commit 723294c: [Minor] Fix tensor multiplication for the vectors case

Vsevolod Stakhov vsevolod at highsecure.ru
Wed Aug 19 13:07:08 UTC 2020


Author: Vsevolod Stakhov
Date: 2020-08-19 13:51:10 +0100
URL: https://github.com/rspamd/rspamd/commit/723294cbaad55a9d738adae263d347e95faca049

[Minor] Fix tensor multiplication for the vectors case

---
 src/lua/lua_tensor.c | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c
index cf91006d0..16bba985b 100644
--- a/src/lua/lua_tensor.c
+++ b/src/lua/lua_tensor.c
@@ -351,6 +351,7 @@ lua_tensor_index (lua_State *L)
 			}
 		}
 		else if (lua_isstring (L, 2)) {
+			/* Access to methods */
 			lua_getmetatable (L, 1);
 			lua_pushvalue (L, 2);
 			lua_rawget (L, -2);
@@ -392,7 +393,20 @@ lua_tensor_mul (lua_State *L)
 					dims[0], shadow_dims[1], shadow_dims[0], dims[1]);
 		}
 
-		res = lua_newtensor (L, 2, dims, true);
+		if (dims[0] == 0) {
+			/* Column */
+			dims[0] = 1;
+			res = lua_newtensor (L, 2, dims, true, true);
+		}
+		else if (dims[1] == 0) {
+			/* Row */
+			res = lua_newtensor (L, 1, dims, true, true);
+			dims[1] = 1;
+		}
+		else {
+			res = lua_newtensor (L, 2, dims, true, true);
+		}
+
 		kad_sgemm_simple (transA, transB, dims[0], dims[1], shadow_dims[0],
 				t1->data, t2->data, res->data);
 	}
@@ -438,7 +452,7 @@ lua_tensor_load (lua_State *L)
 		if (sz == nelts * sizeof (rspamd_tensor_num_t) + sizeof (int) * 4) {
 			if (ndims == 1) {
 				if (nelts == dims[0]) {
-					struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false);
+					struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false, true);
 					memcpy (t->data, data + sizeof (int) * 4, nelts *
 							sizeof (rspamd_tensor_num_t));
 				}
@@ -449,7 +463,7 @@ lua_tensor_load (lua_State *L)
 			}
 			else if (ndims == 2) {
 				if (nelts == dims[0] * dims[1]) {
-					struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false);
+					struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false, true);
 					memcpy (t->data, data + sizeof (int) * 4, nelts *
 							sizeof (rspamd_tensor_num_t));
 				}


More information about the Commits mailing list