commit 131c74b: [Rework] Enable explicit coroutines symbols

Vsevolod Stakhov vsevolod at highsecure.ru
Fri Mar 1 10:07:03 UTC 2019


Author: Vsevolod Stakhov
Date: 2019-03-01 09:59:53 +0000
URL: https://github.com/rspamd/rspamd/commit/131c74bd2c3419c6ca6dcdd259376f489f305205

[Rework] Enable explicit coroutines symbols

---
 src/libserver/rspamd_symcache.h |   1 +
 src/lua/lua_config.c            | 208 ++++++++++++++++++++++++++++++++++++----
 2 files changed, 193 insertions(+), 16 deletions(-)

diff --git a/src/libserver/rspamd_symcache.h b/src/libserver/rspamd_symcache.h
index 69eac1f01..a038d6a9d 100644
--- a/src/libserver/rspamd_symcache.h
+++ b/src/libserver/rspamd_symcache.h
@@ -50,6 +50,7 @@ enum rspamd_symbol_type {
 	SYMBOL_TYPE_MIME_ONLY = (1 << 15), /* Symbol is mime only */
 	SYMBOL_TYPE_EXPLICIT_DISABLE = (1 << 16), /* Symbol should be disabled explicitly only */
 	SYMBOL_TYPE_IGNORE_PASSTHROUGH = (1 << 17), /* Symbol ignores passthrough result */
+	SYMBOL_TYPE_USE_CORO = (1 << 18), /* Symbol uses lua coroutines */
 };
 
 /**
diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c
index b5e668e71..9d7f3341c 100644
--- a/src/lua/lua_config.c
+++ b/src/lua/lua_config.c
@@ -1222,6 +1222,142 @@ lua_metric_symbol_callback (struct rspamd_task *task,
 	g_assert (lua_gettop (L) == level - 1);
 }
 
+static void lua_metric_symbol_callback_return (struct thread_entry *thread_entry,
+											   int ret);
+
+static void lua_metric_symbol_callback_error (struct thread_entry *thread_entry,
+											  int ret,
+											  const char *msg);
+
+static void
+lua_metric_symbol_callback_coro (struct rspamd_task *task,
+							struct rspamd_symcache_item *item,
+							gpointer ud)
+{
+	struct lua_callback_data *cd = ud;
+	struct rspamd_task **ptask;
+	struct thread_entry *thread_entry;
+
+	rspamd_symcache_item_async_inc (task, item, "lua coro symbol");
+	thread_entry = lua_thread_pool_get_for_task (task);
+
+	g_assert(thread_entry->cd == NULL);
+	thread_entry->cd = cd;
+
+	lua_State *thread = thread_entry->lua_state;
+	cd->stack_level = lua_gettop (thread);
+	cd->item = item;
+
+	if (cd->cb_is_ref) {
+		lua_rawgeti (thread, LUA_REGISTRYINDEX, cd->callback.ref);
+	}
+	else {
+		lua_getglobal (thread, cd->callback.name);
+	}
+
+	ptask = lua_newuserdata (thread, sizeof (struct rspamd_task *));
+	rspamd_lua_setclass (thread, "rspamd{task}", -1);
+	*ptask = task;
+
+	thread_entry->finish_callback = lua_metric_symbol_callback_return;
+	thread_entry->error_callback = lua_metric_symbol_callback_error;
+
+	lua_thread_call (thread_entry, 1);
+}
+
+static void
+lua_metric_symbol_callback_error (struct thread_entry *thread_entry,
+								  int ret,
+								  const char *msg)
+{
+	struct lua_callback_data *cd = thread_entry->cd;
+	struct rspamd_task *task = thread_entry->task;
+	msg_err_task ("call to coroutine (%s) failed (%d): %s", cd->symbol, ret, msg);
+
+	rspamd_symcache_item_async_dec_check (task, cd->item, "lua coro symbol");
+}
+
+static void
+lua_metric_symbol_callback_return (struct thread_entry *thread_entry, int ret)
+{
+	struct lua_callback_data *cd = thread_entry->cd;
+	struct rspamd_task *task = thread_entry->task;
+	int nresults;
+	struct rspamd_symbol_result *s;
+
+	(void)ret;
+
+	lua_State *L = thread_entry->lua_state;
+
+	nresults = lua_gettop (L) - cd->stack_level;
+
+	if (nresults >= 1) {
+		/* Function returned boolean, so maybe we need to insert result? */
+		gint res = 0;
+		gint i;
+		gdouble flag = 1.0;
+		gint type;
+
+		type = lua_type (L, cd->stack_level + 1);
+
+		if (type == LUA_TBOOLEAN) {
+			res = lua_toboolean (L, cd->stack_level + 1);
+		}
+		else if (type == LUA_TFUNCTION) {
+			g_assert_not_reached ();
+		}
+		else {
+			res = lua_tonumber (L, cd->stack_level + 1);
+		}
+
+		if (res) {
+			gint first_opt = 2;
+
+			if (lua_type (L, cd->stack_level + 2) == LUA_TNUMBER) {
+				flag = lua_tonumber (L, cd->stack_level + 2);
+				/* Shift opt index */
+				first_opt = 3;
+			}
+			else {
+				flag = res;
+			}
+
+			s = rspamd_task_insert_result (task, cd->symbol, flag, NULL);
+
+			if (s) {
+				guint last_pos = lua_gettop (L);
+
+				for (i = cd->stack_level + first_opt; i <= last_pos; i++) {
+					if (lua_type (L, i) == LUA_TSTRING) {
+						const char *opt = lua_tostring (L, i);
+
+						rspamd_task_add_result_option (task, s, opt);
+					}
+					else if (lua_type (L, i) == LUA_TTABLE) {
+						lua_pushvalue (L, i);
+
+						for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) {
+							const char *opt = lua_tostring (L, -1);
+
+							rspamd_task_add_result_option (task, s, opt);
+						}
+
+						lua_pop (L, 1);
+					}
+				}
+			}
+
+		}
+
+		lua_pop (L, nresults);
+	}
+
+	g_assert (lua_gettop (L) == cd->stack_level); /* we properly cleaned up the stack */
+
+	cd->stack_level = 0;
+	rspamd_symcache_item_async_dec_check (task, cd->item, "lua coro symbol");
+}
+
 static gint
 rspamd_register_symbol_fromlua (lua_State *L,
 		struct rspamd_config *cfg,
@@ -1255,6 +1391,10 @@ rspamd_register_symbol_fromlua (lua_State *L,
 	}
 
 	if (ref != -1) {
+		if (type & SYMBOL_TYPE_USE_CORO) {
+			/* Coroutines are incompatible with squeezing */
+			no_squeeze = TRUE;
+		}
 		/*
 		 * We call for routine called lua_squeeze_rules.squeeze_rule if it exists
 		 */
@@ -1322,15 +1462,27 @@ rspamd_register_symbol_fromlua (lua_State *L,
 					cd->symbol = rspamd_mempool_strdup (cfg->cfg_pool, name);
 				}
 
-				ret = rspamd_symcache_add_symbol (cfg->cache,
-						name,
-						priority,
-						lua_metric_symbol_callback,
-						cd,
-						type,
-						parent);
+				if (type & SYMBOL_TYPE_USE_CORO) {
+					ret = rspamd_symcache_add_symbol (cfg->cache,
+							name,
+							priority,
+							lua_metric_symbol_callback_coro,
+							cd,
+							type,
+							parent);
+				}
+				else {
+					ret = rspamd_symcache_add_symbol (cfg->cache,
+							name,
+							priority,
+							lua_metric_symbol_callback,
+							cd,
+							type,
+							parent);
+				}
+
 				rspamd_mempool_add_destructor (cfg->cfg_pool,
-						(rspamd_mempool_destruct_t)lua_destroy_cfg_symbol,
+						(rspamd_mempool_destruct_t) lua_destroy_cfg_symbol,
 						cd);
 			}
 		}
@@ -1346,13 +1498,24 @@ rspamd_register_symbol_fromlua (lua_State *L,
 				cd->symbol = rspamd_mempool_strdup (cfg->cfg_pool, name);
 			}
 
-			ret = rspamd_symcache_add_symbol (cfg->cache,
-					name,
-					priority,
-					lua_metric_symbol_callback,
-					cd,
-					type,
-					parent);
+			if (type & SYMBOL_TYPE_USE_CORO) {
+				ret = rspamd_symcache_add_symbol (cfg->cache,
+						name,
+						priority,
+						lua_metric_symbol_callback_coro,
+						cd,
+						type,
+						parent);
+			}
+			else {
+				ret = rspamd_symcache_add_symbol (cfg->cache,
+						name,
+						priority,
+						lua_metric_symbol_callback,
+						cd,
+						type,
+						parent);
+			}
 			rspamd_mempool_add_destructor (cfg->cfg_pool,
 					(rspamd_mempool_destruct_t)lua_destroy_cfg_symbol,
 					cd);
@@ -1529,6 +1692,9 @@ lua_parse_symbol_flags (const gchar *str)
 		if (strstr (str, "explicit_disable") != NULL) {
 			ret |= SYMBOL_TYPE_EXPLICIT_DISABLE;
 		}
+		if (strstr (str, "coro") != NULL) {
+			ret |= SYMBOL_TYPE_USE_CORO;
+		}
 	}
 
 	return ret;
@@ -2423,7 +2589,7 @@ lua_config_newindex (lua_State *L)
 	LUA_TRACE_POINT;
 	struct rspamd_config *cfg = lua_check_config (L, 1);
 	const gchar *name;
-	gint id, nshots;
+	gint id, nshots, flags = 0;
 	gboolean optional = FALSE, no_squeeze = FALSE;
 
 	name = luaL_checkstring (L, 2);
@@ -2458,6 +2624,7 @@ lua_config_newindex (lua_State *L)
 			 * "weight" - optional weight
 			 * "priority" - optional priority
 			 * "type" - optional type (normal, virtual, callback)
+			 * "flags" - optional flags
 			 * -- Metric options
 			 * "score" - optional default score (overridden by metric)
 			 * "group" - optional default group
@@ -2510,6 +2677,15 @@ lua_config_newindex (lua_State *L)
 			}
 			lua_pop (L, 1);
 
+			lua_pushstring (L, "flags");
+			lua_gettable (L, -2);
+
+			if (lua_type (L, -1) == LUA_TSTRING) {
+				type_str = lua_tostring (L, -1);
+				type |= lua_parse_symbol_flags (type_str);
+			}
+			lua_pop (L, 1);
+
 			lua_pushstring (L, "condition");
 			lua_gettable (L, -2);
 


More information about the Commits mailing list