commit b821683: [Project] Add tensors index method

Vsevolod Stakhov vsevolod at highsecure.ru
Wed Aug 5 20:07:14 UTC 2020


Author: Vsevolod Stakhov
Date: 2020-08-05 21:04:32 +0100
URL: https://github.com/rspamd/rspamd/commit/b8216839b2a9f083259b71947bf0caa4b4eef091

[Project] Add tensors index method

---
 src/lua/lua_tensor.c | 50 +++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 49 insertions(+), 1 deletion(-)

diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c
index 21bdf9673..85aaa2e95 100644
--- a/src/lua/lua_tensor.c
+++ b/src/lua/lua_tensor.c
@@ -32,6 +32,7 @@ LUA_FUNCTION_DEF (tensor, fromtable);
 LUA_FUNCTION_DEF (tensor, destroy);
 LUA_FUNCTION_DEF (tensor, mul);
 LUA_FUNCTION_DEF (tensor, tostring);
+LUA_FUNCTION_DEF (tensor, index);
 
 static luaL_reg rspamd_tensor_f[] = {
 		LUA_INTERFACE_DEF (tensor, load),
@@ -45,8 +46,9 @@ static luaL_reg rspamd_tensor_m[] = {
 		{"__gc", lua_tensor_destroy},
 		{"__mul", lua_tensor_mul},
 		{"mul", lua_tensor_mul},
-		{"__tostring", lua_tensor_tostring},
 		{"tostring", lua_tensor_tostring},
+		{"__tostring", lua_tensor_tostring},
+		{"__index", lua_tensor_index},
 		{NULL, NULL},
 };
 
@@ -284,6 +286,52 @@ lua_tensor_tostring (lua_State *L)
 	return 1;
 }
 
+static gint
+lua_tensor_index (lua_State *L)
+{
+	struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
+	gint idx;
+
+	if (t) {
+		if (lua_isnumber (L, 2)) {
+			idx = lua_tointeger (L, 2);
+
+			if (t->ndims == 1) {
+				/* Individual element */
+				if (idx <= t->dim[0]) {
+					lua_pushnumber (L, t->data[idx - 1]);
+				}
+				else {
+					lua_pushnil (L);
+				}
+			}
+			else {
+				/* Push row */
+				gint dim = t->dim[1];
+
+
+				if (idx <= t->dim[0]) {
+					struct rspamd_lua_tensor *res =
+							lua_newtensor (L, 1, &dim, false);
+					for (gint i = 0; i < dim; i++) {
+						res->data[i] = t->data[(idx - 1) * t->dim[1] + i];
+					}
+				}
+				else {
+					lua_pushnil (L);
+				}
+			}
+		}
+		else if (lua_isstring (L, 2)) {
+			lua_getmetatable (L, 1);
+			lua_pushvalue (L, 2);
+			lua_rawget (L, -2);
+		}
+	}
+
+	return 1;
+}
+
 /***
  * @method tensor:mul(other, [transA, [transB]])
  * Multiply two tensors (optionally transposed) and return a new tensor


More information about the Commits mailing list