commit 06664d0: [Minor] Neural: Moar fixes
Vsevolod Stakhov
vsevolod at highsecure.ru
Mon Jul 8 13:28:08 UTC 2019
Author: Vsevolod Stakhov
Date: 2019-07-08 14:21:05 +0100
URL: https://github.com/rspamd/rspamd/commit/06664d00f86ead46600aa91655ce2e960e4b4d0a (HEAD -> master)
[Minor] Neural: Moar fixes
---
src/plugins/lua/neural.lua | 254 ++++++++++++++++++++++++++++-----------------
1 file changed, 159 insertions(+), 95 deletions(-)
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 7b6c2fa5f..0375d57cd 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -104,14 +104,14 @@ local redis_lua_script_can_store_train_vec = [[
if ret then nham = tonumber(ret) end
if KEYS[2] == 'spam' then
- if nham <= lim and nham + 1 >= nspam then
- return tostring(nspam + 1)
+ if nspam <= lim then
+ return tostring(nspam)
else
return tostring(-(nspam))
end
else
- if nspam <= lim and nspam + 1 >= nham then
- return tostring(nham + 1)
+ if nham <= lim then
+ return tostring(nham)
else
return tostring(-(nham))
end
@@ -127,8 +127,9 @@ local redis_can_store_train_vec_id = nil
-- key2 - number of elements to leave
local redis_lua_script_maybe_invalidate = [[
local card = redis.call('ZCARD', KEYS[1])
- if card > tonumber(KEYS[2]) then
- local to_delete = redis.call('ZRANGE', KEYS[1], 0, (-(tonumber(KEYS[2] - 1))))
+ local lim = tonumber(KEYS[2])
+ if card > lim then
+ local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1)
for _,k in ipairs(to_delete) do
local tb = cjson.decode(k)
redis.call('DEL', tb.redis_key)
@@ -136,7 +137,7 @@ local redis_lua_script_maybe_invalidate = [[
redis.call('DEL', tb.redis_key .. '_spam')
redis.call('DEL', tb.redis_key .. '_ham')
end
- redis.call('ZREMRANGEBYRANK', KEYS[1], 0, (-(tonumber(KEYS[2] - 1))))
+ redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1)
return to_delete
else
return {}
@@ -152,17 +153,17 @@ local redis_maybe_invalidate_id = nil
-- key4 - hostname
local redis_lua_script_maybe_lock = [[
local locked = redis.call('HGET', KEYS[1], 'lock')
+ local now = tonumber(KEYS[2])
if locked then
locked = tonumber(locked)
- now = tonumber(KEYS[2])
- expire = tonumber(KEYS[3])
+ local expire = tonumber(KEYS[3])
if now > locked and (now - locked) < expire then
return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname')}
end
end
redis.call('HSET', KEYS[1], 'lock', tostring(now))
redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
- return true
+ return 1
]]
local redis_maybe_lock_id = nil
@@ -178,6 +179,8 @@ local redis_lua_script_save_unlock = [[
local now = tonumber(KEYS[6])
redis.call('ZADD', KEYS[2], now, KEYS[4])
redis.call('HSET', KEYS[1], 'ann', KEYS[3])
+ redis.call('DEL', KEYS[1] .. '_spam')
+ redis.call('DEL', KEYS[1] .. '_ham')
redis.call('HDEL', KEYS[1], 'lock')
redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
return 1
@@ -267,7 +270,7 @@ local function new_ann_profile(task, rule, set, version)
}
local ucl = require "ucl"
- local profile_serialized = ucl.to_format(profile, 'json-compact')
+ local profile_serialized = ucl.to_format(profile, 'json-compact', true)
local function add_cb(err, _)
if err then
@@ -322,7 +325,8 @@ local function ann_scores_filter(task)
score = out[1]
local symscore = string.format('%.3f', score)
- rspamd_logger.infox(task, '%s ann score: %s', rule.name, symscore)
+ lua_util.debugm(N, task, '%s:%s ann score: %s',
+ rule.prefix, set.name, symscore)
if score > 0 then
local result = score
@@ -348,26 +352,44 @@ end
local function ann_push_task_result(rule, task, verdict, score, set)
local train_opts = rule.train
-
-
local learn_spam, learn_ham
+ local skip_reason = 'unknown'
if train_opts.autotrain then
- if verdict == 'passthrough' or verdict == 'uncertain' then
+ if verdict == 'passthrough' then
lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
verdict, score)
end
- if train_opts['spam_score'] then
- learn_spam = score >= train_opts['spam_score']
+ if train_opts.spam_score then
+ learn_spam = score >= train_opts.spam_score
+
+ if not learn_spam then
+ skip_reason = string.format('score < spam_score: %f < %f',
+ score, train_opts.spam_score)
+ end
else
learn_spam = verdict == 'spam' or verdict == 'junk'
+
+ if not learn_spam then
+ skip_reason = string.format('verdict: %s',
+ verdict)
+ end
end
- if train_opts['ham_score'] then
- learn_ham = score <= train_opts['ham_score']
+ if train_opts.ham_score then
+ learn_ham = score <= train_opts.ham_score
+ if not learn_ham then
+ skip_reason = string.format('score > ham_score: %f < %f',
+ score, train_opts.ham_score)
+ end
else
learn_ham = verdict == 'ham'
+
+ if not learn_ham then
+ skip_reason = string.format('verdict: %s',
+ verdict)
+ end
end
else
-- Train by request header
@@ -378,6 +400,8 @@ local function ann_push_task_result(rule, task, verdict, score, set)
learn_spam = true
elseif hdr:lower() == 'ham' then
learn_ham = true
+ else
+ skip_reason = string.format('no explicit header')
end
end
end
@@ -387,18 +411,8 @@ local function ann_push_task_result(rule, task, verdict, score, set)
local learn_type
if learn_spam then learn_type = 'spam' else learn_type = 'ham' end
- local function learn_vec_cb(err)
- if err then
- rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
- rule.prefix, set.name, err)
- else
- rspamd_logger.infox(task, "trained ANN rule %s:%s, save %s vector",
- rule.prefix, set.name, learn_type)
- end
- end
-
local function can_train_cb(err, data)
- if not err and tonumber(data) > 0 then
+ if not err and tonumber(data) >= 0 then
local coin = math.random()
if coin < 1.0 - train_opts.train_prob then
rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
@@ -408,6 +422,17 @@ local function ann_push_task_result(rule, task, verdict, score, set)
local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
+ local function learn_vec_cb(_err)
+ if _err then
+ rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
+ rule.prefix, set.name, _err)
+ else
+ lua_util.debugm(N, task,
+ "add train data for ANN rule %s:%s, save %s vector of %s elts; %s bytes compressed",
+ rule.prefix, set.name, learn_type, #vec, #str)
+ end
+ end
+
lua_redis.redis_make_request(task,
rule.redis,
nil,
@@ -422,7 +447,7 @@ local function ann_push_task_result(rule, task, verdict, score, set)
rule.prefix, set.name, err)
elseif tonumber(data) < 0 then
rspamd_logger.infox(task, "cannot learn ANN %s:%s: too many %s samples: %s",
- rule.prefix, set.name, learn_type, -tonumber(data))
+ rule.prefix, set.name, learn_type, -tonumber(data))
end
end
end
@@ -436,6 +461,9 @@ local function ann_push_task_result(rule, task, verdict, score, set)
{task = task, is_write = true},
can_train_cb,
{ set.ann.redis_key, learn_type, tostring(train_opts.max_trains)})
+ else
+ lua_util.debugm(N, task, 'do not push data: train condition not satisfied; reason: %s',
+ skip_reason)
end
end
@@ -481,6 +509,7 @@ local function register_lock_extender(rule, set, ev_base, ann_key)
{ann_key, 'lock', '30'}
)
else
+ lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
return false -- do not plan any more updates
end
@@ -537,7 +566,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
return out
end
- rule.learning_spawned = true
+ set.learning_spawned = true
local function redis_save_cb(err)
if err then
@@ -559,7 +588,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
end
local function ann_trained(err, data)
- rule.learning_spawned = false
+ set.learning_spawned = false
if err then
rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
rule.prefix, set.name, err)
@@ -598,7 +627,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
}
local ucl = require "ucl"
- local profile_serialized = ucl.to_format(profile, 'json-compact')
+ local profile_serialized = ucl.to_format(profile, 'json-compact', true)
lua_redis.exec_redis_script(redis_save_unlock_id,
{ev_base = ev_base, is_write = true},
@@ -695,7 +724,7 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
if err then
rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
ann_key, err)
- elseif type(data) == 'boolean' and data then
+ elseif type(data) == 'number' and data == 1 then
-- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
@@ -752,47 +781,52 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
ann_key, err)
else
- local _err,ann_data = rspamd_util.zstd_decompress(data[1])
- local ann
-
- if _err or not ann_data then
- rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
- rule.prefix .. ':' .. set.name, ann_key, _err)
- return
- else
- ann = rspamd_kann.load(ann_data)
-
- if ann then
- set.ann = {
- ann = ann,
- version = profile.version,
- symbols = profile.symbols,
- distance = min_diff,
- redis_key = profile.redis_key
- }
+ if type(data) == 'string' then
+ local _err,ann_data = rspamd_util.zstd_decompress(data)
+ local ann
- local ucl = require "ucl"
- local profile_serialized = ucl.to_format(profile, 'json-compact')
-
- local function rank_cb(_, _)
- -- TODO: maybe add some logging
- end
- -- Also update rank for the loaded ANN to avoid removal
- lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- rank_cb, --callback
- 'ZADD', -- command
- {set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
- )
- rspamd_logger.infox(rspamd_config, 'loaded ANN for %s from %s; %s bytes compressed; version=%s',
- rule.prefix .. ':' .. set.name, ann_key, #ann_data, profile.version)
+ if _err or not ann_data then
+ rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
+ rule.prefix .. ':' .. set.name, ann_key, _err)
+ return
else
- rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s from Redis key %s',
- rule.prefix .. ':' .. set.name, ann_key)
+ ann = rspamd_kann.load(ann_data)
+
+ if ann then
+ set.ann = {
+ ann = ann,
+ version = profile.version,
+ symbols = profile.symbols,
+ distance = min_diff,
+ redis_key = profile.redis_key
+ }
+
+ local ucl = require "ucl"
+ local profile_serialized = ucl.to_format(profile, 'json-compact', true)
+
+ local function rank_cb(_, _)
+ -- TODO: maybe add some logging
+ end
+ -- Also update rank for the loaded ANN to avoid removal
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ rank_cb, --callback
+ 'ZADD', -- command
+ {set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
+ )
+ rspamd_logger.infox(rspamd_config, 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
+ rule.prefix, set.name, ann_key, #ann_data, profile.version)
+ else
+ rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s:%s from Redis key %s',
+ rule.prefix, set.name, ann_key)
+ end
end
+ else
+ lua_util.debugm(N, rspamd_config, 'no ANN for %s:%s in Redis key %s',
+ rule.prefix, set.name, ann_key)
end
end
end
@@ -803,8 +837,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
false, -- is write
data_cb, --callback
'HGET', -- command
- {ann_key, 'ann'}, -- arguments
- {opaque_data = true}
+ {ann_key, 'ann'} -- arguments
)
end
@@ -900,23 +933,46 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
-- We have our ANN and that's train vectors, check if we can learn
local ann_key = sel_elt.redis_key
- lua_util.debugm(N, rspamd_config, "check ANN %s", ann_key)
- local redis_len_cb = function(err, data)
- if err then
- rspamd_logger.errx(rspamd_config,
- 'cannot get FANN trains %s from redis: %s', ann_key, err)
- elseif data and type(data) == 'number' or type(data) == 'string' then
- if tonumber(data) and tonumber(data) >= rule.train.max_trains then
- rspamd_logger.infox(rspamd_config,
- 'need to learn ANN %s after %s learn vectors (%s required)',
- ann_key, tonumber(data), rule.train.max_trains)
- do_train_ann(worker, ev_base, rule, set, ann_key)
- else
- rspamd_logger.debugm(N, rspamd_config,
- 'no need to learn ANN %s %s learn vectors (%s required)',
- ann_key, tonumber(data), rule.train.max_trains)
+ lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
+ ann_key)
+
+ -- Create continuation closure
+ local redis_len_cb_gen = function(cont_cb)
+ return function(err, data)
+ if err then
+ rspamd_logger.errx(rspamd_config,
+ 'cannot get ANN trains %s from redis: %s', ann_key, err)
+ elseif data and type(data) == 'number' or type(data) == 'string' then
+ if tonumber(data) and tonumber(data) >= rule.train.max_trains then
+ cont_cb()
+ else
+ rspamd_logger.debugm(N, rspamd_config,
+ 'no need to learn ANN %s %s learn vectors (%s required)',
+ ann_key, tonumber(data), rule.train.max_trains)
+ end
end
end
+
+ end
+
+ local function initiate_train()
+ rspamd_logger.infox(rspamd_config,
+ 'need to learn ANN %s after %s learn vectors (%s required)',
+ ann_key, tonumber(data), rule.train.max_trains)
+ do_train_ann(worker, ev_base, rule, set, ann_key)
+ end
+
+ -- Spam vector is OK, check ham vector length
+ local function check_ham_len()
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ redis_len_cb_gen(initiate_train), --callback
+ 'LLEN', -- command
+ {ann_key .. '_ham'}
+ )
end
lua_redis.redis_make_request_taskless(ev_base,
@@ -924,7 +980,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
rule.redis,
nil,
false, -- is write
- redis_len_cb, --callback
+ redis_len_cb_gen(check_ham_len), --callback
'LLEN', -- command
{ann_key .. '_spam'}
)
@@ -1005,14 +1061,22 @@ local function cleanup_anns(rule, cfg, ev_base)
end
local function ann_push_vector(task)
- if task:has_flag('skip') then return end
- if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end
+ if task:has_flag('skip') then
+ lua_util.debugm(N, task, 'do not push data for skipped task')
+ return
+ end
+ if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then
+ lua_util.debugm(N, task, 'do not push data for manual scan')
+ return
+ end
local verdict,score = lua_util.get_task_verdict(task)
for _,rule in pairs(settings.rules) do
local set = get_rule_settings(task, rule)
if set then
ann_push_task_result(rule, task, verdict, score, set)
+ else
+ lua_util.debugm(N, task, 'settings not found in rule %s', rule.prefix)
end
end
@@ -1064,7 +1128,7 @@ local function process_rules_settings()
if rule.default then
local default_settings = {
- symbols = lua_util.keys(lua_settings.default_symbols()),
+ symbols = lua_settings.default_symbols(),
name = 'default'
}
@@ -1099,7 +1163,7 @@ local function process_rules_settings()
if nelt then
rule.settings[s] = nelt
- lua_util.debugm(N, rspamd_config, 'added new settings id %s to %s; same symbols',
+ lua_util.debugm(N, rspamd_config, 'added new settings id %s to %s',
nelt.name, rule.prefix)
end
end
More information about the Commits mailing list