commit 718238f: [Rework] Rework learn and add classify condition

Vsevolod Stakhov vsevolod at highsecure.ru
Wed Sep 1 13:28:05 UTC 2021


Author: Vsevolod Stakhov
Date: 2021-09-01 14:26:32 +0100
URL: https://github.com/rspamd/rspamd/commit/718238fd33017f346d1e84fe757481f9f147eb90 (HEAD -> master)

[Rework] Rework learn and add classify condition

---
 src/libserver/cfg_file.h   |   1 +
 src/libserver/cfg_rcl.c    |  30 +++++++-
 src/libstat/stat_process.c | 180 +++++++++++++++++++++++++--------------------
 src/lua/lua_common.c       |  16 ++--
 src/lua/lua_common.h       |   2 +-
 5 files changed, 140 insertions(+), 89 deletions(-)

diff --git a/src/libserver/cfg_file.h b/src/libserver/cfg_file.h
index 4d865e273..745f0fb22 100644
--- a/src/libserver/cfg_file.h
+++ b/src/libserver/cfg_file.h
@@ -192,6 +192,7 @@ struct rspamd_classifier_config {
 	const gchar *backend;                           /**< name of statfile's backend							*/
 	ucl_object_t *opts;                             /**< other options                                      */
 	GList *learn_conditions;                        /**< list of learn condition callbacks					*/
+	GList *classify_conditions;                     /**< list of classify condition callbacks					*/
 	gchar *name;                                    /**< unique name of classifier							*/
 	guint32 min_tokens;                             /**< minimal number of tokens to process classifier 	*/
 	guint32 max_tokens;                             /**< maximum number of tokens							*/
diff --git a/src/libserver/cfg_rcl.c b/src/libserver/cfg_rcl.c
index 717b16bea..e3c69c343 100644
--- a/src/libserver/cfg_rcl.c
+++ b/src/libserver/cfg_rcl.c
@@ -1299,7 +1299,7 @@ rspamd_rcl_classifier_handler (rspamd_mempool_t *pool,
 	ccf->tokenizer = tkcf;
 
 	/* Handle lua conditions */
-	val = ucl_object_lookup_any (obj, "condition", "learn_condition", NULL);
+	val = ucl_object_lookup_any (obj, "learn_condition", NULL);
 
 	if (val) {
 		LL_FOREACH (val, cur) {
@@ -1310,7 +1310,7 @@ rspamd_rcl_classifier_handler (rspamd_mempool_t *pool,
 
 				lua_script = ucl_object_tolstring(cur, &slen);
 				ref_idx = rspamd_lua_function_ref_from_str(L,
-						lua_script, slen, err);
+						lua_script, slen, "learn_condition", err);
 
 				if (ref_idx == LUA_NOREF) {
 					return FALSE;
@@ -1325,6 +1325,32 @@ rspamd_rcl_classifier_handler (rspamd_mempool_t *pool,
 		}
 	}
 
+	val = ucl_object_lookup_any (obj, "classify_condition", NULL);
+
+	if (val) {
+		LL_FOREACH (val, cur) {
+			if (ucl_object_type(cur) == UCL_STRING) {
+				const gchar *lua_script;
+				gsize slen;
+				gint ref_idx;
+
+				lua_script = ucl_object_tolstring(cur, &slen);
+				ref_idx = rspamd_lua_function_ref_from_str(L,
+						lua_script, slen, "classify_condition", err);
+
+				if (ref_idx == LUA_NOREF) {
+					return FALSE;
+				}
+
+				rspamd_lua_add_ref_dtor (L, cfg->cfg_pool, ref_idx);
+				ccf->classify_conditions = rspamd_mempool_glist_append(
+						cfg->cfg_pool,
+						ccf->classify_conditions,
+						GINT_TO_POINTER (ref_idx));
+			}
+		}
+	}
+
 	ccf->opts = (ucl_object_t *)obj;
 	cfg->classifiers = g_list_prepend (cfg->classifiers, ccf);
 
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
index 8ac4e499e..4e856b563 100644
--- a/src/libstat/stat_process.c
+++ b/src/libstat/stat_process.c
@@ -190,9 +190,75 @@ rspamd_stat_process_tokenize (struct rspamd_stat_ctx *st_ctx,
 			b32_hout, g_free);
 }
 
+static gboolean
+rspamd_stat_classifier_is_skipped (struct rspamd_task *task,
+		struct rspamd_classifier *cl, gboolean is_learn, gboolean is_spam)
+{
+	GList *cur = is_learn ? cl->cfg->learn_conditions : cl->cfg->classify_conditions;
+	lua_State *L = task->cfg->lua_state;
+	gboolean ret = FALSE;
+
+	while (cur) {
+		gint cb_ref = GPOINTER_TO_INT (cur->data);
+		gint old_top = lua_gettop (L);
+
+		lua_rawgeti (L, LUA_REGISTRYINDEX, cb_ref);
+		/* Push task and two booleans: is_spam and is_unlearn */
+		struct rspamd_task **ptask = lua_newuserdata (L, sizeof (*ptask));
+		*ptask = task;
+		rspamd_lua_setclass (L, "rspamd{task}", -1);
+
+		if (is_learn) {
+			lua_pushboolean(L, is_spam);
+			lua_pushboolean(L,
+					task->flags & RSPAMD_TASK_FLAG_UNLEARN ? true : false);
+		}
+
+		if (lua_pcall (L, 3, LUA_MULTRET, 0) != 0) {
+			msg_err_task ("call to %s failed: %s",
+					"condition callback",
+					lua_tostring (L, -1));
+		}
+		else {
+			if (lua_isboolean (L, 1)) {
+				if (!lua_toboolean (L, 1)) {
+					ret = TRUE;
+				}
+			}
+
+			if (lua_isstring (L, 2)) {
+				if (ret) {
+					msg_notice_task ("%s condition for classifier %s returned: %s; skip classifier",
+							is_learn ? "learn" : "classify", cl->cfg->name,
+							lua_tostring(L, 2));
+				}
+				else {
+					msg_info_task ("%s condition for classifier %s returned: %s",
+							is_learn ? "learn" : "classify", cl->cfg->name,
+							lua_tostring(L, 2));
+				}
+			}
+			else if (ret) {
+				msg_notice_task("%s condition for classifier %s returned false; skip classifier",
+						is_learn ? "learn" : "classify", cl->cfg->name);
+			}
+
+			if (ret) {
+				lua_settop (L, old_top);
+				break;
+			}
+		}
+
+		lua_settop (L, old_top);
+		cur = g_list_next (cur);
+	}
+
+	return ret;
+}
+
 static void
 rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
-		struct rspamd_task *task, gboolean learn)
+		struct rspamd_task *task, gboolean is_learn, gboolean is_spam)
 {
 	guint i;
 	struct rspamd_statfile *st;
@@ -207,12 +273,39 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
 	rspamd_mempool_add_destructor (task->task_pool,
 			rspamd_ptr_array_free_hard, task->stat_runtimes);
 
+	/* Temporary set all stat_runtimes to some max size to distinguish from NULL */
+	for (i = 0; i < st_ctx->statfiles->len; i ++) {
+		g_ptr_array_index (task->stat_runtimes, i) = GSIZE_TO_POINTER(G_MAXSIZE);
+	}
+
+	for (i = 0; i < st_ctx->classifiers->len; i++) {
+		struct rspamd_classifier *cl = g_ptr_array_index (st_ctx->classifiers, i);
+		gboolean skip_classifier = FALSE;
+
+		if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
+			skip_classifier = TRUE;
+		}
+		else {
+			if (rspamd_stat_classifier_is_skipped (task, cl, is_learn , is_spam)) {
+				skip_classifier = TRUE;
+			}
+		}
+
+		if (skip_classifier) {
+			/* Set NULL for all statfiles indexed by id */
+			for (int j = 0; j < cl->statfiles_ids->len; j++) {
+				int id = g_array_index (cl->statfiles_ids, gint, j);
+				g_ptr_array_index (task->stat_runtimes, id) = NULL;
+			}
+		}
+	}
+
 	for (i = 0; i < st_ctx->statfiles->len; i ++) {
 		st = g_ptr_array_index (st_ctx->statfiles, i);
 		g_assert (st != NULL);
 
-		if (st->classifier->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
-			g_ptr_array_index (task->stat_runtimes, i) = NULL;
+		if (g_ptr_array_index (task->stat_runtimes, i) == NULL) {
+			/* The whole classifier is skipped */
 			continue;
 		}
 
@@ -224,7 +317,7 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
 			continue;
 		}
 
-		bk_run = st->backend->runtime (task, st->stcf, learn, st->bkcf);
+		bk_run = st->backend->runtime (task, st->stcf, is_learn, st->bkcf);
 
 		if (bk_run == NULL) {
 			msg_err_task ("cannot init backend %s for statfile %s",
@@ -249,11 +342,6 @@ rspamd_stat_backends_process (struct rspamd_stat_ctx *st_ctx,
 	for (i = 0; i < st_ctx->statfiles->len; i++) {
 		st = g_ptr_array_index (st_ctx->statfiles, i);
 		cl = st->classifier;
-
-		if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
-			continue;
-		}
-
 		bk_run = g_ptr_array_index (task->stat_runtimes, i);
 
 		if (bk_run != NULL) {
@@ -302,10 +390,6 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx,
 		st = g_ptr_array_index (st_ctx->statfiles, i);
 		cl = st->classifier;
 
-		if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
-			continue;
-		}
-
 		bk_run = g_ptr_array_index (task->stat_runtimes, i);
 		g_assert (st != NULL);
 
@@ -332,10 +416,6 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx,
 
 		/* Do not process classifiers on backend failures */
 		for (j = 0; j < cl->statfiles_ids->len; j++) {
-			if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
-				continue;
-			}
-
 			id = g_array_index (cl->statfiles_ids, gint, j);
 			bk_run =  g_ptr_array_index (task->stat_runtimes, id);
 			st = g_ptr_array_index (st_ctx->statfiles, id);
@@ -406,7 +486,7 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, guint stage,
 
 	if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_PRE) {
 		/* Preprocess tokens */
-		rspamd_stat_preprocess (st_ctx, task, FALSE);
+		rspamd_stat_preprocess (st_ctx, task, FALSE, FALSE);
 	}
 	else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS) {
 		/* Process backends */
@@ -490,13 +570,7 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
 {
 	struct rspamd_classifier *cl, *sel = NULL;
 	guint i;
-	gboolean learned = FALSE, too_small = FALSE, too_large = FALSE,
-			conditionally_skipped = FALSE;
-	lua_State *L;
-	struct rspamd_task **ptask;
-	GList *cur;
-	gint cb_ref;
-	gchar *cond_str = NULL;
+	gboolean learned = FALSE, too_small = FALSE, too_large = FALSE;
 
 	if ((task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) && err != NULL &&
 			*err == NULL) {
@@ -544,52 +618,6 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
 			continue;
 		}
 
-		/* Check all conditions for this classifier */
-		cur = cl->cfg->learn_conditions;
-		L = task->cfg->lua_state;
-
-		while (cur) {
-			cb_ref = GPOINTER_TO_INT (cur->data);
-
-			gint old_top = lua_gettop (L);
-			lua_rawgeti (L, LUA_REGISTRYINDEX, cb_ref);
-			/* Push task and two booleans: is_spam and is_unlearn */
-			ptask = lua_newuserdata (L, sizeof (*ptask));
-			*ptask = task;
-			rspamd_lua_setclass (L, "rspamd{task}", -1);
-			lua_pushboolean (L, spam);
-			lua_pushboolean (L,
-					task->flags & RSPAMD_TASK_FLAG_UNLEARN ? true : false);
-
-			if (lua_pcall (L, 3, LUA_MULTRET, 0) != 0) {
-				msg_err_task ("call to %s failed: %s",
-						"condition callback",
-						lua_tostring (L, -1));
-			}
-			else {
-				if (lua_isboolean (L, 1)) {
-					if (!lua_toboolean (L, 1)) {
-						conditionally_skipped = TRUE;
-						/* Also check for error string if needed */
-						if (lua_isstring (L, 2)) {
-							cond_str = rspamd_mempool_strdup (task->task_pool,
-									lua_tostring (L, 2));
-						}
-
-						lua_settop (L, old_top);
-						break;
-					}
-				}
-			}
-
-			lua_settop (L, old_top);
-			cur = g_list_next (cur);
-		}
-
-		if (conditionally_skipped) {
-			break;
-		}
-
 		if (cl->subrs->learn_spam_func (cl, task->tokens, task, spam,
 				task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
 			learned = TRUE;
@@ -627,14 +655,6 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
 					task->tokens->len,
 					sel->cfg->min_tokens);
 		}
-		else if (conditionally_skipped) {
-			g_set_error (err, rspamd_stat_quark (), 204,
-					"<%s> is skipped for %s classifier: "
-					"%s",
-					MESSAGE_FIELD (task, message_id),
-					sel->cfg->name,
-					cond_str ? cond_str : "unknown reason");
-		}
 	}
 
 	return learned;
@@ -828,7 +848,7 @@ rspamd_stat_learn (struct rspamd_task *task,
 
 	if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) {
 		/* Process classifiers */
-		rspamd_stat_preprocess (st_ctx, task, TRUE);
+		rspamd_stat_preprocess (st_ctx, task, TRUE, spam);
 
 		if (!rspamd_stat_cache_check (st_ctx, task, classifier, spam, err)) {
 			return RSPAMD_STAT_PROCESS_ERROR;
diff --git a/src/lua/lua_common.c b/src/lua/lua_common.c
index 5d874d507..ee29f9b9d 100644
--- a/src/lua/lua_common.c
+++ b/src/lua/lua_common.c
@@ -2294,7 +2294,7 @@ rspamd_lua_require_function (lua_State *L, const gchar *modname,
 
 gint
 rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
-								  GError **err)
+								  const gchar *modname, GError **err)
 {
 	gint err_idx, ref_idx;
 
@@ -2302,11 +2302,12 @@ rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
 	err_idx = lua_gettop (L);
 
 	/* Load file */
-	if (luaL_loadbuffer (L, str, slen, "lua_embedded_str") != 0) {
+	if (luaL_loadbuffer (L, str, slen, modname) != 0) {
 		g_set_error (err,
 				lua_error_quark(),
 				EINVAL,
-				"cannot load lua script: %s",
+				"%s: cannot load lua script: %s",
+				modname,
 				lua_tostring (L, -1));
 		lua_settop (L, err_idx - 1); /* Error function */
 
@@ -2318,7 +2319,8 @@ rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
 		g_set_error (err,
 				lua_error_quark(),
 				EINVAL,
-				"cannot init lua script: %s",
+				"%s: cannot init lua script: %s",
+				modname,
 				lua_tostring (L, -1));
 		lua_settop (L, err_idx - 1);
 
@@ -2329,8 +2331,10 @@ rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
 		g_set_error (err,
 				lua_error_quark(),
 				EINVAL,
-				"cannot init lua script: "
-				"must return function");
+				"%s: cannot init lua script: "
+				"must return function not %s",
+				modname,
+				lua_typename (L, lua_type (L, -1)));
 		lua_settop (L, err_idx - 1);
 
 		return LUA_NOREF;
diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h
index b929ab864..10816d450 100644
--- a/src/lua/lua_common.h
+++ b/src/lua/lua_common.h
@@ -572,7 +572,7 @@ void rspamd_lua_add_ref_dtor (lua_State *L, rspamd_mempool_t *pool,
  * @return
  */
 gint rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
-									   GError **err);
+									   const gchar *modname, GError **err);
 
 /**
 * Tries to load some module using `require` and get some method from it


More information about the Commits mailing list