commit 71e5848: [Minor] Allow using lua_tensor in kann apply

Vsevolod Stakhov vsevolod at highsecure.ru
Fri Aug 21 20:49:06 UTC 2020


Author: Vsevolod Stakhov
Date: 2020-08-21 16:11:32 +0100
URL: https://github.com/rspamd/rspamd/commit/71e58489aa8efbc883ba961b1f1cf15eebec3c87

[Minor] Allow using lua_tensor in kann apply

---
 src/lua/lua_kann.c   | 102 +++++++++++++++++++++++++++++++++++----------------
 src/lua/lua_tensor.c |   2 +-
 src/lua/lua_tensor.h |   2 +
 3 files changed, 74 insertions(+), 32 deletions(-)
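
For readability, here is a condensed, plain-C sketch of the code path this commit adds to lua_kann_apply1(): a 1D rspamd{tensor} is now accepted in place of a Lua table, its float buffer is fed straight into the network, and the result is returned as a tensor rather than a table. The helper name is hypothetical and error handling is omitted; the full change follows in the diff below.

#include "lua_common.h"
#include "lua_tensor.h"
#include "contrib/kann/kann.h"
#include <string.h>

/* Sketch of the tensor branch of lua_kann_apply1() (hypothetical helper,
 * argument checks and error paths omitted for brevity) */
static int
lua_kann_apply1_tensor_sketch (lua_State *L, kann_t *k,
		struct rspamd_lua_tensor *t)
{
	int i_out = kann_find (k, KANN_F_OUT, 0);

	/* Feed the tensor's float buffer directly, no per-element copying */
	kann_set_batch_size (k, 1);
	kann_feed_bind (k, KANN_F_IN, 0, &t->data);
	kad_eval_at (k->n, k->v, i_out);

	/* Copy the output layer into a fresh 1D tensor pushed on the Lua stack */
	gint outlen = kad_len (k->v[i_out]);
	struct rspamd_lua_tensor *out = lua_newtensor (L, 1, &outlen, false, false);
	memcpy (out->data, k->v[i_out]->x, outlen * sizeof (float));

	return 1;
}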

diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c
index 33036fe04..1827fe1ac 100644
--- a/src/lua/lua_kann.c
+++ b/src/lua/lua_kann.c
@@ -15,6 +15,7 @@
  */
 
 #include "lua_common.h"
+#include "lua_tensor.h"
 #include "contrib/kann/kann.h"
 
 /***
@@ -1155,48 +1156,87 @@ lua_kann_apply1 (lua_State *L)
 {
 	kann_t *k = lua_check_kann (L, 1);
 
-	if (k && lua_istable (L, 2)) {
-		gsize vec_len = rspamd_lua_table_size (L, 2);
-		float *vec = (float *)g_malloc (sizeof (float) * vec_len);
-		int i_out;
-		int n_in = kann_dim_in (k);
+	if (k) {
+		if (lua_istable (L, 2)) {
+			gsize vec_len = rspamd_lua_table_size (L, 2);
+			float *vec = (float *) g_malloc (sizeof (float) * vec_len);
+			int i_out;
+			int n_in = kann_dim_in (k);
 
-		if (n_in <= 0) {
-			return luaL_error (L, "invalid inputs count: %d", n_in);
-		}
+			if (n_in <= 0) {
+				return luaL_error (L, "invalid inputs count: %d", n_in);
+			}
 
-		if (n_in != vec_len) {
-			return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
-					(int)vec_len, n_in);
-		}
+			if (n_in != vec_len) {
+				return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
+						(int) vec_len, n_in);
+			}
 
-		for (gsize i = 0; i < vec_len; i ++) {
-			lua_rawgeti (L, 2, i + 1);
-			vec[i] = lua_tonumber (L, -1);
-			lua_pop (L, 1);
-		}
+			for (gsize i = 0; i < vec_len; i++) {
+				lua_rawgeti (L, 2, i + 1);
+				vec[i] = lua_tonumber (L, -1);
+				lua_pop (L, 1);
+			}
 
-		i_out = kann_find (k, KANN_F_OUT, 0);
+			i_out = kann_find (k, KANN_F_OUT, 0);
+
+			if (i_out <= 0) {
+				g_free (vec);
+				return luaL_error (L, "invalid ANN: output layer is missing or is "
+									  "at the input pos");
+			}
+
+			kann_set_batch_size (k, 1);
+			kann_feed_bind (k, KANN_F_IN, 0, &vec);
+			kad_eval_at (k->n, k->v, i_out);
+
+			gsize outlen = kad_len (k->v[i_out]);
+			lua_createtable (L, outlen, 0);
+
+			for (gsize i = 0; i < outlen; i++) {
+				lua_pushnumber (L, k->v[i_out]->x[i]);
+				lua_rawseti (L, -2, i + 1);
+			}
 
-		if (i_out <= 0) {
 			g_free (vec);
-			return luaL_error (L, "invalid ANN: output layer is missing or is "
-						 "at the input pos");
 		}
+		else if (lua_isuserdata (L, 2)) {
+			struct rspamd_lua_tensor *t = lua_check_tensor (L, 2);
 
-		kann_set_batch_size (k, 1);
-		kann_feed_bind (k, KANN_F_IN, 0, &vec);
-		kad_eval_at (k->n, k->v, i_out);
+			if (t && t->ndims == 1) {
+				int i_out;
+				int n_in = kann_dim_in (k);
 
-		gsize outlen = kad_len (k->v[i_out]);
-		lua_createtable (L, outlen, 0);
+				if (n_in != t->dim[0]) {
+					return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
+							(int) t->dim[0], n_in);
+				}
 
-		for (gsize i = 0; i < outlen; i ++) {
-			lua_pushnumber (L, k->v[i_out]->x[i]);
-			lua_rawseti (L, -2, i + 1);
-		}
+				i_out = kann_find (k, KANN_F_OUT, 0);
+
+				if (i_out <= 0) {
+					return luaL_error (L, "invalid ANN: output layer is missing or is "
+										  "at the input pos");
+				}
 
-		g_free (vec);
+				kann_set_batch_size (k, 1);
+				kann_feed_bind (k, KANN_F_IN, 0, &t->data);
+				kad_eval_at (k->n, k->v, i_out);
+
+				gint outlen = kad_len (k->v[i_out]);
+				struct rspamd_lua_tensor *out;
+				out = lua_newtensor (L, 1, &outlen, false, false);
+				/* Ensure that kann and tensor have the same understanding of floats */
+				G_STATIC_ASSERT (sizeof (float) == sizeof (rspamd_tensor_num_t));
+				memcpy (out->data, k->v[i_out]->x, outlen * sizeof (float));
+			}
+			else {
+				return luaL_error (L, "invalid arguments: 1D rspamd{tensor} expected");
+			}
+		}
+		else {
+			return luaL_error (L, "invalid arguments: 1D rspamd{tensor} expected");
+		}
 	}
 	else {
 		return luaL_error (L, "invalid arguments: rspamd{kann} expected");
diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c
index 6e5bec7d8..1506d4548 100644
--- a/src/lua/lua_tensor.c
+++ b/src/lua/lua_tensor.c
@@ -54,7 +54,7 @@ static luaL_reg rspamd_tensor_m[] = {
 		{NULL, NULL},
 };
 
-static struct rspamd_lua_tensor *
+struct rspamd_lua_tensor *
 lua_newtensor (lua_State *L, int ndims, const int *dim, bool zero_fill, bool own)
 {
 	struct rspamd_lua_tensor *res;
diff --git a/src/lua/lua_tensor.h b/src/lua/lua_tensor.h
index e4c110011..e022f64b9 100644
--- a/src/lua/lua_tensor.h
+++ b/src/lua/lua_tensor.h
@@ -28,5 +28,7 @@ struct rspamd_lua_tensor {
 };
 
 struct rspamd_lua_tensor *lua_check_tensor (lua_State *L, int pos);
+struct rspamd_lua_tensor *lua_newtensor (lua_State *L, int ndims,
+		const int *dim, bool zero_fill, bool own);
 
 #endif
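
Since lua_newtensor() is no longer static, other Lua bindings inside rspamd can return tensors the same way lua_kann_apply1() now does. A minimal hypothetical sketch (the wrapper name and the source buffer are illustrative only):

#include "lua_common.h"
#include "lua_tensor.h"
#include <string.h>

/* Hypothetical helper: push a 1D rspamd{tensor} built from a C float buffer,
 * mirroring the lua_newtensor() + memcpy pattern used in lua_kann_apply1() */
static int
lua_push_floats_as_tensor (lua_State *L, const float *src, gint len)
{
	struct rspamd_lua_tensor *out;

	/* The two boolean flags mirror the call in lua_kann_apply1() above */
	out = lua_newtensor (L, 1, &len, false, false);
	memcpy (out->data, src, len * sizeof (float));

	return 1; /* one value (the tensor userdata) is left on the Lua stack */
}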

