commit 43205d7: [Minor] Neural: Add nan check and extensive logging
Vsevolod Stakhov
vsevolod at highsecure.ru
Fri Oct 18 16:21:09 UTC 2019
Author: Vsevolod Stakhov
Date: 2019-10-18 17:18:26 +0100
URL: https://github.com/rspamd/rspamd/commit/43205d7e865312939ed452223442e5128b0e2a6a (HEAD -> master)
[Minor] Neural: Add nan check and extensive logging
---
src/plugins/lua/neural.lua | 46 +++++++++++++++++++++++++++++++++-------------
1 file changed, 33 insertions(+), 13 deletions(-)
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 1ff1f40d7..e6ffe41be 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -564,7 +564,6 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
local inputs, outputs = {}, {}
-- Used to show sparsed vectors in a convenient format (for debugging only)
- --[[
local function debug_vec(t)
local ret = {}
for i,v in ipairs(t) do
@@ -575,7 +574,6 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
return ret
end
- ]]--
-- Make training set by joining vectors
-- KANN automatically shuffles those samples
@@ -595,22 +593,44 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
-- Called in child process
local function train()
local log_thresh = rule.train.max_iterations / 10
- train_ann:train1(inputs, outputs, {
- lr = rule.train.learning_rate,
- max_epoch = rule.train.max_iterations,
- cb = function(iter, train_cost, _)
- if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then
- rspamd_logger.infox(rspamd_config,
- "ANN %s:%s: learned from %s redis key in %s iterations, error: %s",
+ local seen_nan = false
+
+ local function train_cb(iter, train_cost, value_cost)
+ if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then
+ if train_cost ~= train_cost and not seen_nan then
+ -- We have nan :( try to log lots of stuff to dig into a problem
+ seen_nan = true
+ rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
rule.prefix, set.name,
- ann_key,
- iter, train_cost)
+ value_cost)
+ for i,e in ipairs(inputs) do
+ lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
+ debug_vec(e), outputs[i][1])
+ end
end
+
+ rspamd_logger.infox(rspamd_config,
+ "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
+ rule.prefix, set.name,
+ ann_key,
+ iter,
+ train_cost,
+ value_cost)
end
+ end
+
+ train_ann:train1(inputs, outputs, {
+ lr = rule.train.learning_rate,
+ max_epoch = rule.train.max_iterations,
+ cb = train_cb,
})
- local out = train_ann:save()
- return out
+ if not seen_nan then
+ local out = train_ann:save()
+ return out
+ else
+ return nil
+ end
end
set.learning_spawned = true
More information about the Commits
mailing list