commit 9fd03ab: [Project] Add ssyev method interface

Vsevolod Stakhov vsevolod at highsecure.ru
Tue Aug 4 14:35:11 UTC 2020


Author: Vsevolod Stakhov
Date: 2020-08-04 14:56:32 +0100
URL: https://github.com/rspamd/rspamd/commit/9fd03abf5d911af6732b180707d9cb92d662572a

[Project] Add ssyev method interface

---
 contrib/kann/kautodiff.c  | 29 +++++++++++++++++++++++++++++
 contrib/kann/kautodiff.h  |  9 ++++++++-
 lualib/lua_ffi/linalg.lua | 23 ++++++++++++++++++++---
 3 files changed, 57 insertions(+), 4 deletions(-)

diff --git a/contrib/kann/kautodiff.c b/contrib/kann/kautodiff.c
index 47a86a71e..7b0bf8e93 100644
--- a/contrib/kann/kautodiff.c
+++ b/contrib/kann/kautodiff.c
@@ -900,6 +900,7 @@ void kad_vec_mul_sum(int n, float *a, const float *b, const float *c)
 void kad_saxpy(int n, float a, const float *x, float *y) { kad_saxpy_inlined(n, a, x, y); }
 
 #ifdef HAVE_CBLAS
+extern void ssyev(const char* jobz, const char* uplo, int* n, float* a, int* lda, float* w, float* work, int* lwork, int* info);
 #ifdef HAVE_CBLAS_H
 #include "cblas.h"
 #else
@@ -947,6 +948,34 @@ void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float
 }
 #endif
 
+bool kad_ssyev_simple(int N, float *A, float *eugenvals)
+{
+#ifndef HAVE_CBLAS
+	return false;
+#else
+	int n = N, lda = N, info, lwork;
+	float wkopt;
+	float *work;
+
+	/* Query and allocate the optimal workspace */
+	lwork = -1;
+	ssyev ("Vectors", "Upper", &n, A, &lda, eugenvals, &wkopt, &lwork, &info);
+	lwork = wkopt;
+	work = (float*) g_malloc(lwork * sizeof(double));
+	ssyev ("Vectors", "Upper", &n, A, &lda, eugenvals, work, &lwork, &info);
+	/* Check for convergence */
+	if (info > 0) {
+		g_free (work);
+
+		return false;
+	}
+
+	g_free (work);
+
+	return true;
+#endif
+}
+
 /***************************
  * Random number generator *
  ***************************/
diff --git a/contrib/kann/kautodiff.h b/contrib/kann/kautodiff.h
index e51176c84..8c797205c 100644
--- a/contrib/kann/kautodiff.h
+++ b/contrib/kann/kautodiff.h
@@ -244,6 +244,13 @@ static inline int kad_len(const kad_node_t *p) /* calculate the size of p->x */
 }
 
 /* Additions by Rspamd */
-void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C);
+void kad_sgemm_simple (int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C);
+/**
+ * Calculate eugenvectors and eugenvalues
+ * @param N dimensions of A (must be NxN)
+ * @param A input matrix (part of it will be destroyed, so copy if needed), on finish the first `nwork` columns will have eugenvectors
+ * @param eugenvals eugenvalues, must be N elements vector
+ */
+bool kad_ssyev_simple (int N, float *A, float *eugenvals);
 
 #endif
diff --git a/lualib/lua_ffi/linalg.lua b/lualib/lua_ffi/linalg.lua
index c3f6eff5a..85e84b5ac 100644
--- a/lualib/lua_ffi/linalg.lua
+++ b/lualib/lua_ffi/linalg.lua
@@ -25,13 +25,14 @@ local exports = {}
 
 ffi.cdef[[
   void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C);
+  bool kad_ssyev_simple (int N, float *A, float *output);
 ]]
 
 local function table_to_ffi(a, m, n)
-  local a_conv = ffi.new(string.format("float[%d][%d]", m, n), {})
+  local a_conv = ffi.new("float[?]", m * n)
   for i=1,m or #a do
     for j=1,n or #a[1] do
-      a_conv[i - 1][j - 1] = a[i][j]
+      a_conv[(i - 1) * n + (j - 1)] = a[i][j]
     end
   end
   return a_conv
@@ -58,12 +59,28 @@ exports.sgemm = function(a, m, b, n, k, trans_a, trans_b)
   if type(b) == 'table' then
     b = table_to_ffi(b, k, n)
   end
-  local res = ffi.new(string.format("float[%d][%d]", m, n), {})
+  local res = ffi.new("float[?]", m * n)
   ffi.C.kad_sgemm_simple(trans_a or 0, trans_b or 0, m, n, k, ffi.cast('const float*', a),
       ffi.cast('const float*', b), ffi.cast('float*', res))
   return res
 end
 
+exports.eugen = function(a, n)
+  if type(a) == 'table' then
+    -- Need to convert, slow!
+    n = n or #a
+    a = table_to_ffi(a, n, n)
+  end
+
+  local res = ffi.new("float[?]", n)
+
+  if ffi.C.kad_ssyev_simple(n, ffi.cast('float*', a), res) then
+    return res,a
+  end
+
+  return nil
+end
+
 exports.ffi_to_table = ffi_to_table
 exports.table_to_ffi = table_to_ffi
 


More information about the Commits mailing list