commit 844acef: [Project] Neural: Add PCA loading logic
Vsevolod Stakhov
vsevolod at highsecure.ru
Thu Aug 27 21:49:08 UTC 2020
Author: Vsevolod Stakhov
Date: 2020-08-27 15:35:42 +0100
URL: https://github.com/rspamd/rspamd/commit/844acefdab6da55af0b371419ca8039f2bd78d29
[Project] Neural: Add PCA loading logic
---
src/plugins/lua/neural.lua | 118 ++++++++++++++++++++++++++++-----------------
1 file changed, 75 insertions(+), 43 deletions(-)
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index a3027662c..352d397d5 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -22,6 +22,7 @@ end
local rspamd_logger = require "rspamd_logger"
local rspamd_util = require "rspamd_util"
local rspamd_kann = require "rspamd_kann"
+local rspamd_text = require "rspamd_text"
local lua_redis = require "lua_redis"
local lua_util = require "lua_util"
local rspamd_tensor = require "rspamd_tensor"
@@ -71,6 +72,7 @@ local redis_profile_schema = ts.shape{
}
local has_blas = rspamd_tensor.has_blas()
+local text_cookie = rspamd_text.cookie
-- Rule structure:
-- * static config fields (see `default_options`)
@@ -327,7 +329,7 @@ local function ann_scores_filter(task)
local vec = result_to_vector(task, profile)
local score
- local out = ann:apply1(vec)
+ local out = ann:apply1(vec, set.ann.pca)
score = out[1]
local symscore = string.format('%.3f', score)
@@ -940,52 +942,81 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
ann_key, err)
else
- if type(data) == 'string' then
- local _err,ann_data = rspamd_util.zstd_decompress(data)
- local ann
-
- if _err or not ann_data then
- rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
- rule.prefix .. ':' .. set.name, ann_key, _err)
- return
+ if type(data) == 'table' then
+ if type(data[1]) == 'userdata' and data[1].cookie == text_cookie then
+ local _err,ann_data = rspamd_util.zstd_decompress(data[1])
+ local ann
+
+ if _err or not ann_data then
+ rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
+ rule.prefix .. ':' .. set.name, ann_key, _err)
+ return
+ else
+ ann = rspamd_kann.load(ann_data)
+
+ if ann then
+ set.ann = {
+ digest = profile.digest,
+ version = profile.version,
+ symbols = profile.symbols,
+ distance = min_diff,
+ redis_key = profile.redis_key
+ }
+
+ local ucl = require "ucl"
+ local profile_serialized = ucl.to_format(profile, 'json-compact', true)
+ set.ann.ann = ann -- To avoid serialization
+
+ local function rank_cb(_, _)
+ -- TODO: maybe add some logging
+ end
+ -- Also update rank for the loaded ANN to avoid removal
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ rank_cb, --callback
+ 'ZADD', -- command
+ {set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
+ )
+ rspamd_logger.infox(rspamd_config, 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
+ rule.prefix, set.name, ann_key, #ann_data, profile.version)
+ else
+ rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
+ rule.prefix, set.name, ann_key)
+ end
+ end
else
- ann = rspamd_kann.load(ann_data)
-
- if ann then
- set.ann = {
- digest = profile.digest,
- version = profile.version,
- symbols = profile.symbols,
- distance = min_diff,
- redis_key = profile.redis_key
- }
-
- local ucl = require "ucl"
- local profile_serialized = ucl.to_format(profile, 'json-compact', true)
- set.ann.ann = ann -- To avoid serialization
-
- local function rank_cb(_, _)
- -- TODO: maybe add some logging
+ lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
+ rule.prefix, set.name, ann_key)
+ end
+ if set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
+ -- PCA table
+ local _err,pca_data = rspamd_util.zstd_decompress(data[2])
+ if pca_data then
+ if rule.max_inputs then
+ -- We can use PCA
+ set.ann.pca = rspamd_tensor.load(pca_data)
+ else
+ -- PCA is not needed without max_inputs; why is it stored?
+ rspamd_logger.warnx(rspamd_config, 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
+ rule.prefix, set.name, ann_key)
end
- -- Also update rank for the loaded ANN to avoid removal
- lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- rank_cb, --callback
- 'ZADD', -- command
- {set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
- )
- rspamd_logger.infox(rspamd_config, 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
- rule.prefix, set.name, ann_key, #ann_data, profile.version)
else
- rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s:%s from Redis key %s',
- rule.prefix, set.name, ann_key)
+ -- PCA may legitimately be missing only when max_inputs is not set
+ if rule.max_inputs then
+ rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s: no PCA: %s',
+ rule.prefix, set.name, ann_key, _err)
+ set.ann.ann = nil
+ else
+ -- It is okay
+ set.ann.pca = nil
+ end
end
end
else
- lua_util.debugm(N, rspamd_config, 'no ANN for %s:%s in Redis key %s',
+ lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s',
rule.prefix, set.name, ann_key)
end
end
@@ -996,8 +1027,9 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
nil,
false, -- is write
data_cb, --callback
- 'HGET', -- command
- {ann_key, 'ann'} -- arguments
+ 'HMGET', -- command
+ {ann_key, 'ann', 'pca'}, -- arguments
+ {opaque_data = true}
)
end
More information about the Commits
mailing list