commit 6d11758: [Fix] Neural: Another bunch of fixes

Vsevolod Stakhov vsevolod at highsecure.ru
Mon Jul 15 15:49:07 UTC 2019


Author: Vsevolod Stakhov
Date: 2019-07-15 16:40:47 +0100
URL: https://github.com/rspamd/rspamd/commit/6d11758e98e2adac29897cb45b7a243625d9b761

[Fix] Neural: Another bunch of fixes

---
 src/plugins/lua/neural.lua | 45 +++++++++++++++++++++++++++++++++++++++------
 1 file changed, 39 insertions(+), 6 deletions(-)

diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 2e4c8e7cc..a68d6f83a 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -541,6 +541,20 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
   else
     local inputs, outputs = {}, {}
 
+    -- Used to show sparsed vectors in a convenient format (for debugging only)
+    --[[
+    local function debug_vec(t)
+      local ret = {}
+      for i,v in ipairs(t) do
+        if v ~= 0 then
+          ret[#ret + 1] = string.format('%d=%.2f', i, v)
+        end
+      end
+
+      return ret
+    end
+    ]]--
+
     -- Make training set by joining vectors
     -- KANN automatically shuffles those samples
     -- 1.0 is used for spam and -1.0 is used for ham
@@ -548,21 +562,26 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
     for _,e in ipairs(spam_vec) do
       inputs[#inputs + 1] = e
       outputs[#outputs + 1] = {1.0}
+      --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e))
     end
     for _,e in ipairs(ham_vec) do
       inputs[#inputs + 1] = e
       outputs[#outputs + 1] = {-1.0}
+      --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e))
     end
 
     -- Called in child process
     local function train()
+      local log_thresh = rule.train.max_iterations / 10
       train_ann:train1(inputs, outputs, {
         lr = rule.train.learning_rate,
         max_epoch = rule.train.max_iterations,
         cb = function(iter, train_cost, _)
-          if math.floor(iter / rule.train.max_iterations * 10) % 10 == 0 then
-            rspamd_logger.infox(rspamd_config, "ANN %s:%s: learned %s iterations, error: %s",
+          if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then
+            rspamd_logger.infox(rspamd_config,
+                "ANN %s:%s: learned from %s redis key in %s iterations, error: %s",
                 rule.prefix, set.name,
+                ann_key,
                 iter, train_cost)
           end
         end
@@ -589,7 +608,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
         )
       else
         rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
-            rule.prefix, set.name, ann_key)
+            rule.prefix, set.name, set.ann.redis_key)
       end
     end
 
@@ -608,8 +627,6 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
             {ann_key, 'lock'}
         )
       else
-        rspamd_logger.infox(rspamd_config, 'trained ANN %s:%s, %s bytes; redis key: %s',
-            rule.prefix, set.name, #data, ann_key)
         local ann_data = rspamd_util.zstd_compress(data)
         if not set.ann then
           set.ann = {
@@ -637,6 +654,10 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
         local ucl = require "ucl"
         local profile_serialized = ucl.to_format(profile, 'json-compact', true)
 
+        rspamd_logger.infox(rspamd_config,
+            'trained ANN %s:%s, %s bytes; redis key: %s (old key %s)',
+            rule.prefix, set.name, #data, set.ann.redis_key, ann_key)
+
         lua_redis.exec_redis_script(redis_save_unlock_id,
             {ev_base = ev_base, is_write = true},
             redis_save_cb,
@@ -1131,8 +1152,20 @@ local function process_rules_settings()
           rule.prefix, selt.name)
     end
 
+    local function filter_symbols_predicate(sname)
+      local fl = rspamd_config:get_symbol_flags(sname)
+      if fl then
+        fl = lua_util.list_to_hash(fl)
+
+        return not (fl.nostat or fl.idempotent or fl.skip)
+      end
+
+      return false
+    end
+
     -- Generic stuff
-    table.sort(selt.symbols)
+    table.sort(fun.totable(fun.filter(filter_symbols_predicate, selt.symbols)))
+
     selt.digest = lua_util.table_digest(selt.symbols)
     selt.prefix = redis_ann_prefix(rule, selt.name)
 


More information about the Commits mailing list