commit 723294c: [Minor] Fix tensor multiplication for the vectors case
Vsevolod Stakhov
vsevolod at highsecure.ru
Wed Aug 19 13:07:08 UTC 2020
Author: Vsevolod Stakhov
Date: 2020-08-19 13:51:10 +0100
URL: https://github.com/rspamd/rspamd/commit/723294cbaad55a9d738adae263d347e95faca049
[Minor] Fix tensor multiplication for the vectors case
---
src/lua/lua_tensor.c | 20 +++++++++++++++++---
1 file changed, 17 insertions(+), 3 deletions(-)
diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c
index cf91006d0..16bba985b 100644
--- a/src/lua/lua_tensor.c
+++ b/src/lua/lua_tensor.c
@@ -351,6 +351,7 @@ lua_tensor_index (lua_State *L)
}
}
else if (lua_isstring (L, 2)) {
+ /* Access to methods */
lua_getmetatable (L, 1);
lua_pushvalue (L, 2);
lua_rawget (L, -2);
@@ -392,7 +393,20 @@ lua_tensor_mul (lua_State *L)
dims[0], shadow_dims[1], shadow_dims[0], dims[1]);
}
- res = lua_newtensor (L, 2, dims, true);
+ if (dims[0] == 0) {
+ /* Column */
+ dims[0] = 1;
+ res = lua_newtensor (L, 2, dims, true, true);
+ }
+ else if (dims[1] == 0) {
+ /* Row */
+ res = lua_newtensor (L, 1, dims, true, true);
+ dims[1] = 1;
+ }
+ else {
+ res = lua_newtensor (L, 2, dims, true, true);
+ }
+
kad_sgemm_simple (transA, transB, dims[0], dims[1], shadow_dims[0],
t1->data, t2->data, res->data);
}
@@ -438,7 +452,7 @@ lua_tensor_load (lua_State *L)
if (sz == nelts * sizeof (rspamd_tensor_num_t) + sizeof (int) * 4) {
if (ndims == 1) {
if (nelts == dims[0]) {
- struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false);
+ struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false, true);
memcpy (t->data, data + sizeof (int) * 4, nelts *
sizeof (rspamd_tensor_num_t));
}
@@ -449,7 +463,7 @@ lua_tensor_load (lua_State *L)
}
else if (ndims == 2) {
if (nelts == dims[0] * dims[1]) {
- struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false);
+ struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false, true);
memcpy (t->data, data + sizeof (int) * 4, nelts *
sizeof (rspamd_tensor_num_t));
}
More information about the Commits
mailing list