commit 4627ddd: [Rework] Eliminate torch from neural plugin

Vsevolod Stakhov vsevolod at highsecure.ru
Mon Jul 1 16:28:08 UTC 2019


Author: Vsevolod Stakhov
Date: 2019-07-01 14:53:50 +0100
URL: https://github.com/rspamd/rspamd/commit/4627dddfdbd3f7bf56d4d3e88374406b07f08b9b

[Rework] Eliminate torch from neural plugin

---
 src/plugins/lua/neural.lua | 363 +++++++++++++++------------------------------
 1 file changed, 116 insertions(+), 247 deletions(-)

diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 032859d18..193b07614 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -26,9 +26,6 @@ local lua_redis = require "lua_redis"
 local lua_util = require "lua_util"
 local fun = require "fun"
 local meta_functions = require "lua_meta"
-local use_torch = false
-local torch
-local nn
 local N = "neural"
 
 -- Module vars
@@ -216,10 +213,7 @@ local function gen_ann_prefix(rule, id)
   local cksum = rspamd_config:get_symbols_cksum():hex()
   -- We also need to count metatokens:
   local n = meta_functions.rspamd_count_metatokens()
-  local tprefix = ''
-  if use_torch then
-    tprefix = 't';
-  end
+  local tprefix = 'k'
   if id then
     return string.format('%s%s%s%d%s', tprefix, rule.prefix, cksum, n, id), id
   else
@@ -229,27 +223,7 @@ end
 
 local function is_ann_valid(rule, prefix, ann)
   if ann then
-    local n = rspamd_config:get_symbols_count() +
-        meta_functions.rspamd_count_metatokens()
-
-    if use_torch then
-      return true
-    else
-      if n ~= ann:get_inputs() then
-        rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' ..
-            ' is found in the cache', prefix, ann:get_inputs(), n)
-        return false
-      end
-      local layers = ann:get_layers()
-
-      if not layers or #layers ~= rule.nlayers then
-        rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s',
-          prefix, #layers)
-        return false
-      end
-
-      return true
-    end
+    return true
   end
 end
 
@@ -275,28 +249,17 @@ local function ann_scores_filter(task)
       fun.each(function(e) table.insert(ann_data, e) end, mt)
 
       local score
-      if use_torch then
-        local out = rule.anns[id].ann:forward(torch.Tensor(ann_data))
-        score = out[1]
-      else
-        local out = rule.anns[id].ann:test(ann_data)
-        score = out[1]
-      end
+      local out = rule.anns[id].ann:apply1(ann_data)
+      score = out[1]
 
       local symscore = string.format('%.3f', score)
       rspamd_logger.infox(task, '%s ann score: %s', rule.name, symscore)
 
       if score > 0 then
         local result = score
-        if not use_torch then
-          result = rspamd_util.normalize_prob(score / 2.0, 0)
-        end
         task:insert_result(rule.symbol_spam, result, symscore, id)
       else
         local result = -(score)
-        if not use_torch then
-          result = rspamd_util.normalize_prob(-(score) / 2.0, 0)
-        end
         task:insert_result(rule.symbol_ham, result, symscore, id)
       end
     end
@@ -304,20 +267,13 @@ local function ann_scores_filter(task)
 end
 
 local function create_ann(n, nlayers)
-  if use_torch then
-    -- We ignore number of layers so far when using torch
-    local ann = nn.Sequential()
-    local nhidden = math.floor((n + 1) / 2)
-    ann:add(nn.NaN(nn.Identity()))
-    ann:add(nn.Linear(n, nhidden))
-    ann:add(nn.PReLU())
-    ann:add(nn.Linear(nhidden, 1))
-    ann:add(nn.Tanh())
-
-    return ann
-  else
-    assert(false)
-  end
+    -- We ignore number of layers so far when using kann
+  local nhidden = math.floor((n + 1) / 2)
+  local t = rspamd_kann.layer.input(n)
+  t = rspamd_kann.transform.relu(t)
+  t = rspamd_kann.transform.tanh(rspamd_kann.layer.dense(t, nhidden));
+  t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.mse)
+  return rspamd_kann.new.kann(t)
 end
 
 local function create_train_ann(rule, n, id)
@@ -364,11 +320,7 @@ local function load_or_invalidate_ann(rule, data, id, ev_base)
     rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err)
     return
   else
-    if use_torch then
-      ann = torch.MemoryFile(torch.CharStorage():string(tostring(ann_data))):readObject()
-    else
-      assert(false)
-    end
+    ann = rspamd_kann.load(ann_data)
   end
 
   if is_ann_valid(rule, prefix, ann) then
@@ -533,47 +485,9 @@ local function train_ann(rule, _, ev_base, elt, worker)
       )
     else
       rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
-        prefix, train_mse)
-      local ann_data
-      if use_torch then
-        local f = torch.MemoryFile()
-        f:writeObject(rule.anns[elt].ann_train)
-        ann_data = rspamd_util.zstd_compress(f:storage():string())
-      else
-        ann_data = rspamd_util.zstd_compress(rule.anns[elt].ann_train:data())
-      end
-
-      rule.anns[elt].version = rule.anns[elt].version + 1
-      rule.anns[elt].ann = rule.anns[elt].ann_train
-      rule.anns[elt].ann_train = nil
-      lua_redis.exec_redis_script(redis_save_unlock_id,
-        {ev_base = ev_base, is_write = true},
-        redis_save_cb,
-        {prefix, tostring(ann_data), tostring(rule.ann_expire)})
-    end
-  end
-
-  local function ann_trained_torch(err, data)
-    rule.learning_spawned = false
-    if err then
-      rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
-        prefix, err)
-      lua_redis.redis_make_request_taskless(ev_base,
-        rspamd_config,
-        rule.redis,
-        nil,
-        true, -- is write
-        redis_unlock_cb, --callback
-        'DEL', -- command
-        {prefix .. '_locked'}
-      )
-    else
-      rspamd_logger.infox(rspamd_config, 'trained ANN %s, %s bytes',
-        prefix, #data)
-      local ann_data
-      local f = torch.MemoryFile(torch.CharStorage():string(tostring(data)))
-      ann_data = rspamd_util.zstd_compress(f:storage():string())
-      rule.anns[elt].ann_train = f:readObject()
+          prefix, train_mse)
+      local f = rule.anns[elt].ann_train:save()
+      local ann_data = rspamd_util.zstd_compress(f)
 
       rule.anns[elt].version = rule.anns[elt].version + 1
       rule.anns[elt].ann = rule.anns[elt].ann_train
@@ -608,12 +522,6 @@ local function train_ann(rule, _, ev_base, elt, worker)
       -- Now we need to join inputs and create the appropriate test vectors
       local n = rspamd_config:get_symbols_count() +
           meta_functions.rspamd_count_metatokens()
-      local filt = function(elts)
-        -- Basic sanity checks: vector has good length + there are no
-        -- 'bad' values such as NaNs or infinities in its elements
-        return #elts == n and
-            not fun.any(function(e) return e ~= e or e == math.huge or e == -math.huge end, elts)
-      end
 
       -- Now we can train ann
       if not rule.anns[elt] or not rule.anns[elt].ann_train then
@@ -638,67 +546,44 @@ local function train_ann(rule, _, ev_base, elt, worker)
           redis_invalidate_cb,
           {prefix})
       else
-        if use_torch then
-          -- For torch we do not need to mix samples as they would be flushed
-          local dataset = {}
-          fun.each(function(s)
-            table.insert(dataset, {torch.Tensor(s), torch.Tensor({1.0})})
-          end, fun.filter(filt, spam_elts))
-          fun.each(function(s)
-            table.insert(dataset, {torch.Tensor(s), torch.Tensor({-1.0})})
-          end, fun.filter(filt, ham_elts))
-          -- Needed for torch
-          dataset.size = function() return #dataset end
-
-          local function train_torch()
-            if rule.train.learn_threads then
-              torch.setnumthreads(rule.train.learn_threads)
-            end
-            local criterion = nn.MSECriterion()
-            local trainer = nn.StochasticGradient(rule.anns[elt].ann_train,
-              criterion)
-            trainer.learning_rate = rule.train.learning_rate
-            trainer.verbose = false
-            trainer.maxIteration = rule.train.max_iterations
-            trainer.hookIteration = function(_, iteration, currentError)
-              rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s",
-                  iteration, currentError)
-            end
-            trainer.logger = function(s)
-              rspamd_logger.infox(rspamd_config, 'training: %s', s)
-            end
-            trainer:train(dataset)
-            local out = torch.MemoryFile()
-            out:writeObject(rule.anns[elt].ann_train)
-            local st = out:storage():string()
-            return st
+        local inputs, outputs = {}, {}
+
+        for _,e in ipairs(spam_elts) do
+          if e == e then
+            inputs[#inputs + 1] = e
+            outputs[#outputs + 1] = 1.0
+          end
+        end
+        for _,e in ipairs(ham_elts) do
+          if e == e then
+            inputs[#inputs + 1] = e
+            outputs[#outputs + 1] = 0.0
           end
+        end
 
-          rule.learning_spawned = true
 
-          worker:spawn_process{
-            func = train_torch,
-            on_complete = ann_trained_torch,
-          }
-        else
-          local inputs = {}
-          local outputs = {}
-
-          fun.each(function(spam_sample, ham_sample)
-            table.insert(inputs, spam_sample)
-            table.insert(outputs, {1.0})
-            table.insert(inputs, ham_sample)
-            table.insert(outputs, {-1.0})
-          end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts)))
-          rule.learning_spawned = true
-          rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix)
-          rule.anns[elt].ann_train:train_threaded(inputs, outputs, ann_trained,
-            ev_base, {
-              max_epochs = rule.train.max_epoch,
-              desired_mse = rule.train.mse
-            })
+        local function train()
+          rule.anns[elt].ann_train:train1(inputs, outputs, {
+            lr = rule.train.learning_rate,
+            max_epoch = rule.train.max_iterations,
+            cb = function(iter, train_cost, _)
+              if math.floor(iter / rule.train.max_iterations * 10) % 10 == 0 then
+                rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s",
+                    iter, train_cost)
+              end
+            end
+          })
+
+          local out = rule.anns[elt].ann_train:save()
+          return tostring(out)
         end
 
+        rule.learning_spawned = true
+
+        worker:spawn_process{
+          func = train,
+          on_complete = ann_trained,
+        }
       end
     end
   end
@@ -929,99 +814,83 @@ if not (opts and type(opts) == 'table') or not redis_params then
   return
 end
 
-if not use_torch then
-  rspamd_logger.errx(rspamd_config, 'neural networks support is not compiled in rspamd, this ' ..
-    'module is eventually disabled')
-  lua_util.disable_module(N, "fail")
-  return
-else
-  local rules = opts['rules']
-
-  if not rules then
-    -- Use legacy configuration
-    rules = {}
-    rules['RFANN'] = opts
-  end
+local rules = opts['rules']
 
-  if opts.disable_torch then
-    use_torch = false
-  else
-    torch = require "torch"
-    nn = require "nn"
+if not rules then
+  -- Use legacy configuration
+  rules = {}
+  rules['RFANN'] = opts
+end
 
-    torch.setnumthreads(1)
+local id = rspamd_config:register_symbol({
+  name = 'NEURAL_CHECK',
+  type = 'postfilter,nostat',
+  priority = 6,
+  callback = ann_scores_filter
+})
+for k,r in pairs(rules) do
+  local def_rules = lua_util.override_defaults(default_options, r)
+  def_rules['redis'] = redis_params
+  def_rules['anns'] = {} -- Store ANNs here
+
+  if not def_rules.prefix then
+    def_rules.prefix = k
   end
-
-  local id = rspamd_config:register_symbol({
-    name = 'NEURAL_CHECK',
-    type = 'postfilter,nostat',
-    priority = 6,
-    callback = ann_scores_filter
-  })
-  for k,r in pairs(rules) do
-    local def_rules = lua_util.override_defaults(default_options, r)
-    def_rules['redis'] = redis_params
-    def_rules['anns'] = {} -- Store ANNs here
-
-    if not def_rules.prefix then
-      def_rules.prefix = k
-    end
-    if not def_rules.name then
-      def_rules.name = k
-    end
-    if def_rules.train.max_train then
-      def_rules.train.max_trains = def_rules.train.max_train
-    end
-    rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
-    settings.rules[k] = def_rules
-    rspamd_config:set_metric_symbol({
-      name = def_rules.symbol_spam,
-      score = 0.0,
-      description = 'Neural network SPAM',
-      group = 'neural'
-    })
-    rspamd_config:register_symbol({
-      name = def_rules.symbol_spam,
-      type = 'virtual,nostat',
-      parent = id
-    })
-
-    rspamd_config:set_metric_symbol({
-      name = def_rules.symbol_ham,
-      score = -0.0,
-      description = 'Neural network HAM',
-      group = 'neural'
-    })
-    rspamd_config:register_symbol({
-      name = def_rules.symbol_ham,
-      type = 'virtual,nostat',
-      parent = id
-    })
+  if not def_rules.name then
+    def_rules.name = k
+  end
+  if def_rules.train.max_train then
+    def_rules.train.max_trains = def_rules.train.max_train
   end
+  rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
+  settings.rules[k] = def_rules
+  rspamd_config:set_metric_symbol({
+    name = def_rules.symbol_spam,
+    score = 0.0,
+    description = 'Neural network SPAM',
+    group = 'neural'
+  })
+  rspamd_config:register_symbol({
+    name = def_rules.symbol_spam,
+    type = 'virtual,nostat',
+    parent = id
+  })
 
+  rspamd_config:set_metric_symbol({
+    name = def_rules.symbol_ham,
+    score = -0.0,
+    description = 'Neural network HAM',
+    group = 'neural'
+  })
   rspamd_config:register_symbol({
-    name = 'NEURAL_LEARN',
-    type = 'idempotent,nostat',
-    priority = 5,
-    callback = ann_push_vector
+    name = def_rules.symbol_ham,
+    type = 'virtual,nostat',
+    parent = id
   })
+end
 
-  -- Add training scripts
-  for _,rule in pairs(settings.rules) do
-    load_scripts(rule.redis)
-    rspamd_config:add_on_load(function(cfg, ev_base, worker)
+rspamd_config:register_symbol({
+  name = 'NEURAL_LEARN',
+  type = 'idempotent,nostat',
+  priority = 5,
+  callback = ann_push_vector
+})
+
+-- Add training scripts
+for _,rule in pairs(settings.rules) do
+  load_scripts(rule.redis)
+  rspamd_config:add_on_load(function(cfg, ev_base, worker)
+    rspamd_config:add_periodic(ev_base, 0.0,
+        function(_, _)
+          return check_anns(rule, cfg, ev_base)
+        end)
+
+    if worker:is_primary_controller() then
+      -- We also want to train neural nets when they have enough data
       rspamd_config:add_periodic(ev_base, 0.0,
           function(_, _)
-            return check_anns(rule, cfg, ev_base)
+            return maybe_train_anns(rule, cfg, ev_base, worker)
           end)
-
-      if worker:is_primary_controller() then
-        -- We also want to train neural nets when they have enough data
-        rspamd_config:add_periodic(ev_base, 0.0,
-            function(_, _)
-              return maybe_train_anns(rule, cfg, ev_base, worker)
-            end)
-      end
-    end)
-  end
+    end
+  end)
 end


More information about the Commits mailing list