commit f8004be: [Test] Add unit test for kann

Vsevolod Stakhov vsevolod at highsecure.ru
Mon Jul 1 13:07:04 UTC 2019


Author: Vsevolod Stakhov
Date: 2019-07-01 13:50:21 +0100
URL: https://github.com/rspamd/rspamd/commit/f8004be4c94ca214cd399cdb18aa0d085abf0fd7 (HEAD -> master)

[Test] Add unit test for kann

---
 test/lua/unit/kann.lua | 43 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 43 insertions(+)

diff --git a/test/lua/unit/kann.lua b/test/lua/unit/kann.lua
new file mode 100644
index 000000000..bb6930203
--- /dev/null
+++ b/test/lua/unit/kann.lua
@@ -0,0 +1,43 @@
+-- Simple kann test (xor function vs 2 layer MLP)
+
+context("Kann test", function()
+  local kann = require "rspamd_kann"
+  local k
+  local inputs = {
+    {0, 0},
+    {0, 1},
+    {1, 0},
+    {1, 1}
+  }
+
+  local outputs = {
+    {0},
+    {1},
+    {1},
+    {0}
+  }
+
+  local t = kann.layer.input(2)
+  t = kann.transform.relu(t)
+  t = kann.transform.tanh(kann.layer.dense(t, 2));
+  t = kann.layer.cost(t, 1, kann.cost.mse)
+  k = kann.new.kann(t)
+
+  local iters = 500
+  local niter = k:train1(inputs, outputs, {
+    lr = 0.01,
+    max_epoch = iters,
+    mini_size = 80,
+  })
+
+  for i,inp in ipairs(inputs) do
+    test(string.format("Check XOR MLP %s ^ %s == %s", inp[1], inp[2], outputs[i][1]),
+        function()
+          local res = math.floor(k:apply1(inp)[1] + 0.5)
+          assert_equal(outputs[i][1], res,
+              tostring(outputs[i][1]) .. " but test returned " .. tostring(res))
+        end)
+  end
+
+
+end)
\ No newline at end of file


More information about the Commits mailing list