commit 273bd07: [Project] Rdns: Add preliminary reading logic for TCP channels

Vsevolod Stakhov vsevolod at highsecure.ru
Wed Jan 5 11:28:15 UTC 2022


Author: Vsevolod Stakhov
Date: 2022-01-03 17:13:37 +0000
URL: https://github.com/rspamd/rspamd/commit/273bd073821c437ea22fd65ec69079e2baf45ccf

[Project] Rdns: Add preliminary reading logic for TCP channels

---
 contrib/librdns/compression.c |   2 +-
 contrib/librdns/compression.h |   2 +-
 contrib/librdns/dns_private.h |  11 ++-
 contrib/librdns/resolver.c    | 186 ++++++++++++++++++++++++++++++++++++------
 contrib/librdns/util.c        |  41 +++++++++-
 contrib/librdns/util.h        |   7 ++
 6 files changed, 216 insertions(+), 33 deletions(-)

diff --git a/contrib/librdns/compression.c b/contrib/librdns/compression.c
index ac31a92c3..c48090115 100644
--- a/contrib/librdns/compression.c
+++ b/contrib/librdns/compression.c
@@ -66,7 +66,7 @@ rdns_add_compressed (const char *pos, const char *end,
 }
 
 void
-rnds_compression_free (struct rdns_compression_entry *comp)
+rdns_compression_free (struct rdns_compression_entry *comp)
 {
 	struct rdns_compression_entry *cur, *tmp;
 
diff --git a/contrib/librdns/compression.h b/contrib/librdns/compression.h
index adae2f3ef..29832302f 100644
--- a/contrib/librdns/compression.h
+++ b/contrib/librdns/compression.h
@@ -40,6 +40,6 @@ bool rdns_write_name_compressed (struct rdns_request *req,
 		const char *name, unsigned int namelen,
 		struct rdns_compression_entry **comp);
 
-void rnds_compression_free (struct rdns_compression_entry *comp);
+void rdns_compression_free (struct rdns_compression_entry *comp);
 
 #endif /* COMPRESSION_H_ */
diff --git a/contrib/librdns/dns_private.h b/contrib/librdns/dns_private.h
index cc2d48683..137b317e8 100644
--- a/contrib/librdns/dns_private.h
+++ b/contrib/librdns/dns_private.h
@@ -102,6 +102,8 @@ enum rdns_request_state {
 	RDNS_REQUEST_WAIT_REPLY,
 	RDNS_REQUEST_REPLIED,
 	RDNS_REQUEST_FAKE,
+	RDNS_REQUEST_ERROR,
+	RDNS_REQUEST_TCP,
 };
 
 struct rdns_request {
@@ -151,8 +153,8 @@ enum rdns_io_channel_flags {
  * Used to chain output DNS requests for a TCP connection
  */
 struct rdns_tcp_output_chain {
-	uint16_t next_write_size;
-	uint16_t cur_write;
+	uint16_t next_write_size; /* Network byte order! */
+	uint16_t cur_write; /* Cur bytes written including `next_write_size` */
 	struct rdns_request *req;
 	struct rdns_tcp_output_chain *prev, *next;
 };
@@ -161,9 +163,10 @@ struct rdns_tcp_output_chain {
  * Specific stuff for a TCP IO chain
  */
 struct rdns_tcp_channel {
-	uint16_t next_read_size;
-	uint16_t cur_read;
+	uint16_t next_read_size; /* Network byte order on read, then host byte order */
+	uint16_t cur_read; /* Cur bytes read including `next_read_size` */
 	unsigned char *cur_read_buf;
+	unsigned read_buf_allocated;
 
 	/* Chained set of the planned writes */
 	struct rdns_tcp_output_chain *output_chain;
diff --git a/contrib/librdns/resolver.c b/contrib/librdns/resolver.c
index 8598cfdf5..c5cebc572 100644
--- a/contrib/librdns/resolver.c
+++ b/contrib/librdns/resolver.c
@@ -151,25 +151,6 @@ rdns_send_request (struct rdns_request *req, int fd, bool new_req)
 }
 
 
-static struct rdns_reply *
-rdns_make_reply (struct rdns_request *req, enum dns_rcode rcode)
-{
-	struct rdns_reply *rep;
-
-	rep = malloc (sizeof (struct rdns_reply));
-	if (rep != NULL) {
-		rep->request = req;
-		rep->resolver = req->resolver;
-		rep->entries = NULL;
-		rep->code = rcode;
-		req->reply = rep;
-		rep->flags = 0;
-		rep->requested_name = req->requested_names[0].name;
-	}
-
-	return rep;
-}
-
 static struct rdns_request *
 rdns_find_dns_request (uint8_t *in, struct rdns_io_channel *ioc)
 {
@@ -287,18 +268,173 @@ rdns_parse_reply (uint8_t *in, int r, struct rdns_request *req,
 	return true;
 }
 
+static bool
+rdns_tcp_maybe_realloc_read_buf (struct rdns_io_channel *ioc)
+{
+	if (ioc->tcp->read_buf_allocated == 0 && ioc->tcp->next_read_size > 0) {
+		ioc->tcp->cur_read_buf = malloc(ioc->tcp->next_read_size);
+
+		if (ioc->tcp->cur_read_buf == NULL) {
+			return false;
+		}
+		ioc->tcp->read_buf_allocated = ioc->tcp->next_read_size;
+	}
+	else if (ioc->tcp->read_buf_allocated < ioc->tcp->next_read_size) {
+		/* Need to realloc */
+		unsigned next_shift = ioc->tcp->next_read_size;
+
+		if (next_shift < ioc->tcp->read_buf_allocated * 2) {
+			if (next_shift < UINT16_MAX && ioc->tcp->read_buf_allocated * 2 <= UINT16_MAX) {
+				next_shift = ioc->tcp->read_buf_allocated * 2;
+			}
+		}
+		void *next_buf = realloc(ioc->tcp->cur_read_buf, next_shift);
+
+		if (next_buf == NULL) {
+			free (ioc->tcp->cur_read_buf);
+			ioc->tcp->cur_read_buf = NULL;
+			return false;
+		}
+
+		ioc->tcp->cur_read_buf = next_buf;
+	}
+
+	return true;
+}
+
 static void
 rdns_process_tcp_read (int fd, struct rdns_io_channel *ioc)
 {
+	ssize_t r;
+	struct rdns_resolver *resolver = ioc->resolver;
+
+	if (ioc->tcp->cur_read == 0) {
+		/* We have to read size first */
+		r = read(fd, &ioc->tcp->next_read_size, sizeof(ioc->tcp->next_read_size));
+
+		if (r == -1 || r == 0) {
+			goto err;
+		}
+
+		ioc->tcp->cur_read += r;
+
+		if (r == sizeof(ioc->tcp->next_read_size)) {
+			ioc->tcp->next_read_size = ntohl(ioc->tcp->next_read_size);
+
+			/* We have read the size, so we can try read one more time */
+			if (!rdns_tcp_maybe_realloc_read_buf(ioc)) {
+				rdns_err("failed to allocate %d bytes: %s",
+						(int)ioc->tcp->next_read_size, strerror(errno));
+				r = -1;
+				goto err;
+			}
+		}
+		else {
+			/* We have read one byte, need to retry... */
+			return;
+		}
+	}
+	else if (ioc->tcp->cur_read == 1) {
+		r = read(fd, ((unsigned char *)&ioc->tcp->next_read_size) + 1, 1);
 
+		if (r == -1 || r == 0) {
+			goto err;
+		}
+
+		ioc->tcp->cur_read += r;
+		ioc->tcp->next_read_size = ntohl(ioc->tcp->next_read_size);
+
+		/* We have read the size, so we can try read one more time */
+		if (!rdns_tcp_maybe_realloc_read_buf(ioc)) {
+			rdns_err("failed to allocate %d bytes: %s",
+					(int)ioc->tcp->next_read_size, strerror(errno));
+			r = -1;
+			goto err;
+		}
+	}
+
+	if (ioc->tcp->next_read_size < sizeof(struct dns_header)) {
+		/* Truncated reply, reset channel */
+		rdns_err("got truncated size: %d on TCP read", ioc->tcp->next_read_size);
+		r = -1;
+		errno = EINVAL;
+		goto err;
+	}
+
+	/* Try to read the full packet if we can */
+	int to_read = ioc->tcp->next_read_size - (ioc->tcp->cur_read - 2);
+
+	if (to_read <= 0) {
+		/* Internal error */
+		rdns_err("internal buffer error on reading!");
+		r = -1;
+		errno = EINVAL;
+		goto err;
+	}
+
+	r = read(fd, ioc->tcp->cur_read_buf + (ioc->tcp->cur_read - 2), to_read);
+	ioc->tcp->cur_read += r;
+
+	if ((ioc->tcp->cur_read - 2) == ioc->tcp->next_read_size) {
+		/* We have a full packet ready, process it */
+		struct rdns_request *req = rdns_find_dns_request (ioc->tcp->cur_read_buf, ioc);
+
+		if (req != NULL) {
+			struct rdns_reply *rep;
+
+			if (rdns_parse_reply (ioc->tcp->cur_read_buf,
+					ioc->tcp->next_read_size, req, &rep)) {
+				UPSTREAM_OK (req->io->srv);
+
+				if (req->resolver->ups && req->io->srv->ups_elt) {
+					req->resolver->ups->ok (req->io->srv->ups_elt,
+							req->resolver->ups->data);
+				}
+
+				rdns_request_unschedule (req);
+				req->state = RDNS_REQUEST_REPLIED;
+				req->func (rep, req->arg);
+				REF_RELEASE (req);
+			}
+		}
+		else {
+			rdns_warn("unwanted DNS id received over TCP");
+		}
+
+		ioc->tcp->next_read_size = 0;
+		ioc->tcp->cur_read = 0;
+
+		/* Retry read the next packet to avoid unnecessary polling */
+		rdns_process_tcp_read (fd, ioc);
+	}
+
+	return;
+
+err:
+	if (r == 0) {
+		/* Got EOF, just close the socket */
+		rdns_debug ("closing TCP channel due to EOF");
+		rdns_ioc_tcp_reset (ioc);
+	}
+	else if (errno == EINTR || errno == EAGAIN) {
+		/* We just retry later as there is no real error */
+		return;
+	}
+	else {
+		rdns_debug ("closing TCP channel due to IO error: %s", strerror(errno));
+		rdns_ioc_tcp_reset (ioc);
+	}
 }
 
 static void
 rdns_process_tcp_connect (int fd, struct rdns_io_channel *ioc)
 {
 	ioc->flags |= RDNS_CHANNEL_CONNECTED|RDNS_CHANNEL_ACTIVE;
-	ioc->tcp->async_read = ioc->resolver->async->add_read(ioc->resolver->async->data,
-			ioc->sock, ioc);
+
+	if (ioc->tcp->async_read == NULL) {
+		ioc->tcp->async_read = ioc->resolver->async->add_read(ioc->resolver->async->data,
+				ioc->sock, ioc);
+	}
 }
 
 static void
@@ -677,7 +813,7 @@ rdns_process_tcp_write (int fd, struct rdns_io_channel *ioc)
 				return;
 			}
 		}
-		else if (oc->next_write_size < oc->cur_write) {
+		else if (ntohl(oc->next_write_size) < oc->cur_write) {
 			/* Packet has been fully written, remove it */
 			DL_DELETE(ioc->tcp->output_chain, oc);
 			/* Data in output buffer belongs to request */
@@ -937,20 +1073,20 @@ rdns_make_request_full (
 				if (!rdns_add_rr (req, cur_name, clen, type, &comp)) {
 					rdns_err ("cannot add rr");
 					REF_RELEASE (req);
-					rnds_compression_free (comp);
+					rdns_compression_free(comp);
 					return NULL;
 				}
 			} else {
 				if (!rdns_add_rr (req, cur_name, clen, type, NULL)) {
 					rdns_err ("cannot add rr");
 					REF_RELEASE (req);
-					rnds_compression_free (comp);
+					rdns_compression_free(comp);
 					return NULL;
 				}
 			}
 		}
 
-		rnds_compression_free (comp);
+		rdns_compression_free(comp);
 
 		/* Add EDNS RR */
 		rdns_add_edns0 (req);
diff --git a/contrib/librdns/util.c b/contrib/librdns/util.c
index fd71179e9..d96103bb7 100644
--- a/contrib/librdns/util.c
+++ b/contrib/librdns/util.c
@@ -406,6 +406,24 @@ rdns_permutor_generate_id (void)
 	return id;
 }
 
+struct rdns_reply *
+rdns_make_reply (struct rdns_request *req, enum dns_rcode rcode)
+{
+	struct rdns_reply *rep;
+
+	rep = malloc (sizeof (struct rdns_reply));
+	if (rep != NULL) {
+		rep->request = req;
+		rep->resolver = req->resolver;
+		rep->entries = NULL;
+		rep->code = rcode;
+		req->reply = rep;
+		rep->flags = 0;
+		rep->requested_name = req->requested_names[0].name;
+	}
+
+	return rep;
+}
 
 void
 rdns_reply_free (struct rdns_reply *rep)
@@ -508,12 +526,18 @@ rdns_ioc_free (struct rdns_io_channel *ioc)
 {
 	struct rdns_request *req;
 
+	if (IS_CHANNEL_TCP(ioc)) {
+		rdns_ioc_tcp_reset(ioc);
+	}
+
 	kh_foreach_value(ioc->requests, req, {
 		REF_RELEASE (req);
 	});
 
-	ioc->resolver->async->del_read (ioc->resolver->async->data,
-			ioc->async_io);
+	if (ioc->async_io) {
+		ioc->resolver->async->del_read(ioc->resolver->async->data,
+				ioc->async_io);
+	}
 	kh_destroy(rdns_requests_hash, ioc->requests);
 
 	if (ioc->sock != -1) {
@@ -640,6 +664,8 @@ rdns_ioc_tcp_reset (struct rdns_io_channel *ioc)
 			ioc->tcp->async_read = NULL;
 		}
 
+		/* Clean all buffers and temporaries */
+
 		ioc->flags &= ~RDNS_CHANNEL_CONNECTED;
 	}
 
@@ -651,6 +677,17 @@ rdns_ioc_tcp_reset (struct rdns_io_channel *ioc)
 		free (ioc->saddr);
 		ioc->saddr = NULL;
 	}
+
+	/* Remove all requests pending as we are unable to complete them */
+	struct rdns_request *req;
+	kh_foreach_value(ioc->requests, req, {
+		struct rdns_reply *rep = rdns_make_reply (req, RDNS_RC_NETERR);
+		req->state = RDNS_REQUEST_REPLIED;
+		req->func (rep, req->arg);
+		REF_RELEASE (req);
+	});
+
+	kh_clear(rdns_requests_hash, ioc->requests);
 }
 
 bool
diff --git a/contrib/librdns/util.h b/contrib/librdns/util.h
index 70ad053a0..915b8febd 100644
--- a/contrib/librdns/util.h
+++ b/contrib/librdns/util.h
@@ -81,6 +81,13 @@ void rdns_request_free (struct rdns_request *req);
  */
 void rdns_request_remove_from_hash (struct rdns_request *req);
 
+/**
+ * Creates a new reply
+ * @param req
+ * @param rcode
+ * @return
+ */
+struct rdns_reply * rdns_make_reply (struct rdns_request *req, enum dns_rcode rcode);
 /**
  * Free reply
  * @param rep


More information about the Commits mailing list