commit 73d2cee: [Project] Reputation: Migrate to adaptive EMA model

Vsevolod Stakhov vsevolod at highsecure.ru
Wed May 15 14:14:04 UTC 2019


Author: Vsevolod Stakhov
Date: 2019-05-15 14:09:40 +0100
URL: https://github.com/rspamd/rspamd/commit/73d2cee82a5d55a239c628c21454137027a29db2

[Project] Reputation: Migrate to adaptive EMA model

---
 src/plugins/lua/reputation.lua | 371 ++++++++++++++++-------------------------
 1 file changed, 145 insertions(+), 226 deletions(-)

diff --git a/src/plugins/lua/reputation.lua b/src/plugins/lua/reputation.lua
index ad05023be..f1062dbaa 100644
--- a/src/plugins/lua/reputation.lua
+++ b/src/plugins/lua/reputation.lua
@@ -35,12 +35,8 @@ local ts = require("tableshape").types
 
 local redis_params = nil
 local default_expiry = 864000 -- 10 day by default
+local default_prefix = 'RR:' -- Rspamd Reputation
 
-local keymap_schema = ts.shape{
-  ['spam'] = ts.string,
-  ['junk'] = ts.string,
-  ['ham'] = ts.string,
-}
 
 -- Get reputation from ham/spam/probable hits
 local function generic_reputation_calc(token, rule, mult)
@@ -50,16 +46,11 @@ local function generic_reputation_calc(token, rule, mult)
     return cfg.score_calc_func(rule, token, mult)
   end
 
-  local ham_samples = token.h or 0
-  local spam_samples = token.s or 0
-  local probable_samples = token.p or 0
-  local total_samples = ham_samples + spam_samples + probable_samples
-
-  if total_samples < cfg.lower_bound then return 0 end
+  if token[1] < cfg.lower_bound then return 0 end
 
-  local score = (ham_samples / total_samples) * -1.0 +
-      (spam_samples / total_samples) +
-      (probable_samples / total_samples) * 0.5
+  local score = fun.foldl(function(acc, v)
+    return acc + v
+  end, 0.0, fun.map(tonumber, token[2])) / #token[2]
 
   return score
 end
@@ -79,6 +70,38 @@ local function add_symbol_score(task, rule, mult, params)
   end
 end
 
+local function sub_symbol_score(task, rule, score)
+  local function sym_score(sym)
+    local s = task:get_symbol(sym)[1]
+    return s.score
+  end
+  if rule.config.split_symbols then
+    local spam_sym = rule.symbol .. '_SPAM'
+    local ham_sym = rule.symbol .. '_HAM'
+
+    if task:has_symbol(spam_sym) then
+      score = score - sym_score(spam_sym)
+    elseif task:has_symbol(ham_sym) then
+      score = score - sym_score(ham_sym)
+    end
+  else
+    if task:has_symbol(rule.symbol) then
+      score = score - sym_score(rule.symbol)
+    end
+  end
+
+  return score
+end
+
+-- Extracts task score and subtracts score of the rule itself
+local function extract_task_score(task, rule)
+  local _,score = lua_util.get_task_verdict(task)
+
+  if not score then return nil end
+
+  return sub_symbol_score(task, rule, score)
+end
+
 -- DKIM Selector functions
 local gr
 local function gen_dkim_queries(task, rule)
@@ -164,28 +187,14 @@ local function dkim_reputation_filter(task, rule)
 end
 
 local function dkim_reputation_idempotent(task, rule)
-  local verdict = lua_util.get_task_verdict(task)
-  local token = {
-  }
-  local cfg = rule.selector.config
-  local need_set = false
-
-  -- TODO: take metric score into consideration
-  local k = cfg.keys_map[verdict]
-
-  if k then
-    token[k] = 1.0
-    need_set = true
-  end
-
-  if need_set then
-
-    local requests = gen_dkim_queries(task, rule)
+  local requests = gen_dkim_queries(task, rule)
+  local sc = extract_task_score(task, rule)
 
+  if sc then
     for dom,res in pairs(requests) do
       -- tld + "." + check_result, e.g. example.com.+ - reputation for valid sigs
       local query = string.format('%s.%s', dom, res)
-      rule.backend.set_token(task, rule, query, token)
+      rule.backend.set_token(task, rule, query, sc)
     end
   end
 end
@@ -212,15 +221,6 @@ end
 
 local dkim_selector = {
   config = {
-    -- keys map between actions and hash elements in bucket,
-    -- h is for ham,
-    -- s is for spam,
-    -- p is for probable spam
-    keys_map = {
-      ['spam'] = 's',
-      ['junk'] = 'p',
-      ['ham'] = 'h'
-    },
     symbol = 'DKIM_SCORE', -- symbol to be inserted
     lower_bound = 10, -- minimum number of messages to be scored
     min_score = nil,
@@ -270,7 +270,7 @@ local function gen_url_queries(task, rule)
 end
 
 local function url_reputation_filter(task, rule)
-  local requests = gen_url_queries(task, rule)
+  local requests = lua_util.extract_specific_urls(task, rule.selector.config.max_urls)
   local results = {}
   local nchecked = 0
 
@@ -304,47 +304,24 @@ local function url_reputation_filter(task, rule)
     end
   end
 
-  for _,tld in ipairs(requests) do
-    rule.backend.get_token(task, rule, tld[1], tokens_cb)
+  for _,u in ipairs(requests) do
+    rule.backend.get_token(task, rule, u:get_tld(), tokens_cb)
   end
 end
 
 local function url_reputation_idempotent(task, rule)
-  local verdict = lua_util.get_task_verdict(task)
-  local token = {
-  }
-  local cfg = rule.selector.config
-  local need_set = false
-
-  -- TODO: take metric score into consideration
-  local k = cfg.keys_map[verdict]
-
-  if k then
-    token[k] = 1.0
-    need_set = true
-  end
-
-  if need_set then
-
-    local requests = gen_url_queries(task, rule)
+  local requests = gen_url_queries(task, rule)
+  local sc = extract_task_score(task, rule)
 
+  if sc then
     for _,tld in ipairs(requests) do
-      rule.backend.set_token(task, rule, tld[1], token)
+      rule.backend.set_token(task, rule, tld[1], sc)
     end
   end
 end
 
 local url_selector = {
   config = {
-    -- keys map between actions and hash elements in bucket,
-    -- h is for ham,
-    -- s is for spam,
-    -- p is for probable spam
-    keys_map = {
-      ['spam'] = 's',
-      ['junk'] = 'p',
-      ['ham'] = 'h'
-    },
     symbol = 'URL_SCORE', -- symbol to be inserted
     lower_bound = 10, -- minimum number of messages to be scored
     min_score = nil,
@@ -489,44 +466,20 @@ local function ip_reputation_idempotent(task, rule)
       return
     end
   end
-
-  local verdict = lua_util.get_task_verdict(task)
-  local token = {
-  }
-  local need_set = false
-
-  -- TODO: take metric score into consideration
-  local k = cfg.keys_map[verdict]
-
-  if k then
-    token[k] = 1.0
-    need_set = true
+  local sc = extract_task_score(task, rule)
+  if asn then
+    rule.backend.set_token(task, rule, cfg.asn_prefix .. asn, sc)
   end
-
-  if need_set then
-    if asn then
-      rule.backend.set_token(task, rule, cfg.asn_prefix .. asn, token)
-    end
-    if country then
-      rule.backend.set_token(task, rule, cfg.country_prefix .. country, token)
-    end
-
-    rule.backend.set_token(task, rule, cfg.ip_prefix .. tostring(ip), token)
+  if country then
+    rule.backend.set_token(task, rule, cfg.country_prefix .. country, sc)
   end
+
+  rule.backend.set_token(task, rule, cfg.ip_prefix .. tostring(ip), sc)
 end
 
 -- Selectors are used to extract reputation tokens
 local ip_selector = {
   config = {
-    -- keys map between actions and hash elements in bucket,
-    -- h is for ham,
-    -- s is for spam,
-    -- p is for probable spam
-    keys_map = {
-      ['spam'] = 's',
-      ['junk'] = 'p',
-      ['ham'] = 'h'
-    },
     scores = { -- how each component is evaluated
       ['asn'] = 0.4,
       ['country'] = 0.01,
@@ -578,46 +531,23 @@ local function spf_reputation_filter(task, rule)
 end
 
 local function spf_reputation_idempotent(task, rule)
-  local verdict = lua_util.get_task_verdict(task)
+  local sc = extract_task_score(task, rule)
   local spf_record = task:get_mempool():get_variable('spf_record')
   local spf_allow = task:has_symbol('R_SPF_ALLOW')
-  local token = {
-  }
-  local cfg = rule.selector.config
-  local need_set = false
 
-  if not spf_record or not spf_allow then return end
+  if not spf_record or not spf_allow or not sc then return end
 
-  -- TODO: take metric score into consideration
-  local k = cfg.keys_map[verdict]
-
-  if k then
-    token[k] = 1.0
-    need_set = true
-  end
-
-  if need_set then
-    local cr = require "rspamd_cryptobox_hash"
-    local hkey = cr.create(spf_record):base32():sub(1, 32)
+  local cr = require "rspamd_cryptobox_hash"
+  local hkey = cr.create(spf_record):base32():sub(1, 32)
 
-    lua_util.debugm(N, task, 'set spf record %s -> %s = %s',
-        spf_record, hkey, token)
-    rule.backend.set_token(task, rule, hkey, token)
-  end
+  lua_util.debugm(N, task, 'set spf record %s -> %s = %s',
+      spf_record, hkey, token)
+  rule.backend.set_token(task, rule, hkey, sc)
 end
 
 
 local spf_selector = {
   config = {
-    -- keys map between actions and hash elements in bucket,
-    -- h is for ham,
-    -- s is for spam,
-    -- p is for probable spam
-    keys_map = {
-      ['spam'] = 's',
-      ['junk'] = 'p',
-      ['ham'] = 'h'
-    },
     symbol = 'SPF_SCORE', -- symbol to be inserted
     lower_bound = 10, -- minimum number of messages to be scored
     min_score = nil,
@@ -694,32 +624,23 @@ local function generic_reputation_filter(task, rule)
 end
 
 local function generic_reputation_idempotent(task, rule)
-  local verdict = lua_util.get_task_verdict(task)
+  local sc = extract_task_score(task, rule)
   local cfg = rule.selector.config
-  local need_set = false
-  local token = {}
 
   local selector_res = cfg.selector(task)
   if not selector_res then return end
 
-  local k = cfg.keys_map[verdict]
-
-  if k then
-    token[k] = 1.0
-    need_set = true
-  end
-
-  if need_set then
+  if sc then
     if type(selector_res) == 'table' then
       fun.each(function(e)
         lua_util.debugm(N, task, 'set generic selector (%s) %s = %s',
-            rule['symbol'], e, token)
-        rule.backend.set_token(task, rule, e, token)
+            rule['symbol'], e, sc)
+        rule.backend.set_token(task, rule, e, sc)
       end, selector_res)
     else
       lua_util.debugm(N, task, 'set generic selector (%s) %s = %s',
-          rule['symbol'], selector_res, token)
-      rule.backend.set_token(task, rule, selector_res, token)
+          rule['symbol'], selector_res, sc)
+      rule.backend.set_token(task, rule, selector_res, sc)
     end
   end
 end
@@ -727,7 +648,6 @@ end
 
 local generic_selector = {
   schema = ts.shape{
-    keys_map = keymap_schema,
     lower_bound = ts.number + ts.string / tonumber,
     max_score = ts.number:is_optional(),
     min_score = ts.number:is_optional(),
@@ -738,15 +658,6 @@ local generic_selector = {
     whitelist = ts.string:is_optional(),
   },
   config = {
-    -- keys map between actions and hash elements in bucket,
-    -- h is for ham,
-    -- s is for spam,
-    -- p is for probable spam
-    keys_map = {
-      ['spam'] = 's',
-      ['junk'] = 'p',
-      ['ham'] = 'h'
-    },
     lower_bound = 10, -- minimum number of messages to be scored
     min_score = nil,
     max_score = nil,
@@ -806,6 +717,10 @@ local function gen_token_key(token, rule)
     res = string.sub(res, 1, rule.backend.config.hashlen)
   end
 
+  if rule.backend.config.prefix then
+    res = rule.backend.config.prefix .. res
+  end
+
   return res
 end
 
@@ -887,71 +802,78 @@ local function reputation_redis_init(rule, cfg, ev_base, worker)
     return false
   end
   -- Init scripts for buckets
+  -- Redis script to extract data from Redis buckets
+  -- KEYS[1] - key to extract
+  -- Value returned - table of scores as a strings vector + number of scores
   local redis_get_script_tpl = [[
-local key = KEYS[1] .. '${name}'
-local vals = redis.call('HGETALL', key)
-for i=1,#vals,2 do
-  local k = vals[i]
-  local v = vals[i + 1]
-  if scores[k] then
-    scores[k] = scores[k] + tonumber(v) * ${mult}
+  local cnt = redis.call('HGET', KEYS[1], 'n')
+  local results = {}
+  if cnt then
+  {% for w in windows %}
+  local sc = tonumber(redis.call('HGET', KEYS[1], 'v' .. '{= w.name =}'))
+  table.insert(results, tostring(sc * {= w.mult =}))
+  {% endfor %}
   else
-    scores[k] = tonumber(v) * ${mult}
-  end
-end
-]]
-  local redis_script_tbl = {'local scores = {}'}
-  for _,bucket in ipairs(rule.backend.config.buckets) do
-    table.insert(redis_script_tbl, lua_util.template(redis_get_script_tpl, bucket))
-  end
-  table.insert(redis_script_tbl, [[
-  local result = {}
-  for k,v in pairs(scores) do
-   table.insert(result, k)
-   table.insert(result, v)
-  end
-
-  return result
-]])
-  rule.backend.script_get = lua_redis.add_redis_script(table.concat(redis_script_tbl, '\n'),
-      our_redis_params)
-
-  redis_script_tbl = {}
-  local redis_set_script_tpl = [[
-local key = KEYS[1] .. '${name}'
-local last = tonumber(redis.call('HGET', key, 'start'))
-local now = tonumber(KEYS[2])
-if not last then
-  last = 0
-end
-local discriminate_bucket = false
-if now - last > ${time} then
-  discriminate_bucket = true
-  redis.call('HSET', key, 'start', now)
-end
-for i=1,#ARGV,2 do
-  local k = ARGV[i]
-  local v = tonumber(ARGV[i + 1])
-
-  if discriminate_bucket then
-    local last_value = redis.call('HGET', key, k)
-    if last_value then
-      redis.call('HSET', key, k, last_value / 2.0)
+  {% for w in windows %}
+  table.insert(results, '0')
+  {% endfor %}
+  end
+
+  return results,cnt
+  ]]
+
+  local get_script = lua_util.jinja_template(redis_get_script_tpl,
+      {windows = rule.backend.config.buckets})
+  rspamd_logger.debugm(N, rspamd_config, 'added extraction script %s', get_script)
+  rule.backend.script_get = lua_redis.add_redis_script(get_script, our_redis_params)
+
+  -- Redis script to update Redis buckets
+  -- KEYS[1] - key to update
+  -- KEYS[2] - current time in milliseconds
+  -- KEYS[3] - message score
+  -- KEYS[4] - expire for a bucket
+  -- Value returned - table of scores as a strings vector
+  local redis_adaptive_emea_script_tpl = [[
+  local last = redis.call('HGET', KEYS[1], 'l')
+  local score = tonumber(KEYS[3])
+  local now = tonumber(KEYS[2])
+  local scores = {}
+
+  if last then
+    {% for w in windows %}
+    local last_value = tonumber(redis.call('HGET', KEYS[1], 'v' .. '{= w.name =}'))
+    local window = {= w.time =}
+    -- Adjust alpha
+    local time_diff = now - last_value
+    if time_diff > 0 then
+      time_diff = 0
     end
+    local alpha = 1.0 - math.exp((-time_diff) / (1000 * window))
+    local nscore = alpha * score + (1.0 - alpha) * last_value
+    table.insert(scores, tostring(nscore * {= w.mult =}))
+    {% endfor %}
+  else
+    {% for w in windows %}
+    table.insert(scores, tostring(score * {= w.mult =}))
+    {% endfor %}
   end
-  redis.call('HINCRBYFLOAT', key, k, v)
-end
 
-redis.call('EXPIRE', key, KEYS[3])
-redis.call('HSET', key, 'last', now)
+  local i = 1
+  {% for w in windows %}
+    redis.call('HSET', KEYS[1], 'v' .. '{= w.name =}', scores[i])
+    i = i + 1
+  {% endfor %}
+  redis.call('HSET', KEYS[1], 'l', now)
+  redis.call('HINCRBY', KEYS[1], 'n', 1)
+  redis.call('EXPIRE', KEYS[1], tonumber(KEYS[4]))
+
+  return scores
 ]]
-  for _,bucket in ipairs(rule.backend.config.buckets) do
-    table.insert(redis_script_tbl, lua_util.template(redis_set_script_tpl,
-        bucket))
-  end
 
-  rule.backend.script_set = lua_redis.add_redis_script(table.concat(redis_script_tbl, '\n'),
-      our_redis_params)
+  local set_script = lua_util.jinja_template(redis_adaptive_emea_script_tpl,
+      {windows = rule.backend.config.buckets})
+  rspamd_logger.debugm(N, rspamd_config, 'added emea update script %s', set_script)
+  rule.backend.script_set = lua_redis.add_redis_script(set_script, our_redis_params)
 
   return true
 end
@@ -992,13 +914,13 @@ local function reputation_redis_get_token(task, rule, token, continuation_cb)
   local ret = lua_redis.exec_redis_script(rule.backend.script_get,
       {task = task, is_write = false},
       redis_get_cb,
-      {token})
+      {key})
   if not ret then
     rspamd_logger.errx(task, 'cannot make redis request to check results')
   end
 end
 
-local function reputation_redis_set_token(task, rule, token, values, continuation_cb)
+local function reputation_redis_set_token(task, rule, token, sc, continuation_cb)
   local key = gen_token_key(token, rule)
 
   local function redis_set_cb(err, data)
@@ -1015,19 +937,14 @@ local function reputation_redis_set_token(task, rule, token, values, continuatio
     end
   end
 
-  -- We start from expiry update
-  local args = {}
-  for k,v in pairs(values) do
-    table.insert(args, k)
-    table.insert(args, v)
-  end
   lua_util.debugm(N, task, 'rule %s - set values for key %s -> %s',
-      rule['symbol'], key, values)
+      rule['symbol'], key, sc)
   local ret = lua_redis.exec_redis_script(rule.backend.script_set,
       {task = task, is_write = true},
       redis_set_cb,
-      {token, tostring(rspamd_util:get_time()),
-       tostring(rule.backend.config.expiry)}, args)
+      {key, tostring(os.time() * 1000),
+       tonumber(sc),
+       tostring(rule.backend.config.expiry)})
   if not ret then
     rspamd_logger.errx(task, 'got error while connecting to redis')
   end
@@ -1043,6 +960,7 @@ end
 local backends = {
   redis = {
     schema = ts.shape({
+      prefix = ts.string,
       expiry = ts.number + ts.string / lua_util.parse_time_interval,
       buckets = ts.array_of(ts.shape{
         time = ts.number + ts.string / lua_util.parse_time_interval,
@@ -1052,6 +970,7 @@ local backends = {
     }, {extra_fields = lua_redis.config_schema}),
     config = {
       expiry = default_expiry,
+      prefix = default_prefix,
       buckets = {
         {
           time = 60 * 60,


More information about the Commits mailing list