]> git.puffer.fish Git - matthieu/frr.git/commitdiff
lib: grpc: fix shutdown code
authorChristian Hopps <chopps@labn.net>
Sat, 5 Mar 2022 16:04:43 +0000 (11:04 -0500)
committerChristian Hopps <chopps@labn.net>
Sun, 6 Mar 2022 17:00:17 +0000 (12:00 -0500)
fixes #9732

Signed-off-by: Christian Hopps <chopps@labn.net>
lib/northbound_grpc.cpp

index 2d9b61483bfdb8e1a16cc5f7db8b3bfc1a2ea5ac..ca253031f45baf0c8c6e98711a683bc76f941598 100644 (file)
@@ -50,6 +50,8 @@ static struct thread_master *main_master;
 
 static struct frr_pthread *fpt;
 
+static bool grpc_running;
+
 #define grpc_debug(...)                                                        \
        do {                                                                   \
                if (nb_dbg_client_grpc)                                        \
@@ -116,9 +118,11 @@ class Candidates
 class RpcStateBase
 {
       public:
+       virtual ~RpcStateBase() = default;
        virtual CallState doCallback() = 0;
        virtual void do_request(::frr::Northbound::AsyncService *service,
-                               ::grpc::ServerCompletionQueue *cq) = 0;
+                               ::grpc::ServerCompletionQueue *cq,
+                               bool no_copy) = 0;
 };
 
 /*
@@ -188,17 +192,22 @@ template <typename Q, typename S> class NewRpcState : RpcStateBase
        }
 
        void do_request(::frr::Northbound::AsyncService *service,
-                       ::grpc::ServerCompletionQueue *cq) override
+                       ::grpc::ServerCompletionQueue *cq,
+                       bool no_copy) override
        {
                grpc_debug("%s, posting a request for: %s", __func__, name);
                if (requestf) {
                        NewRpcState<Q, S> *copy =
-                               new NewRpcState(cdb, requestf, callback, name);
+                               no_copy ? this
+                                       : new NewRpcState(cdb, requestf,
+                                                         callback, name);
                        (service->*requestf)(&copy->ctx, &copy->request,
                                             &copy->responder, cq, cq, copy);
                } else {
                        NewRpcState<Q, S> *copy =
-                               new NewRpcState(cdb, requestsf, callback, name);
+                               no_copy ? this
+                                       : new NewRpcState(cdb, requestsf,
+                                                         callback, name);
                        (service->*requestsf)(&copy->ctx, &copy->request,
                                              &copy->async_responder, cq, cq,
                                              copy);
@@ -1225,7 +1234,7 @@ void HandleUnaryExecute(
                                                 frr::NAME##Response>(         \
                        (cdb), &frr::Northbound::AsyncService::Request##NAME,  \
                        &HandleUnary##NAME, #NAME);                            \
-               _rpcState->do_request(service, s_cq);                          \
+               _rpcState->do_request(&service, cq.get(), true);               \
        } while (0)
 
 #define REQUEST_NEWRPC_STREAMING(NAME, cdb)                                    \
@@ -1234,7 +1243,7 @@ void HandleUnaryExecute(
                                                 frr::NAME##Response>(         \
                        (cdb), &frr::Northbound::AsyncService::Request##NAME,  \
                        &HandleStreaming##NAME, #NAME);                        \
-               _rpcState->do_request(service, s_cq);                          \
+               _rpcState->do_request(&service, cq.get(), true);               \
        } while (0)
 
 struct grpc_pthread_attr {
@@ -1243,8 +1252,8 @@ struct grpc_pthread_attr {
 };
 
 // Capture these objects so we can try to shut down cleanly
-static std::unique_ptr<grpc::Server> s_server;
-static grpc::ServerCompletionQueue *s_cq;
+static pthread_mutex_t s_server_lock = PTHREAD_MUTEX_INITIALIZER;
+static grpc::Server *s_server;
 
 static void *grpc_pthread_start(void *arg)
 {
@@ -1254,20 +1263,22 @@ static void *grpc_pthread_start(void *arg)
        Candidates candidates;
        grpc::ServerBuilder builder;
        std::stringstream server_address;
-       frr::Northbound::AsyncService *service =
-               new frr::Northbound::AsyncService();
+       frr::Northbound::AsyncService service;
 
        frr_pthread_set_name(fpt);
 
        server_address << "0.0.0.0:" << port;
        builder.AddListeningPort(server_address.str(),
                                 grpc::InsecureServerCredentials());
-       builder.RegisterService(service);
+       builder.RegisterService(&service);
        builder.AddChannelArgument(
                GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS, 5000);
-       auto cq = builder.AddCompletionQueue();
-       s_cq = cq.get();
-       s_server = builder.BuildAndStart();
+       std::unique_ptr<grpc::ServerCompletionQueue> cq =
+               builder.AddCompletionQueue();
+       std::unique_ptr<grpc::Server> server = builder.BuildAndStart();
+       s_server = server.get();
+
+       grpc_running = true;
 
        /* Schedule all RPC handlers */
        REQUEST_NEWRPC(GetCapabilities, NULL);
@@ -1288,20 +1299,25 @@ static void *grpc_pthread_start(void *arg)
                    server_address.str().c_str());
 
        /* Process inbound RPCs */
-       while (true) {
-               void *tag;
-               bool ok;
-
-               s_cq->Next(&tag, &ok);
-               if (!ok)
+       bool ok;
+       void *tag;
+       while (grpc_running) {
+               if (!cq->Next(&tag, &ok)) {
+                       grpc_debug("%s: CQ empty exiting", __func__);
                        break;
+               }
 
-               grpc_debug("%s: Got next from CompletionQueue, %p %d", __func__,
-                          tag, ok);
+               grpc_debug("%s: got next from CQ tag: %p ok: %d", __func__, tag,
+                          ok);
+
+               if (!ok || !grpc_running) {
+                       delete static_cast<RpcStateBase *>(tag);
+                       break;
+               }
 
                RpcStateBase *rpc = static_cast<RpcStateBase *>(tag);
                CallState state = rpc->doCallback();
-               grpc_debug("%s: Callback returned RPC State: %s", __func__,
+               grpc_debug("%s: callback returned RPC State: %s", __func__,
                           call_states[state]);
 
                /*
@@ -1311,10 +1327,30 @@ static void *grpc_pthread_start(void *arg)
                 * to be called back once more in the FINISH state (from the
                 * user indicating Finish() for cleanup.
                 */
-               if (state == FINISH)
-                       rpc->do_request(service, s_cq);
+               if (state == FINISH && grpc_running)
+                       rpc->do_request(&service, cq.get(), false);
        }
 
+       /* This was probably done for us to get here, but let's be safe */
+       pthread_mutex_lock(&s_server_lock);
+       grpc_running = false;
+       if (s_server) {
+               grpc_debug("%s: shutdown server and CQ", __func__);
+               server->Shutdown();
+               s_server = NULL;
+       }
+       pthread_mutex_unlock(&s_server_lock);
+
+       grpc_debug("%s: shutting down CQ", __func__);
+       cq->Shutdown();
+
+       grpc_debug("%s: draining the CQ", __func__);
+       while (cq->Next(&tag, &ok)) {
+               grpc_debug("%s: drain tag %p", __func__, tag);
+               delete static_cast<RpcStateBase *>(tag);
+       }
+
+       zlog_info("%s: exiting from grpc pthread", __func__);
        return NULL;
 }
 
@@ -1326,6 +1362,8 @@ static int frr_grpc_init(uint port)
                .stop = NULL,
        };
 
+       grpc_debug("%s: entered", __func__);
+
        fpt = frr_pthread_new(&attr, "frr-grpc", "frr-grpc");
        fpt->data = reinterpret_cast<void *>((intptr_t)port);
 
@@ -1341,24 +1379,27 @@ static int frr_grpc_init(uint port)
 
 static int frr_grpc_finish(void)
 {
-       // Shutdown the grpc server
-       if (s_server) {
-               s_server->Shutdown();
-               s_cq->Shutdown();
+       grpc_debug("%s: entered", __func__);
 
-               // And drain the queue
-               void *ignore;
-               bool ok;
-
-               while (s_cq->Next(&ignore, &ok))
-                       ;
-       }
+       if (!fpt)
+               return 0;
 
-       if (fpt) {
-               pthread_join(fpt->thread, NULL);
-               frr_pthread_destroy(fpt);
+       /*
+        * Shut the server down here in main thread. This will cause the wait on
+        * the completion queue (cq.Next()) to exit and cleanup everything else.
+        */
+       pthread_mutex_lock(&s_server_lock);
+       grpc_running = false;
+       if (s_server) {
+               grpc_debug("%s: shutdown server", __func__);
+               s_server->Shutdown();
+               s_server = NULL;
        }
+       pthread_mutex_unlock(&s_server_lock);
 
+       grpc_debug("%s: joining and destroy grpc thread", __func__);
+       pthread_join(fpt->thread, NULL);
+       frr_pthread_destroy(fpt);
        return 0;
 }