commit bef7060: [Feature] Add ROC feature to neural network plugin

Pragadeesh Chandiran pchandiran at mimecast.com
Mon Nov 15 19:07:04 UTC 2021


Author: Pragadeesh Chandiran
Date: 2021-11-08 00:13:04 -0500
URL: https://github.com/rspamd/rspamd/commit/bef70607af40943fa1626d1c0a32f94925d4f15a (refs/pull/3980/head)

[Feature] Add ROC feature to neural network plugin

---
 lualib/plugins/neural.lua  | 161 +++++++++++++++++++++++++++++++++++++++++++--
 src/plugins/lua/neural.lua |  51 +++++++++++---
 2 files changed, 198 insertions(+), 14 deletions(-)

diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua
index 64d21ce37..f677119fe 100644
--- a/lualib/plugins/neural.lua
+++ b/lualib/plugins/neural.lua
@@ -54,7 +54,8 @@ local default_options = {
   learning_spawned = false,
   ann_expire = 60 * 60 * 24 * 2, -- 2 days
   hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
-  -- Check ROC curve and AUC in the ML literature
+  roc_enabled = false, -- Use ROC to find the best possible thresholds for ham and spam. If spam_score_threshold or ham_score_threshold is defined, it takes precedence over ROC thresholds.
+  roc_misclassification_cost = 0.5, -- Cost of misclassifying a spam message (must be 0..1).
   spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable)
   ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable)
   flat_threshold_curve = false, -- use binary classification 0/1 when threshold is reached
@@ -170,7 +171,8 @@ local redis_lua_script_maybe_lock = [[
 -- key5 - expire in seconds
 -- key6 - current time
 -- key7 - old key
--- key8 - optional PCA
+-- key8 - ROC Thresholds
+-- key9 - optional PCA
 local redis_lua_script_save_unlock = [[
   local now = tonumber(KEYS[6])
   redis.call('ZADD', KEYS[2], now, KEYS[4])
@@ -180,8 +182,9 @@ local redis_lua_script_save_unlock = [[
   redis.call('HDEL', KEYS[1], 'lock')
   redis.call('HDEL', KEYS[7], 'lock')
   redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
-  if KEYS[8] then
-    redis.call('HSET', KEYS[1], 'pca', KEYS[8])
+  redis.call('HSET', KEYS[1], 'roc_thresholds', KEYS[8])
+  if KEYS[9] then
+    redis.call('HSET', KEYS[1], 'pca', KEYS[9])
   end
   return 1
 ]]
@@ -239,6 +242,126 @@ local function learn_pca(inputs, max_inputs)
   return w
 end
 
+-- This function computes optimal threshold using ROC for the given set of inputs.
+-- Returns a threshold that minimizes:
+--        alpha * (false_positive_rate)  +  beta * (false_negative_rate)
+--        Where alpha is cost of false positive result
+--              beta is cost of false negative result
+local function get_roc_thresholds(ann, inputs, outputs, alpha, beta)
+
+  -- Sorts list x and list y based on the values in list x.
+  local sort_relative = function(x, y)
+
+    local r = {}
+
+    assert(#x == #y)
+    local n = #x
+
+    local a = {}
+    local b = {}
+    for i=1,n do
+      r[i] = i
+    end
+
+    local cmp = function(p, q) return p < q end
+
+    table.sort(r, function(p, q) return cmp(x[p], x[q]) end)
+
+    for i=1,n do 
+      a[i] = x[r[i]]
+      b[i] = y[r[i]]
+    end
+
+    return a, b
+  end
+
+  local function get_scores(nn, input_vectors)
+    local scores = {}
+    for i=1,#inputs do
+      local score = nn:apply1(input_vectors[i], nn.pca)[1]
+      scores[#scores+1] = score
+    end
+
+    return scores
+  end
+
+  local fpr = {}
+	local fnr = {}
+	local scores = get_scores(ann, inputs)
+
+	scores, outputs = sort_relative(scores, outputs)
+
+	local n_samples = #outputs
+	local n_spam = 0
+	local n_ham = 0
+	local ham_count_ahead = {}
+	local spam_count_ahead = {}
+	local ham_count_behind = {}
+	local spam_count_behind = {}
+
+	ham_count_ahead[n_samples + 1] = 0
+	spam_count_ahead[n_samples + 1] = 0
+
+	for i=n_samples,1,-1 do
+
+		if outputs[i][1] == 0 then
+			n_ham = n_ham + 1
+			ham_count_ahead[i] = 1
+			spam_count_ahead[i] = 0
+		else
+			n_spam = n_spam + 1
+			ham_count_ahead[i] = 0
+			spam_count_ahead[i] = 1
+		end
+
+		ham_count_ahead[i] = ham_count_ahead[i] + ham_count_ahead[i + 1]
+		spam_count_ahead[i] = spam_count_ahead[i] + spam_count_ahead[i + 1]
+	end
+
+	for i=1,n_samples do
+    if outputs[i][1] == 0 then
+			ham_count_behind[i] = 1
+			spam_count_behind[i] = 0
+		else
+			ham_count_behind[i] = 0
+			spam_count_behind[i] = 1
+		end
+
+		if i ~= 1 then
+			ham_count_behind[i] = ham_count_behind[i] + ham_count_behind[i - 1]
+			spam_count_behind[i] = spam_count_behind[i] + spam_count_behind[i - 1]
+		end
+	end
+
+	for i=1,n_samples do
+		fpr[i] = 0
+		fnr[i] = 0
+
+		if (ham_count_ahead[i + 1] + ham_count_behind[i]) ~= 0 then
+			fpr[i] = ham_count_ahead[i + 1] / (ham_count_ahead[i + 1] + ham_count_behind[i])
+		end
+
+		if (spam_count_behind[i] + spam_count_ahead[i + 1]) ~= 0 then
+			fnr[i] = spam_count_behind[i] / (spam_count_behind[i] + spam_count_ahead[i + 1])
+		end
+	end
+
+	local p = n_spam / (n_spam + n_ham)
+
+	local cost = {}
+	local min_cost_idx = 0
+	local min_cost = math.huge
+	for i=1,n_samples do
+		cost[i] = ((1 - p) * alpha * fpr[i]) + (p * beta * fnr[i])
+		if min_cost >= cost[i] then
+			min_cost = cost[i]
+			min_cost_idx = i
+		end
+	end
+
+	return scores[min_cost_idx]
+end
+
 -- This function is intended to extend lock for ANN during training
 -- It registers periodic that increases locked key each 30 seconds unless
 -- `set.learning_spawned` is set to `true`
@@ -497,6 +620,24 @@ local function spawn_train(params)
             params.rule.prefix, params.set.name)
       end
 
+      local roc_thresholds
+      if params.rule.roc_enabled then
+        local spam_threshold = get_roc_thresholds(train_ann,
+                                                  inputs,
+                                                  outputs,
+                                                  1 - params.rule.roc_misclassification_cost,
+                                                  params.rule.roc_misclassification_cost)
+        local ham_threshold = get_roc_thresholds(train_ann,
+                                                  inputs,
+                                                  outputs,
+                                                  params.rule.roc_misclassification_cost,
+                                                  1 - params.rule.roc_misclassification_cost)
+        roc_thresholds = {spam_threshold, ham_threshold}
+      end
+
+      rspamd_logger.messagex("ROC thresholds: (spam_threshold: %s, ham_threshold: %s)",
+                              roc_thresholds[1], roc_thresholds[2])
+
       if not seen_nan then
         -- Convert to strings as ucl cannot rspamd_text properly
         local pca_data
@@ -506,6 +647,7 @@ local function spawn_train(params)
         local out = {
           ann_data = tostring(train_ann:save()),
           pca_data = pca_data,
+          roc_thresholds = roc_thresholds,
         }
 
         local final_data = ucl.to_format(out, 'msgpack')
@@ -559,12 +701,19 @@ local function spawn_train(params)
         local parsed = parser:get_object()
         local ann_data = rspamd_util.zstd_compress(parsed.ann_data)
         local pca_data = parsed.pca_data
+        local roc_thresholds = parsed.roc_thresholds
 
         fill_set_ann(params.set, params.ann_key)
         if pca_data then
           params.set.ann.pca = rspamd_tensor.load(pca_data)
           pca_data = rspamd_util.zstd_compress(pca_data)
         end
+
+        if roc_thresholds then
+          params.set.ann.roc_thresholds = roc_thresholds
+        end
+
+
         -- Deserialise ANN from the child process
         ann_trained = rspamd_kann.load(parsed.ann_data)
         local version = (params.set.ann.version or 0) + 1
@@ -581,6 +730,7 @@ local function spawn_train(params)
         }
 
         local profile_serialized = ucl.to_format(profile, 'json-compact', true)
+        local roc_thresholds_serialized = ucl.to_format(roc_thresholds, 'json-compact', true)
 
         rspamd_logger.infox(rspamd_config,
             'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
@@ -599,7 +749,8 @@ local function spawn_train(params)
              tostring(params.rule.ann_expire),
              tostring(os.time()),
              params.ann_key, -- old key to unlock...
-             pca_data
+             roc_thresholds_serialized,
+             pca_data,
             })
       end
     end
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 5458dd007..36eb9adaf 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -120,31 +120,47 @@ local function ann_scores_filter(task)
 
       if score > 0 then
         local result = score
+        
+        -- If spam_score_threshold is defined, override all other thresholds.
+        local spam_threshold = 0
+        if rule.spam_score_threshold then
+          spam_threshold = rule.spam_score_threshold
+        elseif rule.roc_enabled and not set.ann.roc_thresholds then
+          spam_threshold = set.ann.roc_thresholds[1]
+        end
 
-        if not rule.spam_score_threshold or result >= rule.spam_score_threshold then
+        if result >= spam_threshold then
           if rule.flat_threshold_curve then
             task:insert_result(rule.symbol_spam, 1.0, symscore)
           else
             task:insert_result(rule.symbol_spam, result, symscore)
           end
         else
-          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam_score_threshold)',
+          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)',
               rule.prefix, set.name, set.ann.version, symscore,
-              rule.spam_score_threshold)
+              spam_threshold)
         end
       else
         local result = -(score)
 
-        if not rule.ham_score_threshold or result >= rule.ham_score_threshold then
+        -- If ham_score_threshold is defined, override all other thresholds.
+        local ham_threshold = 0
+        if rule.ham_score_threshold then
+          ham_threshold = rule.ham_score_threshold
+        elseif rule.roc_enabled and not set.ann.roc_thresholds then
+          ham_threshold = set.ann.roc_thresholds[2]
+        end
+
+        if result >= ham_threshold then
           if rule.flat_threshold_curve then
             task:insert_result(rule.symbol_ham, 1.0, symscore)
           else
             task:insert_result(rule.symbol_ham, result, symscore)
           end
         else
-          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham_score_threshold)',
+          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)',
               rule.prefix, set.name, set.ann.version, result,
-              rule.ham_score_threshold)
+              ham_threshold)
         end
       end
     end
@@ -481,16 +497,32 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
           lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
               rule.prefix, set.name, ann_key)
         end
+
         if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
+          if rule.roc_enabled then
+            local ucl = require "ucl"
+            local parser = ucl.parser()
+            local ok, parse_err = parser:parse_text(data[2])
+            assert(ok, parse_err)
+            local roc_thresholds = parser:get_object()
+            set.ann.roc_thresholds = roc_thresholds
+            rspamd_logger.infox(rspamd_config, 
+                                'loaded ROC thresholds for %s:%s; version=%s',
+                                rule.prefix, set.name, profile.version)
+            rspamd_logger.debugx("ROC thresholds: %s", roc_thresholds)
+          end
+        end
+
+        if set.ann and set.ann.ann and type(data[3]) == 'userdata' and data[3].cookie == text_cookie then
           -- PCA table
-          local _err,pca_data = rspamd_util.zstd_decompress(data[2])
+          local _err,pca_data = rspamd_util.zstd_decompress(data[3])
           if pca_data then
             if rule.max_inputs then
               -- We can use PCA
               set.ann.pca = rspamd_tensor.load(pca_data)
               rspamd_logger.infox(rspamd_config,
                   'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
-                  rule.prefix, set.name, ann_key, #data[2], profile.version)
+                  rule.prefix, set.name, ann_key, #data[3], profile.version)
             else
               -- no need in pca, why is it there?
               rspamd_logger.warnx(rspamd_config,
@@ -509,6 +541,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
             end
           end
         end
+
       else
         lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s',
             rule.prefix, set.name, ann_key)
@@ -522,7 +555,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
       false, -- is write
       data_cb, --callback
       'HMGET', -- command
-      {ann_key, 'ann', 'pca'}, -- arguments
+      {ann_key, 'ann', 'roc_thresholds', 'pca'}, -- arguments
       {opaque_data = true}
   )
 end


More information about the Commits mailing list