commit ab77b33: [Project] Implement 'probabilistic' learn mode for ANN

Vsevolod Stakhov vsevolod at highsecure.ru
Mon Aug 3 13:07:07 UTC 2020


Author: Vsevolod Stakhov
Date: 2020-08-03 13:18:59 +0100
URL: https://github.com/rspamd/rspamd/commit/ab77b3398b3462776c4404291a9bdb4ee1b67365

[Project] Implement 'probabilistic' learn mode for ANN

---
 src/plugins/lua/neural.lua | 174 +++++++++++++++++++++++++++++++++------------
 1 file changed, 130 insertions(+), 44 deletions(-)

diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index e6c52912a..41a9b4f07 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -42,8 +42,11 @@ local default_options = {
     autotrain = true,
     train_prob = 1.0,
     learn_threads = 1,
+    learn_mode = 'balanced', -- Possible values: balanced, proportional
     learning_rate = 0.01,
-    classes_bias = 0.0, -- What difference is allowed between classes (1:1 proportion means 0 bias)
+    classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias)
+    spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1)
+    ham_skip_prob = 0.0, -- proportional mode: ham skip probability
   },
   watch_interval = 60.0,
   lock_expire = 600,
@@ -97,26 +100,21 @@ end
 -- Lua script that checks if we can store a new training vector
 -- Uses the following keys:
 -- key1 - ann key
--- key2 - spam or ham
--- key3 - maximum trains
--- key4 - sampling coin (as Redis scripts do not allow math.random calls)
--- key5 - classes bias
--- returns 1 or 0 + reason: 1 - allow learn, 0 - not allow learn
-local redis_lua_script_can_store_train_vec = [[
+-- returns nspam,nham (or nil if locked)
+local redis_lua_script_vectors_len = [[
   local prefix = KEYS[1]
   local locked = redis.call('HGET', prefix, 'lock')
-  if locked then return {tostring(-1),'locked by another process till: ' .. locked} end
+  if locked then return false end
   local nspam = 0
   local nham = 0
-  local lim = tonumber(KEYS[3])
-  local coin = tonumber(KEYS[4])
-  local classes_bias = tonumber(KEYS[5])
 
   local ret = redis.call('LLEN', prefix .. '_spam')
   if ret then nspam = tonumber(ret) end
   ret = redis.call('LLEN', prefix .. '_ham')
   if ret then nham = tonumber(ret) end
 
+  return {nspam,nham}
+
   if KEYS[2] == 'spam' then
     if nspam <= lim then
       if nspam > nham then
@@ -147,7 +145,7 @@ local redis_lua_script_can_store_train_vec = [[
 
   return {tostring(-1),'bad input'}
 ]]
-local redis_can_store_train_vec_id = nil
+local redis_lua_script_vectors_len_id = nil
 
 -- Lua script to invalidate ANNs by rank
 -- Uses the following keys
@@ -220,7 +218,7 @@ local redis_save_unlock_id = nil
 local redis_params
 
 local function load_scripts(params)
-  redis_can_store_train_vec_id = lua_redis.add_redis_script(redis_lua_script_can_store_train_vec,
+  redis_lua_script_vectors_len_id = lua_redis.add_redis_script(redis_lua_script_vectors_len,
     params)
   redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate,
     params)
@@ -379,6 +377,88 @@ local function create_ann(n, nlayers)
   return rspamd_kann.new.kann(t)
 end
 
+local function can_push_train_vector(rule, task, learn_type, nspam, nham)
+  local train_opts = rule.train
+  local coin = math.random()
+
+  if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then
+    rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
+    return false
+  end
+
+  if train_opts.learn_mode == 'balanced' then
+    -- Keep balanced training set based on number of spam and ham samples
+    if learn_type == 'spam' then
+      if nspam <= train_opts.max_trains then
+        if nspam > nham then
+          -- Apply sampling
+          local skip_rate = 1.0 - nham / (nspam + 1)
+          if coin < skip_rate - train_opts.classes_bias then
+            rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; probability %s', learn_type,
+                skip_rate - train_opts.classes_bias)
+            return false
+          end
+        end
+        return true
+      else -- Enough learns
+        rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s', learn_type,
+            nspam)
+      end
+    else
+      if nham <= train_opts.max_trains then
+        if nham > nspam then
+          -- Apply sampling
+          local skip_rate = 1.0 - nspam / (nham + 1)
+          if coin < skip_rate - train_opts.classes_bias then
+            rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; probability %s', learn_type,
+                skip_rate - train_opts.classes_bias)
+            return false
+          end
+        end
+        return true
+      else
+        rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type,
+            nham)
+      end
+    end
+  else
+    -- Probabilistic learn mode, we just skip learn if we already have enough samples or
+    -- if our coin drop is less than desired probability
+    if learn_type == 'spam' then
+      if nspam <= train_opts.max_trains then
+        if train_opts.spam_skip_prob then
+          if coin <= train_opts.spam_skip_prob then
+            rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type,
+                coin, train_opts.spam_skip_prob)
+            return false
+          end
+
+          return true
+        end
+      else
+        rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type,
+            nspam, train_opts.max_trains)
+      end
+    else
+      if nham <= train_opts.max_trains then
+        if train_opts.ham_skip_prob then
+          if coin <= train_opts.ham_skip_prob then
+            rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type,
+                coin, train_opts.ham_skip_prob)
+            return false
+          end
+
+          return true
+        end
+      else
+        rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type,
+            nham, train_opts.max_trains)
+      end
+    end
+  end
+
+  return false
+end
 
 local function ann_push_task_result(rule, task, verdict, score, set)
   local train_opts = rule.train
@@ -436,17 +516,12 @@ local function ann_push_task_result(rule, task, verdict, score, set)
     local learn_type
     if learn_spam then learn_type = 'spam' else learn_type = 'ham' end
 
-    local function can_train_cb(err, data)
+    local function vectors_len_cb(err, data)
       if not err and type(data) == 'table' then
-        local nsamples,reason = tonumber(data[1]),data[2]
+        local nspam,nham = data[1],data[2]
 
-        if nsamples >= 0 then
-          local coin = math.random()
-
-          if coin < 1.0 - train_opts.train_prob then
-            rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
-            return
-          end
+        if nspam > 0 and nham > 0 and
+            can_push_train_vector(rule, task, learn_type, nspam, nham) then
 
           local vec = result_to_vector(task, set)
 
@@ -473,15 +548,15 @@ local function ann_push_task_result(rule, task, verdict, score, set)
               'LPUSH', -- command
               { target_key, str } -- arguments
           )
-        else
-          -- Negative result returned
-          rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: %s (%s vectors stored)",
-              learn_type, rule.prefix, set.name, set.ann.redis_key, reason, -tonumber(nsamples))
         end
       else
         if err then
           rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
               rule.prefix, set.name, err)
+        elseif type(data) == 'userdata' then
+          -- nil return value
+          rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning",
+              learn_type, rule.prefix, set.name, set.ann.redis_key)
         else
           rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' ..
               'please remove this key from Redis manually if you perform upgrade from the previous version',
@@ -500,15 +575,11 @@ local function ann_push_task_result(rule, task, verdict, score, set)
             set.name)
       end
 
-      lua_redis.exec_redis_script(redis_can_store_train_vec_id,
-          {task = task, is_write = true},
-          can_train_cb,
+      lua_redis.exec_redis_script(redis_lua_script_vectors_len_id,
+          {task = task, is_write = false},
+          vectors_len_cb,
           {
             set.ann.redis_key,
-            learn_type,
-            tostring(train_opts.max_trains),
-            tostring(math.random()),
-            tostring(train_opts.classes_bias)
           })
     else
       lua_util.debugm(N, task,
@@ -1059,18 +1130,33 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
             -- at least (10 * (1 - 0.25)) = 8 trains
 
             local max_len = math.max(lua_util.unpack(lua_util.values(lens)))
-            local len_bias_check_pred = function(_, l)
-              return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
-            end
-            if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
-              rspamd_logger.debugm(N, rspamd_config,
-                  'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
-                  ann_key, lens, rule.train.max_trains, what)
-              cont_cb()
+
+            if rule.train.learn_type == 'balanced' then
+              local len_bias_check_pred = function(_, l)
+                return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
+              end
+              if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
+                rspamd_logger.debugm(N, rspamd_config,
+                    'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+                    ann_key, lens, rule.train.max_trains, what)
+                cont_cb()
+              else
+                rspamd_logger.debugm(N, rspamd_config,
+                    'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+                    ann_key, what, lens, rule.train.max_trains)
+              end
             else
-              rspamd_logger.debugm(N, rspamd_config,
-                  'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
-                  ann_key, what, lens, rule.train.max_trains)
+              -- Probabilistic mode, just ensure that at least one vector is okay
+              if max_len >= rule.train.max_trains then
+                rspamd_logger.debugm(N, rspamd_config,
+                    'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+                    ann_key, lens, rule.train.max_trains, what)
+                cont_cb()
+              else
+                rspamd_logger.debugm(N, rspamd_config,
+                    'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+                    ann_key, what, lens, rule.train.max_trains)
+              end
             end
 
           else


More information about the Commits mailing list