commit 5e40de9: [Minor] Neural: Fix random sampling

Vsevolod Stakhov vsevolod at highsecure.ru
Thu Oct 24 17:28:06 UTC 2019


Author: Vsevolod Stakhov
Date: 2019-10-24 18:22:51 +0100
URL: https://github.com/rspamd/rspamd/commit/5e40de9bed1151e8f85c67b21d5d7ce87d6dc014 (HEAD -> master)

[Minor] Neural: Fix random sampling
Issue: #3119

---
 src/plugins/lua/neural.lua | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 87df49325..faeb66412 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -97,6 +97,7 @@ end
 -- key1 - ann key
 -- key2 - spam or ham
 -- key3 - maximum trains
+-- key4 - sampling coin (as Redis scripts do not allow math.random calls)
 -- returns 1 or 0 + reason: 1 - allow learn, 0 - not allow learn
 local redis_lua_script_can_store_train_vec = [[
   local prefix = KEYS[1]
@@ -105,6 +106,7 @@ local redis_lua_script_can_store_train_vec = [[
   local nspam = 0
   local nham = 0
   local lim = tonumber(KEYS[3])
+  local coin = tonumber(KEYS[4])
 
   local ret = redis.call('LLEN', prefix .. '_spam')
   if ret then nspam = tonumber(ret) end
@@ -116,7 +118,7 @@ local redis_lua_script_can_store_train_vec = [[
       if nspam > nham then
         -- Apply sampling
         local skip_rate = 1.0 - nham / (nspam + 1)
-        if math.random() < skip_rate then
+        if coun < skip_rate then
           return {tostring(-(nspam)),'sampled out with probability ' .. tostring(skip_rate)}
         end
       end
@@ -129,7 +131,7 @@ local redis_lua_script_can_store_train_vec = [[
       if nham > nspam then
         -- Apply sampling
         local skip_rate = 1.0 - nspam / (nham + 1)
-        if math.random() < skip_rate then
+        if coin < skip_rate then
           return {tostring(-(nham)),'sampled out with probability ' .. tostring(skip_rate)}
         end
       end
@@ -488,6 +490,7 @@ local function ann_push_task_result(rule, task, verdict, score, set)
           set.ann.redis_key,
           learn_type,
           tostring(train_opts.max_trains),
+          tostring(math.random()),
         })
   else
     lua_util.debugm(N, task, 'do not push data: train condition not satisfied; reason: %s',


More information about the Commits mailing list