commit 913ac14: [Project] Neural: Fix PCA based learning

Vsevolod Stakhov vsevolod at highsecure.ru
Thu Aug 27 22:56:06 UTC 2020


Author: Vsevolod Stakhov
Date: 2020-08-27 23:51:38 +0100
URL: https://github.com/rspamd/rspamd/commit/913ac147bbc4e706095003fb8c16d24e2187a77f (HEAD -> master)

[Project] Neural: Fix PCA based learning

---
 src/lua/lua_kann.c         | 88 ++++++++++++++++++++++++++++++++++++----------
 src/plugins/lua/neural.lua | 21 ++++++-----
 2 files changed, 81 insertions(+), 28 deletions(-)

diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c
index 1827fe1ac..db12e1f87 100644
--- a/src/lua/lua_kann.c
+++ b/src/lua/lua_kann.c
@@ -1023,6 +1023,7 @@ static int
 lua_kann_train1 (lua_State *L)
 {
 	kann_t *k = lua_check_kann (L, 1);
+	struct rspamd_lua_tensor *pca = NULL;
 
 	/* Default train params */
 	double lr = 0.001;
@@ -1055,8 +1056,8 @@ lua_kann_train1 (lua_State *L)
 
 			if (!rspamd_lua_parse_table_arguments (L, 4, &err,
 					RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING,
-					"lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F",
-					&lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref)) {
+					"lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F;pca=u{tensor}",
+					&lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref, &pca)) {
 				n = luaL_error (L, "invalid params: %s",
 						err ? err->message : "unknown error");
 				g_error_free (err);
@@ -1065,36 +1066,83 @@ lua_kann_train1 (lua_State *L)
 			}
 		}
 
-		float **x, **y;
+		if (pca) {
+			/* Check pca matrix validity */
+			if (pca->ndims != 2) {
+				return luaL_error (L, "invalid pca tensor: matrix expected, got a row");
+			}
+
+			if (pca->dim[0] != n_in) {
+				return luaL_error (L, "invalid pca tensor: "
+						  "matrix must have %d rows and it has %d rows instead",
+						  n_in, pca->dim[0]);
+			}
+		}
 
-		/* Fill vectors */
+		float **x, **y, *tmp_row = NULL;
+
+		/* Fill vectors row by row */
 		x = (float **)g_malloc0 (sizeof (float *) * n);
 		y = (float **)g_malloc0 (sizeof (float *) * n);
 
+		if (pca) {
+			tmp_row = g_malloc (sizeof (float) * pca->dim[1]);
+		}
+
 		for (int s = 0; s < n; s ++) {
 			/* Inputs */
 			lua_rawgeti (L, 2, s + 1);
 			x[s] = (float *)g_malloc (sizeof (float) * n_in);
 
-			if (rspamd_lua_table_size (L, -1) != n_in) {
-				FREE_VEC (x, n);
-				FREE_VEC (y, n);
+			if (pca == NULL) {
+				if (rspamd_lua_table_size (L, -1) != n_in) {
+					FREE_VEC (x, n);
+					FREE_VEC (y, n);
 
-				lua_pop (L, 1);
-				n = luaL_error (L, "invalid params at pos %d: "
-					   "bad input dimension %d; %d expected",
-						s + 1,
-						(int)rspamd_lua_table_size (L, -1),
-						n_in);
+					n = luaL_error (L, "invalid params at pos %d: "
+									   "bad input dimension %d; %d expected",
+							s + 1,
+							(int) rspamd_lua_table_size (L, -1),
+							n_in);
+					lua_pop (L, 1);
 
-				return n;
+					return n;
+				}
+
+				for (int i = 0; i < n_in; i++) {
+					lua_rawgeti (L, -1, i + 1);
+					x[s][i] = lua_tonumber (L, -1);
+
+					lua_pop (L, 1);
+				}
 			}
+			else {
+				if (rspamd_lua_table_size (L, -1) != pca->dim[1]) {
+					FREE_VEC (x, n);
+					FREE_VEC (y, n);
+					g_free (tmp_row);
+
+					n = luaL_error (L, "(pca on) invalid params at pos %d: "
+									   "bad input dimension %d; %d expected",
+							s + 1,
+							(int) rspamd_lua_table_size (L, -1),
+							pca->dim[1]);
+					lua_pop (L, 1);
 
-			for (int i = 0; i < n_in; i ++) {
-				lua_rawgeti (L, -1, i + 1);
-				x[s][i] = lua_tonumber (L, -1);
+					return n;
+				}
 
-				lua_pop (L, 1);
+
+				for (int i = 0; i < pca->dim[1]; i++) {
+					lua_rawgeti (L, -1, i + 1);
+					tmp_row[i] = lua_tonumber (L, -1);
+
+					lua_pop (L, 1);
+				}
+
+				kad_sgemm_simple (0, 0, pca->dim[0], 1,
+						pca->dim[1], pca->data,
+						tmp_row, x[s]);
 			}
 
 			lua_pop (L, 1);
@@ -1104,9 +1152,9 @@ lua_kann_train1 (lua_State *L)
 			lua_rawgeti (L, 3, s + 1);
 
 			if (rspamd_lua_table_size (L, -1) != n_out) {
-				lua_pop (L, 1);
 				FREE_VEC (x, n);
 				FREE_VEC (y, n);
+				g_free (tmp_row);
 
 				n = luaL_error (L, "invalid params at pos %d: "
 					   "bad output dimension %d; "
@@ -1114,6 +1162,7 @@ lua_kann_train1 (lua_State *L)
 						s + 1,
 						(int)rspamd_lua_table_size (L, -1),
 						n_out);
+				lua_pop (L, 1);
 
 				return n;
 			}
@@ -1142,6 +1191,7 @@ lua_kann_train1 (lua_State *L)
 
 		FREE_VEC (x, n);
 		FREE_VEC (y, n);
+		g_free (tmp_row);
 	}
 	else {
 		return luaL_error (L, "invalid arguments: kann, inputs, outputs and"
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 5b4ff8b3b..0258fb0b0 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -640,17 +640,15 @@ end
 
 -- This is an utility function for PCA training
 local function fill_scatter(inputs)
-  local scatter_matrix = rspamd_tensor.new(2, #inputs, #inputs)
-  local row_len = #inputs[1]
+  local scatter_matrix = rspamd_tensor.new(2, #inputs[1], #inputs[1])
+  local nsamples = #inputs
 
-  if type(inputs) == 'table' then
-    -- Convert to a tensor
-    inputs = rspamd_tensor.fromtable(inputs)
-  end
+  -- Convert to a tensor where each row is an input dimension
+  inputs = rspamd_tensor.fromtable(inputs):transpose()
 
   local meanv = inputs:mean()
 
-  for i=1,row_len do
+  for i=1,nsamples do
     local col = rspamd_tensor.new(1, #inputs)
     for j=1,#inputs do
       local x = inputs[j][i] - meanv[j]
@@ -679,6 +677,8 @@ local function learn_pca(inputs, max_inputs)
     w[i] = scatter_matrix[#scatter_matrix - i + 1]
   end
 
+  lua_util.debugm(N, 'pca matrix: %s', w)
+
   return w
 end
 
@@ -856,8 +856,11 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
         local profile_serialized = ucl.to_format(profile, 'json-compact', true)
 
         rspamd_logger.infox(rspamd_config,
-            'trained ANN %s:%s, %s bytes; redis key: %s (old key %s)',
-            rule.prefix, set.name, #data, set.ann.redis_key, ann_key)
+            'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
+            rule.prefix, set.name,
+            #data, #ann_data,
+            #(set.ann.pca or {}), #(pca_data or {}),
+            set.ann.redis_key, ann_key)
 
         lua_redis.exec_redis_script(redis_save_unlock_id,
             {ev_base = ev_base, is_write = true},


More information about the Commits mailing list