commit 7ec92c4: [Feature] Neural: Move PCA learning to a subprocess
Vsevolod Stakhov
vsevolod at highsecure.ru
Fri Dec 18 16:14:10 UTC 2020
Author: Vsevolod Stakhov
Date: 2020-12-18 16:06:53 +0000
URL: https://github.com/rspamd/rspamd/commit/7ec92c421a5d1da6d26b6f2cdfed3cc585481155
[Feature] Neural: Move PCA learning to a subprocess
---
lualib/plugins/neural.lua | 50 ++++++++++++++++++++++++++++++++++++-----------
1 file changed, 39 insertions(+), 11 deletions(-)
diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua
index 4d4c44b5d..6f82089a4 100644
--- a/lualib/plugins/neural.lua
+++ b/lualib/plugins/neural.lua
@@ -23,6 +23,7 @@ local rspamd_kann = require "rspamd_kann"
local rspamd_logger = require "rspamd_logger"
local rspamd_tensor = require "rspamd_tensor"
local rspamd_util = require "rspamd_util"
+local ucl = require "ucl"
local N = 'neural'
@@ -464,12 +465,22 @@ local function spawn_train(params)
lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
params.rule.prefix, params.set.name)
+ local pca
+ if params.rule.max_inputs then
+ -- Train PCA in the main process, presumably it is not that long
+ lua_util.debugm(N, rspamd_config, "start PCA train for ANN %s:%s",
+ params.rule.prefix, params.set.name)
+ pca = learn_pca(inputs, params.rule.max_inputs)
+ end
+
+ lua_util.debugm(N, rspamd_config, "start neural train for ANN %s:%s",
+ params.rule.prefix, params.set.name)
local ret,err = pcall(train_ann.train1, train_ann,
inputs, outputs, {
lr = params.rule.train.learning_rate,
max_epoch = params.rule.train.max_iterations,
cb = train_cb,
- pca = (params.set.ann or {}).pca
+ pca = pca
})
if not ret then
@@ -477,11 +488,26 @@ local function spawn_train(params)
params.rule.prefix, params.set.name, err)
return nil
+ else
+ lua_util.debugm(N, rspamd_config, "finished neural train for ANN %s:%s",
+ params.rule.prefix, params.set.name)
end
if not seen_nan then
- local out = train_ann:save()
- return out
+ -- Convert to strings as ucl cannot rspamd_text properly
+ local pca_data
+ if pca then
+ pca_data = tostring(pca:save())
+ end
+ local out = {
+ ann_data = tostring(train_ann:save()),
+ pca_data = pca_data,
+ }
+
+ local final_data = ucl.to_format(out, 'msgpack')
+ lua_util.debugm(N, rspamd_config, "subprocess for ANN %s:%s returned %s bytes",
+ params.rule.prefix, params.set.name, #final_data)
+ return final_data
else
return nil
end
@@ -523,15 +549,20 @@ local function spawn_train(params)
{params.ann_key, 'lock'}
)
else
- local ann_data = rspamd_util.zstd_compress(data)
- local pca_data
+ local parser = ucl.parser()
+ local ok, parse_err = parser:parse_text(data, 'msgpack')
+ assert(ok, parse_err)
+ local parsed = parser:get_object()
+ local ann_data = rspamd_util.zstd_compress(parsed.ann_data)
+ local pca_data = parsed.pca_data
fill_set_ann(params.set, params.ann_key)
- if params.set.ann.pca then
- pca_data = rspamd_util.zstd_compress(params.set.ann.pca:save())
+ if pca_data then
+ params.set.ann.pca = rspamd_tensor.load(pca_data)
+ pca_data = rspamd_util.zstd_compress(pca_data)
end
-- Deserialise ANN from the child process
- ann_trained = rspamd_kann.load(data)
+ ann_trained = rspamd_kann.load(parsed.ann_data)
local version = (params.set.ann.version or 0) + 1
params.set.ann.version = version
params.set.ann.ann = ann_trained
@@ -545,7 +576,6 @@ local function spawn_train(params)
version = version
}
- local ucl = require "ucl"
local profile_serialized = ucl.to_format(profile, 'json-compact', true)
rspamd_logger.infox(rspamd_config,
@@ -572,8 +602,6 @@ local function spawn_train(params)
if params.rule.max_inputs then
fill_set_ann(params.set, params.ann_key)
- -- Train PCA in the main process, presumably it is not that long
- params.set.ann.pca = learn_pca(inputs, params.rule.max_inputs)
end
params.worker:spawn_process{
More information about the Commits
mailing list