commit 1d26ec3: [Project] Neural: Implement PCA in learning
Vsevolod Stakhov
vsevolod at highsecure.ru
Thu Aug 27 21:49:15 UTC 2020
Author: Vsevolod Stakhov
Date: 2020-08-27 22:00:38 +0100
URL: https://github.com/rspamd/rspamd/commit/1d26ec302293d9eeeafb5e14f4d3a0d73c126f4f
[Project] Neural: Implement PCA in learning
---
src/plugins/lua/neural.lua | 62 ++++++++++++++++++++++++++++++++--------------
1 file changed, 43 insertions(+), 19 deletions(-)
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 91caa8e07..5b4ff8b3b 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -639,10 +639,17 @@ local function register_lock_extender(rule, set, ev_base, ann_key)
end
-- This is a utility function for PCA training
-local function fill_scatter(inputs, meanv)
+local function fill_scatter(inputs)
local scatter_matrix = rspamd_tensor.new(2, #inputs, #inputs)
local row_len = #inputs[1]
+ if type(inputs) == 'table' then
+ -- Convert to a tensor
+ inputs = rspamd_tensor.fromtable(inputs)
+ end
+
+ local meanv = inputs:mean()
+
for i=1,row_len do
local col = rspamd_tensor.new(1, #inputs)
for j=1,#inputs do
@@ -663,8 +670,7 @@ end
-- This function takes all inputs, applies PCA transformation and returns the final
-- PCA matrix as rspamd_tensor
local function learn_pca(inputs, max_inputs)
- local meanv = inputs:mean()
- local scatter_matrix = fill_scatter(inputs, meanv)
+ local scatter_matrix = fill_scatter(inputs)
local eigenvals = scatter_matrix:eigen()
-- scatter matrix is now filled with eigenvectors
lua_util.debugm(N, 'eigenvalues: %s', eigenvals)
@@ -676,6 +682,19 @@ local function learn_pca(inputs, max_inputs)
return w
end
+-- Fills ANN data for a specific settings element
+local function fill_set_ann(set, ann_key)
+ if not set.ann then
+ set.ann = {
+ symbols = set.symbols,
+ distance = 0,
+ digest = set.digest,
+ redis_key = ann_key,
+ version = 0,
+ }
+ end
+end
+
-- This function receives training vectors, checks them, spawns learning and saves ANN in Redis
local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec)
-- Check training data sanity
@@ -684,7 +703,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
meta_functions.rspamd_count_metatokens()
-- Now we can train ann
- local train_ann = create_ann(n, 3, rule)
+ local train_ann = create_ann(rule.max_inputs or n, 3, rule)
if #ham_vec + #spam_vec < rule.train.max_trains / 2 then
-- Invalidate ANN as it is definitely invalid
@@ -749,12 +768,23 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
end
end
- train_ann:train1(inputs, outputs, {
- lr = rule.train.learning_rate,
- max_epoch = rule.train.max_iterations,
- cb = train_cb,
- pca = set.ann.pca
- })
+ lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
+ rule.prefix, set.name)
+
+ local ret,err = pcall(train_ann.train1, train_ann,
+ inputs, outputs, {
+ lr = rule.train.learning_rate,
+ max_epoch = rule.train.max_iterations,
+ cb = train_cb,
+ pca = set.ann.pca
+ })
+
+ if not ret then
+ rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s",
+ rule.prefix, set.name, err)
+
+ return nil
+ end
if not seen_nan then
local out = train_ann:save()
@@ -806,14 +836,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
if set.ann.pca then
pca_data = rspamd_util.zstd_compress(set.ann.pca:save())
end
- if not set.ann then
- set.ann = {
- symbols = set.symbols,
- distance = 0,
- digest = set.digest,
- redis_key = ann_key,
- }
- end
+ fill_set_ann(set, ann_key)
-- Deserialise ANN from the child process
ann_trained = rspamd_kann.load(data)
local version = (set.ann.version or 0) + 1
@@ -852,6 +875,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
end
if rule.max_inputs then
+ fill_set_ann(set, ann_key)
-- Train PCA in the main process, presumably it is not that long
set.ann.pca = learn_pca(inputs, rule.max_inputs)
end
@@ -1045,7 +1069,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
rule.prefix, set.name, ann_key)
end
- if set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
+ if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
-- PCA table
local _err,pca_data = rspamd_util.zstd_decompress(data[2])
if pca_data then
More information about the Commits
mailing list