commit cbc9079: [Project] Further caching logic modifications

Vsevolod Stakhov vsevolod at rspamd.com
Mon Jul 29 17:49:51 UTC 2024


Author: Vsevolod Stakhov
Date: 2023-12-06 15:36:52 +0000
URL: https://github.com/rspamd/rspamd/commit/cbc907994e73a2d0aa16900c587298fb307cca5b

[Project] Further caching logic modifications

---
 lualib/lua_bayes_redis.lua             |   2 +-
 src/libstat/backends/redis_backend.cxx | 161 +++++++++++++++++++++++++++------
 2 files changed, 132 insertions(+), 31 deletions(-)

diff --git a/lualib/lua_bayes_redis.lua b/lualib/lua_bayes_redis.lua
index 5dca2db43..25c56d58b 100644
--- a/lualib/lua_bayes_redis.lua
+++ b/lualib/lua_bayes_redis.lua
@@ -31,7 +31,7 @@ local function gen_classify_functor(redis_params, classify_script_id)
       if err then
         callback(task, false, err)
       else
-        callback(task, true, data[1], data[2], data[3])
+        callback(task, true, data[1], data[2], data[3], data[4])
       end
     end
 
diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx
index 46b27cb15..46f88de19 100644
--- a/src/libstat/backends/redis_backend.cxx
+++ b/src/libstat/backends/redis_backend.cxx
@@ -63,32 +63,58 @@ struct redis_stat_runtime {
 	struct redis_stat_ctx *ctx;
 	struct rspamd_task *task;
 	struct rspamd_statfile_config *stcf;
-	GPtrArray *tokens;
+	GPtrArray *tokens = nullptr;
 	const char *redis_object_expanded;
 	std::uint64_t learned = 0;
 	int id;
 	std::vector<std::pair<int, T>> *results = nullptr;
+	bool need_redis_call = true;
 
 	using result_type = std::vector<std::pair<int, T>>;
 
-	explicit redis_stat_runtime(struct redis_stat_ctx *_ctx, struct rspamd_task *_task, const char *_redis_object_expanded)
-		: ctx(_ctx), task(_task), stcf(_ctx->stcf), redis_object_expanded(_redis_object_expanded)
+private:
+	/* Called on connection termination */
+	static void rt_dtor(gpointer data)
 	{
+		auto *rt = REDIS_RUNTIME(data);
+
+		delete rt;
 	}
 
-	void init()
+	/* Avoid occasional deletion */
+	~redis_stat_runtime()
 	{
+		if (tokens) {
+			g_ptr_array_unref(tokens);
+		}
+
+		delete results;
 	}
 
-	void set_results(std::vector<std::pair<int, T>> *_results)
+public:
+	explicit redis_stat_runtime(struct redis_stat_ctx *_ctx, struct rspamd_task *_task, const char *_redis_object_expanded)
+		: ctx(_ctx), task(_task), stcf(_ctx->stcf), redis_object_expanded(_redis_object_expanded)
 	{
-		results = _results;
+		rspamd_mempool_add_destructor(task->task_pool, redis_stat_runtime<T>::rt_dtor, this);
 	}
 
-	~redis_stat_runtime()
+	static auto maybe_recover_from_mempool(struct rspamd_task *task, const char *redis_object_expanded,
+										   bool is_spam) -> std::optional<redis_stat_runtime<T> *>
 	{
-		g_ptr_array_unref(tokens);
-		delete results;
+		auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
+		auto *res = rspamd_mempool_steal_variable(task->task_pool, var_name.c_str());
+
+		if (res) {
+			return reinterpret_cast<redis_stat_runtime<T> *>(res);
+		}
+		else {
+			return std::nullopt;
+		}
+	}
+
+	void set_results(std::vector<std::pair<int, T>> *results)
+	{
+		this->results = results;
 	}
 
 	/* Propagate results from internal representation to the tokens array */
@@ -104,6 +130,15 @@ struct redis_stat_runtime {
 			tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx);
 			tok->values[id] = val;
 		}
+
+		return true;
+	}
+
+	auto save_in_mempool(bool is_spam) const
+	{
+		auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
+		/* We do not set destructor for the variable, as it should be already added on creation */
+		rspamd_mempool_set_variable(task->task_pool, var_name.c_str(), (gpointer) this, nullptr);
 	}
 };
 
@@ -1095,16 +1130,6 @@ rspamd_redis_async_stat_fin(struct rspamd_stat_async_elt *elt, gpointer d)
 #endif
 
 
-/* Called on connection termination */
-static void
-rspamd_redis_fin(gpointer data)
-{
-	auto *rt = REDIS_RUNTIME(data);
-
-	delete rt;
-}
-
-
 static bool
 rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend,
 								   const ucl_object_t *statfile_obj,
@@ -1296,19 +1321,40 @@ rspamd_redis_runtime(struct rspamd_task *task,
 		return nullptr;
 	}
 
-	auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
-	rspamd_mempool_add_destructor(task->task_pool, rspamd_redis_fin, rt);
-
 	/* Look for the cached results */
 	if (!learn) {
-		auto var_name = fmt::format("{}_{}", object_expanded, stcf->is_spam ? "S" : "H");
-		auto *res = rspamd_mempool_steal_variable(task->task_pool, var_name.c_str());
+		auto maybe_existing = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+																					object_expanded, stcf->is_spam);
 
-		if (res) {
-			rt->set_results(reinterpret_cast<redis_stat_runtime<float>::result_type *>(res));
+		if (maybe_existing) {
+			/* Update stcf to correspond to what we have been asked */
+			maybe_existing.value()->stcf = stcf;
+			return maybe_existing.value();
+		}
+	}
+
+	/* No cached result, create new one */
+	auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
+
+	if (!learn) {
+		/*
+		 * For check, we also need to create the opposite class runtime to avoid
+		 * double call for Redis scripts.
+		 * This runtime will be filled later.
+		 */
+		auto maybe_opposite_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+																					   object_expanded,
+																					   !stcf->is_spam);
+
+		if (!maybe_opposite_rt) {
+			auto *opposite_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
+			opposite_rt->save_in_mempool(!stcf->is_spam);
+			opposite_rt->need_redis_call = false;
 		}
 	}
 
+	rt->save_in_mempool(stcf->is_spam);
+
 	return rt;
 }
 
@@ -1385,8 +1431,65 @@ rspamd_redis_classified(lua_State *L)
 	bool result = lua_toboolean(L, 2);
 
 	if (result) {
+		/* Indexes:
+		 * 3 - learned_ham (int)
+		 * 4 - learned_spam (int)
+		 * 5 - ham_tokens (pair<int, int>)
+		 * 6 - spam_tokens (pair<int, int>)
+		 */
+
+		/*
+		 * We need to fill our runtime AND the opposite runtime
+		 */
+		auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, unsigned learned, int tokens_pos) {
+			rt->learned = learned;
+			redis_stat_runtime<float>::result_type *res;
+
+			res = new redis_stat_runtime<float>::result_type(lua_objlen(L, tokens_pos));
+
+			for (lua_pushnil(L); lua_next(L, tokens_pos); lua_pop(L, 1)) {
+				lua_rawgeti(L, -1, 1);
+				auto idx = lua_tointeger(L, -1);
+				lua_pop(L, 1);
+
+				lua_rawgeti(L, -1, 2);
+				auto value = lua_tonumber(L, -1);
+				lua_pop(L, 1);
+
+				res->emplace_back(idx, value);
+			}
+
+			rt->set_results(res);
+		};
+
+		auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+																					   rt->redis_object_expanded,
+																					   !rt->stcf->is_spam);
+
+		if (!opposite_rt_maybe) {
+			msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie);
+
+			return 0;
+		}
+
+		if (rt->stcf->is_spam) {
+			filler_func(rt, L, lua_tointeger(L, 4), 6);
+			filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5);
+		}
+		else {
+			filler_func(rt, L, lua_tointeger(L, 3), 5);
+			filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6);
+		}
+
+		/* Process all tokens */
+		g_assert(rt->tokens != nullptr);
+		rt->process_tokens(rt->tokens);
+		opposite_rt_maybe.value()->process_tokens(rt->tokens);
 	}
 	else {
+		/* Error message is on index 3 */
+		msg_err_task("cannot classify task: %s",
+					 lua_tostring(L, 3));
 	}
 
 	return 0;
@@ -1408,9 +1511,8 @@ rspamd_redis_process_tokens(struct rspamd_task *task,
 		return FALSE;
 	}
 
-	if (rt->results) {
-		/* No need to do anything, we have results ready */
-		rt->process_tokens(tokens);
+	if (!rt->need_redis_call) {
+		/* No need to do anything, as it is already done in the opposite class processing */
 
 		return TRUE;
 	}
@@ -1440,7 +1542,6 @@ rspamd_redis_process_tokens(struct rspamd_task *task,
 	lua_pushstring(L, cookie);
 	lua_pushcclosure(L, &rspamd_redis_classified, 1);
 
-
 	if (lua_pcall(L, 6, 0, err_idx) != 0) {
 		msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
 		lua_settop(L, err_idx - 1);


More information about the Commits mailing list