commit d7d7100: [Project] Add linalg ffi library for prototyping

Vsevolod Stakhov vsevolod at highsecure.ru
Tue Aug 4 14:35:09 UTC 2020


Author: Vsevolod Stakhov
Date: 2020-08-04 14:17:01 +0100
URL: https://github.com/rspamd/rspamd/commit/d7d71002117e4ec30d96ca91c54f971b8e835325

[Project] Add linalg ffi library for prototyping

---
 contrib/kann/kautodiff.h  |  5 +++-
 lualib/lua_ffi/init.lua   |  1 +
 lualib/lua_ffi/linalg.lua | 70 +++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 75 insertions(+), 1 deletion(-)

diff --git a/contrib/kann/kautodiff.h b/contrib/kann/kautodiff.h
index a2c648835..e51176c84 100644
--- a/contrib/kann/kautodiff.h
+++ b/contrib/kann/kautodiff.h
@@ -102,7 +102,7 @@ void kad_delete(int n, kad_node_t **a); /* deallocate a compiled/linearized grap
 
 /**
  * Compute the value at a node
- * 
+ *
  * @param n       number of nodes
  * @param a       list of nodes
  * @param from    compute the value at this node, 0<=from<n
@@ -243,4 +243,7 @@ static inline int kad_len(const kad_node_t *p) /* calculate the size of p->x */
 	return n;
 }
 
+/* Additions by Rspamd */
+void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C);
+
 #endif
diff --git a/lualib/lua_ffi/init.lua b/lualib/lua_ffi/init.lua
index 02b54f932..08a6763bb 100644
--- a/lualib/lua_ffi/init.lua
+++ b/lualib/lua_ffi/init.lua
@@ -49,6 +49,7 @@ pcall(ffi.load, "rspamd-server", true)
 exports.common = require "lua_ffi/common"
 exports.dkim = require "lua_ffi/dkim"
 exports.spf = require "lua_ffi/spf"
+exports.linalg = require "lua_ffi/linalg"
 
 for k,v in pairs(ffi) do
   -- Preserve all stuff to use lua_ffi as ffi itself
diff --git a/lualib/lua_ffi/linalg.lua b/lualib/lua_ffi/linalg.lua
new file mode 100644
index 000000000..c3f6eff5a
--- /dev/null
+++ b/lualib/lua_ffi/linalg.lua
@@ -0,0 +1,70 @@
+--[[
+Copyright (c) 2020, Vsevolod Stakhov <vsevolod at highsecure.ru>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_ffi/linalg
+-- This module contains ffi interfaces to linear algebra routines
+--]]
+
+local ffi = require 'ffi'
+
+local exports = {}
+
+ffi.cdef[[
+  void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C);
+]]
+
+local function table_to_ffi(a, m, n)
+  local a_conv = ffi.new(string.format("float[%d][%d]", m, n), {})
+  for i=1,m or #a do
+    for j=1,n or #a[1] do
+      a_conv[i - 1][j - 1] = a[i][j]
+    end
+  end
+  return a_conv
+end
+
+local function ffi_to_table(a, m, n)
+  local res = {}
+
+  for i=0,m-1 do
+    res[i + 1] = {}
+    for j=0,n-1 do
+      res[i + 1][j + 1] = a[i][j]
+    end
+  end
+
+  return res
+end
+
+exports.sgemm = function(a, m, b, n, k, trans_a, trans_b)
+  if type(a) == 'table' then
+    -- Need to convert, slow!
+    a = table_to_ffi(a, m, k)
+  end
+  if type(b) == 'table' then
+    b = table_to_ffi(b, k, n)
+  end
+  local res = ffi.new(string.format("float[%d][%d]", m, n), {})
+  ffi.C.kad_sgemm_simple(trans_a or 0, trans_b or 0, m, n, k, ffi.cast('const float*', a),
+      ffi.cast('const float*', b), ffi.cast('float*', res))
+  return res
+end
+
+exports.ffi_to_table = ffi_to_table
+exports.table_to_ffi = table_to_ffi
+
+return exports
\ No newline at end of file


More information about the Commits mailing list