commit ff02466: [Project] Add ANN load function

Vsevolod Stakhov vsevolod at highsecure.ru
Sat Jul 6 19:56:04 UTC 2019


Author: Vsevolod Stakhov
Date: 2019-07-06 09:39:48 +0100
URL: https://github.com/rspamd/rspamd/commit/ff024667584fa62f342507647ef0beab5917cf58

[Project] Add ANN load function

---
 src/plugins/lua/neural.lua | 45 ++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 44 insertions(+), 1 deletion(-)

diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index ff53249c5..cca6f647c 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -725,7 +725,50 @@ end
 -- serialize profile one more time and set its rank to the current time
 -- set.ann fields are set according to Redis data received
 local function load_new_ann(rule, ev_base, set, profile, min_diff)
+  local ann_key = profile.ann_key
 
+  local function data_cb(err, data)
+    if err then
+      rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
+          ann_key, err)
+    else
+      local _err,ann_data = rspamd_util.zstd_decompress(data[1])
+      local ann
+
+      if _err or not ann_data then
+        rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
+            rule.prefix .. ':' .. set.name, ann_key, _err)
+        return
+      else
+        ann = rspamd_kann.load(ann_data)
+
+        if ann then
+          set.ann = {
+            ann = ann,
+            version = profile.version,
+            symbols = profile.symbols,
+            distance = min_diff
+          }
+
+          rspamd_logger.infox(rspamd_config, 'loaded ANN for %s from %s; %s bytes compressed; version=%s',
+              rule.prefix .. ':' .. set.name, ann_key, #ann_data, profile.version)
+        else
+          rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s from Redis key %s',
+              rule.prefix .. ':' .. set.name, ann_key)
+        end
+      end
+    end
+  end
+  lua_redis.redis_make_request_taskless(ev_base,
+      rspamd_config,
+      rule.redis,
+      nil,
+      false, -- is write
+      data_cb, --callback
+      'HGET', -- command
+      {ann_key, 'ann'}, -- arguments
+      {opaque_data = true}
+  )
 end
 
 -- Used to check an element in Redis serialized as JSON
@@ -740,7 +783,7 @@ local function process_existing_ann(rule, ev_base, set, profiles)
   for _,elt in fun.iter(profiles) do
     if elt and elt.symbols then
       local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
-
+      -- Check distance
       if dist < #my_symbols * .3 then
         if dist < min_diff then
           min_diff = dist


More information about the Commits mailing list