commit 273bd07: [Project] Rdns: Add preliminary reading logic for TCP channels
Vsevolod Stakhov
vsevolod at highsecure.ru
Wed Jan 5 11:28:15 UTC 2022
Author: Vsevolod Stakhov
Date: 2022-01-03 17:13:37 +0000
URL: https://github.com/rspamd/rspamd/commit/273bd073821c437ea22fd65ec69079e2baf45ccf
[Project] Rdns: Add preliminary reading logic for TCP channels
---
contrib/librdns/compression.c | 2 +-
contrib/librdns/compression.h | 2 +-
contrib/librdns/dns_private.h | 11 ++-
contrib/librdns/resolver.c | 186 ++++++++++++++++++++++++++++++++++++------
contrib/librdns/util.c | 41 +++++++++-
contrib/librdns/util.h | 7 ++
6 files changed, 216 insertions(+), 33 deletions(-)
diff --git a/contrib/librdns/compression.c b/contrib/librdns/compression.c
index ac31a92c3..c48090115 100644
--- a/contrib/librdns/compression.c
+++ b/contrib/librdns/compression.c
@@ -66,7 +66,7 @@ rdns_add_compressed (const char *pos, const char *end,
}
void
-rnds_compression_free (struct rdns_compression_entry *comp)
+rdns_compression_free (struct rdns_compression_entry *comp)
{
struct rdns_compression_entry *cur, *tmp;
diff --git a/contrib/librdns/compression.h b/contrib/librdns/compression.h
index adae2f3ef..29832302f 100644
--- a/contrib/librdns/compression.h
+++ b/contrib/librdns/compression.h
@@ -40,6 +40,6 @@ bool rdns_write_name_compressed (struct rdns_request *req,
const char *name, unsigned int namelen,
struct rdns_compression_entry **comp);
-void rnds_compression_free (struct rdns_compression_entry *comp);
+void rdns_compression_free (struct rdns_compression_entry *comp);
#endif /* COMPRESSION_H_ */
diff --git a/contrib/librdns/dns_private.h b/contrib/librdns/dns_private.h
index cc2d48683..137b317e8 100644
--- a/contrib/librdns/dns_private.h
+++ b/contrib/librdns/dns_private.h
@@ -102,6 +102,8 @@ enum rdns_request_state {
RDNS_REQUEST_WAIT_REPLY,
RDNS_REQUEST_REPLIED,
RDNS_REQUEST_FAKE,
+ RDNS_REQUEST_ERROR,
+ RDNS_REQUEST_TCP,
};
struct rdns_request {
@@ -151,8 +153,8 @@ enum rdns_io_channel_flags {
* Used to chain output DNS requests for a TCP connection
*/
struct rdns_tcp_output_chain {
- uint16_t next_write_size;
- uint16_t cur_write;
+ uint16_t next_write_size; /* Network byte order! */
+ uint16_t cur_write; /* Cur bytes written including `next_write_size` */
struct rdns_request *req;
struct rdns_tcp_output_chain *prev, *next;
};
@@ -161,9 +163,10 @@ struct rdns_tcp_output_chain {
* Specific stuff for a TCP IO chain
*/
struct rdns_tcp_channel {
- uint16_t next_read_size;
- uint16_t cur_read;
+ uint16_t next_read_size; /* Network byte order on read, then host byte order */
+ uint16_t cur_read; /* Cur bytes read including `next_read_size` */
unsigned char *cur_read_buf;
+ unsigned read_buf_allocated;
/* Chained set of the planned writes */
struct rdns_tcp_output_chain *output_chain;
diff --git a/contrib/librdns/resolver.c b/contrib/librdns/resolver.c
index 8598cfdf5..c5cebc572 100644
--- a/contrib/librdns/resolver.c
+++ b/contrib/librdns/resolver.c
@@ -151,25 +151,6 @@ rdns_send_request (struct rdns_request *req, int fd, bool new_req)
}
-static struct rdns_reply *
-rdns_make_reply (struct rdns_request *req, enum dns_rcode rcode)
-{
- struct rdns_reply *rep;
-
- rep = malloc (sizeof (struct rdns_reply));
- if (rep != NULL) {
- rep->request = req;
- rep->resolver = req->resolver;
- rep->entries = NULL;
- rep->code = rcode;
- req->reply = rep;
- rep->flags = 0;
- rep->requested_name = req->requested_names[0].name;
- }
-
- return rep;
-}
-
static struct rdns_request *
rdns_find_dns_request (uint8_t *in, struct rdns_io_channel *ioc)
{
@@ -287,18 +268,173 @@ rdns_parse_reply (uint8_t *in, int r, struct rdns_request *req,
return true;
}
+static bool
+rdns_tcp_maybe_realloc_read_buf (struct rdns_io_channel *ioc)
+{
+ if (ioc->tcp->read_buf_allocated == 0 && ioc->tcp->next_read_size > 0) {
+ ioc->tcp->cur_read_buf = malloc(ioc->tcp->next_read_size);
+
+ if (ioc->tcp->cur_read_buf == NULL) {
+ return false;
+ }
+ ioc->tcp->read_buf_allocated = ioc->tcp->next_read_size;
+ }
+ else if (ioc->tcp->read_buf_allocated < ioc->tcp->next_read_size) {
+ /* Need to realloc */
+ unsigned next_shift = ioc->tcp->next_read_size;
+
+ if (next_shift < ioc->tcp->read_buf_allocated * 2) {
+ if (next_shift < UINT16_MAX && ioc->tcp->read_buf_allocated * 2 <= UINT16_MAX) {
+ next_shift = ioc->tcp->read_buf_allocated * 2;
+ }
+ }
+ void *next_buf = realloc(ioc->tcp->cur_read_buf, next_shift);
+
+ if (next_buf == NULL) {
+ free (ioc->tcp->cur_read_buf);
+ ioc->tcp->cur_read_buf = NULL;
+ return false;
+ }
+
+ ioc->tcp->cur_read_buf = next_buf;
+ }
+
+ return true;
+}
+
static void
rdns_process_tcp_read (int fd, struct rdns_io_channel *ioc)
{
+ ssize_t r;
+ struct rdns_resolver *resolver = ioc->resolver;
+
+ if (ioc->tcp->cur_read == 0) {
+ /* We have to read size first */
+ r = read(fd, &ioc->tcp->next_read_size, sizeof(ioc->tcp->next_read_size));
+
+ if (r == -1 || r == 0) {
+ goto err;
+ }
+
+ ioc->tcp->cur_read += r;
+
+ if (r == sizeof(ioc->tcp->next_read_size)) {
+ ioc->tcp->next_read_size = ntohl(ioc->tcp->next_read_size);
+
+ /* We have read the size, so we can try read one more time */
+ if (!rdns_tcp_maybe_realloc_read_buf(ioc)) {
+ rdns_err("failed to allocate %d bytes: %s",
+ (int)ioc->tcp->next_read_size, strerror(errno));
+ r = -1;
+ goto err;
+ }
+ }
+ else {
+ /* We have read one byte, need to retry... */
+ return;
+ }
+ }
+ else if (ioc->tcp->cur_read == 1) {
+ r = read(fd, ((unsigned char *)&ioc->tcp->next_read_size) + 1, 1);
+ if (r == -1 || r == 0) {
+ goto err;
+ }
+
+ ioc->tcp->cur_read += r;
+ ioc->tcp->next_read_size = ntohl(ioc->tcp->next_read_size);
+
+ /* We have read the size, so we can try read one more time */
+ if (!rdns_tcp_maybe_realloc_read_buf(ioc)) {
+ rdns_err("failed to allocate %d bytes: %s",
+ (int)ioc->tcp->next_read_size, strerror(errno));
+ r = -1;
+ goto err;
+ }
+ }
+
+ if (ioc->tcp->next_read_size < sizeof(struct dns_header)) {
+ /* Truncated reply, reset channel */
+ rdns_err("got truncated size: %d on TCP read", ioc->tcp->next_read_size);
+ r = -1;
+ errno = EINVAL;
+ goto err;
+ }
+
+ /* Try to read the full packet if we can */
+ int to_read = ioc->tcp->next_read_size - (ioc->tcp->cur_read - 2);
+
+ if (to_read <= 0) {
+ /* Internal error */
+ rdns_err("internal buffer error on reading!");
+ r = -1;
+ errno = EINVAL;
+ goto err;
+ }
+
+ r = read(fd, ioc->tcp->cur_read_buf + (ioc->tcp->cur_read - 2), to_read);
+ ioc->tcp->cur_read += r;
+
+ if ((ioc->tcp->cur_read - 2) == ioc->tcp->next_read_size) {
+ /* We have a full packet ready, process it */
+ struct rdns_request *req = rdns_find_dns_request (ioc->tcp->cur_read_buf, ioc);
+
+ if (req != NULL) {
+ struct rdns_reply *rep;
+
+ if (rdns_parse_reply (ioc->tcp->cur_read_buf,
+ ioc->tcp->next_read_size, req, &rep)) {
+ UPSTREAM_OK (req->io->srv);
+
+ if (req->resolver->ups && req->io->srv->ups_elt) {
+ req->resolver->ups->ok (req->io->srv->ups_elt,
+ req->resolver->ups->data);
+ }
+
+ rdns_request_unschedule (req);
+ req->state = RDNS_REQUEST_REPLIED;
+ req->func (rep, req->arg);
+ REF_RELEASE (req);
+ }
+ }
+ else {
+ rdns_warn("unwanted DNS id received over TCP");
+ }
+
+ ioc->tcp->next_read_size = 0;
+ ioc->tcp->cur_read = 0;
+
+ /* Retry read the next packet to avoid unnecessary polling */
+ rdns_process_tcp_read (fd, ioc);
+ }
+
+ return;
+
+err:
+ if (r == 0) {
+ /* Got EOF, just close the socket */
+ rdns_debug ("closing TCP channel due to EOF");
+ rdns_ioc_tcp_reset (ioc);
+ }
+ else if (errno == EINTR || errno == EAGAIN) {
+ /* We just retry later as there is no real error */
+ return;
+ }
+ else {
+ rdns_debug ("closing TCP channel due to IO error: %s", strerror(errno));
+ rdns_ioc_tcp_reset (ioc);
+ }
}
static void
rdns_process_tcp_connect (int fd, struct rdns_io_channel *ioc)
{
ioc->flags |= RDNS_CHANNEL_CONNECTED|RDNS_CHANNEL_ACTIVE;
- ioc->tcp->async_read = ioc->resolver->async->add_read(ioc->resolver->async->data,
- ioc->sock, ioc);
+
+ if (ioc->tcp->async_read == NULL) {
+ ioc->tcp->async_read = ioc->resolver->async->add_read(ioc->resolver->async->data,
+ ioc->sock, ioc);
+ }
}
static void
@@ -677,7 +813,7 @@ rdns_process_tcp_write (int fd, struct rdns_io_channel *ioc)
return;
}
}
- else if (oc->next_write_size < oc->cur_write) {
+ else if (ntohl(oc->next_write_size) < oc->cur_write) {
/* Packet has been fully written, remove it */
DL_DELETE(ioc->tcp->output_chain, oc);
/* Data in output buffer belongs to request */
@@ -937,20 +1073,20 @@ rdns_make_request_full (
if (!rdns_add_rr (req, cur_name, clen, type, &comp)) {
rdns_err ("cannot add rr");
REF_RELEASE (req);
- rnds_compression_free (comp);
+ rdns_compression_free(comp);
return NULL;
}
} else {
if (!rdns_add_rr (req, cur_name, clen, type, NULL)) {
rdns_err ("cannot add rr");
REF_RELEASE (req);
- rnds_compression_free (comp);
+ rdns_compression_free(comp);
return NULL;
}
}
}
- rnds_compression_free (comp);
+ rdns_compression_free(comp);
/* Add EDNS RR */
rdns_add_edns0 (req);
diff --git a/contrib/librdns/util.c b/contrib/librdns/util.c
index fd71179e9..d96103bb7 100644
--- a/contrib/librdns/util.c
+++ b/contrib/librdns/util.c
@@ -406,6 +406,24 @@ rdns_permutor_generate_id (void)
return id;
}
+struct rdns_reply *
+rdns_make_reply (struct rdns_request *req, enum dns_rcode rcode)
+{
+ struct rdns_reply *rep;
+
+ rep = malloc (sizeof (struct rdns_reply));
+ if (rep != NULL) {
+ rep->request = req;
+ rep->resolver = req->resolver;
+ rep->entries = NULL;
+ rep->code = rcode;
+ req->reply = rep;
+ rep->flags = 0;
+ rep->requested_name = req->requested_names[0].name;
+ }
+
+ return rep;
+}
void
rdns_reply_free (struct rdns_reply *rep)
@@ -508,12 +526,18 @@ rdns_ioc_free (struct rdns_io_channel *ioc)
{
struct rdns_request *req;
+ if (IS_CHANNEL_TCP(ioc)) {
+ rdns_ioc_tcp_reset(ioc);
+ }
+
kh_foreach_value(ioc->requests, req, {
REF_RELEASE (req);
});
- ioc->resolver->async->del_read (ioc->resolver->async->data,
- ioc->async_io);
+ if (ioc->async_io) {
+ ioc->resolver->async->del_read(ioc->resolver->async->data,
+ ioc->async_io);
+ }
kh_destroy(rdns_requests_hash, ioc->requests);
if (ioc->sock != -1) {
@@ -640,6 +664,8 @@ rdns_ioc_tcp_reset (struct rdns_io_channel *ioc)
ioc->tcp->async_read = NULL;
}
+ /* Clean all buffers and temporaries */
+
ioc->flags &= ~RDNS_CHANNEL_CONNECTED;
}
@@ -651,6 +677,17 @@ rdns_ioc_tcp_reset (struct rdns_io_channel *ioc)
free (ioc->saddr);
ioc->saddr = NULL;
}
+
+ /* Remove all requests pending as we are unable to complete them */
+ struct rdns_request *req;
+ kh_foreach_value(ioc->requests, req, {
+ struct rdns_reply *rep = rdns_make_reply (req, RDNS_RC_NETERR);
+ req->state = RDNS_REQUEST_REPLIED;
+ req->func (rep, req->arg);
+ REF_RELEASE (req);
+ });
+
+ kh_clear(rdns_requests_hash, ioc->requests);
}
bool
diff --git a/contrib/librdns/util.h b/contrib/librdns/util.h
index 70ad053a0..915b8febd 100644
--- a/contrib/librdns/util.h
+++ b/contrib/librdns/util.h
@@ -81,6 +81,13 @@ void rdns_request_free (struct rdns_request *req);
*/
void rdns_request_remove_from_hash (struct rdns_request *req);
+/**
+ * Creates a new reply
+ * @param req
+ * @param rcode
+ * @return
+ */
+struct rdns_reply * rdns_make_reply (struct rdns_request *req, enum dns_rcode rcode);
/**
* Free reply
* @param rep
More information about the Commits
mailing list