commit 7ff9314: [Project] Rework stat runtime

Vsevolod Stakhov vsevolod at rspamd.com
Mon Jul 29 17:49:48 UTC 2024


Author: Vsevolod Stakhov
Date: 2023-12-06 14:46:45 +0000
URL: https://github.com/rspamd/rspamd/commit/7ff93147757a3c491a0dba20558fa54eb97b48b0

[Project] Rework stat runtime

---
 src/libstat/backends/redis_backend.cxx | 177 +++++++++++++++++----------------
 1 file changed, 90 insertions(+), 87 deletions(-)

diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx
index 973e60671..46b27cb15 100644
--- a/src/libstat/backends/redis_backend.cxx
+++ b/src/libstat/backends/redis_backend.cxx
@@ -19,6 +19,11 @@
 #include "stat_internal.h"
 #include "upstream.h"
 #include "libserver/mempool_vars_internal.h"
+#include "fmt/core.h"
+
+#include <string>
+#include <cstdint>
+#include <vector>
 
 #define msg_debug_stat_redis(...) rspamd_conditional_debug_fast(nullptr, nullptr,                                                 \
 																rspamd_stat_redis_log_id, "stat_redis", task->task_pool->tag.uid, \
@@ -28,7 +33,7 @@
 INIT_LOG_MODULE(stat_redis)
 
 #define REDIS_CTX(p) (reinterpret_cast<struct redis_stat_ctx *>(p))
-#define REDIS_RUNTIME(p) (reinterpret_cast<struct redis_stat_runtime *>(p))
+#define REDIS_RUNTIME(p) (reinterpret_cast<struct redis_stat_runtime<float> *>(p))
 #define REDIS_DEFAULT_OBJECT "%s%l"
 #define REDIS_DEFAULT_USERS_OBJECT "%s%l%r"
 #define REDIS_DEFAULT_TIMEOUT 0.5
@@ -38,31 +43,68 @@ INIT_LOG_MODULE(stat_redis)
 struct redis_stat_ctx {
 	lua_State *L;
 	struct rspamd_statfile_config *stcf;
-	gint conf_ref;
 	struct rspamd_stat_async_elt *stat_elt;
-	const char *redis_object;
-	gboolean enable_users;
-	gboolean store_tokens;
-	gboolean new_schema;
-	gboolean enable_signatures;
-	guint expiry;
-	guint max_users;
-	gint cbref_user;
-
-	gint cbref_classify;
-	gint cbref_learn;
+	const char *redis_object = REDIS_DEFAULT_OBJECT;
+	bool enable_users = false;
+	bool store_tokens = false;
+	bool enable_signatures = false;
+	unsigned expiry = 0; /* explicit default so the field is never read while indeterminate */
+	unsigned max_users = REDIS_MAX_USERS;
+	int cbref_user = -1;
+
+	int cbref_classify = -1;
+	int cbref_learn = -1;
+	int conf_ref = -1;
 };
 
 
+template<class T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true>
 struct redis_stat_runtime {
 	struct redis_stat_ctx *ctx;
 	struct rspamd_task *task;
 	struct rspamd_statfile_config *stcf;
 	GPtrArray *tokens;
-	gchar *redis_object_expanded;
-	guint64 learned;
-	gint id;
-	GError *err;
+	const char *redis_object_expanded;
+	std::uint64_t learned = 0;
+	int id; /* NOTE(review): not set by the constructor; confirm it is assigned before process_tokens() runs */
+	std::vector<std::pair<int, T>> *results = nullptr; /* owned by this object: deleted in the destructor */
+
+	using result_type = std::vector<std::pair<int, T>>;
+
+	explicit redis_stat_runtime(struct redis_stat_ctx *_ctx, struct rspamd_task *_task, const char *_redis_object_expanded)
+		: ctx(_ctx), task(_task), stcf(_ctx->stcf), tokens(nullptr), redis_object_expanded(_redis_object_expanded)
+	{
+	}
+
+	void init()
+	{
+	}
+
+	void set_results(std::vector<std::pair<int, T>> *_results)
+	{
+		results = _results; /* takes ownership of the vector */
+	}
+
+	~redis_stat_runtime()
+	{
+		if (tokens) { g_ptr_array_unref(tokens); } /* tokens starts as nullptr and may never be attached */
+		delete results;
+	}
+
+	/* Copy cached (index, value) pairs into the given tokens array (parameter shadows the member); false if nothing is cached */
+	auto process_tokens(GPtrArray *tokens) const -> bool
+	{
+		if (!results) {
+			return false;
+		}
+
+		for (const auto &[idx, val]: *results) {
+			auto *tok = static_cast<rspamd_token_t *>(g_ptr_array_index(tokens, idx));
+			tok->values[id] = val;
+		}
+
+		return true;
+	}
 };
 
 /* Used to get statistics from redis */
@@ -217,14 +259,7 @@ gsize rspamd_redis_expand_object(const gchar *pattern,
 				/* Label miss is OK */
 				break;
 			case 's':
-				if (ctx->new_schema) {
-					tlen += sizeof("RS") - 1;
-				}
-				else {
-					if (stcf->symbol) {
-						tlen += strlen(stcf->symbol);
-					}
-				}
+				tlen += sizeof("RS") - 1;
 				break;
 			default:
 				state = just_char;
@@ -306,14 +341,7 @@ gsize rspamd_redis_expand_object(const gchar *pattern,
 				}
 				break;
 			case 's':
-				if (ctx->new_schema) {
-					d += rspamd_strlcpy(d, "RS", end - d);
-				}
-				else {
-					if (stcf->symbol) {
-						d += rspamd_strlcpy(d, stcf->symbol, end - d);
-					}
-				}
+				d += rspamd_strlcpy(d, "RS", end - d);
 				break;
 			default:
 				state = just_char;
@@ -1071,15 +1099,9 @@ rspamd_redis_async_stat_fin(struct rspamd_stat_async_elt *elt, gpointer d)
 static void
 rspamd_redis_fin(gpointer data)
 {
-	struct redis_stat_runtime *rt = REDIS_RUNTIME(data);
-
-	if (rt->err) {
-		g_error_free(rt->err);
-	}
+	auto *rt = REDIS_RUNTIME(data); /* REDIS_RUNTIME casts the opaque pointer to redis_stat_runtime<float> * */
 
-	if (rt->tokens) {
-		g_ptr_array_unref(rt->tokens);
-	}
+	delete rt; /* runs ~redis_stat_runtime, which releases the tokens array and the results vector */
 }
 
 
@@ -1260,7 +1282,6 @@ rspamd_redis_runtime(struct rspamd_task *task,
 					 gboolean learn, gpointer c, gint _id)
 {
 	struct redis_stat_ctx *ctx = REDIS_CTX(c);
-	struct redis_stat_runtime *rt;
 	char *object_expanded = nullptr;
 
 	g_assert(ctx != nullptr);
@@ -1275,16 +1296,18 @@ rspamd_redis_runtime(struct rspamd_task *task,
 		return nullptr;
 	}
 
-	/* Look for the cached results */
-
+	auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
+	rspamd_mempool_add_destructor(task->task_pool, rspamd_redis_fin, rt);
 
-	rt = (struct redis_stat_runtime *) rspamd_mempool_alloc0(task->task_pool, sizeof(*rt));
-	rt->task = task;
-	rt->ctx = ctx;
-	rt->redis_object_expanded = object_expanded;
-	rt->stcf = stcf;
+	/* Look for the cached results */
+	if (!learn) {
+		auto var_name = fmt::format("{}_{}", object_expanded, stcf->is_spam ? "S" : "H");
+		auto *res = rspamd_mempool_steal_variable(task->task_pool, var_name.c_str());
 
-	rspamd_mempool_add_destructor(task->task_pool, rspamd_redis_fin, rt);
+		if (res) {
+			rt->set_results(reinterpret_cast<redis_stat_runtime<float>::result_type *>(res));
+		}
+	}
 
 	return rt;
 }
@@ -1348,9 +1371,9 @@ rspamd_redis_serialize_tokens(struct rspamd_task *task, GPtrArray *tokens, gsize
 static gint
 rspamd_redis_classified(lua_State *L)
 {
-	const gchar *cookie = lua_tostring(L, lua_upvalueindex(1));
-	struct rspamd_task *task = lua_check_task(L, 1);
-	struct redis_stat_runtime *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie));
+	const auto *cookie = lua_tostring(L, lua_upvalueindex(1));
+	auto *task = lua_check_task(L, 1);
+	auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie));
 	/* TODO: write it */
 
 	if (rt == nullptr) {
@@ -1374,8 +1397,8 @@ rspamd_redis_process_tokens(struct rspamd_task *task,
 							GPtrArray *tokens,
 							gint id, gpointer p)
 {
-	struct redis_stat_runtime *rt = REDIS_RUNTIME(p);
-	lua_State *L = rt->ctx->L;
+	auto *rt = REDIS_RUNTIME(p);
+	auto *L = rt->ctx->L;
 
 	if (rspamd_session_blocked(task->s)) {
 		return FALSE;
@@ -1385,7 +1408,12 @@ rspamd_redis_process_tokens(struct rspamd_task *task,
 		return FALSE;
 	}
 
-	/* TODO: check if we have tokens for that particular id for this class */
+	if (rt->results) {
+		/* No need to do anything, we have results ready */
+		rt->process_tokens(tokens);
+
+		return TRUE;
+	}
 
 	gsize tokens_len;
 	gchar *tokens_buf = rspamd_redis_serialize_tokens(task, tokens, &tokens_len);
@@ -1429,19 +1457,6 @@ gboolean
 rspamd_redis_finalize_process(struct rspamd_task *task, gpointer runtime,
 							  gpointer ctx)
 {
-	struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
-
-	if (rt->err) {
-		msg_info_task("cannot retrieve stat tokens from Redis: %e", rt->err);
-		g_error_free(rt->err);
-		rt->err = nullptr;
-		rspamd_redis_fin(rt);
-
-		return FALSE;
-	}
-
-	rspamd_redis_fin(rt);
-
 	return TRUE;
 }
 
@@ -1449,7 +1464,7 @@ gboolean
 rspamd_redis_learn_tokens(struct rspamd_task *task, GPtrArray *tokens,
 						  gint id, gpointer p)
 {
-	struct redis_stat_runtime *rt = REDIS_RUNTIME(p);
+	auto *rt = REDIS_RUNTIME(p);
 
 	/* TODO: write learn function */
 
@@ -1461,18 +1476,6 @@ gboolean
 rspamd_redis_finalize_learn(struct rspamd_task *task, gpointer runtime,
 							gpointer ctx, GError **err)
 {
-	struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
-
-	if (rt->err) {
-		g_propagate_error(err, rt->err);
-		rt->err = nullptr;
-		rspamd_redis_fin(rt);
-
-		return FALSE;
-	}
-
-	rspamd_redis_fin(rt);
-
 	return TRUE;
 }
 
@@ -1480,7 +1483,7 @@ gulong
 rspamd_redis_total_learns(struct rspamd_task *task, gpointer runtime,
 						  gpointer ctx)
 {
-	struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
+	auto *rt = REDIS_RUNTIME(runtime);
 
 	return rt->learned;
 }
@@ -1489,7 +1492,7 @@ gulong
 rspamd_redis_inc_learns(struct rspamd_task *task, gpointer runtime,
 						gpointer ctx)
 {
-	struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
+	auto *rt = REDIS_RUNTIME(runtime);
 
 	/* XXX: may cause races */
 	return rt->learned + 1;
@@ -1499,7 +1502,7 @@ gulong
 rspamd_redis_dec_learns(struct rspamd_task *task, gpointer runtime,
 						gpointer ctx)
 {
-	struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
+	auto *rt = REDIS_RUNTIME(runtime);
 
 	/* XXX: may cause races */
 	return rt->learned + 1;
@@ -1509,7 +1512,7 @@ gulong
 rspamd_redis_learns(struct rspamd_task *task, gpointer runtime,
 					gpointer ctx)
 {
-	struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
+	auto *rt = REDIS_RUNTIME(runtime);
 
 	return rt->learned;
 }
@@ -1518,7 +1521,7 @@ ucl_object_t *
 rspamd_redis_get_stat(gpointer runtime,
 					  gpointer ctx)
 {
-	struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
+	auto *rt = REDIS_RUNTIME(runtime);
 	struct rspamd_redis_stat_elt *st;
 	redisAsyncContext *redis;
 


More information about the Commits mailing list