commit 44e393f: [Project] Add kann library to start torch removal

Vsevolod Stakhov vsevolod at highsecure.ru
Thu Jun 27 14:42:05 UTC 2019


Author: Vsevolod Stakhov
Date: 2019-06-27 15:38:34 +0100
URL: https://github.com/rspamd/rspamd/commit/44e393f9fe9a86bd99ebc5cfcddfe8eb50c8813e (HEAD -> master)

[Project] Add kann library to start torch removal

---
 CMakeLists.txt              |   12 +-
 contrib/kann/CMakeLists.txt |   22 +
 contrib/kann/LICENSE.txt    |   24 +
 contrib/kann/kann.c         |  977 ++++++++++++++++++
 contrib/kann/kann.h         |  235 +++++
 contrib/kann/kautodiff.c    | 2396 +++++++++++++++++++++++++++++++++++++++++++
 contrib/kann/kautodiff.h    |  246 +++++
 7 files changed, 3908 insertions(+), 4 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 80e00e67e..7e2bb0184 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -370,7 +370,7 @@ ENDFUNCTION(INSTALL_IF_NOT_EXISTS)
 MACRO(ProcessPackage PKG_NAME)
 
 	CMAKE_PARSE_ARGUMENTS(PKG "OPTIONAL" "ROOT;INCLUDE"
-		"LIBRARY;INCLUDE_SUFFIXES;LIB_SUFFIXES;MODULES" ${ARGN})
+		"LIBRARY;INCLUDE_SUFFIXES;LIB_SUFFIXES;MODULES;LIB_OUTPUT" ${ARGN})
 
 	IF(NOT PKG_LIBRARY)
 		SET(PKG_LIBRARY "${PKG_NAME}")
@@ -378,6 +378,9 @@ MACRO(ProcessPackage PKG_NAME)
 	IF(NOT PKG_INCLUDE)
 		SET(PKG_INCLUDE "${PKG_NAME}.h")
 	ENDIF()
+	IF(NOT PKG_LIB_OUTPUT)
+		SET(PKG_LIB_OUTPUT RSPAMD_REQUIRED_LIBRARIES)
+	ENDIF()
 
 	IF(NOT PKG_ROOT AND PKG_MODULES)
 		PKG_SEARCH_MODULE(${PKG_NAME} ${PKG_MODULES})
@@ -406,7 +409,7 @@ MACRO(ProcessPackage PKG_NAME)
 		FOREACH(_arg ${${_XPREFIX}_LDFLAGS_OTHER})
 			SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${_arg}")
 		ENDFOREACH(_arg ${${_XPREFIX}_LDFLAGS_OTHER})
-		LIST(APPEND RSPAMD_REQUIRED_LIBRARIES "${${_XPREFIX}_LIBRARIES}")
+		LIST(APPEND ${PKG_LIB_OUTPUT} "${${_XPREFIX}_LIBRARIES}")
 		INCLUDE_DIRECTORIES(${${_XPREFIX}_INCLUDEDIR})
 	ELSE()
 		IF(NOT ${PKG_NAME}_GUESSED)
@@ -442,7 +445,7 @@ MACRO(ProcessPackage PKG_NAME)
 				GET_FILENAME_COMPONENT(_lib_path "${_lib}" PATH)
 				INCLUDE_DIRECTORIES("${_stripped_incl}")
 				LINK_DIRECTORIES("${_lib_path}")
-				LIST(APPEND RSPAMD_REQUIRED_LIBRARIES ${_lib})
+				LIST(APPEND ${PKG_LIB_OUTPUT} ${_lib})
 				SET(${PKG_NAME}_INCLUDE "${_stripped_incl}" CACHE INTERNAL "")
 				SET(${PKG_NAME}_LIBRARY_PATH "${_lib_path}" CACHE INTERNAL "")
 				SET(${PKG_NAME}_LIBRARY "${_lib}" CACHE INTERNAL "")
@@ -455,7 +458,7 @@ MACRO(ProcessPackage PKG_NAME)
 			MESSAGE(STATUS "Found package ${PKG_NAME} (cached)")
 			INCLUDE_DIRECTORIES("${${PKG_NAME}_INCLUDE}")
 			LINK_DIRECTORIES("${${PKG_NAME}_LIBRARY_PATH}")
-			LIST(APPEND RSPAMD_REQUIRED_LIBRARIES "${${PKG_NAME}_LIBRARY}")
+			LIST(APPEND ${PKG_LIB_OUTPUT} "${${PKG_NAME}_LIBRARY}")
 		ENDIF()
 	ENDIF(${PKG_NAME}_FOUND)
 
@@ -1211,6 +1214,7 @@ ADD_SUBDIRECTORY(contrib/lua-lpeg)
 ADD_SUBDIRECTORY(contrib/linenoise)
 ADD_SUBDIRECTORY(contrib/t1ha)
 ADD_SUBDIRECTORY(contrib/libev)
+ADD_SUBDIRECTORY(contrib/kann)
 
 IF (ENABLE_SNOWBALL MATCHES "ON")
 	LIST(APPEND RSPAMD_REQUIRED_LIBRARIES stemmer)
diff --git a/contrib/kann/CMakeLists.txt b/contrib/kann/CMakeLists.txt
new file mode 100644
index 000000000..d7bd73d28
--- /dev/null
+++ b/contrib/kann/CMakeLists.txt
@@ -0,0 +1,22 @@
+SET(LIBKANNSRC	kautodiff.c kann.c)
+
+IF(ENABLE_FULL_DEBUG MATCHES "OFF")
+    if ("${CMAKE_C_COMPILER_ID}" STREQUAL "Clang" OR "${CMAKE_C_COMPILER_ID}" STREQUAL "GNU")
+        SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3")
+    endif ()
+ENDIF()
+
+ADD_LIBRARY(rspamd-kann SHARED ${LIBKANNSRC})
+
+ProcessPackage(BLAS OPTIONAL LIBRARY openblas blas
+        INCLUDE cblas.h INCLUDE_SUFFIXES include/openblas
+        include/blas
+        ROOT ${BLAS_ROOT_DIR}
+        LIB_OUTPUT BLAS_REQUIRED_LIBRARIES)
+IF(WITH_BLAS)
+    MESSAGE(STATUS "Use openblas to accelerate kann")
+    TARGET_LINK_LIBRARIES(rspamd-kann  ${BLAS_REQUIRED_LIBRARIES})
+    ADD_DEFINITIONS(-DHAVE_CBLAS)
+ENDIF(WITH_BLAS)
+
+INSTALL(TARGETS rspamd-kann LIBRARY DESTINATION ${RSPAMD_LIBDIR})
\ No newline at end of file
diff --git a/contrib/kann/LICENSE.txt b/contrib/kann/LICENSE.txt
new file mode 100644
index 000000000..8b2cf1141
--- /dev/null
+++ b/contrib/kann/LICENSE.txt
@@ -0,0 +1,24 @@
+The MIT License
+
+Copyright (c) 2018-2019 Dana-Farber Cancer Institute
+              2016-2018 Broad Institute
+
+Permission is hereby granted, free of charge, to any person obtaining
+a copy of this software and associated documentation files (the
+"Software"), to deal in the Software without restriction, including
+without limitation the rights to use, copy, modify, merge, publish,
+distribute, sublicense, and/or sell copies of the Software, and to
+permit persons to whom the Software is furnished to do so, subject to
+the following conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/contrib/kann/kann.c b/contrib/kann/kann.c
new file mode 100644
index 000000000..0af15fb2a
--- /dev/null
+++ b/contrib/kann/kann.c
@@ -0,0 +1,977 @@
+#include <math.h>
+#include <float.h>
+#include <string.h>
+#include <stdlib.h>
+#include <assert.h>
+#include <stdarg.h>
+#include "kann.h"
+
+int kann_verbose = 3;
+
+/******************************************
+ *** @@BASIC: fundamental KANN routines ***
+ ******************************************/
+
+static void kad_ext_collate(int n, kad_node_t **a, float **_x, float **_g, float **_c)
+{
+	int i, j, k, l, n_var;
+	float *x, *g, *c;
+	n_var = kad_size_var(n, a);
+	x = *_x = (float*)realloc(*_x, n_var * sizeof(float));
+	g = *_g = (float*)realloc(*_g, n_var * sizeof(float));
+	c = *_c = (float*)realloc(*_c, kad_size_const(n, a) * sizeof(float));
+	memset(g, 0, n_var * sizeof(float));
+	for (i = j = k = 0; i < n; ++i) {
+		kad_node_t *v = a[i];
+		if (kad_is_var(v)) {
+			l = kad_len(v);
+			memcpy(&x[j], v->x, l * sizeof(float));
+			free(v->x);
+			v->x = &x[j];
+			v->g = &g[j];
+			j += l;
+		} else if (kad_is_const(v)) {
+			l = kad_len(v);
+			memcpy(&c[k], v->x, l * sizeof(float));
+			free(v->x);
+			v->x = &c[k];
+			k += l;
+		}
+	}
+}
+
+static void kad_ext_sync(int n, kad_node_t **a, float *x, float *g, float *c)
+{
+	int i, j, k;
+	for (i = j = k = 0; i < n; ++i) {
+		kad_node_t *v = a[i];
+		if (kad_is_var(v)) {
+			v->x = &x[j];
+			v->g = &g[j];
+			j += kad_len(v);
+		} else if (kad_is_const(v)) {
+			v->x = &c[k];
+			k += kad_len(v);
+		}
+	}
+}
+
+kann_t *kann_new(kad_node_t *cost, int n_rest, ...)
+{
+	kann_t *a;
+	int i, n_roots = 1 + n_rest, has_pivot = 0, has_recur = 0;
+	kad_node_t **roots;
+	va_list ap;
+
+	if (cost->n_d != 0) return 0;
+
+	va_start(ap, n_rest);
+	roots = (kad_node_t**)malloc((n_roots + 1) * sizeof(kad_node_t*));
+	for (i = 0; i < n_rest; ++i)
+		roots[i] = va_arg(ap, kad_node_t*);
+	roots[i++] = cost;
+	va_end(ap);
+
+	cost->ext_flag |= KANN_F_COST;
+	a = (kann_t*)calloc(1, sizeof(kann_t));
+	a->v = kad_compile_array(&a->n, n_roots, roots);
+
+	for (i = 0; i < a->n; ++i) {
+		if (a->v[i]->pre) has_recur = 1;
+		if (kad_is_pivot(a->v[i])) has_pivot = 1;
+	}
+	if (has_recur && !has_pivot) { /* an RNN that doesn't have a pivot; then add a pivot on top of cost and recompile */
+		cost->ext_flag &= ~KANN_F_COST;
+		roots[n_roots-1] = cost = kad_avg(1, &cost), cost->ext_flag |= KANN_F_COST;
+		free(a->v);
+		a->v = kad_compile_array(&a->n, n_roots, roots);
+	}
+	kad_ext_collate(a->n, a->v, &a->x, &a->g, &a->c);
+	free(roots);
+	return a;
+}
+
+kann_t *kann_clone(kann_t *a, int batch_size)
+{
+	kann_t *b;
+	b = (kann_t*)calloc(1, sizeof(kann_t));
+	b->n = a->n;
+	b->v = kad_clone(a->n, a->v, batch_size);
+	kad_ext_collate(b->n, b->v, &b->x, &b->g, &b->c);
+	return b;
+}
+
+kann_t *kann_unroll_array(kann_t *a, int *len)
+{
+	kann_t *b;
+	b = (kann_t*)calloc(1, sizeof(kann_t));
+	b->x = a->x, b->g = a->g, b->c = a->c; /* these arrays are shared */
+	b->v = kad_unroll(a->n, a->v, &b->n, len);
+	return b;
+}
+
+kann_t *kann_unroll(kann_t *a, ...)
+{
+	kann_t *b;
+	va_list ap;
+	int i, n_pivots, *len;
+	n_pivots = kad_n_pivots(a->n, a->v);
+	len = (int*)calloc(n_pivots, sizeof(int));
+	va_start(ap, a);
+	for (i = 0; i < n_pivots; ++i) len[i] = va_arg(ap, int);
+	va_end(ap);
+	b = kann_unroll_array(a, len);
+	free(len);
+	return b;
+}
+
+void kann_delete_unrolled(kann_t *a)
+{
+	if (a && a->mt) kann_mt(a, 0, 0);
+	if (a && a->v) kad_delete(a->n, a->v);
+	free(a);
+}
+
+void kann_delete(kann_t *a)
+{
+	if (a == 0) return;
+	free(a->x); free(a->g); free(a->c);
+	kann_delete_unrolled(a);
+}
+
+static void kann_switch_core(kann_t *a, int is_train)
+{
+	int i;
+	for (i = 0; i < a->n; ++i)
+		if (a->v[i]->op == 12 && a->v[i]->n_child == 2)
+			*(int32_t*)a->v[i]->ptr = !!is_train;
+}
+
+#define chk_flg(flag, mask) ((mask) == 0 || ((flag) & (mask)))
+#define chk_lbl(label, query) ((query) == 0 || (label) == (query))
+
+int kann_find(const kann_t *a, uint32_t ext_flag, int32_t ext_label)
+{
+	int i, k, r = -1;
+	for (i = k = 0; i < a->n; ++i)
+		if (chk_flg(a->v[i]->ext_flag, ext_flag) && chk_lbl(a->v[i]->ext_label, ext_label))
+			++k, r = i;
+	return k == 1? r : k == 0? -1 : -2;
+}
+
+int kann_feed_bind(kann_t *a, uint32_t ext_flag, int32_t ext_label, float **x)
+{
+	int i, k;
+	if (x == 0) return 0;
+	for (i = k = 0; i < a->n; ++i)
+		if (kad_is_feed(a->v[i]) && chk_flg(a->v[i]->ext_flag, ext_flag) && chk_lbl(a->v[i]->ext_label, ext_label))
+			a->v[i]->x = x[k++];
+	return k;
+}
+
+int kann_feed_dim(const kann_t *a, uint32_t ext_flag, int32_t ext_label)
+{
+	int i, k, n = 0;
+	for (i = k = 0; i < a->n; ++i)
+		if (kad_is_feed(a->v[i]) && chk_flg(a->v[i]->ext_flag, ext_flag) && chk_lbl(a->v[i]->ext_label, ext_label))
+			++k, n = a->v[i]->n_d > 1? kad_len(a->v[i]) / a->v[i]->d[0] : a->v[i]->n_d == 1? a->v[i]->d[0] : 1;
+	return k == 1? n : k == 0? -1 : -2;
+}
+
+static float kann_cost_core(kann_t *a, int cost_label, int cal_grad)
+{
+	int i_cost;
+	float cost;
+	i_cost = kann_find(a, KANN_F_COST, cost_label);
+	assert(i_cost >= 0);
+	cost = *kad_eval_at(a->n, a->v, i_cost);
+	if (cal_grad) kad_grad(a->n, a->v, i_cost);
+	return cost;
+}
+
+int kann_eval(kann_t *a, uint32_t ext_flag, int ext_label)
+{
+	int i, k;
+	for (i = k = 0; i < a->n; ++i)
+		if (chk_flg(a->v[i]->ext_flag, ext_flag) && chk_lbl(a->v[i]->ext_label, ext_label))
+			++k, a->v[i]->tmp = 1;
+	kad_eval_marked(a->n, a->v);
+	return k;
+}
+
+void kann_rnn_start(kann_t *a)
+{
+	int i;
+	kann_set_batch_size(a, 1);
+	for (i = 0; i < a->n; ++i) {
+		kad_node_t *p = a->v[i];
+		if (p->pre) { /* NB: BE CAREFUL of the interaction between kann_rnn_start() and kann_set_batch_size() */
+			kad_node_t *q = p->pre;
+			if (q->x) memcpy(p->x, q->x, kad_len(p) * sizeof(float));
+			else memset(p->x, 0, kad_len(p) * sizeof(float));
+			if (q->n_child > 0) free(q->x);
+			q->x = p->x;
+		}
+	}
+}
+
+void kann_rnn_end(kann_t *a)
+{
+	int i;
+	kad_ext_sync(a->n, a->v, a->x, a->g, a->c);
+	for (i = 0; i < a->n; ++i)
+		if (a->v[i]->pre && a->v[i]->pre->n_child > 0)
+			a->v[i]->pre->x = (float*)calloc(kad_len(a->v[i]->pre), sizeof(float));
+}
+
+static int kann_class_error_core(const kann_t *ann, int *base)
+{
+	int i, j, k, m, n, off, n_err = 0;
+	for (i = 0, *base = 0; i < ann->n; ++i) {
+		kad_node_t *p = ann->v[i];
+		if (((p->op == 13 && (p->n_child == 2 || p->n_child == 3)) || (p->op == 22 && p->n_child == 2)) && p->n_d == 0) { /* ce_bin or ce_multi */
+			kad_node_t *x = p->child[0], *t = p->child[1];
+			n = t->d[t->n_d - 1], m = kad_len(t) / n;
+			for (j = off = 0; j < m; ++j, off += n) {
+				float t_sum = 0.0f, t_min = 1.0f, t_max = 0.0f, x_max = 0.0f, x_min = 1.0f;
+				int x_max_k = -1, t_max_k = -1;
+				for (k = 0; k < n; ++k) {
+					float xk = x->x[off+k], tk = t->x[off+k];
+					t_sum += tk;
+					t_min = t_min < tk? t_min : tk;
+					x_min = x_min < xk? x_min : xk;
+					if (t_max < tk) t_max = tk, t_max_k = k;
+					if (x_max < xk) x_max = xk, x_max_k = k;
+				}
+				if (t_sum - 1.0f == 0 && t_min >= 0.0f && x_min >= 0.0f && x_max <= 1.0f) {
+					++(*base);
+					n_err += (x_max_k != t_max_k);
+				}
+			}
+		}
+	}
+	return n_err;
+}
+
+/*************************
+ * @@MT: multi-threading *
+ *************************/
+
+#ifdef HAVE_PTHREAD
+#include <pthread.h>
+
+struct mtaux_t;
+
+typedef struct { /* per-worker data */
+	kann_t *a;
+	float cost;
+	int action;
+	pthread_t tid;
+	struct mtaux_t *g;
+} mtaux1_t;
+
+typedef struct mtaux_t { /* cross-worker data */
+	int n_threads, max_batch_size;
+	int cal_grad, cost_label, eval_out;
+	volatile int n_idle; /* we will be busy waiting on this, so volatile necessary */
+	pthread_mutex_t mtx;
+	pthread_cond_t cv;
+	mtaux1_t *mt;
+} mtaux_t;
+
+static void *mt_worker(void *data) /* pthread worker */
+{
+	mtaux1_t *mt1 = (mtaux1_t*)data;
+	mtaux_t *mt = mt1->g;
+	for (;;) {
+		int action;
+		pthread_mutex_lock(&mt->mtx);
+		mt1->action = 0;
+		++mt->n_idle;
+		while (mt1->action == 0)
+			pthread_cond_wait(&mt->cv, &mt->mtx);
+		action = mt1->action;
+		pthread_mutex_unlock(&mt->mtx);
+		if (action == -1) break;
+
+		if (mt->eval_out) kann_eval(mt1->a, KANN_F_OUT, 0);
+		else mt1->cost = kann_cost_core(mt1->a, mt->cost_label, mt->cal_grad);
+	}
+	pthread_exit(0);
+}
+
+static void mt_destroy(mtaux_t *mt) /* de-allocate an entire mtaux_t struct */
+{
+	int i;
+	pthread_mutex_lock(&mt->mtx);
+	mt->n_idle = 0;
+	for (i = 1; i < mt->n_threads; ++i) mt->mt[i].action = -1;
+	pthread_cond_broadcast(&mt->cv);
+	pthread_mutex_unlock(&mt->mtx);
+	for (i = 1; i < mt->n_threads; ++i) pthread_join(mt->mt[i].tid, 0);
+	for (i = 0; i < mt->n_threads; ++i) kann_delete(mt->mt[i].a);
+	free(mt->mt);
+	pthread_cond_destroy(&mt->cv);
+	pthread_mutex_destroy(&mt->mtx);
+	free(mt);
+}
+
+void kann_mt(kann_t *ann, int n_threads, int max_batch_size)
+{
+	mtaux_t *mt;
+	int i, k;
+
+	if (n_threads <= 1) {
+		if (ann->mt) mt_destroy((mtaux_t*)ann->mt);
+		ann->mt = 0;
+		return;
+	}
+	if (n_threads > max_batch_size) n_threads = max_batch_size;
+	if (n_threads <= 1) return;
+
+	mt = (mtaux_t*)calloc(1, sizeof(mtaux_t));
+	mt->n_threads = n_threads, mt->max_batch_size = max_batch_size;
+	pthread_mutex_init(&mt->mtx, 0);
+	pthread_cond_init(&mt->cv, 0);
+	mt->mt = (mtaux1_t*)calloc(n_threads, sizeof(mtaux1_t));
+	for (i = k = 0; i < n_threads; ++i) {
+		int size = (max_batch_size - k) / (n_threads - i);
+		mt->mt[i].a = kann_clone(ann, size);
+		mt->mt[i].g = mt;
+		k += size;
+	}
+	for (i = 1; i < n_threads; ++i)
+		pthread_create(&mt->mt[i].tid, 0, mt_worker, &mt->mt[i]);
+	while (mt->n_idle < n_threads - 1); /* busy waiting until all threads in sync */
+	ann->mt = mt;
+}
+
+static void mt_kickoff(kann_t *a, int cost_label, int cal_grad, int eval_out)
+{
+	mtaux_t *mt = (mtaux_t*)a->mt;
+	int i, j, k, B, n_var;
+
+	B = kad_sync_dim(a->n, a->v, -1); /* get the current batch size */
+	assert(B <= mt->max_batch_size); /* TODO: can be relaxed */
+	n_var = kann_size_var(a);
+
+	pthread_mutex_lock(&mt->mtx);
+	mt->cost_label = cost_label, mt->cal_grad = cal_grad, mt->eval_out = eval_out;
+	for (i = k = 0; i < mt->n_threads; ++i) {
+		int size = (B - k) / (mt->n_threads - i);
+		for (j = 0; j < a->n; ++j)
+			if (kad_is_feed(a->v[j]))
+				mt->mt[i].a->v[j]->x = &a->v[j]->x[k * kad_len(a->v[j]) / a->v[j]->d[0]];
+		kad_sync_dim(mt->mt[i].a->n, mt->mt[i].a->v, size); /* TODO: we can point ->x to internal nodes, too */
+		k += size;
+		memcpy(mt->mt[i].a->x, a->x, n_var * sizeof(float));
+		mt->mt[i].action = 1;
+	}
+	mt->n_idle = 0;
+	pthread_cond_broadcast(&mt->cv);
+	pthread_mutex_unlock(&mt->mtx);
+}
+
+float kann_cost(kann_t *a, int cost_label, int cal_grad)
+{
+	mtaux_t *mt = (mtaux_t*)a->mt;
+	int i, j, B, k, n_var;
+	float cost;
+
+	if (mt == 0) return kann_cost_core(a, cost_label, cal_grad);
+	B = kad_sync_dim(a->n, a->v, -1); /* get the current batch size */
+	n_var = kann_size_var(a);
+
+	mt_kickoff(a, cost_label, cal_grad, 0);
+	mt->mt[0].cost = kann_cost_core(mt->mt[0].a, cost_label, cal_grad);
+	while (mt->n_idle < mt->n_threads - 1); /* busy waiting until all threads in sync */
+
+	memset(a->g, 0, n_var * sizeof(float)); /* TODO: check if this is necessary when cal_grad is false */
+	for (i = k = 0, cost = 0.0f; i < mt->n_threads; ++i) {
+		int size = (B - k) / (mt->n_threads - i);
+		cost += mt->mt[i].cost * size / B;
+		kad_saxpy(n_var, (float)size / B, mt->mt[i].a->g, a->g);
+		k += size;
+	}
+	for (j = 0; j < a->n; ++j) { /* copy values back at recurrent nodes (needed by textgen; TODO: temporary solution) */
+		kad_node_t *p = a->v[j];
+		if (p->pre && p->n_d >= 2 && p->d[0] == B) {
+			for (i = k = 0; i < mt->n_threads; ++i) {
+				kad_node_t *q = mt->mt[i].a->v[j];
+				memcpy(&p->x[k], q->x, kad_len(q) * sizeof(float));
+				k += kad_len(q);
+			}
+		}
+	}
+	return cost;
+}
+
+int kann_eval_out(kann_t *a)
+{
+	mtaux_t *mt = (mtaux_t*)a->mt;
+	int j, B, n_eval;
+	if (mt == 0) return kann_eval(a, KANN_F_OUT, 0);
+	B = kad_sync_dim(a->n, a->v, -1); /* get the current batch size */
+	mt_kickoff(a, 0, 0, 1);
+	n_eval = kann_eval(mt->mt[0].a, KANN_F_OUT, 0);
+	while (mt->n_idle < mt->n_threads - 1); /* busy waiting until all threads in sync */
+	for (j = 0; j < a->n; ++j) { /* copy output values back */
+		kad_node_t *p = a->v[j];
+		if (p->ext_flag & KANN_F_OUT) {
+			int i, t, k, d0 = p->d[0] / B, d1 = 1; /* for RNN, p->d[0] may equal unroll_len * batch_size */
+			assert(p->d[0] % B == 0);
+			for (i = 1; i < p->n_d; ++i) d1 *= p->d[i];
+			for (i = 0; i < d0; ++i) {
+				for (t = k = 0; t < mt->n_threads; ++t) { /* similar to the forward pass of kad_op_concat() */
+					kad_node_t *q = mt->mt[t].a->v[j];
+					int size = q->d[0] / d0;
+					memcpy(&p->x[(i * B + k) * d1], &q->x[i * size * d1], size * d1 * sizeof(float));
+					k += size;
+				}
+			}
+		}
+	}
+	return n_eval;
+}
+
+int kann_class_error(const kann_t *ann, int *base)
+{
+	mtaux_t *mt = (mtaux_t*)ann->mt;
+	int i, n_err = 0, b = 0;
+	if (mt == 0) return kann_class_error_core(ann, base);
+	for (i = 0; i < mt->n_threads; ++i) {
+		n_err += kann_class_error_core(mt->mt[i].a, &b);
+		*base += b;
+	}
+	return n_err;
+}
+
+void kann_switch(kann_t *ann, int is_train)
+{
+	mtaux_t *mt = (mtaux_t*)ann->mt;
+	int i;
+	if (mt == 0) {
+		kann_switch_core(ann, is_train);
+		return;
+	}
+	for (i = 0; i < mt->n_threads; ++i)
+		kann_switch_core(mt->mt[i].a, is_train);
+}
+#else
+void kann_mt(kann_t *ann, int n_threads, int max_batch_size) {}
+float kann_cost(kann_t *a, int cost_label, int cal_grad) { return kann_cost_core(a, cost_label, cal_grad); }
+int kann_eval_out(kann_t *a) { return kann_eval(a, KANN_F_OUT, 0); }
+int kann_class_error(const kann_t *a, int *base) { return kann_class_error_core(a, base); }
+void kann_switch(kann_t *ann, int is_train) { return kann_switch_core(ann, is_train); }
+#endif
+
+/***********************
+ *** @@IO: model I/O ***
+ ***********************/
+
+#define KANN_MAGIC "KAN\1"
+
+void kann_save_fp(FILE *fp, kann_t *ann)
+{
+	kann_set_batch_size(ann, 1);
+	fwrite(KANN_MAGIC, 1, 4, fp);
+	kad_save(fp, ann->n, ann->v);
+	fwrite(ann->x, sizeof(float), kann_size_var(ann), fp);
+	fwrite(ann->c, sizeof(float), kann_size_const(ann), fp);
+}
+
+void kann_save(const char *fn, kann_t *ann)
+{
+	FILE *fp;
+	fp = fn && strcmp(fn, "-")? fopen(fn, "wb") : stdout;
+	kann_save_fp(fp, ann);
+	fclose(fp);
+}
+
+kann_t *kann_load_fp(FILE *fp)
+{
+	char magic[4];
+	kann_t *ann;
+	int n_var, n_const;
+
+	fread(magic, 1, 4, fp);
+	if (strncmp(magic, KANN_MAGIC, 4) != 0) {
+		fclose(fp);
+		return 0;
+	}
+	ann = (kann_t*)calloc(1, sizeof(kann_t));
+	ann->v = kad_load(fp, &ann->n);
+	n_var = kad_size_var(ann->n, ann->v);
+	n_const = kad_size_const(ann->n, ann->v);
+	ann->x = (float*)malloc(n_var * sizeof(float));
+	ann->g = (float*)calloc(n_var, sizeof(float));
+	ann->c = (float*)malloc(n_const * sizeof(float));
+	fread(ann->x, sizeof(float), n_var, fp);
+	fread(ann->c, sizeof(float), n_const, fp);
+	kad_ext_sync(ann->n, ann->v, ann->x, ann->g, ann->c);
+	return ann;
+}
+
+kann_t *kann_load(const char *fn)
+{
+	FILE *fp;
+	kann_t *ann;
+	fp = fn && strcmp(fn, "-")? fopen(fn, "rb") : stdin;
+	ann = kann_load_fp(fp);
+	fclose(fp);
+	return ann;
+}
+
+/**********************************************
+ *** @@LAYER: layers and model generation ***
+ **********************************************/
+
+/********** General but more complex APIs **********/
+
+kad_node_t *kann_new_leaf_array(int *offset, kad_node_p *par, uint8_t flag, float x0_01, int n_d, int32_t d[KAD_MAX_DIM])
+{
+	int i, len, off = offset && par? *offset : -1;
+	kad_node_t *p;
+
+	if (off >= 0 && par[off]) return par[(*offset)++];
+	p = (kad_node_t*)calloc(1, sizeof(kad_node_t));
+	p->n_d = n_d, p->flag = flag;
+	memcpy(p->d, d, n_d * sizeof(int32_t));
+	len = kad_len(p);
+	p->x = (float*)calloc(len, sizeof(float));
+	if (p->n_d <= 1) {
+		for (i = 0; i < len; ++i)
+			p->x[i] = x0_01;
+	} else {
+		double sdev_inv;
+		sdev_inv = 1.0 / sqrt((double)len / p->d[0]);
+		for (i = 0; i < len; ++i)
+			p->x[i] = (float)(kad_drand_normal(0) * sdev_inv);
+	}
+	if (off >= 0) par[off] = p, ++(*offset);
+	return p;
+}
+
+kad_node_t *kann_new_leaf2(int *offset, kad_node_p *par, uint8_t flag, float x0_01, int n_d, ...)
+{
+	int32_t i, d[KAD_MAX_DIM];
+	va_list ap;
+	va_start(ap, n_d); for (i = 0; i < n_d; ++i) d[i] = va_arg(ap, int); va_end(ap);
+	return kann_new_leaf_array(offset, par, flag, x0_01, n_d, d);
+}
*** OUTPUT TRUNCATED, 3311 LINES SKIPPED ***


More information about the Commits mailing list