commit 2d80fda: [Fix] Avoid cyclic references in symcache and fix memory leaks

Vsevolod Stakhov vsevolod at rspamd.com
Sat Sep 17 16:14:05 UTC 2022


Author: Vsevolod Stakhov
Date: 2022-09-17 17:10:22 +0100
URL: https://github.com/rspamd/rspamd/commit/2d80fdad5fb5e918a5bbe8385fafb67eb7e4797a (HEAD -> master)

[Fix] Avoid cyclic references in symcache and fix memory leaks

---
 src/libserver/symcache/symcache_impl.cxx     | 48 ++++++++++++++--------------
 src/libserver/symcache/symcache_internal.hxx | 18 +++++------
 src/libserver/symcache/symcache_item.cxx     | 14 ++++----
 src/libserver/symcache/symcache_item.hxx     | 42 ++++++++++++------------
 src/libserver/symcache/symcache_runtime.cxx  |  8 ++---
 5 files changed, 66 insertions(+), 64 deletions(-)

diff --git a/src/libserver/symcache/symcache_impl.cxx b/src/libserver/symcache/symcache_impl.cxx
index 7b1127a7a..a7721f6df 100644
--- a/src/libserver/symcache/symcache_impl.cxx
+++ b/src/libserver/symcache/symcache_impl.cxx
@@ -151,12 +151,12 @@ auto symcache::init() -> bool
 
 		auto &additional_vec = get_item_specific_vector(*deleted_element_refcount);
 #if defined(__cpp_lib_erase_if)
-		std::erase_if(additional_vec, [id_to_disable](const cache_item_ptr &elt) {
+		std::erase_if(additional_vec, [id_to_disable](cache_item *elt) {
 			return elt->id == id_to_disable;
 		});
 #else
 		auto it = std::remove_if(additional_vec.begin(),
-		additional_vec.end(), [id_to_disable](const cache_item_ptr &elt) {
+		additional_vec.end(), [id_to_disable](cache_item *elt) {
 			return elt->id == id_to_disable;
 		});
 		additional_vec.erase(it, additional_vec.end());
@@ -499,7 +499,7 @@ auto symcache::get_item_by_name(std::string_view name, bool resolve_parent) cons
 		return it->second->get_parent(*this);
 	}
 
-	return it->second.get();
+	return it->second;
 }
 
 auto symcache::get_item_by_name_mut(std::string_view name, bool resolve_parent) const -> cache_item *
@@ -514,7 +514,7 @@ auto symcache::get_item_by_name_mut(std::string_view name, bool resolve_parent)
 		return (cache_item *) it->second->get_parent(*this);
 	}
 
-	return it->second.get();
+	return it->second;
 }
 
 auto symcache::add_dependency(int id_from, std::string_view to, int virtual_id_from) -> void
@@ -523,7 +523,7 @@ auto symcache::add_dependency(int id_from, std::string_view to, int virtual_id_f
 	const auto &source = items_by_id[id_from];
 	g_assert (source.get() != nullptr);
 
-	source->deps.emplace_back(cache_item_ptr{nullptr},
+	source->deps.emplace_back(nullptr,
 			std::string(to),
 			id_from,
 			-1);
@@ -534,7 +534,7 @@ auto symcache::add_dependency(int id_from, std::string_view to, int virtual_id_f
 		/* We need that for settings id propagation */
 		const auto &vsource = items_by_id[virtual_id_from];
 		g_assert (vsource.get() != nullptr);
-		vsource->deps.emplace_back(cache_item_ptr{nullptr},
+		vsource->deps.emplace_back(nullptr,
 				std::string(to),
 				-1,
 				virtual_id_from);
@@ -557,7 +557,7 @@ auto symcache::resort() -> void
 			total_hits += it->st->total_hits;
 			/* Unmask topological order */
 			it->order = 0;
-			ord->d.emplace_back(it);
+			ord->d.emplace_back(it->getptr());
 		}
 	}
 
@@ -614,7 +614,7 @@ auto symcache::resort() -> void
 
 		for (const auto &dep: it->deps) {
 			msg_debug_cache_lambda("visiting dep: %s (%d)", dep.item->symbol.c_str(), cur_order + 1);
-			rec(dep.item.get(), cur_order + 1, rec);
+			rec(dep.item, cur_order + 1, rec);
 		}
 
 		it->order = cur_order;
@@ -686,7 +686,7 @@ auto symcache::resort() -> void
 	constexpr auto append_items_vec = [](const auto &vec, auto &out) {
 		for (const auto &it: vec) {
 			if (it) {
-				out.emplace_back(it);
+				out.emplace_back(it->getptr());
 			}
 		}
 	};
@@ -700,7 +700,7 @@ auto symcache::resort() -> void
 
 	/* After sorting is done, we can assign all elements in the by_symbol hash */
 	for (const auto [i, it] : rspamd::enumerate(ord->d)) {
-		ord->by_symbol[it->get_name()] = i;
+		ord->by_symbol.emplace(it->get_name(), i);
 		ord->by_cache_id[it->id] = i;
 	}
 	/* Finally set the current order */
@@ -768,9 +768,9 @@ auto symcache::add_symbol_with_callback(std::string_view name,
 			priority, func, user_data,
 			real_type_pair.first, real_type_pair.second);
 
-	items_by_symbol[item->get_name()] = item;
-	get_item_specific_vector(*item).push_back(item);
-	items_by_id.emplace(id, item);
+	items_by_symbol.emplace(item->get_name(), item.get());
+	get_item_specific_vector(*item).push_back(item.get());
+	items_by_id.emplace(id, std::move(item)); // Takes ownership
 
 	if (!(real_type_pair.second & SYMBOL_TYPE_NOSTAT)) {
 		cksum = t1ha(name.data(), name.size(), cksum);
@@ -813,11 +813,11 @@ auto symcache::add_virtual_symbol(std::string_view name, int parent_id, enum rsp
 			id,
 			std::string{name},
 			parent_id, real_type_pair.first, real_type_pair.second);
-	const auto &parent = items_by_id[parent_id];
-	parent->add_child(item);
-	items_by_symbol[item->get_name()] = item;
-	get_item_specific_vector(*item).push_back(item);
-	items_by_id.emplace(id, item);
+	const auto &parent = items_by_id[parent_id].get();
+	parent->add_child(item.get());
+	items_by_symbol.emplace(item->get_name(), item.get());
+	get_item_specific_vector(*item).push_back(item.get());
+	items_by_id.emplace(id, std::move(item)); // Takes ownership
 
 	return id;
 }
@@ -1194,12 +1194,12 @@ auto symcache::get_max_timeout(std::vector<std::pair<double, const cache_item *>
 	auto log_func = RSPAMD_LOG_FUNC;
 	ankerl::unordered_dense::set<const cache_item *> seen_items;
 
-	auto get_item_timeout = [](const cache_item_ptr &it) {
+	auto get_item_timeout = [](cache_item *it) {
 		return it->get_numeric_augmentation("timeout").value_or(0.0);
 	};
 
 	/* This function returns the timeout for an item and all it's dependencies */
-	auto get_filter_timeout = [&](const cache_item_ptr &it, auto self) -> double {
+	auto get_filter_timeout = [&](cache_item *it, auto self) -> double {
 		auto own_timeout = get_item_timeout(it);
 		auto max_child_timeout = 0.0;
 
@@ -1241,7 +1241,7 @@ auto symcache::get_max_timeout(std::vector<std::pair<double, const cache_item *>
 
 			if (timeout > max_timeout) {
 				max_timeout = timeout;
-				max_elt = it.get();
+				max_elt = it;
 			}
 		}
 
@@ -1274,9 +1274,9 @@ auto symcache::get_max_timeout(std::vector<std::pair<double, const cache_item *>
 
 		if (timeout > max_filters_timeout) {
 			max_filters_timeout = timeout;
-			if (!seen_items.contains(it.get())) {
-				elts.emplace_back(timeout, it.get());
-				seen_items.insert(it.get());
+			if (!seen_items.contains(it)) {
+				elts.emplace_back(timeout, it);
+				seen_items.insert(it);
 			}
 		}
 	}
diff --git a/src/libserver/symcache/symcache_internal.hxx b/src/libserver/symcache/symcache_internal.hxx
index f97c2694e..f2a1e6669 100644
--- a/src/libserver/symcache/symcache_internal.hxx
+++ b/src/libserver/symcache/symcache_internal.hxx
@@ -236,9 +236,9 @@ struct delayed_symbol_elt_hash {
 
 class symcache {
 private:
-	using items_ptr_vec = std::vector<cache_item_ptr>;
+	using items_ptr_vec = std::vector<cache_item *>;
 	/* Map indexed by symbol name: all symbols must have unique names, so this map holds ownership */
-	ankerl::unordered_dense::map<std::string_view, cache_item_ptr> items_by_symbol;
+	ankerl::unordered_dense::map<std::string_view, cache_item *> items_by_symbol;
 	ankerl::unordered_dense::map<int, cache_item_ptr> items_by_id;
 
 	/* Items sorted into some order */
@@ -502,7 +502,7 @@ public:
 	template<typename Functor>
 	auto symbols_foreach(Functor f) -> void {
 		for (const auto &sym_it : items_by_symbol) {
-			f(sym_it.second.get());
+			f(sym_it.second);
 		}
 	}
 
@@ -514,7 +514,7 @@ public:
 	template<typename Functor>
 	auto composites_foreach(Functor f) -> void {
 		for (const auto &sym_it : composites) {
-			f(sym_it.get());
+			f(sym_it);
 		}
 	}
 
@@ -527,35 +527,35 @@ public:
 	auto connfilters_foreach(Functor f) -> bool {
 		return std::all_of(std::begin(connfilters), std::end(connfilters),
 						   [&](const auto &sym_it){
-			return f(sym_it.get());
+			return f(sym_it);
 		});
 	}
 	template<typename Functor>
 	auto prefilters_foreach(Functor f) -> bool {
 		return std::all_of(std::begin(prefilters), std::end(prefilters),
 				[&](const auto &sym_it){
-					return f(sym_it.get());
+					return f(sym_it);
 				});
 	}
 	template<typename Functor>
 	auto postfilters_foreach(Functor f) -> bool {
 		return std::all_of(std::begin(postfilters), std::end(postfilters),
 				[&](const auto &sym_it){
-					return f(sym_it.get());
+					return f(sym_it);
 				});
 	}
 	template<typename Functor>
 	auto idempotent_foreach(Functor f) -> bool {
 		return std::all_of(std::begin(idempotent), std::end(idempotent),
 				[&](const auto &sym_it){
-					return f(sym_it.get());
+					return f(sym_it);
 				});
 	}
 	template<typename Functor>
 	auto filters_foreach(Functor f) -> bool {
 		return std::all_of(std::begin(filters), std::end(filters),
 				[&](const auto &sym_it){
-					return f(sym_it.get());
+					return f(sym_it);
 				});
 	}
 
diff --git a/src/libserver/symcache/symcache_item.cxx b/src/libserver/symcache/symcache_item.cxx
index 87e857d23..cfe8b4c14 100644
--- a/src/libserver/symcache/symcache_item.cxx
+++ b/src/libserver/symcache/symcache_item.cxx
@@ -183,8 +183,8 @@ auto cache_item::process_deps(const symcache &cache) -> void
 						auto *parent = get_parent_mut(cache);
 
 						if (parent) {
-							dit->rdeps.emplace_back(parent->getptr(), dep.sym, parent->id, -1);
-							dep.item = dit->getptr();
+							dit->rdeps.emplace_back(parent, parent->symbol, parent->id, -1);
+							dep.item = dit;
 							dep.id = dit->id;
 
 							msg_debug_cache ("added reverse dependency from %d on %d", parent->id,
@@ -192,9 +192,9 @@ auto cache_item::process_deps(const symcache &cache) -> void
 						}
 					}
 					else {
-						dep.item = dit->getptr();
+						dep.item = dit;
 						dep.id = dit->id;
-						dit->rdeps.emplace_back(getptr(), dep.sym, id, -1);
+						dit->rdeps.emplace_back(this, symbol, id, -1);
 						msg_debug_cache ("added reverse dependency from %d on %d", id,
 								dit->id);
 					}
@@ -525,7 +525,7 @@ auto cache_item::get_numeric_augmentation(std::string_view name) const -> std::o
 auto virtual_item::get_parent(const symcache &cache) const -> const cache_item *
 {
 	if (parent) {
-		return parent.get();
+		return parent;
 	}
 
 	return cache.get_item_by_id(parent_id, false);
@@ -534,7 +534,7 @@ auto virtual_item::get_parent(const symcache &cache) const -> const cache_item *
 auto virtual_item::get_parent_mut(const symcache &cache) -> cache_item *
 {
 	if (parent) {
-		return parent.get();
+		return parent;
 	}
 
 	return const_cast<cache_item *>(cache.get_item_by_id(parent_id, false));
@@ -549,7 +549,7 @@ auto virtual_item::resolve_parent(const symcache &cache) -> bool
 	auto item_ptr = cache.get_item_by_id(parent_id, true);
 
 	if (item_ptr) {
-		parent = const_cast<cache_item *>(item_ptr)->getptr();
+		parent = const_cast<cache_item *>(item_ptr);
 
 		return true;
 	}
diff --git a/src/libserver/symcache/symcache_item.hxx b/src/libserver/symcache/symcache_item.hxx
index 67c0960f3..de25199f2 100644
--- a/src/libserver/symcache/symcache_item.hxx
+++ b/src/libserver/symcache/symcache_item.hxx
@@ -109,9 +109,9 @@ public:
 
 class normal_item {
 private:
-	symbol_func_t func;
-	void *user_data;
-	std::vector<cache_item_ptr> virtual_children;
+	symbol_func_t func = nullptr;
+	void *user_data = nullptr;
+	std::vector<cache_item *> virtual_children;
 	std::vector<item_condition> conditions;
 public:
 	explicit normal_item(symbol_func_t _func, void *_user_data) : func(_func), user_data(_user_data)
@@ -137,19 +137,19 @@ public:
 		return user_data;
 	}
 
-	auto add_child(const cache_item_ptr &ptr) -> void {
+	auto add_child(cache_item *ptr) -> void {
 		virtual_children.push_back(ptr);
 	}
 
-	auto get_childen() const -> const std::vector<cache_item_ptr>& {
+	auto get_childen() const -> const std::vector<cache_item *>& {
 		return virtual_children;
 	}
 };
 
 class virtual_item {
 private:
-	int parent_id;
-	cache_item_ptr parent;
+	int parent_id = -1;
+	cache_item *parent = nullptr;
 public:
 	explicit virtual_item(int _parent_id) : parent_id(_parent_id)
 	{
@@ -162,14 +162,14 @@ public:
 };
 
 struct cache_dependency {
-	cache_item_ptr item; /* Real dependency */
+	cache_item *item; /* Real dependency */
 	std::string sym; /* Symbolic dep name */
 	int id; /* Real from */
 	int vid; /* Virtual from */
 public:
 	/* Default piecewise constructor */
-	cache_dependency(cache_item_ptr _item, std::string _sym, int _id, int _vid) :
-			item(std::move(_item)), sym(std::move(_sym)), id(_id), vid(_vid)
+	explicit cache_dependency(cache_item *_item, std::string _sym, int _id, int _vid) :
+			item(_item), sym(std::move(_sym)), id(_id), vid(_vid)
 	{
 	}
 };
@@ -236,9 +236,10 @@ public:
 	 * @param flags
 	 * @return
 	 */
-	[[nodiscard]] static auto create_with_function(rspamd_mempool_t *pool,
+	 template <typename T>
+	 static auto create_with_function(rspamd_mempool_t *pool,
 												   int id,
-												   std::string &&name,
+												   T &&name,
 												   int priority,
 												   symbol_func_t func,
 												   void *user_data,
@@ -246,7 +247,7 @@ public:
 												   int flags) -> cache_item_ptr
 	{
 		return std::shared_ptr<cache_item>(new cache_item(pool,
-				id, std::move(name), priority,
+				id, std::forward<T>(name), priority,
 				func, user_data,
 				type, flags));
 	}
@@ -260,21 +261,22 @@ public:
 	 * @param flags
 	 * @return
 	 */
-	[[nodiscard]] static auto create_with_virtual(rspamd_mempool_t *pool,
+	template <typename T>
+	static auto create_with_virtual(rspamd_mempool_t *pool,
 												  int id,
-												  std::string &&name,
+												  T &&name,
 												  int parent,
 												  symcache_item_type type,
 												  int flags) -> cache_item_ptr
 	{
-		return std::shared_ptr<cache_item>(new cache_item(pool, id, std::move(name),
+		return std::shared_ptr<cache_item>(new cache_item(pool, id, std::forward<T>(name),
 				parent, type, flags));
 	}
 
 	/**
 	 * Share ownership on the item
-	 * @return
-	 */
+ 	 * @return
+ 	 */
 	auto getptr() -> cache_item_ptr
 	{
 		return shared_from_this();
@@ -435,7 +437,7 @@ public:
 	 * Add a virtual symbol as a child of some normal symbol
 	 * @param ptr
 	 */
-	auto add_child(const cache_item_ptr &ptr) -> void {
+	auto add_child(cache_item *ptr) -> void {
 		if (std::holds_alternative<normal_item>(specific)) {
 			auto &filter_data = std::get<normal_item>(specific);
 
@@ -451,7 +453,7 @@ public:
 	 * @param ptr
 	 * @return
 	 */
-	auto get_children() const -> std::optional<std::reference_wrapper<const std::vector<cache_item_ptr>>> {
+	auto get_children() const -> std::optional<std::reference_wrapper<const std::vector<cache_item *>>> {
 		if (std::holds_alternative<normal_item>(specific)) {
 			const auto &filter_data = std::get<normal_item>(specific);
 
diff --git a/src/libserver/symcache/symcache_runtime.cxx b/src/libserver/symcache/symcache_runtime.cxx
index 0d3594a8b..f60b81fbd 100644
--- a/src/libserver/symcache/symcache_runtime.cxx
+++ b/src/libserver/symcache/symcache_runtime.cxx
@@ -582,7 +582,7 @@ auto symcache_runtime::check_item_deps(struct rspamd_task *task, symcache &cache
 					/* Not started */
 					if (!check_only) {
 						if (!rec_functor(recursion + 1,
-								dep.item.get(),
+								dep.item,
 								dep_dyn_item,
 								rec_functor)) {
 
@@ -591,7 +591,7 @@ auto symcache_runtime::check_item_deps(struct rspamd_task *task, symcache &cache
 												 "symbol %d(%s)",
 									dep.id, dep.sym.c_str(), item->id, item->symbol.c_str());
 						}
-						else if (!process_symbol(task, cache, dep.item.get(), dep_dyn_item)) {
+						else if (!process_symbol(task, cache, dep.item, dep_dyn_item)) {
 							/* Now started, but has events pending */
 							ret = false;
 							msg_debug_cache_task_lambda("started check of %d(%s) symbol "
@@ -801,13 +801,13 @@ auto symcache_runtime::process_item_rdeps(struct rspamd_task *task, cache_item *
 				msg_debug_cache_task ("check item %d(%s) rdep of %s ",
 						rdep.item->id, rdep.item->symbol.c_str(), item->symbol.c_str());
 
-				if (!check_item_deps(task, *cache_ptr, rdep.item.get(), dyn_item, false)) {
+				if (!check_item_deps(task, *cache_ptr, rdep.item, dyn_item, false)) {
 					msg_debug_cache_task ("blocked execution of %d(%s) rdep of %s "
 										  "unless deps are resolved",
 							rdep.item->id, rdep.item->symbol.c_str(), item->symbol.c_str());
 				}
 				else {
-					process_symbol(task, *cache_ptr, rdep.item.get(),
+					process_symbol(task, *cache_ptr, rdep.item,
 							dyn_item);
 				}
 			}


More information about the Commits mailing list