commit bef7060: [Feature] Add ROC feature to neural network plugin
Pragadeesh Chandiran
pchandiran at mimecast.com
Mon Nov 15 19:07:04 UTC 2021
Author: Pragadeesh Chandiran
Date: 2021-11-08 00:13:04 -0500
URL: https://github.com/rspamd/rspamd/commit/bef70607af40943fa1626d1c0a32f94925d4f15a (refs/pull/3980/head)
[Feature] Add ROC feature to neural network plugin
---
lualib/plugins/neural.lua | 161 +++++++++++++++++++++++++++++++++++++++++++--
src/plugins/lua/neural.lua | 51 +++++++++++---
2 files changed, 198 insertions(+), 14 deletions(-)
diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua
index 64d21ce37..f677119fe 100644
--- a/lualib/plugins/neural.lua
+++ b/lualib/plugins/neural.lua
@@ -54,7 +54,8 @@ local default_options = {
learning_spawned = false,
ann_expire = 60 * 60 * 24 * 2, -- 2 days
hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
- -- Check ROC curve and AUC in the ML literature
+ roc_enabled = false, -- Use ROC to find the best possible thresholds for ham and spam. If spam_score_threshold or ham_score_threshold is defined, it takes precedence over ROC thresholds.
+ roc_misclassification_cost = 0.5, -- Cost of misclassifying a spam message (must be 0..1).
spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable)
ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable)
flat_threshold_curve = false, -- use binary classification 0/1 when threshold is reached
@@ -170,7 +171,8 @@ local redis_lua_script_maybe_lock = [[
-- key5 - expire in seconds
-- key6 - current time
-- key7 - old key
--- key8 - optional PCA
+-- key8 - ROC Thresholds
+-- key9 - optional PCA
local redis_lua_script_save_unlock = [[
local now = tonumber(KEYS[6])
redis.call('ZADD', KEYS[2], now, KEYS[4])
@@ -180,8 +182,9 @@ local redis_lua_script_save_unlock = [[
redis.call('HDEL', KEYS[1], 'lock')
redis.call('HDEL', KEYS[7], 'lock')
redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
- if KEYS[8] then
- redis.call('HSET', KEYS[1], 'pca', KEYS[8])
+ redis.call('HSET', KEYS[1], 'roc_thresholds', KEYS[8])
+ if KEYS[9] then
+ redis.call('HSET', KEYS[1], 'pca', KEYS[9])
end
return 1
]]
@@ -239,6 +242,126 @@ local function learn_pca(inputs, max_inputs)
return w
end
+-- This function computes optimal threshold using ROC for the given set of inputs.
+-- Returns a threshold that minimizes:
+-- alpha * (false_positive_rate) + beta * (false_negative_rate)
+-- Where alpha is cost of false positive result
+-- beta is cost of false negative result
+local function get_roc_thresholds(ann, inputs, outputs, alpha, beta)
+
+ -- Sorts list x and list y based on the values in list x.
+ local sort_relative = function(x, y)
+
+ local r = {}
+
+ assert(#x == #y)
+ local n = #x
+
+ local a = {}
+ local b = {}
+ for i=1,n do
+ r[i] = i
+ end
+
+ local cmp = function(p, q) return p < q end
+
+ table.sort(r, function(p, q) return cmp(x[p], x[q]) end)
+
+ for i=1,n do
+ a[i] = x[r[i]]
+ b[i] = y[r[i]]
+ end
+
+ return a, b
+ end
+
+ local function get_scores(nn, input_vectors)
+ local scores = {}
+ for i=1,#inputs do
+ local score = nn:apply1(input_vectors[i], nn.pca)[1]
+ scores[#scores+1] = score
+ end
+
+ return scores
+ end
+
+ local fpr = {}
+ local fnr = {}
+ local scores = get_scores(ann, inputs)
+
+ scores, outputs = sort_relative(scores, outputs)
+
+ local n_samples = #outputs
+ local n_spam = 0
+ local n_ham = 0
+ local ham_count_ahead = {}
+ local spam_count_ahead = {}
+ local ham_count_behind = {}
+ local spam_count_behind = {}
+
+ ham_count_ahead[n_samples + 1] = 0
+ spam_count_ahead[n_samples + 1] = 0
+
+ for i=n_samples,1,-1 do
+
+ if outputs[i][1] == 0 then
+ n_ham = n_ham + 1
+ ham_count_ahead[i] = 1
+ spam_count_ahead[i] = 0
+ else
+ n_spam = n_spam + 1
+ ham_count_ahead[i] = 0
+ spam_count_ahead[i] = 1
+ end
+
+ ham_count_ahead[i] = ham_count_ahead[i] + ham_count_ahead[i + 1]
+ spam_count_ahead[i] = spam_count_ahead[i] + spam_count_ahead[i + 1]
+ end
+
+ for i=1,n_samples do
+ if outputs[i][1] == 0 then
+ ham_count_behind[i] = 1
+ spam_count_behind[i] = 0
+ else
+ ham_count_behind[i] = 0
+ spam_count_behind[i] = 1
+ end
+
+ if i ~= 1 then
+ ham_count_behind[i] = ham_count_behind[i] + ham_count_behind[i - 1]
+ spam_count_behind[i] = spam_count_behind[i] + spam_count_behind[i - 1]
+ end
+ end
+
+ for i=1,n_samples do
+ fpr[i] = 0
+ fnr[i] = 0
+
+ if (ham_count_ahead[i + 1] + ham_count_behind[i]) ~= 0 then
+ fpr[i] = ham_count_ahead[i + 1] / (ham_count_ahead[i + 1] + ham_count_behind[i])
+ end
+
+ if (spam_count_behind[i] + spam_count_ahead[i + 1]) ~= 0 then
+ fnr[i] = spam_count_behind[i] / (spam_count_behind[i] + spam_count_ahead[i + 1])
+ end
+ end
+
+ local p = n_spam / (n_spam + n_ham)
+
+ local cost = {}
+ local min_cost_idx = 0
+ local min_cost = math.huge
+ for i=1,n_samples do
+ cost[i] = ((1 - p) * alpha * fpr[i]) + (p * beta * fnr[i])
+ if min_cost >= cost[i] then
+ min_cost = cost[i]
+ min_cost_idx = i
+ end
+ end
+
+ return scores[min_cost_idx]
+end
+
-- This function is intended to extend lock for ANN during training
-- It registers periodic that increases locked key each 30 seconds unless
-- `set.learning_spawned` is set to `true`
@@ -497,6 +620,24 @@ local function spawn_train(params)
params.rule.prefix, params.set.name)
end
+ local roc_thresholds
+ if params.rule.roc_enabled then
+ local spam_threshold = get_roc_thresholds(train_ann,
+ inputs,
+ outputs,
+ 1 - params.rule.roc_misclassification_cost,
+ params.rule.roc_misclassification_cost)
+ local ham_threshold = get_roc_thresholds(train_ann,
+ inputs,
+ outputs,
+ params.rule.roc_misclassification_cost,
+ 1 - params.rule.roc_misclassification_cost)
+ roc_thresholds = {spam_threshold, ham_threshold}
+ end
+
+ rspamd_logger.messagex("ROC thresholds: (spam_threshold: %s, ham_threshold: %s)",
+ roc_thresholds[1], roc_thresholds[2])
+
if not seen_nan then
-- Convert to strings as ucl cannot rspamd_text properly
local pca_data
@@ -506,6 +647,7 @@ local function spawn_train(params)
local out = {
ann_data = tostring(train_ann:save()),
pca_data = pca_data,
+ roc_thresholds = roc_thresholds,
}
local final_data = ucl.to_format(out, 'msgpack')
@@ -559,12 +701,19 @@ local function spawn_train(params)
local parsed = parser:get_object()
local ann_data = rspamd_util.zstd_compress(parsed.ann_data)
local pca_data = parsed.pca_data
+ local roc_thresholds = parsed.roc_thresholds
fill_set_ann(params.set, params.ann_key)
if pca_data then
params.set.ann.pca = rspamd_tensor.load(pca_data)
pca_data = rspamd_util.zstd_compress(pca_data)
end
+
+ if roc_thresholds then
+ params.set.ann.roc_thresholds = roc_thresholds
+ end
+
+
-- Deserialise ANN from the child process
ann_trained = rspamd_kann.load(parsed.ann_data)
local version = (params.set.ann.version or 0) + 1
@@ -581,6 +730,7 @@ local function spawn_train(params)
}
local profile_serialized = ucl.to_format(profile, 'json-compact', true)
+ local roc_thresholds_serialized = ucl.to_format(roc_thresholds, 'json-compact', true)
rspamd_logger.infox(rspamd_config,
'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
@@ -599,7 +749,8 @@ local function spawn_train(params)
tostring(params.rule.ann_expire),
tostring(os.time()),
params.ann_key, -- old key to unlock...
- pca_data
+ roc_thresholds_serialized,
+ pca_data,
})
end
end
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 5458dd007..36eb9adaf 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -120,31 +120,47 @@ local function ann_scores_filter(task)
if score > 0 then
local result = score
+
+ -- If spam_score_threshold is defined, override all other thresholds.
+ local spam_threshold = 0
+ if rule.spam_score_threshold then
+ spam_threshold = rule.spam_score_threshold
+ elseif rule.roc_enabled and not set.ann.roc_thresholds then
+ spam_threshold = set.ann.roc_thresholds[1]
+ end
- if not rule.spam_score_threshold or result >= rule.spam_score_threshold then
+ if result >= spam_threshold then
if rule.flat_threshold_curve then
task:insert_result(rule.symbol_spam, 1.0, symscore)
else
task:insert_result(rule.symbol_spam, result, symscore)
end
else
- lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam_score_threshold)',
+ lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)',
rule.prefix, set.name, set.ann.version, symscore,
- rule.spam_score_threshold)
+ spam_threshold)
end
else
local result = -(score)
- if not rule.ham_score_threshold or result >= rule.ham_score_threshold then
+ -- If ham_score_threshold is defined, override all other thresholds.
+ local ham_threshold = 0
+ if rule.ham_score_threshold then
+ ham_threshold = rule.ham_score_threshold
+ elseif rule.roc_enabled and not set.ann.roc_thresholds then
+ ham_threshold = set.ann.roc_thresholds[2]
+ end
+
+ if result >= ham_threshold then
if rule.flat_threshold_curve then
task:insert_result(rule.symbol_ham, 1.0, symscore)
else
task:insert_result(rule.symbol_ham, result, symscore)
end
else
- lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham_score_threshold)',
+ lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)',
rule.prefix, set.name, set.ann.version, result,
- rule.ham_score_threshold)
+ ham_threshold)
end
end
end
@@ -481,16 +497,32 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
rule.prefix, set.name, ann_key)
end
+
if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
+ if rule.roc_enabled then
+ local ucl = require "ucl"
+ local parser = ucl.parser()
+ local ok, parse_err = parser:parse_text(data[2])
+ assert(ok, parse_err)
+ local roc_thresholds = parser:get_object()
+ set.ann.roc_thresholds = roc_thresholds
+ rspamd_logger.infox(rspamd_config,
+ 'loaded ROC thresholds for %s:%s; version=%s',
+ rule.prefix, set.name, profile.version)
+ rspamd_logger.debugx("ROC thresholds: %s", roc_thresholds)
+ end
+ end
+
+ if set.ann and set.ann.ann and type(data[3]) == 'userdata' and data[3].cookie == text_cookie then
-- PCA table
- local _err,pca_data = rspamd_util.zstd_decompress(data[2])
+ local _err,pca_data = rspamd_util.zstd_decompress(data[3])
if pca_data then
if rule.max_inputs then
-- We can use PCA
set.ann.pca = rspamd_tensor.load(pca_data)
rspamd_logger.infox(rspamd_config,
'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
- rule.prefix, set.name, ann_key, #data[2], profile.version)
+ rule.prefix, set.name, ann_key, #data[3], profile.version)
else
-- no need in pca, why is it there?
rspamd_logger.warnx(rspamd_config,
@@ -509,6 +541,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
end
end
end
+
else
lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s',
rule.prefix, set.name, ann_key)
@@ -522,7 +555,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
false, -- is write
data_cb, --callback
'HMGET', -- command
- {ann_key, 'ann', 'pca'}, -- arguments
+ {ann_key, 'ann', 'roc_thresholds', 'pca'}, -- arguments
{opaque_data = true}
)
end
More information about the Commits
mailing list