commit e1fadcc: [Feature] Improve autolearning

Vsevolod Stakhov vsevolod at highsecure.ru
Wed Jul 24 14:07:04 UTC 2019


Author: Vsevolod Stakhov
Date: 2019-07-24 15:03:29 +0100
URL: https://github.com/rspamd/rspamd/commit/e1fadcc80b5f6a3d566224b0ed1a74d7a9dbc9ed

[Feature] Improve autolearning

---
 conf/statistic.conf                   |  2 +-
 lualib/lua_bayes_learn.lua            | 67 ++++++++++++++++++++++++++++++
 src/libserver/mempool_vars_internal.h |  2 +
 src/libstat/backends/redis_backend.c  | 26 ++++++++++++
 src/libstat/stat_internal.h           |  1 +
 src/libstat/stat_process.c            | 76 ++++++++++++++++++++++++++++++++++-
 6 files changed, 172 insertions(+), 2 deletions(-)

diff --git a/conf/statistic.conf b/conf/statistic.conf
index 1e78c73cd..396564a23 100644
--- a/conf/statistic.conf
+++ b/conf/statistic.conf
@@ -41,7 +41,7 @@ classifier "bayes" {
     symbol = "BAYES_SPAM";
     spam = true;
   }
-  learn_condition = "return require("lua_bayes_learn").can_learn"
+  learn_condition = 'return require("lua_bayes_learn").can_learn';
 
   .include(try=true; priority=1) "$LOCAL_CONFDIR/local.d/classifier-bayes.conf"
   .include(try=true; priority=10) "$LOCAL_CONFDIR/override.d/classifier-bayes.conf"
diff --git a/lualib/lua_bayes_learn.lua b/lualib/lua_bayes_learn.lua
index 7df52a2ef..5a46265e7 100644
--- a/lualib/lua_bayes_learn.lua
+++ b/lualib/lua_bayes_learn.lua
@@ -16,6 +16,10 @@ limitations under the License.
 
 -- This file contains functions to simplify bayes classifier auto-learning
 
+local lua_util = require "lua_util"
+
+local N = "lua_bayes"
+
 local exports = {}
 
 exports.can_learn = function(task, is_spam, is_unlearn)
@@ -46,4 +50,67 @@ exports.can_learn = function(task, is_spam, is_unlearn)
   return true
 end
 
+exports.autolearn = function(task, conf)
+  -- We have autolearn config so let's figure out what is requested
+  local verdict,score = lua_util.get_task_verdict(task)
+  local learn_spam,learn_ham = false, false
+
+  if verdict == 'passthrough' then
+    -- No need to autolearn
+    lua_util.debugm(N, task, 'no need to autolearn - verdict: %s',
+        verdict)
+    return
+  end
+
+  if conf.spam_threshold and conf.ham_threshold then
+    if verdict == 'spam' then
+      if conf.spam_threshold and score >= conf.spam_threshold then
+        lua_util.debugm(N, task, 'can autolearn spam: score %s >= %s',
+            score, conf.spam_threshold)
+        learn_spam = true
+      end
+    elseif verdict == 'ham' then
+      if conf.ham_threshold and score <= conf.ham_threshold then
+        lua_util.debugm(N, task, 'can autolearn ham: score %s <= %s',
+            score, conf.ham_threshold)
+        learn_ham = true
+      end
+    end
+  end
+
+  if conf.check_balance then
+    -- Check balance of learns
+    local spam_learns = task:get_mempool():get_variable('spam_learns', 'int64') or 0
+    local ham_learns = task:get_mempool():get_variable('ham_learns', 'int64') or 0
+
+    local min_balance = 0.9
+    if conf.min_balance then min_balance = conf.min_balance end
+
+    if spam_learns > 0 or ham_learns > 0 then
+      local max_ratio = 1.0 / min_balance
+      local spam_learns_ratio = spam_learns / (ham_learns + 1)
+      if  spam_learns_ratio > max_ratio and learn_spam then
+        lua_util.debugm(N, task,
+            'skip learning spam, balance is not satisfied: %s < %s; %s spam learns; %s ham learns',
+            spam_learns_ratio, min_balance, spam_learns, ham_learns)
+        learn_spam = false
+      end
+
+      local ham_learns_ratio = ham_learns / (spam_learns + 1)
+      if  ham_learns_ratio > max_ratio and learn_ham then
+        lua_util.debugm(N, task,
+            'skip learning ham, balance is not satisfied: %s < %s; %s spam learns; %s ham learns',
+            ham_learns_ratio, min_balance, spam_learns, ham_learns)
+        learn_ham = false
+      end
+    end
+  end
+
+  if learn_spam then
+    return 'spam'
+  elseif learn_ham then
+    return 'ham'
+  end
+end
+
 return exports
\ No newline at end of file
diff --git a/src/libserver/mempool_vars_internal.h b/src/libserver/mempool_vars_internal.h
index c062d44d4..576635a9b 100644
--- a/src/libserver/mempool_vars_internal.h
+++ b/src/libserver/mempool_vars_internal.h
@@ -38,5 +38,7 @@
 #define RSPAMD_MEMPOOL_ARC_SIGN_SELECTOR "arc_selector"
 #define RSPAMD_MEMPOOL_STAT_SIGNATURE "stat_signature"
 #define RSPAMD_MEMPOOL_FUZZY_RESULT "fuzzy_hashes"
+#define RSPAMD_MEMPOOL_SPAM_LEARNS "spam_learns"
+#define RSPAMD_MEMPOOL_HAM_LEARNS "ham_learns"
 
 #endif
diff --git a/src/libstat/backends/redis_backend.c b/src/libstat/backends/redis_backend.c
index 7263b3c16..9ac6fb445 100644
--- a/src/libstat/backends/redis_backend.c
+++ b/src/libstat/backends/redis_backend.c
@@ -1230,6 +1230,32 @@ rspamd_redis_connected (redisAsyncContext *c, gpointer r, gpointer priv)
 					rt->redis_object_expanded, rt->learned);
 			rspamd_upstream_ok (rt->selected);
 
+			/* Save learn count in mempool variable */
+			gint64 *learns_cnt;
+			const gchar *var_name;
+
+			if (rt->stcf->is_spam) {
+				var_name = RSPAMD_MEMPOOL_SPAM_LEARNS;
+			}
+			else {
+				var_name = RSPAMD_MEMPOOL_HAM_LEARNS;
+			}
+
+			learns_cnt = rspamd_mempool_get_variable (task->task_pool,
+					var_name);
+
+			if (learns_cnt) {
+				(*learns_cnt) += rt->learned;
+			}
+			else {
+				learns_cnt = rspamd_mempool_alloc (task->task_pool,
+						sizeof (*learns_cnt));
+				*learns_cnt = rt->learned;
+				rspamd_mempool_set_variable (task->task_pool,
+						var_name,
+						learns_cnt, NULL);
+			}
+
 			if (rt->learned >= rt->stcf->clcf->min_learns && rt->learned > 0) {
 				rspamd_fstring_t *query = rspamd_redis_tokens_to_query (
 						task,
diff --git a/src/libstat/stat_internal.h b/src/libstat/stat_internal.h
index 967a3c4d6..5e2578177 100644
--- a/src/libstat/stat_internal.h
+++ b/src/libstat/stat_internal.h
@@ -43,6 +43,7 @@ struct rspamd_classifier {
 	gpointer cachecf;
 	gulong spam_learns;
 	gulong ham_learns;
+	gint autolearn_cbref;
 	struct rspamd_classifier_config *cfg;
 	struct rspamd_stat_classifier *subrs;
 	gpointer specific;
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
index 034e1a5be..d720a77ab 100644
--- a/src/libstat/stat_process.c
+++ b/src/libstat/stat_process.c
@@ -906,6 +906,19 @@ rspamd_stat_has_classifier_symbols (struct rspamd_task *task,
 	return FALSE;
 }
 
+struct cl_cbref_dtor_data {
+	lua_State *L;
+	gint ref_idx;
+};
+
+static void
+rspamd_stat_cbref_dtor (void *d)
+{
+	struct cl_cbref_dtor_data *data = (struct cl_cbref_dtor_data *)d;
+
+	luaL_unref (data->L, LUA_REGISTRYINDEX, data->ref_idx);
+}
+
 gboolean
 rspamd_stat_check_autolearn (struct rspamd_task *task)
 {
@@ -925,6 +938,8 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
 	st_ctx = rspamd_stat_get_ctx ();
 	g_assert (st_ctx != NULL);
 
+	L = task->cfg->lua_state;
+
 	for (i = 0; i < st_ctx->classifiers->len; i ++) {
 		cl = g_ptr_array_index (st_ctx->classifiers, i);
 		ret = FALSE;
@@ -933,6 +948,7 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
 			obj = ucl_object_lookup (cl->cfg->opts, "autolearn");
 
 			if (ucl_object_type (obj) == UCL_BOOLEAN) {
+				/* Legacy true/false */
 				if (ucl_object_toboolean (obj)) {
 					/*
 					 * Default learning algorithm:
@@ -956,6 +972,7 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
 				}
 			}
 			else if (ucl_object_type (obj) == UCL_ARRAY && obj->len == 2) {
+				/* Legacy thresholds */
 				/*
 				 * We have an array of 2 elements, treat it as a
 				 * ham_score, spam_score
@@ -994,8 +1011,8 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
 				}
 			}
 			else if (ucl_object_type (obj) == UCL_STRING) {
+				/* Legacy sript */
 				lua_script = ucl_object_tostring (obj);
-				L = task->cfg->lua_state;
 
 				if (luaL_dostring (L, lua_script) != 0) {
 					msg_err_task ("cannot execute lua script for autolearn "
@@ -1018,6 +1035,7 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
 						else {
 							lua_ret = lua_tostring (L, -1);
 
+							/* We can have immediate results */
 							if (lua_ret) {
 								if (strcmp (lua_ret, "ham") == 0) {
 									task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
@@ -1041,6 +1059,62 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
 					}
 				}
 			}
+			else if (ucl_object_type (obj) == UCL_OBJECT) {
+				/* Try to find autolearn callback */
+				if (cl->autolearn_cbref == 0) {
+					/* We don't have preprocessed cb id, so try to get it */
+					if (!rspamd_lua_require_function (L, "lua_bayes_learn",
+							"autolearn")) {
+						msg_err_task ("cannot get autolearn library from "
+									  "`lua_bayes_learn`");
+					}
+					else {
+						struct cl_cbref_dtor_data *dtor_data;
+
+						dtor_data = (struct cl_cbref_dtor_data *)
+								rspamd_mempool_alloc (task->cfg->cfg_pool,
+									sizeof (*dtor_data));
+						cl->autolearn_cbref = luaL_ref (L, LUA_REGISTRYINDEX);
+						dtor_data->L = L;
+						dtor_data->ref_idx = cl->autolearn_cbref;
+						rspamd_mempool_add_destructor (task->cfg->cfg_pool,
+								rspamd_stat_cbref_dtor, dtor_data);
+					}
+				}
+
+				if (cl->autolearn_cbref != -1) {
+					lua_pushcfunction (L, &rspamd_lua_traceback);
+					err_idx = lua_gettop (L);
+					lua_rawgeti (L, LUA_REGISTRYINDEX, cl->autolearn_cbref);
+
+					ptask = lua_newuserdata (L, sizeof (struct rspamd_task *));
+					*ptask = task;
+					rspamd_lua_setclass (L, "rspamd{task}", -1);
+					/* Push the whole object as well */
+					ucl_object_push_lua (L, obj, true);
+
+					if (lua_pcall (L, 2, 1, err_idx) != 0) {
+						msg_err_task ("call to autolearn script failed: "
+									  "%s", lua_tostring (L, -1));
+					}
+					else {
+						lua_ret = lua_tostring (L, -1);
+
+						if (lua_ret) {
+							if (strcmp (lua_ret, "ham") == 0) {
+								task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+								ret = TRUE;
+							}
+							else if (strcmp (lua_ret, "spam") == 0) {
+								task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
+								ret = TRUE;
+							}
+						}
+					}
+
+					lua_settop (L, err_idx - 1);
+				}
+			}
 
 			if (ret) {
 				/* Do not autolearn if we have this symbol already */


More information about the Commits mailing list