commit ad97a14: [Minor] Lua_tensor: Implement non-owning tensors (slices)
Vsevolod Stakhov
vsevolod at highsecure.ru
Wed Aug 19 13:07:06 UTC 2020
Author: Vsevolod Stakhov
Date: 2020-08-19 12:43:06 +0100
URL: https://github.com/rspamd/rspamd/commit/ad97a143fd5dae292d901402828e4fb059de0b7e
[Minor] Lua_tensor: Implement non-owning tensors (slices)
---
src/lua/lua_tensor.c | 45 ++++++++++++++++++++++++++++++---------------
src/lua/lua_tensor.h | 2 +-
2 files changed, 31 insertions(+), 16 deletions(-)
diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c
index 91fcd763e..cf91006d0 100644
--- a/src/lua/lua_tensor.c
+++ b/src/lua/lua_tensor.c
@@ -53,7 +53,7 @@ static luaL_reg rspamd_tensor_m[] = {
};
static struct rspamd_lua_tensor *
-lua_newtensor (lua_State *L, int ndims, const int *dim, bool zero_fill)
+lua_newtensor (lua_State *L, int ndims, const int *dim, bool zero_fill, bool own)
{
struct rspamd_lua_tensor *res;
@@ -68,10 +68,16 @@ lua_newtensor (lua_State *L, int ndims, const int *dim, bool zero_fill)
}
/* To avoid allocating large stuff in Lua */
- res->data = g_malloc (sizeof (rspamd_tensor_num_t) * res->size);
+ if (own) {
+ res->data = g_malloc (sizeof (rspamd_tensor_num_t) * res->size);
- if (zero_fill) {
- memset (res->data, 0, sizeof (rspamd_tensor_num_t) * res->size);
+ if (zero_fill) {
+ memset (res->data, 0, sizeof (rspamd_tensor_num_t) * res->size);
+ }
+ }
+ else {
+ /* Mark size negative to distinguish */
+ res->size = -(res->size);
}
rspamd_lua_setclass (L, TENSOR_CLASS, -1);
@@ -96,7 +102,7 @@ lua_tensor_new (lua_State *L)
dims[i] = lua_tointeger (L, i + 2);
}
- (void)lua_newtensor (L, ndims, dims, true);
+ (void)lua_newtensor (L, ndims, dims, true, true);
}
else {
return luaL_error (L, "incorrect dimensions number: %d", ndims);
@@ -124,7 +130,7 @@ lua_tensor_fromtable (lua_State *L)
dims[1] = rspamd_lua_table_size (L, 1);
struct rspamd_lua_tensor *res = lua_newtensor (L, 2,
- dims, false);
+ dims, false, true);
for (guint i = 0; i < dims[1]; i ++) {
lua_rawgeti (L, 1, i + 1);
@@ -179,7 +185,7 @@ lua_tensor_fromtable (lua_State *L)
dims[1] = ncols;
struct rspamd_lua_tensor *res = lua_newtensor (L, 2,
- dims, false);
+ dims, false, true);
for (gint i = 0; i < nrows; i ++) {
lua_rawgeti (L, 1, i + 1);
@@ -219,7 +225,9 @@ lua_tensor_destroy (lua_State *L)
struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
if (t) {
- g_free (t->data);
+ if (t->size > 0) {
+ g_free (t->data);
+ }
}
return 0;
@@ -234,19 +242,27 @@ static gint
lua_tensor_save (lua_State *L)
{
struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
+ gint size;
if (t) {
- gsize sz = sizeof (gint) * 4 + t->size * sizeof (rspamd_tensor_num_t);
+ if (t->size > 0) {
+ size = t->size;
+ }
+ else {
+ size = -(t->size);
+ }
+
+ gsize sz = sizeof (gint) * 4 + size * sizeof (rspamd_tensor_num_t);
guchar *data;
struct rspamd_lua_text *out = lua_new_text (L, NULL, 0, TRUE);
data = g_malloc (sz);
memcpy (data, &t->ndims, sizeof (int));
- memcpy (data + sizeof (int), &t->size, sizeof (int));
+ memcpy (data + sizeof (int), &size, sizeof (int));
memcpy (data + 2 * sizeof (int), t->dim, sizeof (int) * 2);
memcpy (data + 4 * sizeof (int), t->data,
- t->size * sizeof (rspamd_tensor_num_t));
+ size * sizeof (rspamd_tensor_num_t));
out->start = (const gchar *)data;
out->len = sz;
@@ -324,11 +340,10 @@ lua_tensor_index (lua_State *L)
if (idx <= t->dim[0]) {
+ /* Non-owning tensor */
struct rspamd_lua_tensor *res =
- lua_newtensor (L, 1, &dim, false);
- for (gint i = 0; i < dim; i++) {
- res->data[i] = t->data[(idx - 1) * t->dim[1] + i];
- }
+ lua_newtensor (L, 1, &dim, false, false);
+ res->data = &t->data[(idx - 1) * t->dim[1]];
}
else {
lua_pushnil (L);
diff --git a/src/lua/lua_tensor.h b/src/lua/lua_tensor.h
index 554245f0b..e4c110011 100644
--- a/src/lua/lua_tensor.h
+++ b/src/lua/lua_tensor.h
@@ -23,8 +23,8 @@ typedef float rspamd_tensor_num_t;
struct rspamd_lua_tensor {
int ndims;
int size; /* overall size (product of dims) */
- rspamd_tensor_num_t *data;
int dim[2];
+ rspamd_tensor_num_t *data;
};
struct rspamd_lua_tensor *lua_check_tensor (lua_State *L, int pos);
More information about the Commits
mailing list