commit 195e79d: [Feature] Neural: Introduce classes bias that allows non-equal classes learning

Vsevolod Stakhov vsevolod at highsecure.ru
Mon Mar 16 11:21:06 UTC 2020


Author: Vsevolod Stakhov
Date: 2020-03-16 11:15:12 +0000
URL: https://github.com/rspamd/rspamd/commit/195e79d69ba3b0ab8b21e9f81eb76ee98e3858f6

[Feature] Neural: Introduce classes bias that allows non-equal classes learning

---
 .luacheckrc                |  8 +++-----
 src/plugins/lua/neural.lua | 49 ++++++++++++++++++++++++++++++++++------------
 2 files changed, 39 insertions(+), 18 deletions(-)

diff --git a/.luacheckrc b/.luacheckrc
index 1af61bfda..7ade0174a 100644
--- a/.luacheckrc
+++ b/.luacheckrc
@@ -34,7 +34,9 @@ globals = {
   'rspamadm_ev_base',
   'rspamadm_session',
   'rspamadm_dns_resolver',
-  'jit'
+  'jit',
+  'table.unpack',
+  'unpack',
 }
 
 ignore = {
@@ -55,10 +57,6 @@ files['/**/src/plugins/lua/reputation.lua'].globals = {
   'math.tanh',
 }
 
-files['/**/lualib/lua_util.lua'].globals = {
-  'table.unpack',
-  'unpack',
-}
 
 files['/**/lualib/lua_redis.lua'].globals = {
   'rspamadm_ev_base',
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 1897f0843..affb07307 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -43,6 +43,7 @@ local default_options = {
     train_prob = 1.0,
     learn_threads = 1,
     learning_rate = 0.01,
+    classes_bias = 0.0, -- What difference is allowed between classes (1:1 proportion means 0 bias)
   },
   watch_interval = 60.0,
   lock_expire = 600,
@@ -99,6 +100,7 @@ end
 -- key2 - spam or ham
 -- key3 - maximum trains
 -- key4 - sampling coin (as Redis scripts do not allow math.random calls)
+-- key5 - classes bias
 -- returns 1 or 0 + reason: 1 - allow learn, 0 - not allow learn
 local redis_lua_script_can_store_train_vec = [[
   local prefix = KEYS[1]
@@ -108,6 +110,7 @@ local redis_lua_script_can_store_train_vec = [[
   local nham = 0
   local lim = tonumber(KEYS[3])
   local coin = tonumber(KEYS[4])
+  local classes_bias = tonumber(KEYS[5])
 
   local ret = redis.call('LLEN', prefix .. '_spam')
   if ret then nspam = tonumber(ret) end
@@ -119,8 +122,8 @@ local redis_lua_script_can_store_train_vec = [[
       if nspam > nham then
         -- Apply sampling
         local skip_rate = 1.0 - nham / (nspam + 1)
-        if coin < skip_rate then
-          return {tostring(-(nspam)),'sampled out with probability ' .. tostring(skip_rate)}
+        if coin < skip_rate - classes_bias then
+          return {tostring(-(nspam)),'sampled out with probability ' .. tostring(skip_rate - classes_bias)}
         end
       end
       return {tostring(nspam),'can learn'}
@@ -132,8 +135,8 @@ local redis_lua_script_can_store_train_vec = [[
       if nham > nspam then
         -- Apply sampling
         local skip_rate = 1.0 - nspam / (nham + 1)
-        if coin < skip_rate then
-          return {tostring(-(nham)),'sampled out with probability ' .. tostring(skip_rate)}
+        if coin < skip_rate - classes_bias then
+          return {tostring(-(nham)),'sampled out with probability ' .. tostring(skip_rate - classes_bias)}
         end
       end
       return {tostring(nham),'can learn'}
@@ -505,6 +508,7 @@ local function ann_push_task_result(rule, task, verdict, score, set)
             learn_type,
             tostring(train_opts.max_trains),
             tostring(math.random()),
+            tostring(train_opts.classes_bias)
           })
     else
       lua_util.debugm(N, task,
@@ -1014,6 +1018,10 @@ end
 local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
   local my_symbols = set.symbols
   local sel_elt
+  local lens = {
+    spam = 0,
+    ham = 0,
+  }
 
   for _,elt in fun.iter(profiles) do
     if elt and elt.symbols then
@@ -1040,21 +1048,36 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
           rspamd_logger.errx(rspamd_config,
               'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
         elseif data and type(data) == 'number' or type(data) == 'string' then
-          if tonumber(data) and tonumber(data) >= rule.train.max_trains then
-            if is_final then
+          local ntrains = tonumber(data) or 0
+          lens[what] = ntrains
+          if is_final then
+            local unpack = rawget(table, "unpack") or unpack
+            -- Ensure that we have the following:
+            -- one class has reached max_trains
+            -- other class(es) are at least as full as classes_bias
+            -- e.g. if classes_bias = 0.25 and we have 10 max_trains then
+            -- one class must have 10 or more trains whilst another should have
+            -- at least (10 * (1 - 0.25)) = 8 trains
+
+            local max_len = math.max(unpack(lens))
+            local len_bias_check_pred = function(l)
+              return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
+            end
+            if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
               rspamd_logger.debugm(N, rspamd_config,
                   'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
-                  ann_key, tonumber(data), rule.train.max_trains, what)
+                  ann_key, lens, rule.train.max_trains, what)
+              cont_cb()
             else
               rspamd_logger.debugm(N, rspamd_config,
-                  'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
-                  what, ann_key, tonumber(data), rule.train.max_trains)
+                  'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+                  ann_key, what, lens, rule.train.max_trains)
             end
-            cont_cb()
+
           else
             rspamd_logger.debugm(N, rspamd_config,
-                'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
-                ann_key, what, tonumber(data), rule.train.max_trains)
+                'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
+                what, ann_key, ntrains, rule.train.max_trains)
           end
         end
       end
@@ -1064,7 +1087,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
     local function initiate_train()
       rspamd_logger.infox(rspamd_config,
           'need to learn ANN %s after %s required learn vectors',
-          ann_key, rule.train.max_trains)
+          ann_key, lens)
       do_train_ann(worker, ev_base, rule, set, ann_key)
     end
 


More information about the Commits mailing list