commit 083e6ac: [Project] Add simple forward propagation function

Vsevolod Stakhov vsevolod at highsecure.ru
Sun Jun 30 08:42:04 UTC 2019


Author: Vsevolod Stakhov
Date: 2019-06-30 09:40:58 +0100
URL: https://github.com/rspamd/rspamd/commit/083e6ac5ce374e1e9759c7998dd04b9525333eb4 (HEAD -> master)

[Project] Add simple forward propagation function

---
 contrib/kann/kann.h |  5 ++++-
 src/lua/lua_kann.c  | 51 ++++++++++++++++++++++++++++++++++++++++++++-------
 2 files changed, 48 insertions(+), 8 deletions(-)

diff --git a/contrib/kann/kann.h b/contrib/kann/kann.h
index 7ec748561..af0de5fba 100644
--- a/contrib/kann/kann.h
+++ b/contrib/kann/kann.h
@@ -220,7 +220,10 @@ kad_node_t *kann_layer_rnn2(int *offset, kad_node_t **par, kad_node_t *in, kad_n
 kad_node_t *kann_layer_gru2(int *offset, kad_node_t **par, kad_node_t *in, kad_node_t *h0, int rnn_flag);
 
 /* operations on network with a single input node and a single output node */
-int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch, int max_drop_streak, float frac_val, int n, float **_x, float **_y);
+typedef void (*kann_train_cb)(int iter, float train_cost, float val_cost, void *ud);
+int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch,
+		int max_drop_streak, float frac_val, int n,
+		float **_x, float **_y, kann_train_cb cb, void *ud);
 float kann_cost_fnn1(kann_t *a, int n, float **x, float **y);
 const float *kann_apply1(kann_t *a, float *x);
 
diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c
index 3d50cc587..a1b31014d 100644
--- a/src/lua/lua_kann.c
+++ b/src/lua/lua_kann.c
@@ -143,13 +143,13 @@ static luaL_reg rspamd_kann_new_f[] = {
 LUA_FUNCTION_DEF (kann, load);
 LUA_FUNCTION_DEF (kann, destroy);
 LUA_FUNCTION_DEF (kann, save);
-LUA_FUNCTION_DEF (kann, train);
-LUA_FUNCTION_DEF (kann, forward);
+LUA_FUNCTION_DEF (kann, train1);
+LUA_FUNCTION_DEF (kann, apply1);
 
 static luaL_reg rspamd_kann_m[] = {
 		LUA_INTERFACE_DEF (kann, save),
-		LUA_INTERFACE_DEF (kann, train),
-		LUA_INTERFACE_DEF (kann, forward),
+		LUA_INTERFACE_DEF (kann, train1),
+		LUA_INTERFACE_DEF (kann, apply1),
 		{"__gc", lua_kann_destroy},
 		{NULL, NULL},
 };
@@ -985,7 +985,7 @@ lua_kann_load (lua_State *L)
 }
 
 static int
-lua_kann_train (lua_State *L)
+lua_kann_train1 (lua_State *L)
 {
 	kann_t *k = lua_check_kann (L, 1);
 
@@ -993,9 +993,46 @@ lua_kann_train (lua_State *L)
 }
 
 static int
-lua_kann_forward (lua_State *L)
+lua_kann_apply1 (lua_State *L)
 {
 	kann_t *k = lua_check_kann (L, 1);
 
-	g_assert_not_reached (); /* TODO: implement */
+	if (k && lua_istable (L, 2)) {
+		gsize vec_len = rspamd_lua_table_size (L, 2);
+		float *vec = (float *)g_malloc (sizeof (float) * vec_len);
+		int i_out;
+
+		for (gsize i = 0; i < vec_len; i ++) {
+			lua_rawgeti (L, 2, i + 1);
+			vec[i] = lua_tonumber (L, -1);
+			lua_pop (L, 1);
+		}
+
+		i_out = kann_find (k, KANN_F_OUT, 0);
+
+		if (i_out <= 0) {
+			g_free (vec);
+			return luaL_error (L, "invalid ANN: output layer is missing or is "
+						 "at the input pos");
+		}
+
+		kann_set_batch_size (k, 1);
+		kann_feed_bind (k, KANN_F_IN, 0, &vec);
+		kad_eval_at (k->n, k->v, i_out);
+
+		gsize outlen = kad_len (k->v[i_out]);
+		lua_createtable (L, outlen, 0);
+
+		for (gsize i = 0; i < outlen; i ++) {
+			lua_pushnumber (L, k->v[i_out]->x[i]);
+			lua_rawseti (L, -2, i + 1);
+		}
+
+		g_free (vec);
+	}
+	else {
+		return luaL_error (L, "invalid arguments: rspamd{kann} expected");
+	}
+
+	return 1;
 }
\ No newline at end of file


More information about the Commits mailing list