commit c5ef059: [Project] Add training support to kann

Vsevolod Stakhov vsevolod at highsecure.ru
Mon Jul 1 12:35:11 UTC 2019


Author: Vsevolod Stakhov
Date: 2019-07-01 13:30:09 +0100
URL: https://github.com/rspamd/rspamd/commit/c5ef059e0d0ea41b8a490c2f838a819e1363d0dd (HEAD -> master)

[Project] Add training support to kann

---
 contrib/kann/kann.c |  12 +++-
 src/lua/lua_kann.c  | 170 +++++++++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 178 insertions(+), 4 deletions(-)

diff --git a/contrib/kann/kann.c b/contrib/kann/kann.c
index 3fbf139cc..43227bdc6 100644
--- a/contrib/kann/kann.c
+++ b/contrib/kann/kann.c
@@ -670,7 +670,8 @@ kad_node_t *kann_new_weight_conv1d(int n_out, int n_in, int kernel_len) { return
 kad_node_t *kann_layer_input(int n1)
 {
 	kad_node_t *t;
-	t = kad_feed(2, 1, n1), t->ext_flag |= KANN_F_IN;
+	t = kad_feed(2, 1, n1);
+	t->ext_flag |= KANN_F_IN;
 	return t;
 }
 
@@ -761,6 +762,7 @@ kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type)
 	assert(cost_type == KANN_C_CEB || cost_type == KANN_C_CEM || cost_type == KANN_C_CEB_NEG || cost_type == KANN_C_MSE);
 	t = kann_layer_dense(t, n_out);
 	truth = kad_feed(2, 1, n_out), truth->ext_flag |= KANN_F_TRUTH;
+
 	if (cost_type == KANN_C_MSE) {
 		cost = kad_mse(t, truth);
 	} else if (cost_type == KANN_C_CEB) {
@@ -773,7 +775,13 @@ kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type)
 		t = kad_softmax(t);
 		cost = kad_ce_multi(t, truth);
 	}
-	t->ext_flag |= KANN_F_OUT, cost->ext_flag |= KANN_F_COST;
+	else {
+		assert (0);
+	}
+
+	t->ext_flag |= KANN_F_OUT;
+	cost->ext_flag |= KANN_F_COST;
+
 	return cost;
 }
 
diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c
index a1b31014d..609f05539 100644
--- a/src/lua/lua_kann.c
+++ b/src/lua/lua_kann.c
@@ -295,7 +295,7 @@ void luaopen_kann (lua_State *L)
 	int fl = 0; \
 	if (lua_type(L, (pos)) == LUA_TTABLE) { fl = rspamd_kann_table_to_flags (L, (pos)); } \
 	else if (lua_type(L, (pos)) == LUA_TNUMBER) { fl = lua_tointeger (L, (pos)); } \
-	(n)->ext_flag = fl; \
+	(n)->ext_flag |= fl; \
 }while(0)
 
 /***
@@ -984,12 +984,168 @@ lua_kann_load (lua_State *L)
 	return 1;
 }
 
+struct rspamd_kann_train_cbdata {
+	lua_State *L;
+	kann_t *k;
+	gint cbref;
+};
+
+static void
+lua_kann_train_cb (int iter, float train_cost, float val_cost, void *ud)
+{
+	struct rspamd_kann_train_cbdata *cbd = (struct rspamd_kann_train_cbdata *)ud;
+
+	if (cbd->cbref != -1) {
+		gint err_idx;
+		lua_State *L = cbd->L;
+
+		lua_pushcfunction (L, &rspamd_lua_traceback);
+		err_idx = lua_gettop (L);
+
+		lua_rawgeti (L, LUA_REGISTRYINDEX, cbd->cbref);
+		lua_pushinteger (L, iter);
+		lua_pushnumber (L, train_cost);
+		lua_pushnumber (L, val_cost);
+
+		if (lua_pcall (L, 3, 0, err_idx) != 0) {
+			msg_err ("cannot run lua train callback: %s",
+					lua_tostring (L, -1));
+		}
+
+		lua_settop (L, err_idx - 1);
+	}
+}
+
+#define FREE_VEC(a, n) do { for(int i = 0; i < (n); i ++) g_free((a)[i]); g_free(a); } while(0)
+
 static int
 lua_kann_train1 (lua_State *L)
 {
 	kann_t *k = lua_check_kann (L, 1);
 
-	g_assert_not_reached (); /* TODO: implement */
+	/* Default train params */
+	double lr = 0.001;
+	gint64 mini_size = 64;
+	gint64 max_epoch = 25;
+	gint64 max_drop_streak = 10;
+	double frac_val = 0.1;
+	gint cbref = -1;
+
+	if (k && lua_istable (L, 2) && lua_istable (L, 3)) {
+		int n = rspamd_lua_table_size (L, 2);
+		int n_in = kann_dim_in (k);
+		int n_out = kann_dim_out (k);
+
+		if (n_in <= 0) {
+			return luaL_error (L, "invalid inputs count: %d", n_in);
+		}
+
+		if (n_out <= 0) {
+			return luaL_error (L, "invalid outputs count: %d", n_in);
+		}
+
+		if (n != rspamd_lua_table_size (L, 3) || n == 0) {
+			return luaL_error (L, "invalid dimensions: outputs size must be "
+						 "equal to inputs and non zero");
+		}
+
+		if (lua_istable (L, 4)) {
+			GError *err = NULL;
+
+			if (!rspamd_lua_parse_table_arguments (L, 4, &err,
+					RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING,
+					"lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F",
+					&lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref)) {
+				n = luaL_error (L, "invalid params: %s",
+						err ? err->message : "unknown error");
+				g_error_free (err);
+
+				return n;
+			}
+		}
+
+		float **x, **y;
+
+		/* Fill vectors */
+		x = (float **)g_malloc (sizeof (float *) * n);
+		y = (float **)g_malloc (sizeof (float *) * n);
+
+		for (int s = 0; s < n; s ++) {
+			/* Inputs */
+			lua_rawgeti (L, 2, s + 1);
+			x[s] = (float *)g_malloc (sizeof (float) * n_in);
+
+			if (rspamd_lua_table_size (L, -1) != n_in) {
+				FREE_VEC (x, n);
+				FREE_VEC (y, n);
+
+				n = luaL_error (L, "invalid params at pos %d: "
+					   "bad input dimension %d; %d expected",
+						s + 1,
+						(int)rspamd_lua_table_size (L, -1),
+						n_in);
+
+				return n;
+			}
+
+			for (int i = 0; i < n_in; i ++) {
+				lua_rawgeti (L, -1, i + 1);
+				x[s][i] = lua_tonumber (L, -1);
+
+				lua_pop (L, 1);
+			}
+
+			lua_pop (L, 1);
+
+			/* Outputs */
+			y[s] = (float *)g_malloc (sizeof (float) * n_out);
+			lua_rawgeti (L, 3, s + 1);
+
+			if (rspamd_lua_table_size (L, -1) != n_out) {
+				FREE_VEC (x, n);
+				FREE_VEC (y, n);
+
+				n = luaL_error (L, "invalid params at pos %d: "
+					   "bad output dimension %d; "
+					   "%d expected",
+						s + 1,
+						(int)rspamd_lua_table_size (L, -1),
+						n_out);
+
+				return n;
+			}
+
+			for (int i = 0; i < n_out; i ++) {
+				lua_rawgeti (L, -1, i + 1);
+				y[s][i] = lua_tonumber (L, -1);
+
+				lua_pop (L, 1);
+			}
+
+			lua_pop (L, 1);
+		}
+
+		struct rspamd_kann_train_cbdata cbd;
+
+		cbd.cbref = cbref;
+		cbd.k = k;
+		cbd.L = L;
+
+		int niters = kann_train_fnn1 (k, lr,
+				mini_size, max_epoch, max_drop_streak,
+				frac_val, n, x, y, lua_kann_train_cb, &cbd);
+
+		lua_pushinteger (L, niters);
+
+		FREE_VEC (x, n);
+		FREE_VEC (y, n);
+	}
+	else {
+		return luaL_error (L, "invalid arguments: kann, inputs, outputs and"
+							  " optional params are expected");
+	}
+
+	return 1;
 }
 
 static int
@@ -1001,6 +1157,16 @@ lua_kann_apply1 (lua_State *L)
 		gsize vec_len = rspamd_lua_table_size (L, 2);
 		float *vec = (float *)g_malloc (sizeof (float) * vec_len);
 		int i_out;
+		int n_in = kann_dim_in (k);
+
+		if (n_in <= 0) {
+			return luaL_error (L, "invalid inputs count: %d", n_in);
+		}
+
+		if (n_in != vec_len) {
+			return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
+					(int)vec_len, n_in);
+		}
 
 		for (gsize i = 0; i < vec_len; i ++) {
 			lua_rawgeti (L, 2, i + 1);


More information about the Commits mailing list