]> git.puffer.fish Git - matthieu/frr.git/commitdiff
lib: fix c-ares thread misuse
authorDavid Lamparter <equinox@opensourcerouting.org>
Sun, 7 Nov 2021 14:41:18 +0000 (15:41 +0100)
committerDavid Lamparter <equinox@opensourcerouting.org>
Mon, 8 Nov 2021 13:06:21 +0000 (14:06 +0100)
The `struct thread **ref` that the thread code takes is written to and
needs to stay valid over the lifetime of a thread.  This does not hold
up if thread pointers are directly put in a `vector` since adding items
to a `vector` may reallocate the entire array.  The thread code would
then write to a now-invalid `ref`, potentially corrupting entirely
unrelated data.

This should be extremely rare to trigger in practice since we only use
one c-ares channel, which will likely only ever use one fd, so the
vector is never resized.  That said, c-ares using only one fd is just
plain fragile luck.

Either way, fix this by creating a resolver_fd tracking struct, and
clean up the code while we're at it.

Signed-off-by: David Lamparter <equinox@opensourcerouting.org>
lib/resolver.c

index 4aba909f2555b758ddaae00261dab0d4705bdf20..e3dba5f8aeb9e83f095ab39e8af7f0bfe7fc07a1 100644 (file)
@@ -14,7 +14,8 @@
 #include <ares.h>
 #include <ares_version.h>
 
-#include "vector.h"
+#include "typesafe.h"
+#include "jhash.h"
 #include "thread.h"
 #include "lib_errors.h"
 #include "resolver.h"
@@ -27,13 +28,78 @@ struct resolver_state {
        ares_channel channel;
        struct thread_master *master;
        struct thread *timeout;
-       vector read_threads, write_threads;
 };
 
 static struct resolver_state state;
 static bool resolver_debug;
 
-#define THREAD_RUNNING ((struct thread *)-1)
+/* a FD doesn't necessarily map 1:1 to a request;  we could be talking to
+ * multiple caches simultaneously, to see which responds fastest.
+ * Theoretically we could also be using the same fd for multiple lookups,
+ * but the c-ares API guarantees an n:1 mapping for fd => channel.
+ *
+ * Either way c-ares makes that decision and we just need to deal with
+ * whatever FDs it gives us.
+ */
+
+DEFINE_MTYPE_STATIC(LIB, ARES_FD, "c-ares (DNS) file descriptor information");
+PREDECL_HASH(resolver_fds);
+
+struct resolver_fd {
+       struct resolver_fds_item itm;
+
+       int fd;
+       struct resolver_state *state;
+       struct thread *t_read, *t_write;
+};
+
+static int resolver_fd_cmp(const struct resolver_fd *a,
+                          const struct resolver_fd *b)
+{
+       return numcmp(a->fd, b->fd);
+}
+
+static uint32_t resolver_fd_hash(const struct resolver_fd *item)
+{
+       return jhash_1word(item->fd, 0xacd04c9e);
+}
+
+DECLARE_HASH(resolver_fds, struct resolver_fd, itm, resolver_fd_cmp,
+            resolver_fd_hash);
+
+static struct resolver_fds_head resfds[1] = {INIT_HASH(resfds[0])};
+
+static struct resolver_fd *resolver_fd_get(int fd,
+                                          struct resolver_state *newstate)
+{
+       struct resolver_fd ref = {.fd = fd}, *res;
+
+       res = resolver_fds_find(resfds, &ref);
+       if (!res && newstate) {
+               res = XCALLOC(MTYPE_ARES_FD, sizeof(*res));
+               res->fd = fd;
+               res->state = newstate;
+               resolver_fds_add(resfds, res);
+
+               if (resolver_debug)
+                       zlog_debug("c-ares registered FD %d", fd);
+       }
+       return res;
+}
+
+static void resolver_fd_drop_maybe(struct resolver_fd *resfd)
+{
+       if (resfd->t_read || resfd->t_write)
+               return;
+
+       if (resolver_debug)
+               zlog_debug("c-ares unregistered FD %d", resfd->fd);
+
+       resolver_fds_del(resfds, resfd);
+       XFREE(MTYPE_ARES_FD, resfd);
+}
+
+/* end of FD housekeeping */
 
 static void resolver_update_timeouts(struct resolver_state *r);
 
@@ -41,9 +107,7 @@ static int resolver_cb_timeout(struct thread *t)
 {
        struct resolver_state *r = THREAD_ARG(t);
 
-       r->timeout = THREAD_RUNNING;
        ares_process(r->channel, NULL, NULL);
-       r->timeout = NULL;
        resolver_update_timeouts(r);
 
        return 0;
@@ -51,17 +115,16 @@ static int resolver_cb_timeout(struct thread *t)
 
 static int resolver_cb_socket_readable(struct thread *t)
 {
-       struct resolver_state *r = THREAD_ARG(t);
-       int fd = THREAD_FD(t);
-       struct thread **t_ptr;
-
-       vector_set_index(r->read_threads, fd, THREAD_RUNNING);
-       ares_process_fd(r->channel, fd, ARES_SOCKET_BAD);
-       if (vector_lookup(r->read_threads, fd) == THREAD_RUNNING) {
-               t_ptr = (struct thread **)vector_get_index(r->read_threads, fd);
-               thread_add_read(r->master, resolver_cb_socket_readable, r, fd,
-                               t_ptr);
-       }
+       struct resolver_fd *resfd = THREAD_ARG(t);
+       struct resolver_state *r = resfd->state;
+
+       thread_add_read(r->master, resolver_cb_socket_readable, resfd,
+                       resfd->fd, &resfd->t_read);
+       /* ^ ordering important:
+        * ares_process_fd may transitively call THREAD_OFF(resfd->t_read)
+        * combined with resolver_fd_drop_maybe, so resfd may be free'd after!
+        */
+       ares_process_fd(r->channel, resfd->fd, ARES_SOCKET_BAD);
        resolver_update_timeouts(r);
 
        return 0;
@@ -69,17 +132,16 @@ static int resolver_cb_socket_readable(struct thread *t)
 
 static int resolver_cb_socket_writable(struct thread *t)
 {
-       struct resolver_state *r = THREAD_ARG(t);
-       int fd = THREAD_FD(t);
-       struct thread **t_ptr;
-
-       vector_set_index(r->write_threads, fd, THREAD_RUNNING);
-       ares_process_fd(r->channel, ARES_SOCKET_BAD, fd);
-       if (vector_lookup(r->write_threads, fd) == THREAD_RUNNING) {
-               t_ptr = (struct thread **)vector_get_index(r->write_threads, fd);
-               thread_add_write(r->master, resolver_cb_socket_writable, r, fd,
-                                t_ptr);
-       }
+       struct resolver_fd *resfd = THREAD_ARG(t);
+       struct resolver_state *r = resfd->state;
+
+       thread_add_write(r->master, resolver_cb_socket_writable, resfd,
+                        resfd->fd, &resfd->t_write);
+       /* ^ ordering important:
+        * ares_process_fd may transitively call THREAD_OFF(resfd->t_write)
+        * combined with resolver_fd_drop_maybe, so resfd may be free'd after!
+        */
+       ares_process_fd(r->channel, ARES_SOCKET_BAD, resfd->fd);
        resolver_update_timeouts(r);
 
        return 0;
@@ -89,13 +151,11 @@ static void resolver_update_timeouts(struct resolver_state *r)
 {
        struct timeval *tv, tvbuf;
 
-       if (r->timeout == THREAD_RUNNING)
-               return;
-
        THREAD_OFF(r->timeout);
        tv = ares_timeout(r->channel, NULL, &tvbuf);
        if (tv) {
                unsigned int timeoutms = tv->tv_sec * 1000 + tv->tv_usec / 1000;
+
                thread_add_timer_msec(r->master, resolver_cb_timeout, r,
                                      timeoutms, &r->timeout);
        }
@@ -105,43 +165,27 @@ static void ares_socket_cb(void *data, ares_socket_t fd, int readable,
                           int writable)
 {
        struct resolver_state *r = (struct resolver_state *)data;
-       struct thread *t, **t_ptr;
-
-       if (readable) {
-               t = vector_lookup(r->read_threads, fd);
-               if (!t) {
-                       t_ptr = (struct thread **)vector_get_index(
-                               r->read_threads, fd);
-                       thread_add_read(r->master, resolver_cb_socket_readable,
-                                       r, fd, t_ptr);
-               }
-       } else {
-               t = vector_lookup(r->read_threads, fd);
-               if (t) {
-                       if (t != THREAD_RUNNING) {
-                               THREAD_OFF(t);
-                       }
-                       vector_unset(r->read_threads, fd);
-               }
-       }
+       struct resolver_fd *resfd;
 
-       if (writable) {
-               t = vector_lookup(r->write_threads, fd);
-               if (!t) {
-                       t_ptr = (struct thread **)vector_get_index(
-                               r->write_threads, fd);
-                       thread_add_read(r->master, resolver_cb_socket_writable,
-                                       r, fd, t_ptr);
-               }
-       } else {
-               t = vector_lookup(r->write_threads, fd);
-               if (t) {
-                       if (t != THREAD_RUNNING) {
-                               THREAD_OFF(t);
-                       }
-                       vector_unset(r->write_threads, fd);
-               }
-       }
+       resfd = resolver_fd_get(fd, (readable || writable) ? r : NULL);
+       if (!resfd)
+               return;
+
+       assert(resfd->state == r);
+
+       if (!readable)
+               THREAD_OFF(resfd->t_read);
+       else if (!resfd->t_read)
+               thread_add_read(r->master, resolver_cb_socket_readable, resfd,
+                               fd, &resfd->t_read);
+
+       if (!writable)
+               THREAD_OFF(resfd->t_write);
+       else if (!resfd->t_write)
+               thread_add_write(r->master, resolver_cb_socket_writable, resfd,
+                                fd, &resfd->t_write);
+
+       resolver_fd_drop_maybe(resfd);
 }
 
 
@@ -271,8 +315,6 @@ void resolver_init(struct thread_master *tm)
        struct ares_options ares_opts;
 
        state.master = tm;
-       state.read_threads = vector_init(1);
-       state.write_threads = vector_init(1);
 
        ares_opts = (struct ares_options){
                .sock_state_cb = &ares_socket_cb,