commit 568fd73: [Project] Add a simple matrix Lua library

Vsevolod Stakhov vsevolod at highsecure.ru
Wed Aug 5 20:07:10 UTC 2020


Author: Vsevolod Stakhov
Date: 2020-08-05 15:36:41 +0100
URL: https://github.com/rspamd/rspamd/commit/568fd73ab8ba4cf27a2a8c068885444e87cbd148

[Project] Add a simple matrix Lua library

---
 src/lua/CMakeLists.txt                             |   3 +-
 src/lua/lua_common.c                               |   1 +
 src/lua/lua_common.h                               |   2 +
 src/lua/lua_kann.c                                 |   2 +
 src/lua/lua_tensor.c                               | 323 +++++++++++++++++++++
 .../base64/base64.h => lua/lua_tensor.h}           |  25 +-
 6 files changed, 343 insertions(+), 13 deletions(-)

diff --git a/src/lua/CMakeLists.txt b/src/lua/CMakeLists.txt
index 30f5008fa..84c819c2d 100644
--- a/src/lua/CMakeLists.txt
+++ b/src/lua/CMakeLists.txt
@@ -31,6 +31,7 @@ SET(LUASRC			  ${CMAKE_CURRENT_SOURCE_DIR}/lua_common.c
 					  ${CMAKE_CURRENT_SOURCE_DIR}/lua_text.c
 		 			  ${CMAKE_CURRENT_SOURCE_DIR}/lua_worker.c
 					  ${CMAKE_CURRENT_SOURCE_DIR}/lua_kann.c
-					  ${CMAKE_CURRENT_SOURCE_DIR}/lua_spf.c)
+					  ${CMAKE_CURRENT_SOURCE_DIR}/lua_spf.c
+					  ${CMAKE_CURRENT_SOURCE_DIR}/lua_tensor.c)
 
 SET(RSPAMD_LUA ${LUASRC} PARENT_SCOPE)
\ No newline at end of file
diff --git a/src/lua/lua_common.c b/src/lua/lua_common.c
index c658181ee..92da12d52 100644
--- a/src/lua/lua_common.c
+++ b/src/lua/lua_common.c
@@ -972,6 +972,7 @@ rspamd_lua_init (bool wipe_mem)
 	luaopen_worker (L);
 	luaopen_kann (L);
 	luaopen_spf (L);
+	luaopen_tensor (L);
 #ifndef WITH_LUAJIT
 	rspamd_lua_add_preload (L, "bit", luaopen_bit);
 	lua_settop (L, 0);
diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h
index cc4c9a9d4..4c9e9dd4f 100644
--- a/src/lua/lua_common.h
+++ b/src/lua/lua_common.h
@@ -365,6 +365,8 @@ void luaopen_kann (lua_State *L);
 
 void luaopen_spf (lua_State *L);
 
+void luaopen_tensor (lua_State *L);
+
 void rspamd_lua_dostring (const gchar *line);
 
 double rspamd_lua_normalize (struct rspamd_config *cfg,
diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c
index e356f5912..33036fe04 100644
--- a/src/lua/lua_kann.c
+++ b/src/lua/lua_kann.c
@@ -1079,6 +1079,7 @@ lua_kann_train1 (lua_State *L)
 				FREE_VEC (x, n);
 				FREE_VEC (y, n);
 
+				lua_pop (L, 1);
 				n = luaL_error (L, "invalid params at pos %d: "
 					   "bad input dimension %d; %d expected",
 						s + 1,
@@ -1102,6 +1103,7 @@ lua_kann_train1 (lua_State *L)
 			lua_rawgeti (L, 3, s + 1);
 
 			if (rspamd_lua_table_size (L, -1) != n_out) {
+				lua_pop (L, 1);
 				FREE_VEC (x, n);
 				FREE_VEC (y, n);
 
diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c
new file mode 100644
index 000000000..e8aebd180
--- /dev/null
+++ b/src/lua/lua_tensor.c
@@ -0,0 +1,323 @@
+/*-
+ * Copyright 2020 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lua_common.h"
+#include "lua_tensor.h"
+#include "contrib/kann/kautodiff.h"
+
+/***
+ * @module rspamd_tensor
+ * `rspamd_tensor` is a simple Lua library to abstract matrices and vectors
+ * Internally, they are represented as arrays of float variables
+ * So far, merely 1D and 2D tensors are supported
+ */
+
+LUA_FUNCTION_DEF (tensor, load);
+LUA_FUNCTION_DEF (tensor, save);
+LUA_FUNCTION_DEF (tensor, new);
+LUA_FUNCTION_DEF (tensor, fromtable);
+LUA_FUNCTION_DEF (tensor, destroy);
+LUA_FUNCTION_DEF (tensor, mul);
+
+static luaL_reg rspamd_tensor_f[] = {
+		LUA_INTERFACE_DEF (tensor, load),
+		LUA_INTERFACE_DEF (tensor, new),
+		LUA_INTERFACE_DEF (tensor, fromtable),
+		{NULL, NULL},
+};
+
+static luaL_reg rspamd_tensor_m[] = {
+		LUA_INTERFACE_DEF (tensor, save),
+		{"__gc", lua_tensor_destroy},
+		{"__mul", lua_tensor_mul},
+		{"mul", lua_tensor_mul},
+		{NULL, NULL},
+};
+
+static struct rspamd_lua_tensor *
+lua_newtensor (lua_State *L, int ndims, const int *dim, bool zero_fill)
+{
+	struct rspamd_lua_tensor *res;
+
+	res = lua_newuserdata (L, sizeof (struct rspamd_lua_tensor));
+
+	res->ndims = ndims;
+	res->size = 1;
+
+	for (guint i = 0; i < ndims; i ++) {
+		res->size *= dim[i];
+		res->dim[i] = dim[i];
+	}
+
+	/* To avoid allocating large stuff in Lua */
+	res->data = g_malloc (sizeof (rspamd_tensor_num_t) * res->size);
+
+	if (zero_fill) {
+		memset (res->data, 0, sizeof (rspamd_tensor_num_t) * res->size);
+	}
+
+	rspamd_lua_setclass (L, TENSOR_CLASS, -1);
+
+	return res;
+}
+
+/***
+ * @function tensor.new(ndims, [dim1, ... dimN])
+ * Creates a new zero filled tensor with the specific number of dimensions
+ * @return
+ */
+static gint
+lua_tensor_new (lua_State *L)
+{
+	gint ndims = luaL_checkinteger (L, 1);
+
+	if (ndims > 0 && ndims <= 2) {
+		gint *dims = g_alloca (sizeof (gint) * ndims);
+
+		for (guint i = 0; i < ndims; i ++) {
+			dims[i] = lua_tointeger (L, i + 2);
+		}
+
+		(void)lua_newtensor (L, ndims, dims, true);
+	}
+	else {
+		return luaL_error (L, "incorrect dimensions number: %d", ndims);
+	}
+
+	return 1;
+}
+
+/***
+ * @function tensor.fromtable(tbl)
+ * Creates a new zero filled tensor with the specific number of dimensions
+ * @return
+ */
+static gint
+lua_tensor_fromtable (lua_State *L)
+{
+	if (lua_istable (L, 1)) {
+		lua_rawgeti (L, 1, 1);
+
+		if (lua_isnumber (L, -1)) {
+			lua_pop (L, 1);
+			/* Input vector */
+			gint dim = rspamd_lua_table_size (L, 1);
+
+			struct rspamd_lua_tensor *res = lua_newtensor (L, 1,
+					&dim, false);
+
+			for (guint i = 0; i < dim; i ++) {
+				lua_rawgeti (L, 1, i + 1);
+				res->data[i] = lua_tonumber (L, -1);
+				lua_pop (L, 1);
+			}
+		}
+		else if (lua_istable (L, -1)) {
+			/* Input matrix */
+			lua_pop (L, 1);
+
+			/* Calculate the overall size */
+			gint nrows = rspamd_lua_table_size (L, 1), ncols = 0;
+			gint err;
+
+			for (gint i = 0; i < nrows; i ++) {
+				lua_rawgeti (L, 1, i + 1);
+
+				if (ncols == 0) {
+					ncols = rspamd_lua_table_size (L, -1);
+
+					if (ncols == 0) {
+						lua_pop (L, 1);
+						err = luaL_error (L, "invalid params at pos %d: "
+										   "bad input dimension %d",
+								i,
+								(int)ncols);
+
+						return err;
+					}
+				}
+				else {
+					if (ncols != rspamd_lua_table_size (L, -1)) {
+						gint t = rspamd_lua_table_size (L, -1);
+
+						lua_pop (L, 1);
+						err = luaL_error (L, "invalid params at pos %d: "
+											 "bad input dimension %d; %d expected",
+								i,
+								t,
+								ncols);
+
+						return err;
+					}
+				}
+
+				lua_pop (L, 1);
+			}
+
+			gint dims[2];
+			dims[0] = ncols;
+			dims[1] = nrows;
+
+			struct rspamd_lua_tensor *res = lua_newtensor (L, 2,
+					dims, false);
+
+			for (gint i = 0; i < nrows; i ++) {
+				lua_rawgeti (L, 1, i + 1);
+
+				for (gint j = 0; j < ncols; j++) {
+					lua_rawgeti (L, -1, j + 1);
+
+					res->data[i * ncols + j] = lua_tonumber (L, -1);
+
+					lua_pop (L, 1);
+				}
+
+				lua_pop (L, 1);
+			}
+		}
+		else {
+			lua_pop (L, 1);
+			return luaL_error (L, "incorrect table");
+		}
+	}
+	else {
+		return luaL_error (L, "incorrect input");
+	}
+
+	return 1;
+}
+
+
+/***
+ * @method tensor:destroy()
+ * Tensor destructor
+ * @return
+ */
+static gint
+lua_tensor_destroy (lua_State *L)
+{
+	struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
+
+	if (t) {
+		g_free (t->data);
+	}
+
+	return 0;
+}
+
+/***
+ * @method tensor:save()
+ * Tensor serialisation function
+ * @return
+ */
+static gint
+lua_tensor_save (lua_State *L)
+{
+	struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
+
+	if (t) {
+
+	}
+	else {
+		return luaL_error (L, "invalid arguments");
+	}
+
+	return 1;
+}
+
+/***
+ * @method tensor:mul(other, [transA, [transB]])
+ * Multiply two tensors (optionally transposed) and return a new tensor
+ * @return
+ */
+static gint
+lua_tensor_mul (lua_State *L)
+{
+	struct rspamd_lua_tensor *t1 = lua_check_tensor (L, 1),
+			*t2 = lua_check_tensor (L, 2), *res;
+	int transA = 0, transB = 0;
+
+	if (lua_isboolean (L, 3)) {
+		transA = lua_toboolean (L, 3);
+	}
+
+	if (lua_isboolean (L, 4)) {
+		transB = lua_toboolean (L, 4);
+	}
+
+	if (t1 && t2) {
+		gint dims[2];
+		dims[0] = transA ? t1->dim[1] : t1->dim[0];
+		dims[1] = transB ? t2->dim[0] : t2->dim[1];
+
+		res = lua_newtensor (L, 2, dims, false);
+		kad_sgemm_simple (transA, transB, t1->dim[1], t2->dim[0], t1->dim[0],
+				t1->data, t2->data, res->data);
+	}
+	else {
+		return luaL_error (L, "invalid arguments");
+	}
+
+	return 1;
+}
+
+/***
+ * @function tensor.load(rspamd_text)
+ * Deserialize tensor
+ * @return
+ */
+static gint
+lua_tensor_load (lua_State *L)
+{
+	struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
+
+	if (t) {
+
+	}
+	else {
+		return luaL_error (L, "invalid arguments");
+	}
+
+	return 1;
+}
+
+static gint
+lua_load_tensor (lua_State * L)
+{
+	lua_newtable (L);
+	luaL_register (L, NULL, rspamd_tensor_f);
+
+	return 1;
+}
+
+
+void luaopen_tensor (lua_State *L)
+{
+	/* Metatables */
+	rspamd_lua_new_class (L, TENSOR_CLASS, rspamd_tensor_m);
+	lua_pop (L, 1); /* No need in metatable... */
+	rspamd_lua_add_preload (L, "rspamd_tensor", lua_load_tensor);
+	lua_settop (L, 0);
+}
+
+struct rspamd_lua_tensor *
+lua_check_tensor (lua_State *L, int pos)
+{
+	void *ud = rspamd_lua_check_udata (L, pos, TENSOR_CLASS);
+	luaL_argcheck (L, ud != NULL, pos, "'tensor' expected");
+	return ud ? ((struct rspamd_lua_tensor *)ud) : NULL;
+}
+
diff --git a/src/libcryptobox/base64/base64.h b/src/lua/lua_tensor.h
similarity index 61%
copy from src/libcryptobox/base64/base64.h
copy to src/lua/lua_tensor.h
index e2be379b5..554245f0b 100644
--- a/src/libcryptobox/base64/base64.h
+++ b/src/lua/lua_tensor.h
@@ -1,5 +1,5 @@
 /*-
- * Copyright 2016 Vsevolod Stakhov
+ * Copyright 2020 Vsevolod Stakhov
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,19 +13,20 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef SRC_LIBCRYPTOBOX_BASE64_BASE64_H_
-#define SRC_LIBCRYPTOBOX_BASE64_BASE64_H_
+#ifndef RSPAMD_LUA_TENSOR_H
+#define RSPAMD_LUA_TENSOR_H
 
-#include "config.h"
+#define TENSOR_CLASS "rspamd{tensor}"
 
-#ifdef  __cplusplus
-extern "C" {
-#endif
+typedef float rspamd_tensor_num_t;
 
-const char *base64_load (void);
+struct rspamd_lua_tensor {
+	int ndims;
+	int size; /* overall size (product of dims) */
+	rspamd_tensor_num_t *data;
+	int dim[2];
+};
 
-#ifdef  __cplusplus
-}
-#endif
+struct rspamd_lua_tensor *lua_check_tensor (lua_State *L, int pos);
 
-#endif /* SRC_LIBCRYPTOBOX_BASE64_BASE64_H_ */
+#endif


More information about the Commits mailing list