commit 718238f: [Rework] Rework learn and add classify condition
Vsevolod Stakhov
vsevolod at highsecure.ru
Wed Sep 1 13:28:05 UTC 2021
Author: Vsevolod Stakhov
Date: 2021-09-01 14:26:32 +0100
URL: https://github.com/rspamd/rspamd/commit/718238fd33017f346d1e84fe757481f9f147eb90 (HEAD -> master)
[Rework] Rework learn and add classify condition
---
src/libserver/cfg_file.h | 1 +
src/libserver/cfg_rcl.c | 30 +++++++-
src/libstat/stat_process.c | 180 +++++++++++++++++++++++++--------------------
src/lua/lua_common.c | 16 ++--
src/lua/lua_common.h | 2 +-
5 files changed, 140 insertions(+), 89 deletions(-)
diff --git a/src/libserver/cfg_file.h b/src/libserver/cfg_file.h
index 4d865e273..745f0fb22 100644
--- a/src/libserver/cfg_file.h
+++ b/src/libserver/cfg_file.h
@@ -192,6 +192,7 @@ struct rspamd_classifier_config {
const gchar *backend; /**< name of statfile's backend */
ucl_object_t *opts; /**< other options */
GList *learn_conditions; /**< list of learn condition callbacks */
+ GList *classify_conditions; /**< list of classify condition callbacks */
gchar *name; /**< unique name of classifier */
guint32 min_tokens; /**< minimal number of tokens to process classifier */
guint32 max_tokens; /**< maximum number of tokens */
diff --git a/src/libserver/cfg_rcl.c b/src/libserver/cfg_rcl.c
index 717b16bea..e3c69c343 100644
--- a/src/libserver/cfg_rcl.c
+++ b/src/libserver/cfg_rcl.c
@@ -1299,7 +1299,7 @@ rspamd_rcl_classifier_handler (rspamd_mempool_t *pool,
ccf->tokenizer = tkcf;
/* Handle lua conditions */
- val = ucl_object_lookup_any (obj, "condition", "learn_condition", NULL);
+ val = ucl_object_lookup_any (obj, "learn_condition", NULL);
if (val) {
LL_FOREACH (val, cur) {
@@ -1310,7 +1310,7 @@ rspamd_rcl_classifier_handler (rspamd_mempool_t *pool,
lua_script = ucl_object_tolstring(cur, &slen);
ref_idx = rspamd_lua_function_ref_from_str(L,
- lua_script, slen, err);
+ lua_script, slen, "learn_condition", err);
if (ref_idx == LUA_NOREF) {
return FALSE;
@@ -1325,6 +1325,32 @@ rspamd_rcl_classifier_handler (rspamd_mempool_t *pool,
}
}
+ val = ucl_object_lookup_any (obj, "classify_condition", NULL);
+
+ if (val) {
+ LL_FOREACH (val, cur) {
+ if (ucl_object_type(cur) == UCL_STRING) {
+ const gchar *lua_script;
+ gsize slen;
+ gint ref_idx;
+
+ lua_script = ucl_object_tolstring(cur, &slen);
+ ref_idx = rspamd_lua_function_ref_from_str(L,
+ lua_script, slen, "classify_condition", err);
+
+ if (ref_idx == LUA_NOREF) {
+ return FALSE;
+ }
+
+ rspamd_lua_add_ref_dtor (L, cfg->cfg_pool, ref_idx);
+ ccf->classify_conditions = rspamd_mempool_glist_append(
+ cfg->cfg_pool,
+ ccf->classify_conditions,
+ GINT_TO_POINTER (ref_idx));
+ }
+ }
+ }
+
ccf->opts = (ucl_object_t *)obj;
cfg->classifiers = g_list_prepend (cfg->classifiers, ccf);
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
index 8ac4e499e..4e856b563 100644
--- a/src/libstat/stat_process.c
+++ b/src/libstat/stat_process.c
@@ -190,9 +190,75 @@ rspamd_stat_process_tokenize (struct rspamd_stat_ctx *st_ctx,
b32_hout, g_free);
}
+/**
+ * Runs the Lua condition callbacks configured for a classifier and tells
+ * whether the classifier must be skipped for this task.
+ *
+ * Learn conditions are invoked as `cb(task, is_spam, is_unlearn)`; classify
+ * conditions are invoked as `cb(task)`. A callback vetoes the classifier by
+ * returning boolean `false` as its first result; an optional string second
+ * result is logged as the reason. A callback that raises a Lua error is
+ * logged and does not cause the classifier to be skipped. Evaluation stops
+ * at the first vetoing callback.
+ *
+ * @param task task being processed
+ * @param cl classifier whose conditions are evaluated
+ * @param is_learn TRUE to use learn_conditions, FALSE to use classify_conditions
+ * @param is_spam learn class; passed to callbacks only when is_learn is TRUE
+ * @return TRUE if some condition returned false (skip the classifier)
+ */
+static gboolean
+rspamd_stat_classifier_is_skipped (struct rspamd_task *task,
+		struct rspamd_classifier *cl, gboolean is_learn, gboolean is_spam)
+{
+	GList *cur = is_learn ? cl->cfg->learn_conditions : cl->cfg->classify_conditions;
+	lua_State *L = task->cfg->lua_state;
+	gboolean ret = FALSE;
+
+	for (; cur != NULL && !ret; cur = g_list_next (cur)) {
+		gint cb_ref = GPOINTER_TO_INT (cur->data);
+		gint old_top = lua_gettop (L);
+		/*
+		 * On success lua_pcall replaces the function and its arguments with
+		 * the results, so the first result lands right above the old top.
+		 * Absolute index 1 is only correct when the stack was empty.
+		 */
+		gint res_idx = old_top + 1;
+		/* Must match the number of values actually pushed below */
+		gint nargs = 1;
+
+		lua_rawgeti (L, LUA_REGISTRYINDEX, cb_ref);
+		/* The task is always the first argument */
+		struct rspamd_task **ptask = lua_newuserdata (L, sizeof (*ptask));
+		*ptask = task;
+		rspamd_lua_setclass (L, "rspamd{task}", -1);
+
+		if (is_learn) {
+			/* Learn conditions additionally receive is_spam and is_unlearn */
+			lua_pushboolean (L, is_spam);
+			lua_pushboolean (L,
+					task->flags & RSPAMD_TASK_FLAG_UNLEARN ? true : false);
+			nargs += 2;
+		}
+
+		if (lua_pcall (L, nargs, LUA_MULTRET, 0) != 0) {
+			msg_err_task ("call to %s failed: %s",
+					"condition callback",
+					lua_tostring (L, -1));
+		}
+		else {
+			/* Only an explicit boolean false vetoes; nil/no result does not */
+			if (lua_isboolean (L, res_idx) && !lua_toboolean (L, res_idx)) {
+				ret = TRUE;
+			}
+
+			if (lua_isstring (L, res_idx + 1)) {
+				if (ret) {
+					msg_notice_task ("%s condition for classifier %s returned: %s; skip classifier",
+							is_learn ? "learn" : "classify", cl->cfg->name,
+							lua_tostring (L, res_idx + 1));
+				}
+				else {
+					msg_info_task ("%s condition for classifier %s returned: %s",
+							is_learn ? "learn" : "classify", cl->cfg->name,
+							lua_tostring (L, res_idx + 1));
+				}
+			}
+			else if (ret) {
+				msg_notice_task ("%s condition for classifier %s returned false; skip classifier",
+						is_learn ? "learn" : "classify", cl->cfg->name);
+			}
+		}
+
+		/* Drop results or the error message, whatever the call left behind */
+		lua_settop (L, old_top);
+	}
+
+	return ret;
+}
+
+
static void
rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
- struct rspamd_task *task, gboolean learn)
+ struct rspamd_task *task, gboolean is_learn, gboolean is_spam)
{
guint i;
struct rspamd_statfile *st;
@@ -207,12 +273,39 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
rspamd_mempool_add_destructor (task->task_pool,
rspamd_ptr_array_free_hard, task->stat_runtimes);
+ /* Temporarily set all stat_runtimes to a non-NULL sentinel (G_MAXSIZE) to distinguish them from NULL */
+ for (i = 0; i < st_ctx->statfiles->len; i ++) {
+ g_ptr_array_index (task->stat_runtimes, i) = GSIZE_TO_POINTER(G_MAXSIZE);
+ }
+
+ for (i = 0; i < st_ctx->classifiers->len; i++) {
+ struct rspamd_classifier *cl = g_ptr_array_index (st_ctx->classifiers, i);
+ gboolean skip_classifier = FALSE;
+
+ if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
+ skip_classifier = TRUE;
+ }
+ else {
+ if (rspamd_stat_classifier_is_skipped (task, cl, is_learn , is_spam)) {
+ skip_classifier = TRUE;
+ }
+ }
+
+ if (skip_classifier) {
+ /* Set NULL for all statfiles indexed by id */
+ for (int j = 0; j < cl->statfiles_ids->len; j++) {
+ int id = g_array_index (cl->statfiles_ids, gint, j);
+ g_ptr_array_index (task->stat_runtimes, id) = NULL;
+ }
+ }
+ }
+
for (i = 0; i < st_ctx->statfiles->len; i ++) {
st = g_ptr_array_index (st_ctx->statfiles, i);
g_assert (st != NULL);
- if (st->classifier->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
- g_ptr_array_index (task->stat_runtimes, i) = NULL;
+ if (g_ptr_array_index (task->stat_runtimes, i) == NULL) {
+ /* The whole classifier is skipped */
continue;
}
@@ -224,7 +317,7 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
continue;
}
- bk_run = st->backend->runtime (task, st->stcf, learn, st->bkcf);
+ bk_run = st->backend->runtime (task, st->stcf, is_learn, st->bkcf);
if (bk_run == NULL) {
msg_err_task ("cannot init backend %s for statfile %s",
@@ -249,11 +342,6 @@ rspamd_stat_backends_process (struct rspamd_stat_ctx *st_ctx,
for (i = 0; i < st_ctx->statfiles->len; i++) {
st = g_ptr_array_index (st_ctx->statfiles, i);
cl = st->classifier;
-
- if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
- continue;
- }
-
bk_run = g_ptr_array_index (task->stat_runtimes, i);
if (bk_run != NULL) {
@@ -302,10 +390,6 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx,
st = g_ptr_array_index (st_ctx->statfiles, i);
cl = st->classifier;
- if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
- continue;
- }
-
bk_run = g_ptr_array_index (task->stat_runtimes, i);
g_assert (st != NULL);
@@ -332,10 +416,6 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx,
/* Do not process classifiers on backend failures */
for (j = 0; j < cl->statfiles_ids->len; j++) {
- if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
- continue;
- }
-
id = g_array_index (cl->statfiles_ids, gint, j);
bk_run = g_ptr_array_index (task->stat_runtimes, id);
st = g_ptr_array_index (st_ctx->statfiles, id);
@@ -406,7 +486,7 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, guint stage,
if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_PRE) {
/* Preprocess tokens */
- rspamd_stat_preprocess (st_ctx, task, FALSE);
+ rspamd_stat_preprocess (st_ctx, task, FALSE, FALSE);
}
else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS) {
/* Process backends */
@@ -490,13 +570,7 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
{
struct rspamd_classifier *cl, *sel = NULL;
guint i;
- gboolean learned = FALSE, too_small = FALSE, too_large = FALSE,
- conditionally_skipped = FALSE;
- lua_State *L;
- struct rspamd_task **ptask;
- GList *cur;
- gint cb_ref;
- gchar *cond_str = NULL;
+ gboolean learned = FALSE, too_small = FALSE, too_large = FALSE;
if ((task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) && err != NULL &&
*err == NULL) {
@@ -544,52 +618,6 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
continue;
}
- /* Check all conditions for this classifier */
- cur = cl->cfg->learn_conditions;
- L = task->cfg->lua_state;
-
- while (cur) {
- cb_ref = GPOINTER_TO_INT (cur->data);
-
- gint old_top = lua_gettop (L);
- lua_rawgeti (L, LUA_REGISTRYINDEX, cb_ref);
- /* Push task and two booleans: is_spam and is_unlearn */
- ptask = lua_newuserdata (L, sizeof (*ptask));
- *ptask = task;
- rspamd_lua_setclass (L, "rspamd{task}", -1);
- lua_pushboolean (L, spam);
- lua_pushboolean (L,
- task->flags & RSPAMD_TASK_FLAG_UNLEARN ? true : false);
-
- if (lua_pcall (L, 3, LUA_MULTRET, 0) != 0) {
- msg_err_task ("call to %s failed: %s",
- "condition callback",
- lua_tostring (L, -1));
- }
- else {
- if (lua_isboolean (L, 1)) {
- if (!lua_toboolean (L, 1)) {
- conditionally_skipped = TRUE;
- /* Also check for error string if needed */
- if (lua_isstring (L, 2)) {
- cond_str = rspamd_mempool_strdup (task->task_pool,
- lua_tostring (L, 2));
- }
-
- lua_settop (L, old_top);
- break;
- }
- }
- }
-
- lua_settop (L, old_top);
- cur = g_list_next (cur);
- }
-
- if (conditionally_skipped) {
- break;
- }
-
if (cl->subrs->learn_spam_func (cl, task->tokens, task, spam,
task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
learned = TRUE;
@@ -627,14 +655,6 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
task->tokens->len,
sel->cfg->min_tokens);
}
- else if (conditionally_skipped) {
- g_set_error (err, rspamd_stat_quark (), 204,
- "<%s> is skipped for %s classifier: "
- "%s",
- MESSAGE_FIELD (task, message_id),
- sel->cfg->name,
- cond_str ? cond_str : "unknown reason");
- }
}
return learned;
@@ -828,7 +848,7 @@ rspamd_stat_learn (struct rspamd_task *task,
if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) {
/* Process classifiers */
- rspamd_stat_preprocess (st_ctx, task, TRUE);
+ rspamd_stat_preprocess (st_ctx, task, TRUE, spam);
if (!rspamd_stat_cache_check (st_ctx, task, classifier, spam, err)) {
return RSPAMD_STAT_PROCESS_ERROR;
diff --git a/src/lua/lua_common.c b/src/lua/lua_common.c
index 5d874d507..ee29f9b9d 100644
--- a/src/lua/lua_common.c
+++ b/src/lua/lua_common.c
@@ -2294,7 +2294,7 @@ rspamd_lua_require_function (lua_State *L, const gchar *modname,
gint
rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
- GError **err)
+ const gchar *modname, GError **err)
{
gint err_idx, ref_idx;
@@ -2302,11 +2302,12 @@ rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
err_idx = lua_gettop (L);
/* Load file */
- if (luaL_loadbuffer (L, str, slen, "lua_embedded_str") != 0) {
+ if (luaL_loadbuffer (L, str, slen, modname) != 0) {
g_set_error (err,
lua_error_quark(),
EINVAL,
- "cannot load lua script: %s",
+ "%s: cannot load lua script: %s",
+ modname,
lua_tostring (L, -1));
lua_settop (L, err_idx - 1); /* Error function */
@@ -2318,7 +2319,8 @@ rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
g_set_error (err,
lua_error_quark(),
EINVAL,
- "cannot init lua script: %s",
+ "%s: cannot init lua script: %s",
+ modname,
lua_tostring (L, -1));
lua_settop (L, err_idx - 1);
@@ -2329,8 +2331,10 @@ rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
g_set_error (err,
lua_error_quark(),
EINVAL,
- "cannot init lua script: "
- "must return function");
+ "%s: cannot init lua script: "
+ "must return function not %s",
+ modname,
+ lua_typename (L, lua_type (L, -1)));
lua_settop (L, err_idx - 1);
return LUA_NOREF;
diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h
index b929ab864..10816d450 100644
--- a/src/lua/lua_common.h
+++ b/src/lua/lua_common.h
@@ -572,7 +572,7 @@ void rspamd_lua_add_ref_dtor (lua_State *L, rspamd_mempool_t *pool,
* @return
*/
gint rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
- GError **err);
+ const gchar *modname, GError **err);
/**
* Tries to load some module using `require` and get some method from it
More information about the Commits
mailing list