commit 82e5883: [Feature] Neural: Allow to balance FP/FN for the network
Vsevolod Stakhov
vsevolod at highsecure.ru
Thu Apr 29 18:42:03 UTC 2021
Author: Vsevolod Stakhov
Date: 2021-04-29 19:41:03 +0100
URL: https://github.com/rspamd/rspamd/commit/82e588390a7f0dc000e74497cfb84e25dcbfafe5 (HEAD -> master)
[Feature] Neural: Allow to balance FP/FN for the network
---
lualib/plugins/neural.lua | 3 +++
src/plugins/lua/neural.lua | 18 ++++++++++++++++--
2 files changed, 19 insertions(+), 2 deletions(-)
diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua
index c35fc0eeb..f0d5cf582 100644
--- a/lualib/plugins/neural.lua
+++ b/lualib/plugins/neural.lua
@@ -54,6 +54,9 @@ local default_options = {
learning_spawned = false,
ann_expire = 60 * 60 * 24 * 2, -- 2 days
hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
+ -- Check ROC curve and AUC in the ML literature
+ spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable)
+ ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable)
symbol_spam = 'NEURAL_SPAM',
symbol_ham = 'NEURAL_HAM',
max_inputs = nil, -- when PCA is used
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 894d42e30..ca11d9e66 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -119,10 +119,24 @@ local function ann_scores_filter(task)
if score > 0 then
local result = score
- task:insert_result(rule.symbol_spam, result, symscore)
+
+ if not rule.spam_score_threshold or result >= rule.spam_score_threshold then
+ task:insert_result(rule.symbol_spam, result, symscore)
+ else
+ lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam_score_threshold)',
+ rule.prefix, set.name, set.ann.version, symscore,
+ rule.spam_score_threshold)
+ end
else
local result = -(score)
- task:insert_result(rule.symbol_ham, result, symscore)
+
+ if not rule.ham_score_threshold or result >= rule.ham_score_threshold then
+ task:insert_result(rule.symbol_ham, result, symscore)
+ else
+ lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham_score_threshold)',
+ rule.prefix, set.name, set.ann.version, result,
+ rule.ham_score_threshold)
+ end
end
end
end
More information about the Commits
mailing list