commit b8a7db1: [Minor] Neural: Various fixes
Vsevolod Stakhov
vsevolod at highsecure.ru
Mon Jul 15 15:49:05 UTC 2019
Author: Vsevolod Stakhov
Date: 2019-07-15 15:42:59 +0100
URL: https://github.com/rspamd/rspamd/commit/b8a7db17236cf6fc3757e593c5ba2b3429ed1dc6
[Minor] Neural: Various fixes
---
src/plugins/lua/neural.lua | 57 +++++++++++++++++++++++++++-------------------
1 file changed, 34 insertions(+), 23 deletions(-)
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 51a33e6e1..2e4c8e7cc 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -273,7 +273,7 @@ local function new_ann_profile(task, rule, set, version)
local profile = {
symbols = set.symbols,
redis_key = ann_key,
- version = version or 0,
+ version = version,
digest = set.digest,
distance = 0 -- Since we are using our own profile
}
@@ -334,8 +334,8 @@ local function ann_scores_filter(task)
score = out[1]
local symscore = string.format('%.3f', score)
- lua_util.debugm(N, task, '%s:%s ann score: %s',
- rule.prefix, set.name, symscore)
+ lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
+ rule.prefix, set.name, set.ann.version, symscore)
if score > 0 then
local result = score
@@ -425,6 +425,7 @@ local function ann_push_task_result(rule, task, verdict, score, set)
local vec = result_to_vector(task, set)
local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
+ local target_key = set.ann.redis_key .. '_' .. learn_type
local function learn_vec_cb(_err)
if _err then
@@ -432,8 +433,9 @@ local function ann_push_task_result(rule, task, verdict, score, set)
rule.prefix, set.name, _err)
else
lua_util.debugm(N, task,
- "add train data for ANN rule %s:%s, save %s vector of %s elts; %s bytes compressed",
- rule.prefix, set.name, learn_type, #vec, #str)
+ "add train data for ANN rule" ..
+ "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
+ rule.prefix, set.name, learn_type, #vec, target_key, #str)
end
end
@@ -443,7 +445,7 @@ local function ann_push_task_result(rule, task, verdict, score, set)
true, -- is write
learn_vec_cb, --callback
'LPUSH', -- command
- { set.ann.redis_key .. '_' .. learn_type, str} -- arguments
+ { target_key, str } -- arguments
)
else
if err then
@@ -458,7 +460,7 @@ local function ann_push_task_result(rule, task, verdict, score, set)
if not set.ann then
-- Need to create or load a profile corresponding to the current configuration
- set.ann = new_ann_profile(task, rule, set)
+ set.ann = new_ann_profile(task, rule, set, 0)
end
-- Check if we can learn
lua_redis.exec_redis_script(redis_can_store_train_vec_id,
@@ -606,8 +608,8 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
{ann_key, 'lock'}
)
else
- rspamd_logger.infox(rspamd_config, 'trained ANN %s:%s, %s bytes',
- rule.prefix, set.name, #data)
+ rspamd_logger.infox(rspamd_config, 'trained ANN %s:%s, %s bytes; redis key: %s',
+ rule.prefix, set.name, #data, ann_key)
local ann_data = rspamd_util.zstd_compress(data)
if not set.ann then
set.ann = {
@@ -800,6 +802,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
if ann then
set.ann = {
+ digest = profile.digest,
version = profile.version,
symbols = profile.symbols,
distance = min_diff,
@@ -943,18 +946,21 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
ann_key)
-- Create continuation closure
- local redis_len_cb_gen = function(cont_cb)
+ local redis_len_cb_gen = function(cont_cb, what)
return function(err, data)
if err then
rspamd_logger.errx(rspamd_config,
- 'cannot get ANN trains %s from redis: %s', ann_key, err)
+ 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
elseif data and type(data) == 'number' or type(data) == 'string' then
if tonumber(data) and tonumber(data) >= rule.train.max_trains then
+ rspamd_logger.debugm(N, rspamd_config,
+ 'ANN %s has %s %s learn vectors (%s required)',
+ ann_key, tonumber(data), what, rule.train.max_trains)
cont_cb()
else
rspamd_logger.debugm(N, rspamd_config,
- 'no need to learn ANN %s %s learn vectors (%s required)',
- ann_key, tonumber(data), rule.train.max_trains)
+ 'no need to learn ANN %s %s %s learn vectors (%s required)',
+ ann_key, tonumber(data), what, rule.train.max_trains)
end
end
end
@@ -975,7 +981,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
rule.redis,
nil,
false, -- is write
- redis_len_cb_gen(initiate_train), --callback
+ redis_len_cb_gen(initiate_train, 'ham'), --callback
'LLEN', -- command
{ann_key .. '_ham'}
)
@@ -986,7 +992,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
rule.redis,
nil,
false, -- is write
- redis_len_cb_gen(check_ham_len), --callback
+ redis_len_cb_gen(check_ham_len, 'spam'), --callback
'LLEN', -- command
{ann_key .. '_spam'}
)
@@ -1016,14 +1022,15 @@ local function load_ann_profile(element)
end
-- Function to check or load ANNs from Redis
-local function check_anns(worker, cfg, ev_base, rule, process_callback)
+local function check_anns(worker, cfg, ev_base, rule, process_callback, what)
for _,set in pairs(rule.settings) do
local function members_cb(err, data)
if err then
rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
err)
elseif type(data) == 'table' then
- lua_util.debugm(N, cfg, 'process element %s:%s', rule.prefix, set.name)
+ lua_util.debugm(N, cfg, '%s: process element %s:%s',
+ what, rule.prefix, set.name)
process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data))
end
end
@@ -1272,7 +1279,7 @@ end
rspamd_config:register_symbol({
name = 'NEURAL_LEARN',
- type = 'idempotent,nostat',
+ type = 'idempotent,nostat,explicit_disable',
priority = 5,
callback = ann_push_vector
})
@@ -1284,10 +1291,13 @@ for _,rule in pairs(settings.rules) do
rspamd_config:add_post_init(process_rules_settings)
-- This function will check ANNs in Redis when a worker is loaded
rspamd_config:add_on_load(function(cfg, ev_base, worker)
- rspamd_config:add_periodic(ev_base, 0.0,
- function(_, _)
- return check_anns(worker, cfg, ev_base, rule, process_existing_ann)
- end)
+ if worker:is_scanner() then
+ rspamd_config:add_periodic(ev_base, 0.0,
+ function(_, _)
+ return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
+ 'try_load_ann')
+ end)
+ end
if worker:is_primary_controller() then
-- We also want to train neural nets when they have enough data
@@ -1295,7 +1305,8 @@ for _,rule in pairs(settings.rules) do
function(_, _)
-- Clean old ANNs
cleanup_anns(rule, cfg, ev_base)
- return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann)
+ return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
+ 'try_train_ann')
end)
end
end)
More information about the Commits
mailing list