commit d13c306: [Project] Tensor: Move scatter matrix calculation to C

Vsevolod Stakhov vsevolod at highsecure.ru
Mon Aug 31 14:49:17 UTC 2020


Author: Vsevolod Stakhov
Date: 2020-08-31 15:34:44 +0100
URL: https://github.com/rspamd/rspamd/commit/d13c3065b8e2ea5c3e4beec04f5d4ed5c5b84515

[Project] Tensor: Move scatter matrix calculation to C

---
 src/lua/lua_tensor.c | 73 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 73 insertions(+)

diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c
index e918188eb..06b7cdffe 100644
--- a/src/lua/lua_tensor.c
+++ b/src/lua/lua_tensor.c
@@ -39,12 +39,14 @@ LUA_FUNCTION_DEF (tensor, eigen);
 LUA_FUNCTION_DEF (tensor, mean);
 LUA_FUNCTION_DEF (tensor, transpose);
 LUA_FUNCTION_DEF (tensor, has_blas);
+LUA_FUNCTION_DEF (tensor, scatter_matrix);
 
 static luaL_reg rspamd_tensor_f[] = {
 		LUA_INTERFACE_DEF (tensor, load),
 		LUA_INTERFACE_DEF (tensor, new),
 		LUA_INTERFACE_DEF (tensor, fromtable),
 		LUA_INTERFACE_DEF (tensor, has_blas),
+		LUA_INTERFACE_DEF (tensor, scatter_matrix),
 		{NULL, NULL},
 };
 
@@ -636,6 +638,7 @@ mean_vec (rspamd_tensor_num_t *x, int n)
 	rspamd_tensor_num_t s = 0;
 	rspamd_tensor_num_t c = 0;
 
+	/* https://en.wikipedia.org/wiki/Kahan_summation_algorithm */
 	for (int i = 0; i < n; i ++) {
 		rspamd_tensor_num_t v = x[i];
 		rspamd_tensor_num_t y = v - c;
@@ -728,6 +731,76 @@ lua_tensor_has_blas (lua_State *L)
 	return 1;
 }
 
+static gint /* Lua binding; returns 1 result or raises a Lua error */
+lua_tensor_scatter_matrix (lua_State *L) /* arg 1: 2-D tensor (rows = samples) */
+{
+	struct rspamd_lua_tensor *t = lua_check_tensor (L, 1), *res;
+	int dims[2];
+
+	if (t) {
+		if (t->ndims != 2) {
+			return luaL_error (L, "matrix required");
+		}
+
+		/* Result is square: dim[1] x dim[1], i.e. columns x columns of the input */
+		dims[0] = t->dim[1];
+		dims[1] = t->dim[1];
+		res = lua_newtensor (L, 2, dims, true, true); /* NOTE(review): res->data is only accumulated into below, so it presumably starts zeroed - confirm flags */
+
+		/* Scratch buffers, all freed before returning */
+		rspamd_tensor_num_t *means, /* per-column sums, later per-column means */
+			*tmp_row, /* Kahan compensation terms first, then the centered sample */
+			*tmp_square /* outer product of one centered sample (dim[1] x dim[1]) */;
+		means = g_malloc0 (sizeof (rspamd_tensor_num_t) * t->dim[1]);
+		tmp_row = g_malloc0 (sizeof (rspamd_tensor_num_t) * t->dim[1]);
+		tmp_square = g_malloc (sizeof (rspamd_tensor_num_t) * t->dim[1] * t->dim[1]); /* not zeroed here: memset before each use */
+
+		/*
+		 * Column sums with Kahan (compensated) summation, one accumulator per column:
+		 * means[j] holds the running sum s, tmp_row[j] the running compensation c
+		 */
+		for (int i = 0; i < t->dim[0]; i ++) {
+			/* Rows are contiguous: element (i, j) is data[i * dim[1] + j] */
+			for (int j = 0; j < t->dim[1]; j ++) {
+				rspamd_tensor_num_t v = t->data[i * t->dim[1] + j];
+				rspamd_tensor_num_t y = v - tmp_row[j];
+				rspamd_tensor_num_t st = means[j] + y;
+				tmp_row[j] = (st - means[j]) - y; /* rounding error lost in the add above */
+				means[j] = st;
+			}
+		}
+
+		for (int j = 0; j < t->dim[1]; j ++) {
+			means[j] /= t->dim[0]; /* sum -> mean over the dim[0] rows */
+		}
+
+		for (int i = 0; i < t->dim[0]; i ++) {
+			/* Center sample i; tmp_row is reused, its Kahan terms are no longer needed */
+			for (int j = 0; j < t->dim[1]; j ++) {
+				tmp_row[j] = t->data[i * t->dim[1] + j] - means[j];
+			}
+
+			memset (tmp_square, 0, t->dim[1] * t->dim[1] * sizeof (rspamd_tensor_num_t)); /* presumably sgemm accumulates into its output - TODO confirm */
+			kad_sgemm_simple (1, 0, t->dim[1], t->dim[1], 1,
+					tmp_row, tmp_row, tmp_square); /* presumably tmp_square = tmp_row^T * tmp_row - confirm argument order */
+
+			for (int j = 0; j < t->dim[1]; j ++) {
+				kad_saxpy (t->dim[1], 1.0, &tmp_square[j * t->dim[1]],
+						&res->data[j * t->dim[1]]); /* presumably res row j += 1.0 * tmp_square row j */
+			}
+		}
+
+		g_free (tmp_row);
+		g_free (means);
+		g_free (tmp_square);
+	}
+	else {
+		return luaL_error (L, "tensor required");
+	}
+
+	return 1; /* one Lua return value - presumably the tensor created by lua_newtensor; confirm */
+}
+
 static gint
 lua_load_tensor (lua_State * L)
 {


More information about the Commits mailing list