commit 43205d7: [Minor] Neural: Add nan check and extensive logging

Vsevolod Stakhov vsevolod at highsecure.ru
Fri Oct 18 16:21:09 UTC 2019


Author: Vsevolod Stakhov
Date: 2019-10-18 17:18:26 +0100
URL: https://github.com/rspamd/rspamd/commit/43205d7e865312939ed452223442e5128b0e2a6a (HEAD -> master)

[Minor] Neural: Add nan check and extensive logging

---
 src/plugins/lua/neural.lua | 46 +++++++++++++++++++++++++++++++++-------------
 1 file changed, 33 insertions(+), 13 deletions(-)

diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 1ff1f40d7..e6ffe41be 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -564,7 +564,6 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
     local inputs, outputs = {}, {}
 
     -- Used to show sparsed vectors in a convenient format (for debugging only)
-    --[[
     local function debug_vec(t)
       local ret = {}
       for i,v in ipairs(t) do
@@ -575,7 +574,6 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
 
       return ret
     end
-    ]]--
 
     -- Make training set by joining vectors
     -- KANN automatically shuffles those samples
@@ -595,22 +593,44 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
     -- Called in child process
     local function train()
       local log_thresh = rule.train.max_iterations / 10
-      train_ann:train1(inputs, outputs, {
-        lr = rule.train.learning_rate,
-        max_epoch = rule.train.max_iterations,
-        cb = function(iter, train_cost, _)
-          if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then
-            rspamd_logger.infox(rspamd_config,
-                "ANN %s:%s: learned from %s redis key in %s iterations, error: %s",
+      local seen_nan = false
+
+      local function train_cb(iter, train_cost, value_cost)
+        if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then
+          if train_cost ~= train_cost and not seen_nan then
+            -- We have nan :( try to log lot's of stuff to dig into a problem
+            seen_nan = true
+            rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
                 rule.prefix, set.name,
-                ann_key,
-                iter, train_cost)
+                value_cost)
+            for i,e in ipairs(inputs) do
+              lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
+                  debug_vec(e), outputs[i][1])
+            end
           end
+
+          rspamd_logger.infox(rspamd_config,
+              "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
+              rule.prefix, set.name,
+              ann_key,
+              iter,
+              train_cost,
+              value_cost)
         end
+      end
+
+      train_ann:train1(inputs, outputs, {
+        lr = rule.train.learning_rate,
+        max_epoch = rule.train.max_iterations,
+        cb = train_cb,
       })
 
-      local out = train_ann:save()
-      return out
+      if not seen_nan then
+        local out = train_ann:save()
+        return out
+      else
+        return nil
+      end
     end
 
     set.learning_spawned = true


More information about the Commits mailing list