commit 264b9f2: [Project] Implement fasttext language detection

Vsevolod Stakhov vsevolod at rspamd.com
Sat Apr 29 17:14:07 UTC 2023


Author: Vsevolod Stakhov
Date: 2023-04-29 15:47:15 +0100
URL: https://github.com/rspamd/rspamd/commit/264b9f2c480a1b0240acb8183a8d7470691aff11

[Project] Implement fasttext language detection

---
 src/libmime/lang_detection.c            | 169 ++++++++++++++++++++------------
 src/libmime/lang_detection_fasttext.cxx |  43 ++++++--
 src/libmime/lang_detection_fasttext.h   |  17 +++-
 3 files changed, 158 insertions(+), 71 deletions(-)

diff --git a/src/libmime/lang_detection.c b/src/libmime/lang_detection.c
index 09591438e..211dfe48b 100644
--- a/src/libmime/lang_detection.c
+++ b/src/libmime/lang_detection.c
@@ -1801,88 +1801,132 @@ rspamd_language_detector_detect (struct rspamd_task *task,
 	}
 
 	if (!ret) {
-		if (part->utf_words->len < default_short_text_limit) {
-			r = rs_detect_none;
-			msg_debug_lang_det ("text is too short for trigrams detection: "
-					   "%d words; at least %d words required",
+		unsigned ndetected = 0;
+		if (rspamd_lang_detection_fasttext_is_enabled(d->fasttext_detector)) {
+			rspamd_fasttext_predict_result_t fasttext_predict_result =
+				rspamd_lang_detection_fasttext_detect(d->fasttext_detector,
+					part->utf_stripped_content->data, part->utf_stripped_content->len, 4);
+			/* NOTE(review): this result is not released anywhere in the visible code - confirm who frees it */
+			ndetected = rspamd_lang_detection_fasttext_get_nlangs(fasttext_predict_result);
+
+			if (ndetected > 0) {
+				candidates = kh_init (rspamd_candidates_hash);
+				kh_resize (rspamd_candidates_hash, candidates, ndetected);
+
+				/* Keep every language scoring above 75% of entry 0 (presumably the best one - confirm ordering) */
+				float max_prob = rspamd_lang_detection_fasttext_get_prob(fasttext_predict_result, 0);
+
+				for (unsigned int i = 0; i < ndetected; i ++) {
+					float prob = rspamd_lang_detection_fasttext_get_prob(fasttext_predict_result, i);
+					const char *label = rspamd_lang_detection_fasttext_get_lang(fasttext_predict_result, i);
+					/* get_lang returns NULL on failure: never duplicate or hash a NULL label */
+					if (label != NULL && prob > max_prob * 0.75) {
+						char *lang = rspamd_mempool_strdup(task->task_pool, label);
+						int tmp;
+						khiter_t k = kh_put (rspamd_candidates_hash, candidates, lang, &tmp);
+
+						cand = rspamd_mempool_alloc0(task->task_pool, sizeof(*cand));
+						kh_value(candidates, k) = cand;
+						cand->lang = lang;
+						cand->prob = prob; /* reuse the value fetched above instead of a second lookup */
+					}
+				}
+
+				if (kh_size(candidates) == 1) {
+					r = rs_detect_single;
+				}
+				else if (kh_size(candidates) > 1) {
+					r = rs_detect_multiple;
+				}
+				else {
+					r = rs_detect_none;
+				}
+			}
+		}
+		if (ndetected == 0) {
+			if (part->utf_words->len < default_short_text_limit) {
+				r = rs_detect_none;
+				msg_debug_lang_det ("text is too short for trigrams detection: "
+									"%d words; at least %d words required",
 					(int)part->utf_words->len,
 					(int)default_short_text_limit);
-			switch (cat) {
-			case RSPAMD_LANGUAGE_CYRILLIC:
-				rspamd_language_detector_set_language (task, part, "ru", NULL);
-				break;
-			case RSPAMD_LANGUAGE_DEVANAGARI:
-				rspamd_language_detector_set_language (task, part, "hi", NULL);
-				break;
-			case RSPAMD_LANGUAGE_ARAB:
-				rspamd_language_detector_set_language (task, part, "ar", NULL);
-				break;
-			default:
-			case RSPAMD_LANGUAGE_LATIN:
-				rspamd_language_detector_set_language (task, part, "en", NULL);
-				break;
-			}
-			msg_debug_lang_det ("set %s language based on symbols category",
+				switch (cat) {
+				case RSPAMD_LANGUAGE_CYRILLIC:
+					rspamd_language_detector_set_language (task, part, "ru", NULL);
+					break;
+				case RSPAMD_LANGUAGE_DEVANAGARI:
+					rspamd_language_detector_set_language (task, part, "hi", NULL);
+					break;
+				case RSPAMD_LANGUAGE_ARAB:
+					rspamd_language_detector_set_language (task, part, "ar", NULL);
+					break;
+				default:
+				case RSPAMD_LANGUAGE_LATIN:
+					rspamd_language_detector_set_language (task, part, "en", NULL);
+					break;
+				}
+				msg_debug_lang_det ("set %s language based on symbols category",
 					part->language);
 
-			candidates = kh_init (rspamd_candidates_hash);
-		}
-		else {
-			candidates = kh_init (rspamd_candidates_hash);
-			kh_resize (rspamd_candidates_hash, candidates, 32);
+				candidates = kh_init (rspamd_candidates_hash);
+			}
+			else {
+				candidates = kh_init (rspamd_candidates_hash);
+				kh_resize (rspamd_candidates_hash, candidates, 32);
 
-			r = rspamd_language_detector_try_ngramm (task,
+				r = rspamd_language_detector_try_ngramm (task,
 					default_words,
 					d,
 					part->utf_words,
 					cat,
 					candidates);
 
-			if (r == rs_detect_none) {
-				msg_debug_lang_det ("no trigrams found, fallback to english");
-				rspamd_language_detector_set_language (task, part, "en", NULL);
-			} else if (r == rs_detect_multiple) {
-				/* Check our guess */
-
-				mean = 0.0;
-				std = 0.0;
-				cand_len = 0;
-
-				/* Check distribution */
-				kh_foreach_value (candidates, cand, {
-					if (!isnan (cand->prob)) {
-						mean += cand->prob;
-						cand_len++;
-					}
-				});
+				if (r == rs_detect_none) {
+					msg_debug_lang_det ("no trigrams found, fallback to english");
+					rspamd_language_detector_set_language (task, part, "en", NULL);
+				} else if (r == rs_detect_multiple) {
+					/* Check our guess */
 
-				if (cand_len > 0) {
-					mean /= cand_len;
+					mean = 0.0;
+					std = 0.0;
+					cand_len = 0;
 
+					/* Check distribution */
 					kh_foreach_value (candidates, cand, {
-						gdouble err;
 						if (!isnan (cand->prob)) {
-							err = cand->prob - mean;
-							std += fabs (err);
+							mean += cand->prob;
+							cand_len++;
 						}
 					});
 
-					std /= cand_len;
-				}
+					if (cand_len > 0) {
+						mean /= cand_len;
 
-				msg_debug_lang_det ("trigrams checked, %d candidates, %.3f mean, %.4f stddev",
+						kh_foreach_value (candidates, cand, {
+							gdouble err;
+							if (!isnan (cand->prob)) {
+								err = cand->prob - mean;
+								std += fabs (err);
+							}
+						});
+
+						std /= cand_len;
+					}
+
+					msg_debug_lang_det ("trigrams checked, %d candidates, %.3f mean, %.4f stddev",
 						cand_len, mean, std);
 
-				if (cand_len > 0 && std / fabs (mean) < 0.25) {
-					msg_debug_lang_det ("apply frequency heuristic sorting");
-					frequency_heuristic_applied = TRUE;
-					cbd.d = d;
-					cbd.mean = mean;
-					cbd.std = std;
-					cbd.flags = RSPAMD_LANG_FLAG_DEFAULT;
+					if (cand_len > 0 && std / fabs (mean) < 0.25) {
+						msg_debug_lang_det ("apply frequency heuristic sorting");
+						frequency_heuristic_applied = TRUE;
+						cbd.d = d;
+						cbd.mean = mean;
+						cbd.std = std;
+						cbd.flags = RSPAMD_LANG_FLAG_DEFAULT;
 
-					if (part->nwords < default_words / 2) {
-						cbd.flags |= RSPAMD_LANG_FLAG_SHORT;
+						if (part->nwords < default_words / 2) {
+							cbd.flags |= RSPAMD_LANG_FLAG_SHORT;
+						}
 					}
 				}
 			}
@@ -1909,7 +1953,9 @@ rspamd_language_detector_detect (struct rspamd_task *task,
 
 			if (result->len > 0 && !frequency_heuristic_applied) {
 				cand = g_ptr_array_index (result, 0);
-				cand->elt->occurrences++;
+				if (cand->elt) {
+					cand->elt->occurrences++;
+				}
 				d->total_occurrences++;
 			}
 
@@ -1918,6 +1964,7 @@ rspamd_language_detector_detect (struct rspamd_task *task,
 			}
 
 			part->languages = result;
+			part->language = result->len > 0 ? ((struct rspamd_lang_detector_res *)g_ptr_array_index (result, 0))->lang : part->language; /* empty result: keep the language already set */
 			ret = TRUE;
 		}
 		else if (part->languages == NULL) {
diff --git a/src/libmime/lang_detection_fasttext.cxx b/src/libmime/lang_detection_fasttext.cxx
index 9ede47a6e..eda4c2850 100644
--- a/src/libmime/lang_detection_fasttext.cxx
+++ b/src/libmime/lang_detection_fasttext.cxx
@@ -72,8 +72,8 @@ public:
 
 	~fasttext_langdet() = default;
 
-
-	auto detect_language(const char *in, size_t len, int k) -> std::vector<std::pair<fasttext::real, std::string>> *
+	auto is_enabled() const -> bool { return loaded; }
+	auto detect_language(const char *in, size_t len, int k) const -> std::vector<std::pair<fasttext::real, std::string>> *
 	{
 		if (!loaded) {
 			return nullptr;
@@ -135,6 +135,19 @@ char *rspamd_lang_detection_fasttext_show_info(void *ud)
 #endif
 }
 
+bool rspamd_lang_detection_fasttext_is_enabled(void *ud)
+{
+#ifdef WITH_FASTTEXT
+	auto *real_model = FASTTEXT_MODEL_TO_C_API(ud);
+	/* A null handle is reported as "not enabled" instead of being dereferenced */
+	if (real_model) {
+		return real_model->is_enabled();
+	}
+#endif
+	/* Reached for a null handle, and always when built without WITH_FASTTEXT */
+	return false;
+}
+
 rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud,
 											   const char *in, size_t len, int k)
 {
@@ -155,27 +168,41 @@ void rspamd_lang_detection_fasttext_destroy(void *ud)
 #endif
 }
 
+
+guint
+rspamd_lang_detection_fasttext_get_nlangs(rspamd_fasttext_predict_result_t res)
+{
+#ifdef WITH_FASTTEXT
+	auto *real_res = FASTTEXT_RESULT_TO_C_API(res);
+	/* A null result handle counts as zero languages */
+	if (real_res) {
+		return real_res->size();
+	}
+#endif
+	return 0; /* also the answer when built without WITH_FASTTEXT */
+}
+
 const char *
-rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res)
+rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res, unsigned int idx) /* idx: zero-based entry */
 {
 #ifdef WITH_FASTTEXT
 	auto *real_res = FASTTEXT_RESULT_TO_C_API(res);
 
-	if (real_res && !real_res->empty()) {
-		return real_res->front().second.c_str();
+	if (real_res && idx < real_res->size()) { /* only in-range entries; otherwise fall through to nullptr */
+		return real_res->at(idx).second.c_str();
 	}
 #endif
 	return nullptr;
 }
 
 float
-rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res)
+rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res, unsigned int idx) /* idx: zero-based entry */
 {
 #ifdef WITH_FASTTEXT
 	auto *real_res = FASTTEXT_RESULT_TO_C_API(res);
 
-	if (real_res && !real_res->empty()) {
-		return real_res->front().first;
+	if (real_res && idx < real_res->size()) { /* only in-range entries; otherwise fall through to 0.0f */
+		return real_res->at(idx).first;
 	}
 #endif
 	return 0.0f;
diff --git a/src/libmime/lang_detection_fasttext.h b/src/libmime/lang_detection_fasttext.h
index 71e253940..2e8a9fe78 100644
--- a/src/libmime/lang_detection_fasttext.h
+++ b/src/libmime/lang_detection_fasttext.h
@@ -27,6 +27,13 @@ struct rspamd_config;
  */
 void* rspamd_lang_detection_fasttext_init(struct rspamd_config *cfg);
 
+/**
+ * Check if fasttext language detector is enabled
+ * @param ud detector handle; presumably the pointer returned by rspamd_lang_detection_fasttext_init (TODO confirm)
+ * @return true when the detector reports itself enabled, false otherwise
+ */
+bool rspamd_lang_detection_fasttext_is_enabled(void *ud);
+
 /**
  * Show info about fasttext language detector
  * @param ud
@@ -47,19 +54,25 @@ typedef  void * rspamd_fasttext_predict_result_t;
 rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud,
 		const char *in, size_t len, int k);
 
+/**
+ * Get number of languages detected
+ * @param ud prediction result as returned by rspamd_lang_detection_fasttext_detect
+ * @return number of entries held by the result, 0 when there are none
+ */
+guint rspamd_lang_detection_fasttext_get_nlangs(rspamd_fasttext_predict_result_t ud);
 /**
  * Get language from fasttext result
  * @param res
  * @return
  */
-const char *rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res);
+const char *rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res, unsigned int idx); /* idx: zero-based position of the entry within res */
 
 /**
  * Get probability from fasttext result
  * @param res
  * @return
  */
-float rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res);
+float rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res, unsigned int idx); /* idx: zero-based position of the entry within res */
 
 /**
  * Destroy fasttext result


More information about the Commits mailing list