commit 7ec92c4: [Feature] Neural: Move PCA learning to a subprocess

Vsevolod Stakhov vsevolod at highsecure.ru
Fri Dec 18 16:14:10 UTC 2020


Author: Vsevolod Stakhov
Date: 2020-12-18 16:06:53 +0000
URL: https://github.com/rspamd/rspamd/commit/7ec92c421a5d1da6d26b6f2cdfed3cc585481155

[Feature] Neural: Move PCA learning to a subprocess

---
 lualib/plugins/neural.lua | 50 ++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 39 insertions(+), 11 deletions(-)

diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua
index 4d4c44b5d..6f82089a4 100644
--- a/lualib/plugins/neural.lua
+++ b/lualib/plugins/neural.lua
@@ -23,6 +23,7 @@ local rspamd_kann = require "rspamd_kann"
 local rspamd_logger = require "rspamd_logger"
 local rspamd_tensor = require "rspamd_tensor"
 local rspamd_util = require "rspamd_util"
+local ucl = require "ucl"
 
 local N = 'neural'
 
@@ -464,12 +465,22 @@ local function spawn_train(params)
       lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
           params.rule.prefix, params.set.name)
 
+      local pca
+      if params.rule.max_inputs then
+        -- Train PCA here in the subprocess instead of the main process
+        lua_util.debugm(N, rspamd_config, "start PCA train for ANN %s:%s",
+            params.rule.prefix, params.set.name)
+        pca = learn_pca(inputs, params.rule.max_inputs)
+      end
+
+      lua_util.debugm(N, rspamd_config, "start neural train for ANN %s:%s",
+          params.rule.prefix, params.set.name)
       local ret,err = pcall(train_ann.train1, train_ann,
           inputs, outputs, {
             lr = params.rule.train.learning_rate,
             max_epoch = params.rule.train.max_iterations,
             cb = train_cb,
-            pca = (params.set.ann or {}).pca
+            pca = pca
           })
 
       if not ret then
@@ -477,11 +488,26 @@ local function spawn_train(params)
             params.rule.prefix, params.set.name, err)
 
         return nil
+      else
+        lua_util.debugm(N, rspamd_config, "finished neural train for ANN %s:%s",
+            params.rule.prefix, params.set.name)
       end
 
       if not seen_nan then
-        local out = train_ann:save()
-        return out
+        -- Convert to strings as ucl cannot handle rspamd_text properly
+        local pca_data
+        if pca then
+          pca_data = tostring(pca:save())
+        end
+        local out = {
+          ann_data = tostring(train_ann:save()),
+          pca_data = pca_data,
+        }
+
+        local final_data = ucl.to_format(out, 'msgpack')
+        lua_util.debugm(N, rspamd_config, "subprocess for ANN %s:%s returned %s bytes",
+            params.rule.prefix, params.set.name, #final_data)
+        return final_data
       else
         return nil
       end
@@ -523,15 +549,20 @@ local function spawn_train(params)
             {params.ann_key, 'lock'}
         )
       else
-        local ann_data = rspamd_util.zstd_compress(data)
-        local pca_data
+        local parser = ucl.parser()
+        local ok, parse_err = parser:parse_text(data, 'msgpack')
+        assert(ok, parse_err)
+        local parsed = parser:get_object()
+        local ann_data = rspamd_util.zstd_compress(parsed.ann_data)
+        local pca_data = parsed.pca_data
 
         fill_set_ann(params.set, params.ann_key)
-        if params.set.ann.pca then
-          pca_data = rspamd_util.zstd_compress(params.set.ann.pca:save())
+        if pca_data then
+          params.set.ann.pca = rspamd_tensor.load(pca_data)
+          pca_data = rspamd_util.zstd_compress(pca_data)
         end
         -- Deserialise ANN from the child process
-        ann_trained = rspamd_kann.load(data)
+        ann_trained = rspamd_kann.load(parsed.ann_data)
         local version = (params.set.ann.version or 0) + 1
         params.set.ann.version = version
         params.set.ann.ann = ann_trained
@@ -545,7 +576,6 @@ local function spawn_train(params)
           version = version
         }
 
-        local ucl = require "ucl"
         local profile_serialized = ucl.to_format(profile, 'json-compact', true)
 
         rspamd_logger.infox(rspamd_config,
@@ -572,8 +602,6 @@ local function spawn_train(params)
 
     if params.rule.max_inputs then
       fill_set_ann(params.set, params.ann_key)
-      -- Train PCA in the main process, presumably it is not that long
-      params.set.ann.pca = learn_pca(inputs, params.rule.max_inputs)
     end
 
     params.worker:spawn_process{


More information about the Commits mailing list