commit 055640c: [Project] Lua_magic: Improve short patterns performance
Vsevolod Stakhov
vsevolod at highsecure.ru
Fri Sep 6 17:49:08 UTC 2019
Author: Vsevolod Stakhov
Date: 2019-09-06 14:05:31 +0100
URL: https://github.com/rspamd/rspamd/commit/055640c105492d4aaa8a75f973ce208ffd8cc045
[Project] Lua_magic: Improve short patterns performance
---
lualib/lua_magic/init.lua | 131 ++++++++++++++++++++++++++++++++++------------
1 file changed, 97 insertions(+), 34 deletions(-)
diff --git a/lualib/lua_magic/init.lua b/lualib/lua_magic/init.lua
index 1ba899b06..a2b2c9882 100644
--- a/lualib/lua_magic/init.lua
+++ b/lualib/lua_magic/init.lua
@@ -31,17 +31,43 @@ local N = "lua_magic"
local exports = {}
-- trie object
local compiled_patterns
+local compiled_short_patterns -- short patterns
-- {<str>, <match_object>, <pattern_object>} indexed by pattern number
local processed_patterns = {}
+local short_patterns = {}
+
+local short_match_limit = 128
+local max_short_offset = -1
+
+local function process_patterns(log_obj)
+ -- Add pattern to either short patterns or to normal patterns
+ local function add_processed(str, match, pattern)
+ if match.position and type(match.position) == 'number' and
+ match.position < short_match_limit then
+ short_patterns[#short_patterns + 1] = {
+ str, match, pattern
+ }
+
+ if max_short_offset < match.position then
+ max_short_offset = match.position
+ end
+ else
+ processed_patterns[#processed_patterns + 1] = {
+ str, match, pattern
+ }
+ end
+ end
-local function process_patterns()
if not compiled_patterns then
- for _,pattern in ipairs(patterns) do
+ for ext,pattern in pairs(patterns) do
+ assert(types[ext])
+ pattern.ext = ext
for _,match in ipairs(pattern.matches) do
if match.string then
- processed_patterns[#processed_patterns + 1] = {
- match.string, match, pattern
- }
+ if match.relative_position and not match.position then
+ match.position = match.relative_position + #match.string
+ end
+ add_processed(match.string, match, pattern)
elseif match.hex then
local hex_table = {}
@@ -49,9 +75,11 @@ local function process_patterns()
local subc = match.hex:sub(i, i + 1)
hex_table[#hex_table + 1] = string.format('\\x{%s}', subc)
end
- processed_patterns[#processed_patterns + 1] = {
- table.concat(hex_table), match, pattern
- }
+
+ if match.relative_position and not match.position then
+ match.position = match.relative_position + #match.hex / 2
+ end
+ add_processed(table.concat(hex_table), match, pattern)
end
end
end
@@ -60,16 +88,19 @@ local function process_patterns()
fun.map(function(t) return t[1] end, processed_patterns)),
rspamd_trie.flags.re
)
+ compiled_short_patterns = rspamd_trie.create(fun.totable(
+ fun.map(function(t) return t[1] end, short_patterns)),
+ rspamd_trie.flags.re
+ )
- lua_util.debugm(N, rspamd_config, 'compiled %s patterns',
- #processed_patterns)
+ lua_util.debugm(N, log_obj,
+ 'compiled %s (%s short and %s long) patterns',
+ #processed_patterns + #short_patterns, #short_patterns, #processed_patterns)
end
end
-local function match_chunk(input, offset, log_obj, res)
- local matches = compiled_patterns:match(input)
-
- if not log_obj then log_obj = rspamd_config end
+local function match_chunk(input, offset, trie, processed_tbl, log_obj, res)
+ local matches = trie:match(input)
local function add_result(match, pattern)
if not res[pattern.ext] then
@@ -86,7 +117,7 @@ local function match_chunk(input, offset, log_obj, res)
end
for npat,matched_positions in pairs(matches) do
- local pat_data = processed_patterns[npat]
+ local pat_data = processed_tbl[npat]
local pattern = pat_data[3]
local match = pat_data[2]
@@ -132,8 +163,25 @@ local function match_chunk(input, offset, log_obj, res)
end
end
end
+
+local function process_detected(res)
+ local extensions = lua_util.keys(res)
+
+ if #extensions > 0 then
+ table.sort(extensions, function(ex1, ex2)
+ return res[ex1] > res[ex2]
+ end)
+
+ return extensions,res[extensions[1]]
+ end
+
+ return nil
+end
+
exports.detect = function(input, log_obj)
- process_patterns()
+ if not log_obj then log_obj = rspamd_config end
+ process_patterns(log_obj)
+
local res = {}
if type(input) == 'string' then
@@ -141,28 +189,43 @@ exports.detect = function(input, log_obj)
input = rspamd_text.fromstring(input)
end
- if type(input) == 'userdata' and #input > exports.chunk_size * 3 then
- -- Split by chunks
- local chunk1, chunk2, chunk3 =
- input:span(1, exports.chunk_size),
- input:span(exports.chunk_size, exports.chunk_size),
- input:span(#input - exports.chunk_size, exports.chunk_size)
- local offset1, offset2, offset3 = 0, exports.chunk_size, #input - exports.chunk_size
-
- match_chunk(chunk1, offset1, log_obj, res)
- match_chunk(chunk2, offset2, log_obj, res)
- match_chunk(chunk3, offset3, log_obj, res)
+
+ if type(input) == 'userdata' then
+ -- Try short match
+ local head = input:span(1, math.min(max_short_offset, #input))
+ match_chunk(head, 0, compiled_short_patterns, short_patterns, log_obj, res)
+
+ local extensions,confidence = process_detected(res)
+
+ if extensions and #extensions > 0 and confidence > 30 then
+ -- We are done on short patterns
+ return extensions[1],types[extensions[1]]
+ end
+
+ if #input > exports.chunk_size * 3 then
+ -- Chunked version as input is too long
+ local chunk1, chunk2, chunk3 =
+ input:span(1, exports.chunk_size),
+ input:span(exports.chunk_size, exports.chunk_size),
+ input:span(#input - exports.chunk_size, exports.chunk_size)
+ local offset1, offset2, offset3 = 0, exports.chunk_size, #input - exports.chunk_size
+
+ match_chunk(chunk1, offset1, compiled_patterns, processed_patterns, log_obj, res)
+ match_chunk(chunk2, offset2, compiled_patterns, processed_patterns, log_obj, res)
+ match_chunk(chunk3, offset3, compiled_patterns, processed_patterns, log_obj, res)
+ else
+ -- Input is short enough to match it all at once
+ match_chunk(input, 0, compiled_patterns, processed_patterns, log_obj, res)
+ end
else
- match_chunk(input, 0, log_obj, res)
+ -- Input is a table so just try to match it all...
+ match_chunk(input, 0, compiled_short_patterns, short_patterns, log_obj, res)
+ match_chunk(input, 0, compiled_patterns, processed_patterns, log_obj, res)
end
- local extensions = lua_util.keys(res)
-
- if #extensions > 0 then
- table.sort(extensions, function(ex1, ex2)
- return res[ex1] > res[ex2]
- end)
+ local extensions = process_detected(res)
+ if extensions and #extensions > 0 then
return extensions[1],types[extensions[1]]
end
More information about the Commits
mailing list