commit 362dc83: [Project] Neural: Implement PCA learning

Vsevolod Stakhov vsevolod at highsecure.ru
Thu Aug 27 21:49:10 UTC 2020


Author: Vsevolod Stakhov
Date: 2020-08-27 15:46:51 +0100
URL: https://github.com/rspamd/rspamd/commit/362dc834f1be24b107a0f3f593e743ce2ae66a04

[Project] Neural: Implement PCA learning

---
 src/plugins/lua/neural.lua | 44 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 32a751987..91caa8e07 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -638,6 +638,44 @@ local function register_lock_extender(rule, set, ev_base, ann_key)
   )
 end
 
+-- This is an utility function for PCA training
+local function fill_scatter(inputs, meanv)
+  local scatter_matrix = rspamd_tensor.new(2, #inputs, #inputs)
+  local row_len = #inputs[1]
+
+  for i=1,row_len do
+    local col = rspamd_tensor.new(1, #inputs)
+    for j=1,#inputs do
+      local x = inputs[j][i] - meanv[j]
+      col[j] = x
+    end
+    local prod = col:mul(col, false, true)
+    for ii=1,#prod do
+      for jj=1,#prod[1] do
+        scatter_matrix[ii][jj] = scatter_matrix[ii][jj] + prod[ii][jj]
+      end
+    end
+  end
+
+  return scatter_matrix
+end
+
+-- This function takes all inputs, applies PCA transformation and returns the final
+-- PCA matrix as rspamd_tensor
+local function learn_pca(inputs, max_inputs)
+  local meanv = inputs:mean()
+  local scatter_matrix = fill_scatter(inputs, meanv)
+  local eigenvals = scatter_matrix:eigen()
+  -- scatter matrix is not filled with eigenvectors
+  lua_util.debugm(N, 'eigenvalues: %s', eigenvals)
+  local w = rspamd_tensor.new(2, max_inputs, #scatter_matrix[1])
+  for i=1,max_inputs do
+    w[i] = scatter_matrix[#scatter_matrix - i + 1]
+  end
+
+  return w
+end
+
 -- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
 local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec)
   -- Check training data sanity
@@ -715,6 +753,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
         lr = rule.train.learning_rate,
         max_epoch = rule.train.max_iterations,
         cb = train_cb,
+        pca = set.ann.pca
       })
 
       if not seen_nan then
@@ -812,6 +851,11 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
       end
     end
 
+    if rule.max_inputs then
+      -- Train PCA in the main process, presumably it is not that long
+      set.ann.pca = learn_pca(inputs, rule.max_inputs)
+    end
+
     worker:spawn_process{
       func = train,
       on_complete = ann_trained,


More information about the Commits mailing list