commit f5a453a: [Minor] Lua_tensor: Add transpose and mean methods

Vsevolod Stakhov vsevolod at highsecure.ru
Tue Aug 25 14:49:09 UTC 2020


Author: Vsevolod Stakhov
Date: 2020-08-25 15:14:58 +0100
URL: https://github.com/rspamd/rspamd/commit/f5a453a97f446996a91e793c54c6144cfbc15522

[Minor] Lua_tensor: Add transpose and mean methods

---
 src/lua/lua_tensor.c | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 88 insertions(+)

diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c
index d14ec8831..09a10cabc 100644
--- a/src/lua/lua_tensor.c
+++ b/src/lua/lua_tensor.c
@@ -37,6 +37,7 @@ LUA_FUNCTION_DEF (tensor, newindex);
 LUA_FUNCTION_DEF (tensor, len);
 LUA_FUNCTION_DEF (tensor, eugen);
 LUA_FUNCTION_DEF (tensor, mean);
+LUA_FUNCTION_DEF (tensor, transpose);
 
 static luaL_reg rspamd_tensor_f[] = {
 		LUA_INTERFACE_DEF (tensor, load),
@@ -57,6 +58,7 @@ static luaL_reg rspamd_tensor_m[] = {
 		{"__len", lua_tensor_len},
 		LUA_INTERFACE_DEF (tensor, eugen),
 		LUA_INTERFACE_DEF (tensor, mean),
+		LUA_INTERFACE_DEF (tensor, transpose),
 		{NULL, NULL},
 };
 
@@ -625,6 +627,92 @@ lua_tensor_eugen (lua_State *L)
 	return 1;
 }
 
+static inline rspamd_tensor_num_t
+mean_vec (rspamd_tensor_num_t *x, int n)
+{
+	rspamd_tensor_num_t s = 0;
+	rspamd_tensor_num_t c = 0;
+
+	for (int i = 0; i < n; i ++) {
+		rspamd_tensor_num_t v = x[i];
+		rspamd_tensor_num_t y = v - c;
+		rspamd_tensor_num_t t = s + y;
+		c = (t - s) - y;
+		s = t;
+	}
+
+	return s / (rspamd_tensor_num_t)n;
+}
+
+static gint
+lua_tensor_mean (lua_State *L)
+{
+	struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
+
+	if (t) {
+		if (t->ndims == 1) {
+			/* Mean of all elements in a vector */
+			lua_pushnumber (L, mean_vec (t->data, t->dim[0]));
+		}
+		else {
+			/* Row-wise mean vector output */
+			struct rspamd_lua_tensor *res;
+
+			res = lua_newtensor (L, 1, &t->dim[0], false, true);
+
+			for (int i = 0; i < t->dim[0]; i ++) {
+				res->data[i] = mean_vec (&t->data[i * t->dim[1]], t->dim[1]);
+			}
+		}
+	}
+	else {
+		return luaL_error (L, "invalid arguments");
+	}
+
+	return 1;
+}
+
+static gint
+lua_tensor_transpose (lua_State *L)
+{
+	struct rspamd_lua_tensor *t = lua_check_tensor (L, 1), *res;
+	int dims[2];
+
+	if (t) {
+		if (t->ndims == 1) {
+			/* Row to column */
+			dims[0] = 1;
+			dims[1] = t->dim[0];
+			res = lua_newtensor (L, 2, dims, false, true);
+			memcpy (res->data, t->data, t->dim[0] * sizeof (rspamd_tensor_num_t));
+		}
+		else {
+			/* Cache friendly algorithm */
+			struct rspamd_lua_tensor *res;
+
+			dims[0] = t->dim[1];
+			dims[1] = t->dim[0];
+			res = lua_newtensor (L, 2, dims, false, true);
+
+			static const int block = 32;
+
+			for (int i = 0; i < t->dim[0]; i += block) {
+				for(int j = 0; j < t->dim[1]; ++j) {
+					for(int boff = 0; boff < block && i + boff < t->dim[0]; ++boff) {
+						res->data[j * t->dim[0] + i + boff] =
+								t->data[(i + boff) * t->dim[1] + j];
+					}
+				}
+			}
+		}
+	}
+	else {
+		return luaL_error (L, "invalid arguments");
+	}
+
+	return 1;
+}
+
 static gint
 lua_load_tensor (lua_State * L)
 {


More information about the Commits mailing list