commit 90fa147: [Project] Add preliminary bindings for kann

Vsevolod Stakhov vsevolod at highsecure.ru
Fri Jun 28 17:14:05 UTC 2019


Author: Vsevolod Stakhov
Date: 2019-06-28 18:06:40 +0100
URL: https://github.com/rspamd/rspamd/commit/90fa147ca70698661da0cce271d6ac0982a92c37

[Project] Add preliminary bindings for kann

---
 src/lua/lua_kann.c | 564 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 564 insertions(+)

diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c
new file mode 100644
index 000000000..cec75acbb
--- /dev/null
+++ b/src/lua/lua_kann.c
@@ -0,0 +1,564 @@
+/*-
+ * Copyright 2019 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lua_common.h"
+#include "contrib/kann/kann.h"
+
+/***
+ * @module rspamd_kann
+ * `rspamd_kann` is a Lua interface to kann library
+ */
+
+#define KANN_NODE_CLASS "rspamd{kann_node}"
+
+/* Simple macros to define behaviour */
+#define KANN_LAYER_DEF(name) static int lua_kann_layer_ ## name (lua_State *L)
+#define KANN_LAYER_INTERFACE(name) {#name, lua_kann_layer_ ## name}
+
+#define KANN_TRANSFORM_DEF(name) static int lua_kann_transform_ ## name (lua_State *L)
+#define KANN_TRANSFORM_INTERFACE(name) {#name, lua_kann_transform_ ## name}
+
+#define KANN_LOSS_DEF(name) static int lua_kann_loss_ ## name (lua_State *L)
+#define KANN_LOSS_INTERFACE(name) {#name, lua_kann_loss_ ## name}
+
+/*
+ * Forwarded declarations
+ */
+static kad_node_t *lua_check_kann_node (lua_State *L, int pos);
+
+/* Layers */
+KANN_LAYER_DEF(input);
+KANN_LAYER_DEF(dense);
+KANN_LAYER_DEF(layernorm);
+KANN_LAYER_DEF(rnn);
+KANN_LAYER_DEF(lstm);
+KANN_LAYER_DEF(gru);
+KANN_LAYER_DEF(conv2d);
+KANN_LAYER_DEF(conv1d);
+KANN_LAYER_DEF(cost);
+
+static luaL_reg rspamd_kann_layers_f[] = {
+		KANN_LAYER_INTERFACE(input),
+		KANN_LAYER_INTERFACE(dense),
+		KANN_LAYER_INTERFACE(layernorm),
+		KANN_LAYER_INTERFACE(rnn),
+		KANN_LAYER_INTERFACE(lstm),
+		KANN_LAYER_INTERFACE(gru),
+		KANN_LAYER_INTERFACE(conv2d),
+		KANN_LAYER_INTERFACE(conv1d),
+		KANN_LAYER_INTERFACE(cost),
+		{NULL, NULL},
+};
+
+/* Transition and composition functions */
+
+/* General transform */
+KANN_TRANSFORM_DEF (add);
+KANN_TRANSFORM_DEF (sub);
+KANN_TRANSFORM_DEF (mul);
+KANN_TRANSFORM_DEF (cmul);
+KANN_TRANSFORM_DEF (matmul);
+
+KANN_TRANSFORM_DEF (square);
+KANN_TRANSFORM_DEF (sigm);
+KANN_TRANSFORM_DEF (tanh);
+KANN_TRANSFORM_DEF (relu);
+KANN_TRANSFORM_DEF (softmax);
+KANN_TRANSFORM_DEF (1minus);
+KANN_TRANSFORM_DEF (exp);
+KANN_TRANSFORM_DEF (log);
+KANN_TRANSFORM_DEF (sin);
+static luaL_reg rspamd_kann_transform_f[] = {
+		KANN_TRANSFORM_INTERFACE (add),
+		KANN_TRANSFORM_INTERFACE (sub),
+		KANN_TRANSFORM_INTERFACE (mul),
+		KANN_TRANSFORM_INTERFACE (cmul),
+		KANN_TRANSFORM_INTERFACE (matmul),
+
+		KANN_TRANSFORM_INTERFACE (square),
+		KANN_TRANSFORM_INTERFACE (sigm),
+		KANN_TRANSFORM_INTERFACE (tanh),
+		KANN_TRANSFORM_INTERFACE (relu),
+		KANN_TRANSFORM_INTERFACE (softmax),
+		KANN_TRANSFORM_INTERFACE (1minus),
+		KANN_TRANSFORM_INTERFACE (exp),
+		KANN_TRANSFORM_INTERFACE (log),
+		KANN_TRANSFORM_INTERFACE (sin),
+		{NULL, NULL},
+};
+
+/* Loss functions */
+KANN_LOSS_DEF (mse);
+KANN_LOSS_DEF (ce_multi);
+KANN_LOSS_DEF (ce_bin);
+KANN_LOSS_DEF (ce_bin_neg);
+KANN_LOSS_DEF (ce_multi_weighted);
+static luaL_reg rspamd_kann_loss_f[] = {
+		KANN_LOSS_INTERFACE (mse),
+		KANN_LOSS_INTERFACE (ce_multi),
+		KANN_LOSS_INTERFACE (ce_bin),
+		KANN_LOSS_INTERFACE (ce_bin_neg),
+		KANN_LOSS_INTERFACE (ce_multi_weighted),
+		{NULL, NULL},
+};
+
+static int
+rspamd_kann_table_to_flags (lua_State *L, int table_pos)
+{
+	int result = 0;
+
+	lua_pushvalue (L, table_pos);
+
+	for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) {
+		int fl = lua_tointeger (L, -1);
+
+		result |= fl;
+	}
+
+	lua_pop (L, 1);
+
+	return result;
+}
+
+static gint
+lua_load_kann (lua_State * L)
+{
+	lua_newtable (L);
+
+	/* Flags */
+	lua_pushstring (L, "flag");
+	lua_newtable (L);
+	lua_pushinteger (L, KANN_F_IN);
+	lua_setfield (L, -2, "in");
+	lua_pushinteger (L, KANN_F_COST);
+	lua_setfield (L, -2, "cost");
+	lua_pushinteger (L, KANN_F_OUT);
+	lua_setfield (L, -2, "out");
+	lua_pushinteger (L, KANN_F_TRUTH);
+	lua_setfield (L, -2, "truth");
+	lua_settable (L, -3);
+
+	/* Cost type */
+	lua_pushstring (L, "cost");
+	lua_newtable (L);
+	/* binary cross-entropy cost, used with sigmoid */
+	lua_pushinteger (L, KANN_C_CEB);
+	lua_setfield (L, -2, "ceb");
+	/* multi-class cross-entropy cost, used with softmax */
+	lua_pushinteger (L, KANN_C_CEM);
+	lua_setfield (L, -2, "cem");
+	/* binary cross-entropy-like cost, used with tanh */
+	lua_pushinteger (L, KANN_C_CEB_NEG);
+	lua_setfield (L, -2, "ceb_neg");
+	lua_pushinteger (L, KANN_C_MSE);
+	lua_setfield (L, -2, "mse");
+	lua_settable (L, -3);
+
+	/* RNN flag */
+	lua_pushstring (L, "rnn");
+	lua_newtable (L);
+	/* apply layer normalization */
+	lua_pushinteger (L, KANN_RNN_NORM);
+	lua_setfield (L, -2, "norm");
+	/* take the initial hidden values as variables */
+	lua_pushinteger (L, KANN_RNN_VAR_H0);
+	lua_setfield (L, -2, "var_h0");
+	lua_settable (L, -3);
+
+	/* Layers */
+	lua_pushstring (L, "layer");
+	lua_newtable (L);
+	luaL_register (L, NULL, rspamd_kann_layers_f);
+	lua_settable (L, -3);
+
+	/* Transforms */
+	lua_pushstring (L, "transform");
+	lua_newtable (L);
+	luaL_register (L, NULL, rspamd_kann_transform_f);
+	lua_settable (L, -3);
+
+	/* Cost */
+	lua_pushstring (L, "loss");
+	lua_newtable (L);
+	luaL_register (L, NULL, rspamd_kann_loss_f);
+	lua_settable (L, -3);
+
+	return 1;
+}
+
+static kad_node_t *
+lua_check_kann_node (lua_State *L, int pos)
+{
+	void *ud = rspamd_lua_check_udata (L, pos, KANN_NODE_CLASS);
+	luaL_argcheck (L, ud != NULL, pos, "'kann_node' expected");
+	return ud ? *((kad_node_t **)ud) : NULL;
+}
+
+void luaopen_kann (lua_State *L)
+{
+	/* Metatables */
+	rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL);
+	lua_pop (L, 1); /* No need in metatable... */
+	rspamd_lua_add_preload (L, "rspamd_kann", lua_load_kann);
+	lua_settop (L, 0);
+}
+
+/* Layers implementation */
+#define PUSH_KAD_NODE(n) do { \
+	kad_node_t **pt; \
+	pt = lua_newuserdata (L, sizeof (kad_node_t *)); \
+	*pt = (n); \
+	rspamd_lua_setclass (L, KANN_NODE_CLASS, -1); \
+} while(0)
+
+#define PROCESS_KAD_FLAGS(n, pos) do { \
+	int fl = 0; \
+	if (lua_type(L, (pos)) == LUA_TTABLE) { fl = rspamd_kann_table_to_flags (L, (pos)); } \
+	else if (lua_type(L, (pos)) == LUA_TNUMBER) { fl = lua_tointeger (L, (pos)); } \
+	(n)->ext_flag = fl; \
+}while(0)
+
+static int
+lua_kann_layer_input (lua_State *L)
+{
+	gint nnodes = luaL_checkinteger (L, 1);
+
+	if (nnodes > 0) {
+		kad_node_t *t;
+
+		t = kann_layer_input (nnodes);
+
+		PROCESS_KAD_FLAGS (t, 2);
+		PUSH_KAD_NODE (t);
+	}
+	else {
+		return luaL_error (L, "invalid arguments, nnodes required");
+	}
+
+	return 1;
+}
+
+static int
+lua_kann_layer_dense (lua_State *L)
+{
+	kad_node_t *in = lua_check_kann_node (L, 1);
+	gint nnodes = luaL_checkinteger (L, 2);
+
+	if (in != NULL && nnodes > 0) {
+		kad_node_t *t;
+
+		t = kann_layer_dense (in, nnodes);
+
+		PROCESS_KAD_FLAGS (t, 3);
+		PUSH_KAD_NODE (t);
+	}
+	else {
+		return luaL_error (L, "invalid arguments, input + nnodes required");
+	}
+
+	return 1;
+}
+
+static int
+lua_kann_layer_layerdropout (lua_State *L)
+{
+	kad_node_t *in = lua_check_kann_node (L, 1);
+	double r = luaL_checknumber (L, 2);
+
+	if (in != NULL) {
+		kad_node_t *t;
+
+		t = kann_layer_dropout (in, r);
+
+		PROCESS_KAD_FLAGS (t, 3);
+		PUSH_KAD_NODE (t);
+	}
+	else {
+		return luaL_error (L, "invalid arguments, input + rate required");
+	}
+
+	return 1;
+}
+
+static int
+lua_kann_layer_layernorm (lua_State *L)
+{
+	kad_node_t *in = lua_check_kann_node (L, 1);
+
+	if (in != NULL) {
+		kad_node_t *t;
+
+		t = kann_layer_layernorm (in);
+
+		PROCESS_KAD_FLAGS (t, 2);
+		PUSH_KAD_NODE (t);
+	}
+	else {
+		return luaL_error (L, "invalid arguments, input required");
+	}
+
+	return 1;
+}
+
+static int
+lua_kann_layer_rnn (lua_State *L)
+{
+	kad_node_t *in = lua_check_kann_node (L, 1);
+	gint nnodes = luaL_checkinteger (L, 2);
+	gint rnnflags = 0;
+
+	if (in != NULL && nnodes > 0) {
+		kad_node_t *t;
+
+		if (lua_type (L, 3) == LUA_TNUMBER) {
+			rnnflags = lua_tointeger (L, 3);
+		}
+
+		t = kann_layer_rnn (in, nnodes, rnnflags);
+
+		PROCESS_KAD_FLAGS (t, 4);
+		PUSH_KAD_NODE (t);
+	}
+	else {
+		return luaL_error (L, "invalid arguments, input + nnodes required");
+	}
+
+	return 1;
+}
+
+static int
+lua_kann_layer_lstm (lua_State *L)
+{
+	kad_node_t *in = lua_check_kann_node (L, 1);
+	gint nnodes = luaL_checkinteger (L, 2);
+	gint rnnflags = 0;
+
+	if (in != NULL && nnodes > 0) {
+		kad_node_t *t;
+
+		if (lua_type (L, 3) == LUA_TNUMBER) {
+			rnnflags = lua_tointeger (L, 3);
+		}
+
+		t = kann_layer_lstm (in, nnodes, rnnflags);
+
+		PROCESS_KAD_FLAGS (t, 4);
+		PUSH_KAD_NODE (t);
+	}
+	else {
+		return luaL_error (L, "invalid arguments, input + nnodes required");
+	}
+
+	return 1;
+}
+
+static int
+lua_kann_layer_gru (lua_State *L)
+{
+	kad_node_t *in = lua_check_kann_node (L, 1);
+	gint nnodes = luaL_checkinteger (L, 2);
+	gint rnnflags = 0;
+
+	if (in != NULL && nnodes > 0) {
+		kad_node_t *t;
+
+		if (lua_type (L, 3) == LUA_TNUMBER) {
+			rnnflags = lua_tointeger (L, 3);
+		}
+
+		t = kann_layer_gru (in, nnodes, rnnflags);
+
+		PROCESS_KAD_FLAGS (t, 4);
+		PUSH_KAD_NODE (t);
+	}
+	else {
+		return luaL_error (L, "invalid arguments, input + nnodes required");
+	}
+
+	return 1;
+}
+
+static int
+lua_kann_layer_conv2d (lua_State *L)
+{
+	kad_node_t *in = lua_check_kann_node (L, 1);
+	int n_flt = luaL_checkinteger (L, 2);
+	int k_rows = luaL_checkinteger (L, 3);
+	int k_cols =  luaL_checkinteger (L, 4);
+	int stride_r = luaL_checkinteger (L, 5);
+	int stride_c = luaL_checkinteger (L, 6);
+	int pad_r = luaL_checkinteger (L, 7);
+	int pad_c = luaL_checkinteger (L, 8);
+
+	if (in != NULL) {
+		kad_node_t *t;
+		t = kann_layer_conv2d (in, n_flt, k_rows, k_cols, stride_r, stride_c,
+				pad_r, pad_c);
+
+		PROCESS_KAD_FLAGS (t, 9);
+		PUSH_KAD_NODE (t);
+	}
+	else {
+		return luaL_error (L, "invalid arguments, input, nflt, kx, ky, stridex, stridey, padx, pady are required");
+	}
+
+	return 1;
+}
+
+static int
+lua_kann_layer_conv1d (lua_State *L)
+{
+	kad_node_t *in = lua_check_kann_node (L, 1);
+	int n_flt = luaL_checkinteger (L, 2);
+	int k_size = luaL_checkinteger (L, 3);
+	int stride = luaL_checkinteger (L, 4);
+	int pad = luaL_checkinteger (L, 5);
+
+	if (in != NULL) {
+		kad_node_t *t;
+		t = kann_layer_conv1d (in, n_flt, k_size, stride, pad);
+
+		PROCESS_KAD_FLAGS (t, 6);
+		PUSH_KAD_NODE (t);
+	}
+	else {
+		return luaL_error (L, "invalid arguments, input, nflt, k, stride, pad required");
+	}
+
+	return 1;
+}
+
+static int
+lua_kann_layer_cost (lua_State *L)
+{
+	kad_node_t *in = lua_check_kann_node (L, 1);
+	int nout = luaL_checkinteger (L, 2);
+	int cost_type = luaL_checkinteger (L, 3);
+
+	if (in != NULL && nout > 0) {
+		kad_node_t *t;
+		t = kann_layer_cost (in, nout, cost_type);
+
+		PROCESS_KAD_FLAGS (t, 4);
+		PUSH_KAD_NODE (t);
+	}
+	else {
+		return luaL_error (L, "invalid arguments, input, nout and cost_type are required");
+	}
+
+	return 1;
+}
+
+/* Generic helpers */
+static int
+lua_kann_call_unary_function (lua_State *L, const char *name,
+		kad_node_t *(*func)(kad_node_t *))
+{
+	kad_node_t *in = lua_check_kann_node (L, 1);
+
+	if (in != NULL) {
+		kad_node_t *t;
+		t = func (in);
+
+		PUSH_KAD_NODE (t);
+	}
+	else {
+		return luaL_error (L, "invalid arguments for %s, input required", name);
+	}
+
+	return 1;
+}
+static int
+lua_kann_call_binary_function (lua_State *L, const char *name,
+							  kad_node_t *(*func)(kad_node_t *, kad_node_t *))
+{
+	kad_node_t *x = lua_check_kann_node (L, 1);
+	kad_node_t *y = lua_check_kann_node (L, 2);
+
+	if (x != NULL && y != NULL) {
+		kad_node_t *t;
+		t = func (x, y);
+
+		PUSH_KAD_NODE (t);
+	}
+	else {
+		return luaL_error (L, "invalid arguments for %s, 2 inputs required", name);
+	}
+
+	return 1;
+}
+
+#define LUA_UNARY_TRANSFORM_FUNC_IMPL(name)									\
+static int lua_kann_transform_ ##name (lua_State *L)						\
+{																			\
+	return lua_kann_call_unary_function(L, #name, kad_##name);				\
+}
+
+#define LUA_BINARY_TRANSFORM_FUNC_IMPL(name)								\
+static int lua_kann_transform_ ##name (lua_State *L)						\
+{																			\
+	return lua_kann_call_binary_function(L, #name, kad_##name);				\
+}
+
+#define LUA_LOSS_FUNC_IMPL(name)											\
+static int lua_kann_loss_ ##name (lua_State *L)								\
+{																			\
+	return lua_kann_call_binary_function(L, #name, kad_##name);				\
+}
+
+/* Transform functions registered via macro helpers */
+LUA_BINARY_TRANSFORM_FUNC_IMPL (add)
+LUA_BINARY_TRANSFORM_FUNC_IMPL (sub)
+LUA_BINARY_TRANSFORM_FUNC_IMPL (mul)
+LUA_BINARY_TRANSFORM_FUNC_IMPL (cmul)
+LUA_BINARY_TRANSFORM_FUNC_IMPL (matmul)
+
+LUA_UNARY_TRANSFORM_FUNC_IMPL (square)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (sigm)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (tanh)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (relu)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (softmax)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (1minus)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (exp)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (log)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (sin)
+
+/* Generic cost functions */
+LUA_LOSS_FUNC_IMPL (mse)
+LUA_LOSS_FUNC_IMPL (ce_multi)
+LUA_LOSS_FUNC_IMPL (ce_bin)
+LUA_LOSS_FUNC_IMPL (ce_bin_neg)
+
+/* The only case of ternary weight function */
+static int
+lua_kann_loss_ce_multi_weighted (lua_State *L)
+{
+	kad_node_t *pred = lua_check_kann_node (L, 1);
+	kad_node_t *truth = lua_check_kann_node (L, 2);
+	kad_node_t *weight = lua_check_kann_node (L, 3);
+
+	if (pred != NULL && truth != NULL && weight != NULL) {
+		kad_node_t *t;
+		t = kad_ce_multi_weighted (pred, truth, weight);
+
+		PUSH_KAD_NODE (t);
+	}
+	else {
+		return luaL_error (L, "invalid arguments for ce_multi_weighted, 3 inputs required");
+	}
+
+	return 1;
+}
\ No newline at end of file


More information about the Commits mailing list