commit b8a7db1: [Minor] Neural: Various fixes

Vsevolod Stakhov vsevolod at highsecure.ru
Mon Jul 15 15:49:05 UTC 2019


Author: Vsevolod Stakhov
Date: 2019-07-15 15:42:59 +0100
URL: https://github.com/rspamd/rspamd/commit/b8a7db17236cf6fc3757e593c5ba2b3429ed1dc6

[Minor] Neural: Various fixes

---
 src/plugins/lua/neural.lua | 57 +++++++++++++++++++++++++++-------------------
 1 file changed, 34 insertions(+), 23 deletions(-)

diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 51a33e6e1..2e4c8e7cc 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -273,7 +273,7 @@ local function new_ann_profile(task, rule, set, version)
   local profile = {
     symbols = set.symbols,
     redis_key = ann_key,
-    version = version or 0,
+    version = version,
     digest = set.digest,
     distance = 0 -- Since we are using our own profile
   }
@@ -334,8 +334,8 @@ local function ann_scores_filter(task)
       score = out[1]
 
       local symscore = string.format('%.3f', score)
-      lua_util.debugm(N, task, '%s:%s ann score: %s',
-          rule.prefix, set.name, symscore)
+      lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
+          rule.prefix, set.name, set.ann.version, symscore)
 
       if score > 0 then
         local result = score
@@ -425,6 +425,7 @@ local function ann_push_task_result(rule, task, verdict, score, set)
         local vec = result_to_vector(task, set)
 
         local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
+        local target_key = set.ann.redis_key .. '_' .. learn_type
 
         local function learn_vec_cb(_err)
           if _err then
@@ -432,8 +433,9 @@ local function ann_push_task_result(rule, task, verdict, score, set)
                 rule.prefix, set.name, _err)
           else
             lua_util.debugm(N, task,
-                "add train data for ANN rule %s:%s, save %s vector of %s elts; %s bytes compressed",
-                rule.prefix, set.name, learn_type, #vec, #str)
+                "add train data for ANN rule" ..
+                "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
+                rule.prefix, set.name, learn_type, #vec, target_key, #str)
           end
         end
 
@@ -443,7 +445,7 @@ local function ann_push_task_result(rule, task, verdict, score, set)
             true, -- is write
             learn_vec_cb, --callback
             'LPUSH', -- command
-            { set.ann.redis_key .. '_' .. learn_type, str} -- arguments
+            { target_key, str } -- arguments
         )
       else
         if err then
@@ -458,7 +460,7 @@ local function ann_push_task_result(rule, task, verdict, score, set)
 
     if not set.ann then
       -- Need to create or load a profile corresponding to the current configuration
-      set.ann = new_ann_profile(task, rule, set)
+      set.ann = new_ann_profile(task, rule, set, 0)
     end
     -- Check if we can learn
     lua_redis.exec_redis_script(redis_can_store_train_vec_id,
@@ -606,8 +608,8 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
             {ann_key, 'lock'}
         )
       else
-        rspamd_logger.infox(rspamd_config, 'trained ANN %s:%s, %s bytes',
-            rule.prefix, set.name, #data)
+        rspamd_logger.infox(rspamd_config, 'trained ANN %s:%s, %s bytes; redis key: %s',
+            rule.prefix, set.name, #data, ann_key)
         local ann_data = rspamd_util.zstd_compress(data)
         if not set.ann then
           set.ann = {
@@ -800,6 +802,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
 
           if ann then
             set.ann = {
+              digest = profile.digest,
               version = profile.version,
               symbols = profile.symbols,
               distance = min_diff,
@@ -943,18 +946,21 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
         ann_key)
 
     -- Create continuation closure
-    local redis_len_cb_gen = function(cont_cb)
+    local redis_len_cb_gen = function(cont_cb, what)
       return function(err, data)
         if err then
           rspamd_logger.errx(rspamd_config,
-              'cannot get ANN trains %s from redis: %s', ann_key, err)
+              'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
         elseif data and type(data) == 'number' or type(data) == 'string' then
           if tonumber(data) and tonumber(data) >= rule.train.max_trains then
+            rspamd_logger.debugm(N, rspamd_config,
+                'ANN %s has %s %s learn vectors (%s required)',
+                ann_key, tonumber(data), what, rule.train.max_trains)
             cont_cb()
           else
             rspamd_logger.debugm(N, rspamd_config,
-                'no need to learn ANN %s %s learn vectors (%s required)',
-                ann_key, tonumber(data), rule.train.max_trains)
+                'no need to learn ANN %s %s %s learn vectors (%s required)',
+                ann_key, tonumber(data), what, rule.train.max_trains)
           end
         end
       end
@@ -975,7 +981,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
           rule.redis,
           nil,
           false, -- is write
-          redis_len_cb_gen(initiate_train), --callback
+          redis_len_cb_gen(initiate_train, 'ham'), --callback
           'LLEN', -- command
           {ann_key .. '_ham'}
       )
@@ -986,7 +992,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
         rule.redis,
         nil,
         false, -- is write
-        redis_len_cb_gen(check_ham_len), --callback
+        redis_len_cb_gen(check_ham_len, 'spam'), --callback
         'LLEN', -- command
         {ann_key .. '_spam'}
     )
@@ -1016,14 +1022,15 @@ local function load_ann_profile(element)
 end
 
 -- Function to check or load ANNs from Redis
-local function check_anns(worker, cfg, ev_base, rule, process_callback)
+local function check_anns(worker, cfg, ev_base, rule, process_callback, what)
   for _,set in pairs(rule.settings) do
     local function members_cb(err, data)
       if err then
         rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
             err)
       elseif type(data) == 'table' then
-        lua_util.debugm(N, cfg, 'process element %s:%s', rule.prefix, set.name)
+        lua_util.debugm(N, cfg, '%s: process element %s:%s',
+            what, rule.prefix, set.name)
         process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data))
       end
     end
@@ -1272,7 +1279,7 @@ end
 
 rspamd_config:register_symbol({
   name = 'NEURAL_LEARN',
-  type = 'idempotent,nostat',
+  type = 'idempotent,nostat,explicit_disable',
   priority = 5,
   callback = ann_push_vector
 })
@@ -1284,10 +1291,13 @@ for _,rule in pairs(settings.rules) do
   rspamd_config:add_post_init(process_rules_settings)
   -- This function will check ANNs in Redis when a worker is loaded
   rspamd_config:add_on_load(function(cfg, ev_base, worker)
-    rspamd_config:add_periodic(ev_base, 0.0,
-        function(_, _)
-          return check_anns(worker, cfg, ev_base, rule, process_existing_ann)
-        end)
+    if worker:is_scanner() then
+      rspamd_config:add_periodic(ev_base, 0.0,
+          function(_, _)
+            return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
+                'try_load_ann')
+          end)
+    end
 
     if worker:is_primary_controller() then
       -- We also want to train neural nets when they have enough data
@@ -1295,7 +1305,8 @@ for _,rule in pairs(settings.rules) do
           function(_, _)
             -- Clean old ANNs
             cleanup_anns(rule, cfg, ev_base)
-            return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann)
+            return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
+                'try_train_ann')
           end)
     end
   end)


More information about the Commits mailing list