commit c8c4280: [Project] Add some missing functions to kann API

Vsevolod Stakhov vsevolod at highsecure.ru
Sat Jun 29 16:28:03 UTC 2019


Author: Vsevolod Stakhov
Date: 2019-06-29 12:35:00 +0100
URL: https://github.com/rspamd/rspamd/commit/c8c4280fee8cc896a98dccb26626b7853a18aa7d

[Project] Add some missing functions to kann API

---
 contrib/kann/kann.h |   2 +
 src/lua/lua_kann.c  | 171 +++++++++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 171 insertions(+), 2 deletions(-)

diff --git a/contrib/kann/kann.h b/contrib/kann/kann.h
index 1605e5ea5..7ec748561 100644
--- a/contrib/kann/kann.h
+++ b/contrib/kann/kann.h
@@ -210,6 +210,8 @@ kad_node_t *kann_new_bias(int n);
 kad_node_t *kann_new_weight_conv2d(int n_out, int n_in, int k_row, int k_col);
 kad_node_t *kann_new_weight_conv1d(int n_out, int n_in, int kernel_len);
 
+kad_node_t *kann_new_leaf_array(int *offset, kad_node_p *par, uint8_t flag, float x0_01, int n_d, int32_t d[KAD_MAX_DIM]);
+
 kad_node_t *kann_new_leaf2(int *offset, kad_node_p *par, uint8_t flag, float x0_01, int n_d, ...);
 kad_node_t *kann_layer_dense2(int *offset, kad_node_p *par, kad_node_t *in, int n1);
 kad_node_t *kann_layer_dropout2(int *offset, kad_node_p *par, kad_node_t *t, float r);
diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c
index cec75acbb..171c81454 100644
--- a/src/lua/lua_kann.c
+++ b/src/lua/lua_kann.c
@@ -23,6 +23,7 @@
  */
 
 #define KANN_NODE_CLASS "rspamd{kann_node}"
+#define KANN_NETWORK_CLASS "rspamd{kann}"
 
 /* Simple macros to define behaviour */
 #define KANN_LAYER_DEF(name) static int lua_kann_layer_ ## name (lua_State *L)
@@ -34,6 +35,10 @@
 #define KANN_LOSS_DEF(name) static int lua_kann_loss_ ## name (lua_State *L)
 #define KANN_LOSS_INTERFACE(name) {#name, lua_kann_loss_ ## name}
 
+#define KANN_NEW_DEF(name) static int lua_kann_new_ ## name (lua_State *L)
+#define KANN_NEW_INTERFACE(name) {#name, lua_kann_new_ ## name}
+
+
 /*
  * Forwarded declarations
  */
@@ -115,6 +120,26 @@ static luaL_reg rspamd_kann_loss_f[] = {
 		{NULL, NULL},
 };
 
+/* Creation functions */
+KANN_NEW_DEF (leaf);
+KANN_NEW_DEF (scalar);
+KANN_NEW_DEF (weight);
+KANN_NEW_DEF (bias);
+KANN_NEW_DEF (weight_conv2d);
+KANN_NEW_DEF (weight_conv1d);
+KANN_NEW_DEF (kann);
+
+static luaL_reg rspamd_kann_new_f[] = {
+		KANN_NEW_INTERFACE (leaf),
+		KANN_NEW_INTERFACE (scalar),
+		KANN_NEW_INTERFACE (weight),
+		KANN_NEW_INTERFACE (bias),
+		KANN_NEW_INTERFACE (weight_conv2d),
+		KANN_NEW_INTERFACE (weight_conv1d),
+		KANN_NEW_INTERFACE (kann),
+		{NULL, NULL},
+};
+
 static int
 rspamd_kann_table_to_flags (lua_State *L, int table_pos)
 {
@@ -196,6 +221,12 @@ lua_load_kann (lua_State * L)
 	luaL_register (L, NULL, rspamd_kann_loss_f);
 	lua_settable (L, -3);
 
+	/* Create functions */
+	lua_pushstring (L, "new");
+	lua_newtable (L);
+	luaL_register (L, NULL, rspamd_kann_new_f);
+	lua_settable (L, -3);
+
 	return 1;
 }
 
@@ -210,7 +241,9 @@ lua_check_kann_node (lua_State *L, int pos)
 void luaopen_kann (lua_State *L)
 {
 	/* Metatables */
-	rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL);
+	rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL); /* TODO: add methods */
+	lua_pop (L, 1); /* No need in metatable... */
+	rspamd_lua_new_class (L, KANN_NETWORK_CLASS, NULL); /* TODO: add methods */
 	lua_pop (L, 1); /* No need in metatable... */
 	rspamd_lua_add_preload (L, "rspamd_kann", lua_load_kann);
 	lua_settop (L, 0);
@@ -224,6 +257,13 @@ void luaopen_kann (lua_State *L)
 	rspamd_lua_setclass (L, KANN_NODE_CLASS, -1); \
 } while(0)
 
+#define PUSH_KAN_NETWORK(n) do { \
+	kann_t **pn; \
+	pn = lua_newuserdata (L, sizeof (kann_t *)); \
+	*pn = (n); \
+	rspamd_lua_setclass (L, KANN_NETWORK_CLASS, -1); \
+} while(0)
+
 #define PROCESS_KAD_FLAGS(n, pos) do { \
 	int fl = 0; \
 	if (lua_type(L, (pos)) == LUA_TTABLE) { fl = rspamd_kann_table_to_flags (L, (pos)); } \
@@ -561,4 +601,131 @@ lua_kann_loss_ce_multi_weighted (lua_State *L)
 	}
 
 	return 1;
-}
\ No newline at end of file
+}
+
+/* Creation functions */
+static int
+lua_kann_new_scalar (lua_State *L)
+{
+	gint flag = luaL_checkinteger (L, 1);
+	double x = luaL_checknumber (L, 2);
+	kad_node_t *t;
+
+	t = kann_new_scalar (flag, x);
+
+	PROCESS_KAD_FLAGS (t, 3);
+	PUSH_KAD_NODE (t);
+
+	return 1;
+}
+
+static int
+lua_kann_new_weight (lua_State *L)
+{
+	gint nrow = luaL_checkinteger (L, 1);
+	gint ncol = luaL_checkinteger (L, 2);
+	kad_node_t *t;
+
+	t = kann_new_weight (nrow, ncol);
+
+	PROCESS_KAD_FLAGS (t, 3);
+	PUSH_KAD_NODE (t);
+
+	return 1;
+}
+
+static int
+lua_kann_new_bias (lua_State *L)
+{
+	gint n = luaL_checkinteger (L, 1);
+	kad_node_t *t;
+
+	t = kann_new_bias (n);
+
+	PROCESS_KAD_FLAGS (t, 2);
+	PUSH_KAD_NODE (t);
+
+	return 1;
+}
+
+static int
+lua_kann_new_weight_conv2d (lua_State *L)
+{
+	gint nout = luaL_checkinteger (L, 1);
+	gint nin = luaL_checkinteger (L, 2);
+	gint krow = luaL_checkinteger (L, 3);
+	gint kcol = luaL_checkinteger (L, 4);
+	kad_node_t *t;
+
+	t = kann_new_weight_conv2d (nout, nin, krow, kcol);
+
+	PROCESS_KAD_FLAGS (t, 5);
+	PUSH_KAD_NODE (t);
+
+	return 1;
+}
+
+static int
+lua_kann_new_weight_conv1d (lua_State *L)
+{
+	gint nout = luaL_checkinteger (L, 1);
+	gint nin = luaL_checkinteger (L, 2);
+	gint klen = luaL_checkinteger (L, 3);
+	kad_node_t *t;
+
+	t = kann_new_weight_conv1d (nout, nin, klen);
+
+	PROCESS_KAD_FLAGS (t, 4);
+	PUSH_KAD_NODE (t);
+
+	return 1;
+}
+
+static int
+lua_kann_new_leaf (lua_State *L)
+{
+	gint dim = luaL_checkinteger (L, 1), i, *ar;
+	kad_node_t *t;
+
+	if (dim >= 1 && dim < KAD_MAX_DIM && lua_istable (L, 2)) {
+		ar = g_malloc0 (sizeof (ar) * dim);
+
+		for (i = 0; i < dim; i ++) {
+			lua_rawgeti (L, 2, i + 1);
+			ar[i] = lua_tointeger (L, -1);
+			lua_pop (L, 1);
+		}
+
+		t = kann_new_leaf_array (NULL, NULL, 0, 0.0, dim, ar);
+
+		PROCESS_KAD_FLAGS (t, 3);
+		PUSH_KAD_NODE (t);
+
+		g_free (ar);
+	}
+	else {
+		return luaL_error (L, "invalid arguments for new.leaf, "
+						"dim and vector of elements are required");
+	}
+
+	return 1;
+}
+
+static int
+lua_kann_new_kann (lua_State *L)
+{
+	kad_node_t *cost = lua_check_kann_node (L, 1);
+	kann_t *k;
+
+	if (cost) {
+		k = kann_new (cost, 0);
+
+		PUSH_KAN_NETWORK (k);
+	}
+	else {
+		return luaL_error (L, "invalid arguments for new.kann, "
+							  "cost node is required");
+	}
+
+	return 1;
+}


More information about the Commits mailing list