commit 6d11758: [Fix] Neural: Another bunch of fixes
Vsevolod Stakhov
vsevolod at highsecure.ru
Mon Jul 15 15:49:07 UTC 2019
Author: Vsevolod Stakhov
Date: 2019-07-15 16:40:47 +0100
URL: https://github.com/rspamd/rspamd/commit/6d11758e98e2adac29897cb45b7a243625d9b761
[Fix] Neural: Another bunch of fixes
---
src/plugins/lua/neural.lua | 45 +++++++++++++++++++++++++++++++++++++++------
1 file changed, 39 insertions(+), 6 deletions(-)
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 2e4c8e7cc..a68d6f83a 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -541,6 +541,20 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
else
local inputs, outputs = {}, {}
+ -- Used to show sparsed vectors in a convenient format (for debugging only)
+ --[[
+ local function debug_vec(t)
+ local ret = {}
+ for i,v in ipairs(t) do
+ if v ~= 0 then
+ ret[#ret + 1] = string.format('%d=%.2f', i, v)
+ end
+ end
+
+ return ret
+ end
+ ]]--
+
-- Make training set by joining vectors
-- KANN automatically shuffles those samples
-- 1.0 is used for spam and -1.0 is used for ham
@@ -548,21 +562,26 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
for _,e in ipairs(spam_vec) do
inputs[#inputs + 1] = e
outputs[#outputs + 1] = {1.0}
+ --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e))
end
for _,e in ipairs(ham_vec) do
inputs[#inputs + 1] = e
outputs[#outputs + 1] = {-1.0}
+ --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e))
end
-- Called in child process
local function train()
+ local log_thresh = rule.train.max_iterations / 10
train_ann:train1(inputs, outputs, {
lr = rule.train.learning_rate,
max_epoch = rule.train.max_iterations,
cb = function(iter, train_cost, _)
- if math.floor(iter / rule.train.max_iterations * 10) % 10 == 0 then
- rspamd_logger.infox(rspamd_config, "ANN %s:%s: learned %s iterations, error: %s",
+ if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then
+ rspamd_logger.infox(rspamd_config,
+ "ANN %s:%s: learned from %s redis key in %s iterations, error: %s",
rule.prefix, set.name,
+ ann_key,
iter, train_cost)
end
end
@@ -589,7 +608,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
)
else
rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
- rule.prefix, set.name, ann_key)
+ rule.prefix, set.name, set.ann.redis_key)
end
end
@@ -608,8 +627,6 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
{ann_key, 'lock'}
)
else
- rspamd_logger.infox(rspamd_config, 'trained ANN %s:%s, %s bytes; redis key: %s',
- rule.prefix, set.name, #data, ann_key)
local ann_data = rspamd_util.zstd_compress(data)
if not set.ann then
set.ann = {
@@ -637,6 +654,10 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
local ucl = require "ucl"
local profile_serialized = ucl.to_format(profile, 'json-compact', true)
+ rspamd_logger.infox(rspamd_config,
+ 'trained ANN %s:%s, %s bytes; redis key: %s (old key %s)',
+ rule.prefix, set.name, #data, set.ann.redis_key, ann_key)
+
lua_redis.exec_redis_script(redis_save_unlock_id,
{ev_base = ev_base, is_write = true},
redis_save_cb,
@@ -1131,8 +1152,20 @@ local function process_rules_settings()
rule.prefix, selt.name)
end
+ local function filter_symbols_predicate(sname)
+ local fl = rspamd_config:get_symbol_flags(sname)
+ if fl then
+ fl = lua_util.list_to_hash(fl)
+
+ return not (fl.nostat or fl.idempotent or fl.skip)
+ end
+
+ return false
+ end
+
-- Generic stuff
- table.sort(selt.symbols)
+ table.sort(fun.totable(fun.filter(filter_symbols_predicate, selt.symbols)))
+
selt.digest = lua_util.table_digest(selt.symbols)
selt.prefix = redis_ann_prefix(rule, selt.name)
More information about the Commits
mailing list