commit 5bcf096: [Fix] Allow to adjust neurons in the hidden layer

Vsevolod Stakhov vsevolod at highsecure.ru
Mon Aug 24 16:07:09 UTC 2020


Author: Vsevolod Stakhov
Date: 2020-08-24 14:32:11 +0100
URL: https://github.com/rspamd/rspamd/commit/5bcf0964e917917c19b7bb1627b4146d3d1d38c7

[Fix] Allow to adjust neurons in the hidden layer

---
 src/plugins/lua/neural.lua | 31 +++++++++++++++++++++----------
 1 file changed, 21 insertions(+), 10 deletions(-)

diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 99efe720e..9df2f1c55 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -31,6 +31,9 @@ local ts = require("tableshape").types
 local lua_verdict = require "lua_verdict"
 local N = "neural"
 
+-- Used in prefix to avoid wrong ANN to be loaded
+local plugin_ver = '2'
+
 -- Module vars
 local default_options = {
   train = {
@@ -52,6 +55,7 @@ local default_options = {
   lock_expire = 600,
   learning_spawned = false,
   ann_expire = 60 * 60 * 24 * 2, -- 2 days
+  hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
   symbol_spam = 'NEURAL_SPAM',
   symbol_ham = 'NEURAL_HAM',
 }
@@ -251,8 +255,8 @@ end
 local function redis_ann_prefix(rule, settings_name)
   -- We also need to count metatokens:
   local n = meta_functions.version
-  return string.format('%s_%s_%d_%s',
-      settings.prefix, rule.prefix, n, settings_name)
+  return string.format('%s%d_%s_%d_%s',
+    settings.prefix, plugin_ver, rule.prefix, n, settings_name)
 end
 
 -- Creates and stores ANN profile in Redis
@@ -337,9 +341,9 @@ local function ann_scores_filter(task)
   end
 end
 
-local function create_ann(n, nlayers)
+local function create_ann(n, nlayers, rule)
     -- We ignore number of layers so far when using kann
-  local nhidden = math.floor((n + 1) / 2)
+  local nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0)
   local t = rspamd_kann.layer.input(n)
   t = rspamd_kann.transform.relu(t)
   t = rspamd_kann.layer.dense(t, nhidden);
@@ -364,14 +368,18 @@ local function can_push_train_vector(rule, task, learn_type, nspam, nham)
           -- Apply sampling
           local skip_rate = 1.0 - nham / (nspam + 1)
           if coin < skip_rate - train_opts.classes_bias then
-            rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; probability %s', learn_type,
-                skip_rate - train_opts.classes_bias)
+            rspamd_logger.infox(task,
+                'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+                learn_type,
+                skip_rate - train_opts.classes_bias,
+                nspam, nham)
             return false
           end
         end
         return true
       else -- Enough learns
-        rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s', learn_type,
+        rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s',
+            learn_type,
             nspam)
       end
     else
@@ -380,8 +388,11 @@ local function can_push_train_vector(rule, task, learn_type, nspam, nham)
           -- Apply sampling
           local skip_rate = 1.0 - nspam / (nham + 1)
           if coin < skip_rate - train_opts.classes_bias then
-            rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; probability %s', learn_type,
-                skip_rate - train_opts.classes_bias)
+            rspamd_logger.infox(task,
+                'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+                learn_type,
+                skip_rate - train_opts.classes_bias,
+                nspam, nham)
             return false
           end
         end
@@ -625,7 +636,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
       meta_functions.rspamd_count_metatokens()
 
   -- Now we can train ann
-  local train_ann = create_ann(n, 3)
+  local train_ann = create_ann(n, 3, rule)
 
   if #ham_vec + #spam_vec < rule.train.max_trains / 2 then
     -- Invalidate ANN as it is definitely invalid


More information about the Commits mailing list