commit 414c7b4: [Minor] Add printing and fix multiplication
Vsevolod Stakhov
vsevolod at highsecure.ru
Wed Aug 5 20:07:12 UTC 2020
Author: Vsevolod Stakhov
Date: 2020-08-05 16:05:40 +0100
URL: https://github.com/rspamd/rspamd/commit/414c7b4ff70f4bbe934166709d29fc37389e20be
[Minor] Add printing and fix multiplication
---
src/lua/lua_tensor.c | 71 +++++++++++++++++++++++++++++++++++++++++++++-------
1 file changed, 62 insertions(+), 9 deletions(-)
diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c
index e8aebd180..21bdf9673 100644
--- a/src/lua/lua_tensor.c
+++ b/src/lua/lua_tensor.c
@@ -31,6 +31,7 @@ LUA_FUNCTION_DEF (tensor, new);
LUA_FUNCTION_DEF (tensor, fromtable);
LUA_FUNCTION_DEF (tensor, destroy);
LUA_FUNCTION_DEF (tensor, mul);
+LUA_FUNCTION_DEF (tensor, tostring);
static luaL_reg rspamd_tensor_f[] = {
LUA_INTERFACE_DEF (tensor, load),
@@ -44,6 +45,8 @@ static luaL_reg rspamd_tensor_m[] = {
{"__gc", lua_tensor_destroy},
{"__mul", lua_tensor_mul},
{"mul", lua_tensor_mul},
+ {"__tostring", lua_tensor_tostring},
+ {"tostring", lua_tensor_tostring},
{NULL, NULL},
};
@@ -114,12 +117,14 @@ lua_tensor_fromtable (lua_State *L)
if (lua_isnumber (L, -1)) {
lua_pop (L, 1);
/* Input vector */
- gint dim = rspamd_lua_table_size (L, 1);
+ gint dims[2];
+ dims[0] = 1;
+ dims[1] = rspamd_lua_table_size (L, 1);
- struct rspamd_lua_tensor *res = lua_newtensor (L, 1,
- &dim, false);
+ struct rspamd_lua_tensor *res = lua_newtensor (L, 2,
+ dims, false);
- for (guint i = 0; i < dim; i ++) {
+ for (guint i = 0; i < dims[1]; i ++) {
lua_rawgeti (L, 1, i + 1);
res->data[i] = lua_tonumber (L, -1);
lua_pop (L, 1);
@@ -168,8 +173,8 @@ lua_tensor_fromtable (lua_State *L)
}
gint dims[2];
- dims[0] = ncols;
- dims[1] = nrows;
+ dims[0] = nrows;
+ dims[1] = ncols;
struct rspamd_lua_tensor *res = lua_newtensor (L, 2,
dims, false);
@@ -238,6 +243,47 @@ lua_tensor_save (lua_State *L)
return 1;
}
+static gint
+lua_tensor_tostring (lua_State *L)
+{
+ struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
+
+ if (t) {
+ GString *out = g_string_sized_new (128);
+
+ if (t->ndims == 1) {
+ /* Print as a vector */
+ for (gint i = 0; i < t->dim[0]; i ++) {
+ rspamd_printf_gstring (out, "%.4f ", t->data[i]);
+ }
+ /* Trim last space */
+ out->len --;
+ }
+ else {
+ for (gint i = 0; i < t->dim[0]; i ++) {
+ for (gint j = 0; j < t->dim[1]; j ++) {
+ rspamd_printf_gstring (out, "%.4f ",
+ t->data[i * t->dim[1] + j]);
+ }
+ /* Trim last space */
+ out->len --;
+ rspamd_printf_gstring (out, "\n");
+ }
+ /* Trim last ; */
+ out->len --;
+ }
+
+ lua_pushlstring (L, out->str, out->len);
+
+ g_string_free (out, TRUE);
+ }
+ else {
+ return luaL_error (L, "invalid arguments");
+ }
+
+ return 1;
+}
+
/***
* @method tensor:mul(other, [transA, [transB]])
* Multiply two tensors (optionally transposed) and return a new tensor
@@ -259,12 +305,19 @@ lua_tensor_mul (lua_State *L)
}
if (t1 && t2) {
- gint dims[2];
+ gint dims[2], shadow_dims[2];
dims[0] = transA ? t1->dim[1] : t1->dim[0];
+ shadow_dims[0] = transB ? t2->dim[1] : t2->dim[0];
dims[1] = transB ? t2->dim[0] : t2->dim[1];
+ shadow_dims[1] = transA ? t1->dim[0] : t1->dim[1];
+
+ if (shadow_dims[0] != shadow_dims[1]) {
+ return luaL_error (L, "incompatible dimensions %d x %d * %d x %d",
+ dims[0], shadow_dims[1], shadow_dims[0], dims[1]);
+ }
- res = lua_newtensor (L, 2, dims, false);
- kad_sgemm_simple (transA, transB, t1->dim[1], t2->dim[0], t1->dim[0],
+ res = lua_newtensor (L, 2, dims, true);
+ kad_sgemm_simple (transA, transB, dims[0], dims[1], shadow_dims[0],
t1->data, t2->data, res->data);
}
else {
More information about the Commits
mailing list