commit 0dd7203: [Minor] Move redis scripts from ratelimit file to simplify checks

Vsevolod Stakhov vsevolod at rspamd.com
Sat Mar 25 12:42:05 UTC 2023


Author: Vsevolod Stakhov
Date: 2023-03-25 12:40:15 +0000
URL: https://github.com/rspamd/rspamd/commit/0dd7203ee541be133c5fab808305146fc311c4f0 (HEAD -> master)

[Minor] Move redis scripts from ratelimit file to simplify checks

---
 .luacheckrc                               |   5 +
 lualib/lua_redis.lua                      |   2 +-
 lualib/redis_scripts/ratelimit_check.lua  |  62 ++++++++++++
 lualib/redis_scripts/ratelimit_update.lua |  80 ++++++++++++++++
 src/plugins/lua/ratelimit.lua             | 154 ++----------------------------
 5 files changed, 154 insertions(+), 149 deletions(-)

diff --git a/.luacheckrc b/.luacheckrc
index 099d9ab65..d5a18cc3b 100644
--- a/.luacheckrc
+++ b/.luacheckrc
@@ -60,6 +60,11 @@ files['/**/lualib/lua_redis.lua'].globals = {
   'rspamadm_ev_base',
 }
 
+files['/**/lualib/redis_scripts/**'].globals = {
+  'redis',
+  'KEYS',
+}
+
 files['/**/src/rspamadm/*'].globals = {
   'ansicolors',
   'getopt',
diff --git a/lualib/lua_redis.lua b/lualib/lua_redis.lua
index 76e8db972..62511451e 100644
--- a/lualib/lua_redis.lua
+++ b/lualib/lua_redis.lua
@@ -1293,7 +1293,7 @@ local function load_redis_script_from_file(filename, redis_params, dir)
   if not dir then dir = rspamd_paths.LUALIBDIR end
   if filename:sub(1, 1) ~= package.config:sub(1,1) then
     -- Relative path
-    filename = lua_util.join_path(dir, filename)
+    filename = lua_util.join_path(dir, "redis_scripts", filename)
   end
   -- Read file contents
   local file = io.open(filename, "r")
diff --git a/lualib/redis_scripts/ratelimit_check.lua b/lualib/redis_scripts/ratelimit_check.lua
new file mode 100644
index 000000000..2b2af11bf
--- /dev/null
+++ b/lualib/redis_scripts/ratelimit_check.lua
@@ -0,0 +1,62 @@
+-- Checks bucket, updating it if needed
+-- KEYS[1] - prefix to update, e.g. RL_<triplet>_<seconds>
+-- KEYS[2] - current time in milliseconds
+-- KEYS[3] - bucket leak rate (messages per millisecond)
+-- KEYS[4] - bucket burst
+-- KEYS[5] - expire for a bucket
+-- KEYS[6] - number of recipients
+-- return 1 if message should be ratelimited and 0 if not
+-- Redis keys used:
+--   l - last hit
+--   b - current burst
+--   p - pending messages (those that are currently processing)
+--   dr - current dynamic rate multiplier (*10000)
+--   db - current dynamic burst multiplier (*10000)
+
+local last = redis.call('HGET', KEYS[1], 'l')
+local now = tonumber(KEYS[2])
+local nrcpt = tonumber(KEYS[6])
+local dynr, dynb, leaked = 0, 0, 0
+if not last then
+  -- New bucket
+  redis.call('HMSET', KEYS[1], 'l', KEYS[2], 'b', '0', 'dr', '10000', 'db', '10000', 'p', tostring(nrcpt))
+  redis.call('EXPIRE', KEYS[1], KEYS[5])
+  return {0, '0', '1', '1', '0'}
+end
+
+last = tonumber(last)
+local burst,pending = unpack(redis.call('HMGET', KEYS[1], 'b', 'p'))
+burst,pending = tonumber(burst or '0'),tonumber(pending or '0')
+-- Sanity to avoid races
+if burst < 0 then burst = 0 end
+if pending < 0 then pending = 0 end
+pending = pending + nrcpt -- this message
+-- Perform leak
+if burst + pending > 0 then
+  if burst > 0 and last < tonumber(KEYS[2]) then
+    local rate = tonumber(KEYS[3])
+    dynr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000.0
+    if dynr == 0 then dynr = 0.0001 end
+    rate = rate * dynr
+    leaked = ((now - last) * rate)
+    if leaked > burst then leaked = burst end
+    burst = burst - leaked
+    redis.call('HINCRBYFLOAT', KEYS[1], 'b', -(leaked))
+    redis.call('HSET', KEYS[1], 'l', KEYS[2])
+  end
+
+  dynb = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000.0
+  if dynb == 0 then dynb = 0.0001 end
+
+  burst = burst + pending
+  if burst > 0 and (burst + tonumber(KEYS[6])) > tonumber(KEYS[4]) * dynb then
+    return {1, tostring(burst - pending), tostring(dynr), tostring(dynb), tostring(leaked)}
+  end
+  -- Increase pending if we allow ratelimit
+  redis.call('HINCRBY', KEYS[1], 'p', nrcpt)
+else
+  burst = 0
+  redis.call('HMSET', KEYS[1], 'b', '0', 'p', tostring(nrcpt))
+end
+
+return {0, tostring(burst), tostring(dynr), tostring(dynb), tostring(leaked)}
\ No newline at end of file
diff --git a/lualib/redis_scripts/ratelimit_update.lua b/lualib/redis_scripts/ratelimit_update.lua
new file mode 100644
index 000000000..682ddd0c6
--- /dev/null
+++ b/lualib/redis_scripts/ratelimit_update.lua
@@ -0,0 +1,80 @@
+-- Updates a bucket
+-- KEYS[1] - prefix to update, e.g. RL_<triplet>_<seconds>
+-- KEYS[2] - current time in milliseconds
+-- KEYS[3] - dynamic rate multiplier
+-- KEYS[4] - dynamic burst multiplier
+-- KEYS[5] - max dyn rate (min: 1/x)
+-- KEYS[6] - max burst rate (min: 1/x)
+-- KEYS[7] - expire for a bucket
+-- KEYS[8] - number of recipients (or increase rate)
+-- Redis keys used:
+--   l - last hit
+--   b - current burst
+--   p - messages pending (must be decreased by 1)
+--   dr - current dynamic rate multiplier
+--   db - current dynamic burst multiplier
+
+local last = redis.call('HGET', KEYS[1], 'l')
+local nrcpt = tonumber(KEYS[8])
+if not last then
+  -- New bucket (why??)
+  redis.call('HMSET', KEYS[1], 'l', KEYS[2], 'b', tostring(nrcpt), 'dr', '10000', 'db', '10000', 'p', '0')
+  redis.call('EXPIRE', KEYS[1], KEYS[7])
+  return {1, 1, 1}
+end
+
+local dr, db = 1.0, 1.0
+
+if tonumber(KEYS[5]) > 1 then
+  local rate_mult = tonumber(KEYS[3])
+  local rate_limit = tonumber(KEYS[5])
+  dr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000
+
+  if rate_mult > 1.0 and dr < rate_limit then
+    dr = dr * rate_mult
+    if dr > 0.0001 then
+      redis.call('HSET', KEYS[1], 'dr', tostring(math.floor(dr * 10000)))
+    else
+      redis.call('HSET', KEYS[1], 'dr', '1')
+    end
+  elseif rate_mult < 1.0 and dr > (1.0 / rate_limit) then
+    dr = dr * rate_mult
+    if dr > 0.0001 then
+      redis.call('HSET', KEYS[1], 'dr', tostring(math.floor(dr * 10000)))
+    else
+      redis.call('HSET', KEYS[1], 'dr', '1')
+    end
+  end
+end
+
+if tonumber(KEYS[6]) > 1 then
+  local rate_mult = tonumber(KEYS[4])
+  local rate_limit = tonumber(KEYS[6])
+  db = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000
+
+  if rate_mult > 1.0 and db < rate_limit then
+    db = db * rate_mult
+    if db > 0.0001 then
+      redis.call('HSET', KEYS[1], 'db', tostring(math.floor(db * 10000)))
+    else
+      redis.call('HSET', KEYS[1], 'db', '1')
+    end
+  elseif rate_mult < 1.0 and db > (1.0 / rate_limit) then
+    db = db * rate_mult
+    if db > 0.0001 then
+      redis.call('HSET', KEYS[1], 'db', tostring(math.floor(db * 10000)))
+    else
+      redis.call('HSET', KEYS[1], 'db', '1')
+    end
+  end
+end
+
+local burst,pending = unpack(redis.call('HMGET', KEYS[1], 'b', 'p'))
+burst,pending = tonumber(burst or '0'),tonumber(pending or '0')
+if burst < 0 then burst = nrcpt else burst = burst + nrcpt end
+if pending < nrcpt then pending = 0 else pending = pending - nrcpt end
+
+redis.call('HMSET', KEYS[1], 'b', tostring(burst), 'p', tostring(pending), 'l', KEYS[2])
+redis.call('EXPIRE', KEYS[1], KEYS[7])
+
+return {tostring(burst), tostring(dr), tostring(db)}
\ No newline at end of file
diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua
index d61e3990f..520efc99e 100644
--- a/src/plugins/lua/ratelimit.lua
+++ b/src/plugins/lua/ratelimit.lua
@@ -54,154 +54,12 @@ local settings = {
   prefilter = true,
 }
 
--- Checks bucket, updating it if needed
--- KEYS[1] - prefix to update, e.g. RL_<triplet>_<seconds>
--- KEYS[2] - current time in milliseconds
--- KEYS[3] - bucket leak rate (messages per millisecond)
--- KEYS[4] - bucket burst
--- KEYS[5] - expire for a bucket
--- KEYS[6] - number of recipients
--- return 1 if message should be ratelimited and 0 if not
--- Redis keys used:
---   l - last hit
---   b - current burst
---   p - pending messages (those that are currently processing)
---   dr - current dynamic rate multiplier (*10000)
---   db - current dynamic burst multiplier (*10000)
-local bucket_check_script = [[
-  local last = redis.call('HGET', KEYS[1], 'l')
-  local now = tonumber(KEYS[2])
-  local nrcpt = tonumber(KEYS[6])
-  local dynr, dynb, leaked = 0, 0, 0
-  if not last then
-    -- New bucket
-    redis.call('HMSET', KEYS[1], 'l', KEYS[2], 'b', '0', 'dr', '10000', 'db', '10000', 'p', tostring(nrcpt))
-    redis.call('EXPIRE', KEYS[1], KEYS[5])
-    return {0, '0', '1', '1', '0'}
-  end
-
-  last = tonumber(last)
-  local burst,pending = unpack(redis.call('HMGET', KEYS[1], 'b', 'p'))
-  burst,pending = tonumber(burst or '0'),tonumber(pending or '0')
-  -- Sanity to avoid races
-  if burst < 0 then burst = 0 end
-  if pending < 0 then pending = 0 end
-  pending = pending + nrcpt -- this message
-  -- Perform leak
-  if burst + pending > 0 then
-   if burst > 0 and last < tonumber(KEYS[2]) then
-    local rate = tonumber(KEYS[3])
-    dynr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000.0
-    if dynr == 0 then dynr = 0.0001 end
-    rate = rate * dynr
-    leaked = ((now - last) * rate)
-    if leaked > burst then leaked = burst end
-    burst = burst - leaked
-    redis.call('HINCRBYFLOAT', KEYS[1], 'b', -(leaked))
-    redis.call('HSET', KEYS[1], 'l', KEYS[2])
-   end
-
-   dynb = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000.0
-   if dynb == 0 then dynb = 0.0001 end
-
-   burst = burst + pending
-   if burst > 0 and (burst + tonumber(KEYS[6])) > tonumber(KEYS[4]) * dynb then
-     return {1, tostring(burst - pending), tostring(dynr), tostring(dynb), tostring(leaked)}
-   end
-   -- Increase pending if we allow ratelimit
-   redis.call('HINCRBY', KEYS[1], 'p', nrcpt)
-  else
-   burst = 0
-   redis.call('HMSET', KEYS[1], 'b', '0', 'p', tostring(nrcpt))
-  end
 
-  return {0, tostring(burst), tostring(dynr), tostring(dynb), tostring(leaked)}
-]]
+local bucket_check_script = "ratelimit_check.lua"
 local bucket_check_id
 
 
--- Updates a bucket
--- KEYS[1] - prefix to update, e.g. RL_<triplet>_<seconds>
--- KEYS[2] - current time in milliseconds
--- KEYS[3] - dynamic rate multiplier
--- KEYS[4] - dynamic burst multiplier
--- KEYS[5] - max dyn rate (min: 1/x)
--- KEYS[6] - max burst rate (min: 1/x)
--- KEYS[7] - expire for a bucket
--- KEYS[8] - number of recipients (or increase rate)
--- Redis keys used:
---   l - last hit
---   b - current burst
---   p - messages pending (must be decreased by 1)
---   dr - current dynamic rate multiplier
---   db - current dynamic burst multiplier
-local bucket_update_script = [[
-  local last = redis.call('HGET', KEYS[1], 'l')
-  local now = tonumber(KEYS[2])
-  local nrcpt = tonumber(KEYS[8])
-  if not last then
-    -- New bucket (why??)
-    redis.call('HMSET', KEYS[1], 'l', KEYS[2], 'b', tostring(nrcpt), 'dr', '10000', 'db', '10000', 'p', '0')
-    redis.call('EXPIRE', KEYS[1], KEYS[7])
-    return {1, 1, 1}
-  end
-
-  local dr, db = 1.0, 1.0
-
-  if tonumber(KEYS[5]) > 1 then
-    local rate_mult = tonumber(KEYS[3])
-    local rate_limit = tonumber(KEYS[5])
-    dr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000
-
-    if rate_mult > 1.0 and dr < rate_limit then
-      dr = dr * rate_mult
-      if dr > 0.0001 then
-        redis.call('HSET', KEYS[1], 'dr', tostring(math.floor(dr * 10000)))
-      else
-        redis.call('HSET', KEYS[1], 'dr', '1')
-      end
-    elseif rate_mult < 1.0 and dr > (1.0 / rate_limit) then
-      dr = dr * rate_mult
-      if dr > 0.0001 then
-        redis.call('HSET', KEYS[1], 'dr', tostring(math.floor(dr * 10000)))
-      else
-        redis.call('HSET', KEYS[1], 'dr', '1')
-      end
-    end
-  end
-
-  if tonumber(KEYS[6]) > 1 then
-    local rate_mult = tonumber(KEYS[4])
-    local rate_limit = tonumber(KEYS[6])
-    db = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000
-
-    if rate_mult > 1.0 and db < rate_limit then
-      db = db * rate_mult
-      if db > 0.0001 then
-        redis.call('HSET', KEYS[1], 'db', tostring(math.floor(db * 10000)))
-      else
-        redis.call('HSET', KEYS[1], 'db', '1')
-      end
-    elseif rate_mult < 1.0 and db > (1.0 / rate_limit) then
-      db = db * rate_mult
-      if db > 0.0001 then
-        redis.call('HSET', KEYS[1], 'db', tostring(math.floor(db * 10000)))
-      else
-        redis.call('HSET', KEYS[1], 'db', '1')
-      end
-    end
-  end
-
-  local burst,pending = unpack(redis.call('HMGET', KEYS[1], 'b', 'p'))
-  burst,pending = tonumber(burst or '0'),tonumber(pending or '0')
-  if burst < 0 then burst = nrcpt else burst = burst + nrcpt end
-  if pending < nrcpt then pending = 0 else pending = pending - nrcpt end
-
-  redis.call('HMSET', KEYS[1], 'b', tostring(burst), 'p', tostring(pending), 'l', KEYS[2])
-  redis.call('EXPIRE', KEYS[1], KEYS[7])
-
-  return {tostring(burst), tostring(dr), tostring(db)}
-]]
+local bucket_update_script = "ratelimit_update.lua"
 local bucket_update_id
 
 -- message_func(task, limit_type, prefix, bucket, limit_key)
@@ -210,9 +68,9 @@ local message_func = function(_, limit_type, _, _, _)
 end
 
 
-local function load_scripts(cfg, ev_base)
-  bucket_check_id = lua_redis.add_redis_script(bucket_check_script, redis_params)
-  bucket_update_id = lua_redis.add_redis_script(bucket_update_script, redis_params)
+local function load_scripts(_, _)
+  bucket_check_id = lua_redis.load_redis_script_from_file(bucket_check_script, redis_params)
+  bucket_update_id = lua_redis.load_redis_script_from_file(bucket_update_script, redis_params)
 end
 
 local limit_parser
@@ -927,6 +785,6 @@ if opts then
   end
 end
 
-rspamd_config:add_on_load(function(cfg, ev_base, worker)
+rspamd_config:add_on_load(function(cfg, ev_base, _)
   load_scripts(cfg, ev_base)
 end)


More information about the Commits mailing list