commit 6d9c4be: [Project] Neural: Implement PCA on ANN forward

Vsevolod Stakhov vsevolod at highsecure.ru
Fri Aug 28 11:49:05 UTC 2020


Author: Vsevolod Stakhov
Date: 2020-08-28 12:43:29 +0100
URL: https://github.com/rspamd/rspamd/commit/6d9c4bed090e852b871d74443c8d34c4fa87a56e (HEAD -> master)

[Project] Neural: Implement PCA on ANN forward

---
 src/lua/lua_kann.c | 50 +++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 45 insertions(+), 5 deletions(-)

diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c
index db12e1f87..30bff538a 100644
--- a/src/lua/lua_kann.c
+++ b/src/lua/lua_kann.c
@@ -1205,21 +1205,48 @@ static int
 lua_kann_apply1 (lua_State *L)
 {
 	kann_t *k = lua_check_kann (L, 1);
+	struct rspamd_lua_tensor *pca = NULL;
 
 	if (k) {
 		if (lua_istable (L, 2)) {
 			gsize vec_len = rspamd_lua_table_size (L, 2);
-			float *vec = (float *) g_malloc (sizeof (float) * vec_len);
+			float *vec = (float *) g_malloc (sizeof (float) * vec_len),
+				*pca_out = NULL;
 			int i_out;
 			int n_in = kann_dim_in (k);
 
 			if (n_in <= 0) {
+				g_free (vec);
 				return luaL_error (L, "invalid inputs count: %d", n_in);
 			}
 
-			if (n_in != vec_len) {
-				return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
-						(int) vec_len, n_in);
+			if (lua_isuserdata (L, 3)) {
+				pca = lua_check_tensor (L, 3);
+
+				if (pca) {
+					if (pca->ndims != 2) {
+						g_free (vec);
+						return luaL_error (L, "invalid pca tensor: matrix expected, got a row");
+					}
+
+					if (pca->dim[0] != n_in) {
+						g_free (vec);
+						return luaL_error (L, "invalid pca tensor: "
+											  "matrix must have %d rows and it has %d rows instead",
+								n_in, pca->dim[0]);
+					}
+				}
+				else {
+					g_free (vec);
+					return luaL_error (L, "invalid params: pca matrix expected");
+				}
+			}
+			else {
+				if (n_in != vec_len) {
+					g_free (vec);
+					return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
+							(int) vec_len, n_in);
+				}
 			}
 
 			for (gsize i = 0; i < vec_len; i++) {
@@ -1237,7 +1264,19 @@ lua_kann_apply1 (lua_State *L)
 			}
 
 			kann_set_batch_size (k, 1);
-			kann_feed_bind (k, KANN_F_IN, 0, &vec);
+			if (pca) {
+				pca_out = g_malloc (sizeof (float) * n_in);
+
+				kad_sgemm_simple (0, 0, pca->dim[0], 1,
+						pca->dim[1], pca->data,
+						vec, pca_out);
+
+				kann_feed_bind (k, KANN_F_IN, 0, &pca_out);
+			}
+			else {
+				kann_feed_bind (k, KANN_F_IN, 0, &vec);
+			}
+
 			kad_eval_at (k->n, k->v, i_out);
 
 			gsize outlen = kad_len (k->v[i_out]);
@@ -1249,6 +1288,7 @@ lua_kann_apply1 (lua_State *L)
 			}
 
 			g_free (vec);
+			g_free (pca_out);
 		}
 		else if (lua_isuserdata (L, 2)) {
 			struct rspamd_lua_tensor *t = lua_check_tensor (L, 2);


More information about the Commits mailing list