commit 362dc83: [Project] Neural: Implement PCA learning
Vsevolod Stakhov
vsevolod at highsecure.ru
Thu Aug 27 21:49:10 UTC 2020
Author: Vsevolod Stakhov
Date: 2020-08-27 15:46:51 +0100
URL: https://github.com/rspamd/rspamd/commit/362dc834f1be24b107a0f3f593e743ce2ae66a04
[Project] Neural: Implement PCA learning
---
src/plugins/lua/neural.lua | 44 ++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 44 insertions(+)
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 32a751987..91caa8e07 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -638,6 +638,44 @@ local function register_lock_extender(rule, set, ev_base, ann_key)
)
end
+--- Utility for PCA training: accumulates the scatter matrix of `inputs`.
+-- For each column index of `inputs`, the per-row values minus the matching
+-- `meanv` entry are collected into a 1-D tensor; that tensor is multiplied
+-- by itself with `mul(v, false, true)` and the 2-D result is added element
+-- by element to the accumulator.
+-- @param inputs indexed as inputs[row][column]; `#inputs` gives the row count
+-- @param meanv value subtracted from every entry of the matching row
+-- @return the accumulator, created as a 2-D `#inputs` x `#inputs` rspamd_tensor
+local function fill_scatter(inputs, meanv)
+  local nrows = #inputs
+  -- NOTE(review): the accumulation below assumes rspamd_tensor.new hands back
+  -- zero-filled storage -- confirm against the tensor implementation
+  local scatter = rspamd_tensor.new(2, nrows, nrows)
+  local ncols = #inputs[1]
+
+  for c = 1, ncols do
+    -- Centered values of column `c`, one entry per row of `inputs`
+    local centered = rspamd_tensor.new(1, nrows)
+    for r = 1, nrows do
+      centered[r] = inputs[r][c] - meanv[r]
+    end
+
+    -- Presumably an outer product (the result is indexed as 2-D below)
+    local outer = centered:mul(centered, false, true)
+    for r = 1, #outer do
+      local acc_row = scatter[r]
+      local outer_row = outer[r]
+      for k = 1, #outer[1] do
+        acc_row[k] = acc_row[k] + outer_row[k]
+      end
+    end
+  end
+
+  return scatter
+end
+
+--- Takes all inputs, applies the PCA transformation and returns the final
+-- PCA matrix as an rspamd_tensor.
+-- @param inputs training inputs; must provide a `mean()` method and be
+--   indexable the way `fill_scatter` expects
+-- @param max_inputs number of rows to place into the returned matrix
+-- @return 2-D rspamd_tensor with `max_inputs` rows, each copied from the
+--   scatter matrix starting at its last row and walking backwards
+local function learn_pca(inputs, max_inputs)
+  local means = inputs:mean()
+  local scatter = fill_scatter(inputs, means)
+  local eigenvalues = scatter:eigen()
+  -- NOTE(review): the rows of `scatter` are consumed below as the projection
+  -- rows, so eigen() presumably overwrites it in place with eigenvectors; the
+  -- original note here read "not filled", likely a typo for "now filled" --
+  -- TODO confirm
+  lua_util.debugm(N, 'eigenvalues: %s', eigenvalues)
+
+  local w = rspamd_tensor.new(2, max_inputs, #scatter[1])
+  -- Walk the scatter matrix from its last row upwards while filling `w`
+  -- from its first row downwards
+  local src = #scatter
+  for dst = 1, max_inputs do
+    w[dst] = scatter[src]
+    src = src - 1
+  end
+
+  return w
+end
+
-- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec)
-- Check training data sanity
@@ -715,6 +753,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
lr = rule.train.learning_rate,
max_epoch = rule.train.max_iterations,
cb = train_cb,
+ pca = set.ann.pca
})
if not seen_nan then
@@ -812,6 +851,11 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
end
end
+ if rule.max_inputs then
+ -- Train PCA in the main process, presumably it is not that long
+ set.ann.pca = learn_pca(inputs, rule.max_inputs)
+ end
+
worker:spawn_process{
func = train,
on_complete = ann_trained,
More information about the Commits
mailing list