commit ca75dba: [Feature] Neural: Add sampling when storing training vectors
Vsevolod Stakhov
vsevolod at highsecure.ru
Fri Oct 18 16:21:06 UTC 2019
Author: Vsevolod Stakhov
Date: 2019-10-18 17:08:44 +0100
URL: https://github.com/rspamd/rspamd/commit/ca75dbad6bcb55084cc1104c2dc5ef1109c270c0
[Feature] Neural: Add sampling when storing training vectors
---
src/plugins/lua/neural.lua | 114 ++++++++++++++++++++++++++++-----------------
1 file changed, 72 insertions(+), 42 deletions(-)
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 0b93cd4a7..7acb0eca3 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -97,7 +97,7 @@ end
-- key1 - ann key
-- key2 - spam or ham
-- key3 - maximum trains
--- returns 1 or 0: 1 - allow learn, 0 - not allow learn
+-- returns 1 or 0 + reason: 1 - allow learn, 0 - not allow learn
local redis_lua_script_can_store_train_vec = [[
local prefix = KEYS[1]
local locked = redis.call('HGET', prefix, 'lock')
@@ -114,19 +114,33 @@ local redis_lua_script_can_store_train_vec = [[
if KEYS[2] == 'spam' then
if nspam <= lim then
- return tostring(nspam)
- else
- return tostring(-(nspam))
+ if nspam > nham then
+ -- Apply sampling
+ local skip_rate = 1.0 - nham / (nspam + 1)
+ if math.random() < skip_rate then
+ return {tostring(-(nspam)),'sampled out with probability ' .. tostring(skip_rate)}
+ end
+ end
+ return {tostring(nspam),'can learn'}
+ else -- Enough learns
+ return {tostring(-(nspam)),'too many spam samples'}
end
else
if nham <= lim then
- return tostring(nham)
+ if nsham > nspam then
+ -- Apply sampling
+ local skip_rate = 1.0 - nspam / (nham + 1)
+ if math.random() < skip_rate then
+ return {tostring(-(nham)),'sampled out with probability ' .. tostring(skip_rate)}
+ end
+ end
+ return {tostring(nham),'can learn'}
else
- return tostring(-(nham))
+ return {tostring(-(nham)),'too many ham samples'}
end
end
- return tostring(0)
+ return {tostring(0),'bad input'}
]]
local redis_can_store_train_vec_id = nil
@@ -416,45 +430,50 @@ local function ann_push_task_result(rule, task, verdict, score, set)
if learn_spam then learn_type = 'spam' else learn_type = 'ham' end
local function can_train_cb(err, data)
- if not err and tonumber(data) >= 0 then
- local coin = math.random()
- if coin < 1.0 - train_opts.train_prob then
- rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
- return
- end
- local vec = result_to_vector(task, set)
+ if not err and type(data) == 'table' then
+ local nsamples,reason = tonumber(data[1]),data[2]
- local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
- local target_key = set.ann.redis_key .. '_' .. learn_type
+ if nsamples > 0 then
+ local coin = math.random()
- local function learn_vec_cb(_err)
- if _err then
- rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
- rule.prefix, set.name, _err)
- else
- lua_util.debugm(N, task,
- "add train data for ANN rule " ..
- "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
- rule.prefix, set.name, learn_type, #vec, target_key, #str)
+ if coin < 1.0 - train_opts.train_prob then
+ rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
+ return
end
- end
- lua_redis.redis_make_request(task,
- rule.redis,
- nil,
- true, -- is write
- learn_vec_cb, --callback
- 'LPUSH', -- command
- { target_key, str } -- arguments
- )
- else
- if err then
- rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
- rule.prefix, set.name, err)
- elseif tonumber(data) < 0 then
- rspamd_logger.infox(task, "cannot learn ANN %s:%s: too many %s samples: %s",
- rule.prefix, set.name, learn_type, -tonumber(data))
+ local vec = result_to_vector(task, set)
+
+ local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
+ local target_key = set.ann.redis_key .. '_' .. learn_type
+
+ local function learn_vec_cb(_err)
+ if _err then
+ rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
+ rule.prefix, set.name, _err)
+ else
+ lua_util.debugm(N, task,
+ "add train data for ANN rule " ..
+ "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
+ rule.prefix, set.name, learn_type, #vec, target_key, #str)
+ end
+ end
+
+ lua_redis.redis_make_request(task,
+ rule.redis,
+ nil,
+ true, -- is write
+ learn_vec_cb, --callback
+ 'LPUSH', -- command
+ { target_key, str } -- arguments
+ )
+ else
+ -- Negative result returned
+ rspamd_logger.infox(task, "cannot learn ANN %s:%s: %s (%s vectors stored)",
+ rule.prefix, set.name, learn_type, reason, -tonumber(nsamples))
end
+ else
+ rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
+ rule.prefix, set.name, err)
end
end
@@ -466,7 +485,11 @@ local function ann_push_task_result(rule, task, verdict, score, set)
lua_redis.exec_redis_script(redis_can_store_train_vec_id,
{task = task, is_write = true},
can_train_cb,
- { set.ann.redis_key, learn_type, tostring(train_opts.max_trains)})
+ {
+ set.ann.redis_key,
+ learn_type,
+ tostring(train_opts.max_trains),
+ })
else
lua_util.debugm(N, task, 'do not push data: train condition not satisfied; reason: %s',
skip_reason)
@@ -1128,6 +1151,13 @@ local function ann_push_vector(task)
return
end
+ if score ~= score then
+ lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)',
+ verdict)
+
+ return
+ end
+
for _,rule in pairs(settings.rules) do
local set = get_rule_settings(task, rule)
More information about the Commits
mailing list