commit 414c7b4: [Minor] Add printing and fix multiplication

Vsevolod Stakhov vsevolod at highsecure.ru
Wed Aug 5 20:07:12 UTC 2020


Author: Vsevolod Stakhov
Date: 2020-08-05 16:05:40 +0100
URL: https://github.com/rspamd/rspamd/commit/414c7b4ff70f4bbe934166709d29fc37389e20be

[Minor] Add printing and fix multiplication

---
 src/lua/lua_tensor.c | 71 +++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 62 insertions(+), 9 deletions(-)

diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c
index e8aebd180..21bdf9673 100644
--- a/src/lua/lua_tensor.c
+++ b/src/lua/lua_tensor.c
@@ -31,6 +31,7 @@ LUA_FUNCTION_DEF (tensor, new);
 LUA_FUNCTION_DEF (tensor, fromtable);
 LUA_FUNCTION_DEF (tensor, destroy);
 LUA_FUNCTION_DEF (tensor, mul);
+LUA_FUNCTION_DEF (tensor, tostring);
 
 static luaL_reg rspamd_tensor_f[] = {
 		LUA_INTERFACE_DEF (tensor, load),
@@ -44,6 +45,8 @@ static luaL_reg rspamd_tensor_m[] = {
 		{"__gc", lua_tensor_destroy},
 		{"__mul", lua_tensor_mul},
 		{"mul", lua_tensor_mul},
+		{"__tostring", lua_tensor_tostring},
+		{"tostring", lua_tensor_tostring},
 		{NULL, NULL},
 };
 
@@ -114,12 +117,14 @@ lua_tensor_fromtable (lua_State *L)
 		if (lua_isnumber (L, -1)) {
 			lua_pop (L, 1);
 			/* Input vector */
-			gint dim = rspamd_lua_table_size (L, 1);
+			gint dims[2];
+			dims[0] = 1;
+			dims[1] = rspamd_lua_table_size (L, 1);
 
-			struct rspamd_lua_tensor *res = lua_newtensor (L, 1,
-					&dim, false);
+			struct rspamd_lua_tensor *res = lua_newtensor (L, 2,
+					dims, false);
 
-			for (guint i = 0; i < dim; i ++) {
+			for (guint i = 0; i < dims[1]; i ++) {
 				lua_rawgeti (L, 1, i + 1);
 				res->data[i] = lua_tonumber (L, -1);
 				lua_pop (L, 1);
@@ -168,8 +173,8 @@ lua_tensor_fromtable (lua_State *L)
 			}
 
 			gint dims[2];
-			dims[0] = ncols;
-			dims[1] = nrows;
+			dims[0] = nrows;
+			dims[1] = ncols;
 
 			struct rspamd_lua_tensor *res = lua_newtensor (L, 2,
 					dims, false);
@@ -238,6 +243,47 @@ lua_tensor_save (lua_State *L)
 	return 1;
 }
 
+static gint
+lua_tensor_tostring (lua_State *L)
+{
+	struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
+
+	if (t) {
+		GString *out = g_string_sized_new (128);
+
+		if (t->ndims == 1) {
+			/* Print as a vector */
+			for (gint i = 0; i < t->dim[0]; i ++) {
+				rspamd_printf_gstring (out, "%.4f ", t->data[i]);
+			}
+			/* Trim last space */
+			out->len --;
+		}
+		else {
+			for (gint i = 0; i < t->dim[0]; i ++) {
+				for (gint j = 0; j < t->dim[1]; j ++) {
+					rspamd_printf_gstring (out, "%.4f ",
+							t->data[i * t->dim[1] + j]);
+				}
+				/* Trim last space */
+				out->len --;
+				rspamd_printf_gstring (out, "\n");
+			}
+			/* Trim last ; */
+			out->len --;
+		}
+
+		lua_pushlstring (L, out->str, out->len);
+
+		g_string_free (out, TRUE);
+	}
+	else {
+		return luaL_error (L, "invalid arguments");
+	}
+
+	return 1;
+}
+
 /***
  * @method tensor:mul(other, [transA, [transB]])
  * Multiply two tensors (optionally transposed) and return a new tensor
@@ -259,12 +305,19 @@ lua_tensor_mul (lua_State *L)
 	}
 
 	if (t1 && t2) {
-		gint dims[2];
+		gint dims[2], shadow_dims[2];
 		dims[0] = transA ? t1->dim[1] : t1->dim[0];
+		shadow_dims[0] = transB ? t2->dim[1] : t2->dim[0];
 		dims[1] = transB ? t2->dim[0] : t2->dim[1];
+		shadow_dims[1] = transA ? t1->dim[0] : t1->dim[1];
+
+		if (shadow_dims[0] != shadow_dims[1]) {
+			return luaL_error (L, "incompatible dimensions %d x %d * %d x %d",
+					dims[0], shadow_dims[1], shadow_dims[0], dims[1]);
+		}
 
-		res = lua_newtensor (L, 2, dims, false);
-		kad_sgemm_simple (transA, transB, t1->dim[1], t2->dim[0], t1->dim[0],
+		res = lua_newtensor (L, 2, dims, true);
+		kad_sgemm_simple (transA, transB, dims[0], dims[1], shadow_dims[0],
 				t1->data, t2->data, res->data);
 	}
 	else {


More information about the Commits mailing list