commit 2381494: [Project] Lua_magic: Some rework in detection

Vsevolod Stakhov vsevolod at highsecure.ru
Mon Sep 9 12:14:10 UTC 2019


Author: Vsevolod Stakhov
Date: 2019-09-09 12:37:06 +0100
URL: https://github.com/rspamd/rspamd/commit/23814946671f0385b38650aa855f701a0bfcd003

[Project] Lua_magic: Some rework in detection

---
 lualib/lua_magic/heuristics.lua | 82 ++++++++++++++++++++++-------------------
 lualib/lua_magic/init.lua       |  5 ++-
 src/libmime/message.c           |  6 +--
 3 files changed, 51 insertions(+), 42 deletions(-)

diff --git a/lualib/lua_magic/heuristics.lua b/lualib/lua_magic/heuristics.lua
index b30f95794..4f2b583f0 100644
--- a/lualib/lua_magic/heuristics.lua
+++ b/lualib/lua_magic/heuristics.lua
@@ -46,42 +46,50 @@ local msoffice_patterns_indexes = {}
 
 local exports = {}
 
-local function compile_msoffice_trie(log_obj)
-  if not msoffice_trie then
-    -- Directory names
+local function compile_tries()
+  local function compile_pats(patterns, indexes, transform_func)
     local strs = {}
-    for ext,pats in pairs(msoffice_patterns) do
+    for ext,pats in pairs(patterns) do
       for _,pat in ipairs(pats) do
         -- These are utf16 strings in fact...
-        strs[#strs + 1] = '^' ..
-            table.concat(
-                fun.totable(
-                    fun.map(function(c) return c .. [[\x{00}]] end,
-                        fun.iter(pat))))
-        msoffice_patterns_indexes[#msoffice_patterns_indexes + 1] = ext
-
+        strs[#strs + 1] = transform_func(pat)
+        indexes[#indexes + 1] = ext
       end
     end
-    msoffice_trie = rspamd_trie.create(strs, rspamd_trie.flags.re)
-    -- Clsids
-    strs = {}
-    for ext,pats in pairs(msoffice_clsids) do
-      for _,pat in ipairs(pats) do
-        -- Convert hex to re
-        local hex_table = {}
-        for i=1,#pat,2 do
-          local subc = pat:sub(i, i + 1)
-          hex_table[#hex_table + 1] = string.format('\\x{%s}', subc)
-        end
-        strs[#strs + 1] = '^' .. table.concat(hex_table) .. '$'
-        msoffice_clsid_indexes[#msoffice_clsid_indexes + 1] = ext
 
+    return rspamd_trie.create(strs, rspamd_trie.flags.re)
+  end
+
+  if not msoffice_trie then
+    -- Directory names
+    local function msoffice_pattern_transform(pat)
+      return '^' ..
+          table.concat(
+              fun.totable(
+                  fun.map(function(c) return c .. [[\x{00}]] end,
+                      fun.iter(pat))))
+    end
+    local function msoffice_clsid_transform(pat)
+      local hex_table = {}
+      for i=1,#pat,2 do
+        local subc = pat:sub(i, i + 1)
+        hex_table[#hex_table + 1] = string.format('\\x{%s}', subc)
       end
+
+      return '^' .. table.concat(hex_table) .. '$'
     end
-    msoffice_trie_clsid = rspamd_trie.create(strs, rspamd_trie.flags.re)
+    -- Directory entries
+    msoffice_trie = compile_pats(msoffice_patterns, msoffice_patterns_indexes,
+        msoffice_pattern_transform)
+    -- Clsids
+    msoffice_trie_clsid = compile_pats(msoffice_clsids, msoffice_clsid_indexes,
+        msoffice_clsid_transform)
   end
 end
 
+-- Call immediately on require
+compile_tries()
+
 local function detect_ole_format(input, log_obj)
   local inplen = #input
   if inplen < 0x31 + 4 then
@@ -89,7 +97,6 @@ local function detect_ole_format(input, log_obj)
     return nil
   end
 
-  compile_msoffice_trie(log_obj)
   local bom,sec_size = rspamd_util.unpack('<I2<I2', input:span(29, 4))
   if bom == 0xFFFE then
     bom = '<'
@@ -167,7 +174,7 @@ end
 
 exports.ole_format_heuristic = detect_ole_format
 
-local function process_detected(res)
+local function process_top_detected(res)
   local extensions = lua_util.keys(res)
 
   if #extensions > 0 then
@@ -175,13 +182,13 @@ local function process_detected(res)
       return res[ex1] > res[ex2]
     end)
 
-    return extensions,res[extensions[1]]
+    return extensions[1],res[extensions[1]]
   end
 
   return nil
 end
 
-local function detect_archive_flaw(part, arch)
+local function detect_archive_flaw(part, arch, log_obj)
   local arch_type = arch:get_type()
   local res = {
     docx = 0,
@@ -206,12 +213,12 @@ local function detect_archive_flaw(part, arch)
     for _,file in ipairs(files) do
       if file == '[Content_Types].xml' then
         add_msoffice_confidence(10)
-      elseif file == 'xl/' then
+      elseif file:sub(1, 3) == 'xl/' then
         res.xlsx = res.xlsx + 30
-      elseif file == 'word/' then
-        res.xlsx = res.docx + 30
-      elseif file == 'ppt/' then
-        res.xlsx = res.pptx + 30
+      elseif file:sub(1, 5) == 'word/' then
+        res.docx = res.docx + 30
+      elseif file:sub(1, 4) == 'ppt/' then
+        res.pptx = res.pptx + 30
       elseif file == 'META-INF/manifest.xml' then
         -- Apply ODT detection logic
         local content = part:get_content()
@@ -245,7 +252,7 @@ local function detect_archive_flaw(part, arch)
       end
     end
 
-    local ext,weight = process_detected(res)
+    local ext,weight = process_top_detected(res)
 
     if weight >= 40 then
       return ext,weight
@@ -254,7 +261,8 @@ local function detect_archive_flaw(part, arch)
 
   return arch_type:lower(),40
 end
-exports.mime_part_heuristic = function(part)
+
+exports.mime_part_heuristic = function(part, log_obj)
   if part:is_text() then
     if part:get_text():is_html() then
       return 'html',60
@@ -270,7 +278,7 @@ exports.mime_part_heuristic = function(part)
 
   if part:is_archive() then
     local arch = part:get_archive()
-    return detect_archive_flaw(part, arch)
+    return detect_archive_flaw(part, arch, log_obj)
   end
 
   return nil
diff --git a/lualib/lua_magic/init.lua b/lualib/lua_magic/init.lua
index 8b5064bfe..e9e0297e9 100644
--- a/lualib/lua_magic/init.lua
+++ b/lualib/lua_magic/init.lua
@@ -132,6 +132,8 @@ local function process_patterns(log_obj)
   end
 end
 
+process_patterns(rspamd_config)
+
 local function match_chunk(chunk, input, tlen, offset, trie, processed_tbl, log_obj, res)
   local matches = trie:match(chunk)
 
@@ -253,7 +255,6 @@ end
 
 exports.detect = function(input, log_obj)
   if not log_obj then log_obj = rspamd_config end
-  process_patterns(log_obj)
 
   local res = {}
 
@@ -319,7 +320,7 @@ exports.detect = function(input, log_obj)
 end
 
 exports.detect_mime_part = function(part, log_obj)
-  local ext,weight = heuristics.mime_part_heuristic(part)
+  local ext,weight = heuristics.mime_part_heuristic(part, log_obj)
 
   if ext and weight and weight > 20 then
     return ext,types[ext]
diff --git a/src/libmime/message.c b/src/libmime/message.c
index 90df43b12..92fa1f51b 100644
--- a/src/libmime/message.c
+++ b/src/libmime/message.c
@@ -1409,6 +1409,9 @@ rspamd_message_process (struct rspamd_task *task)
 	guint tw, *ptw, dw;
 	struct rspamd_mime_part *part;
 
+	rspamd_images_process (task);
+	rspamd_archives_process (task);
+
 	PTR_ARRAY_FOREACH (MESSAGE_FIELD (task, parts), i, part) {
 		if (!rspamd_message_process_text_part_maybe (task, part) &&
 				part->parsed_data.len > 0) {
@@ -1430,9 +1433,6 @@ rspamd_message_process (struct rspamd_task *task)
 		}
 	}
 
-	rspamd_images_process (task);
-	rspamd_archives_process (task);
-
 	/* Calculate average words length and number of short words */
 	struct rspamd_mime_text_part *text_part;
 	gdouble *var;


More information about the Commits mailing list