]> git.puffer.fish Git - matthieu/nova.git/commitdiff
add token from config and change the signal handler to SIGTERM
authorMatthieuCoder <matthieu@matthieu-dev.xyz>
Mon, 2 Jan 2023 15:53:53 +0000 (19:53 +0400)
committerMatthieuCoder <matthieu@matthieu-dev.xyz>
Mon, 2 Jan 2023 15:53:53 +0000 (19:53 +0400)
12 files changed:
Cargo.lock
exes/gateway/src/main.rs
exes/rest/src/config.rs
exes/rest/src/handler.rs
exes/rest/src/main.rs
exes/rest/src/ratelimit_client/mod.rs
exes/webhook/Cargo.toml
exes/webhook/src/config.rs
exes/webhook/src/main.rs
libs/leash/Cargo.toml
libs/leash/src/lib.rs
libs/shared/Cargo.toml

index c76107e2ad78b4e2e2eaa0234a623d797cdd05a2..8181e8d70a519eab52261eb73cf93c0e53874d1a 100644 (file)
@@ -1143,6 +1143,7 @@ name = "leash"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "pretty_env_logger",
  "serde",
  "shared",
  "tokio",
@@ -2183,7 +2184,6 @@ dependencies = [
  "hyper",
  "inner",
  "log",
- "pretty_env_logger",
  "prometheus",
  "redis",
  "serde",
@@ -3006,6 +3006,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "ed25519-dalek",
+ "futures-util",
  "hex",
  "hyper",
  "lazy_static",
index 7957b08435c7f4b5e90bb4bcaa8c17d474c289f2..f2a4f936b24fe42878d349415ed2ff0de4c97a01 100644 (file)
@@ -6,18 +6,24 @@ use shared::{
     nats_crate::Client,
     payloads::{CachePayload, DispatchEventTagged, Tracing},
 };
+use tokio::sync::oneshot;
 use std::{convert::TryFrom, pin::Pin};
 use twilight_gateway::{Event, Shard};
 mod config;
-use futures::{Future, StreamExt};
+use futures::{Future, StreamExt, select};
 use twilight_model::gateway::event::DispatchEvent;
+use futures::FutureExt;
 
 struct GatewayServer {}
 impl Component for GatewayServer {
     type Config = GatewayConfig;
     const SERVICE_NAME: &'static str = "gateway";
 
-    fn start(&self, settings: Settings<Self::Config>) -> AnyhowResultFuture<()> {
+    fn start(
+        &self,
+        settings: Settings<Self::Config>,
+        stop: oneshot::Receiver<()>,
+    ) -> AnyhowResultFuture<()> {
         Box::pin(async move {
             let (shard, mut events) = Shard::builder(settings.token.to_owned(), settings.intents)
                 .shard(settings.shard, settings.shard_total)?
@@ -29,34 +35,48 @@ impl Component for GatewayServer {
 
             shard.start().await?;
 
-            while let Some(event) = events.next().await {
-                match event {
-                    Event::Ready(ready) => {
-                        info!("Logged in as {}", ready.user.name);
-                    }
+            let mut stop = stop.fuse();
+            loop {
 
-                    _ => {
-                        let name = event.kind().name();
-                        if let Ok(dispatch_event) = DispatchEvent::try_from(event) {
-                            let data = CachePayload {
-                                tracing: Tracing {
-                                    node_id: "".to_string(),
-                                    span: None,
-                                },
-                                data: DispatchEventTagged {
-                                    data: dispatch_event,
-                                },
-                            };
-                            let value = serde_json::to_string(&data)?;
-                            debug!("nats send: {}", value);
-                            let bytes = bytes::Bytes::from(value);
-                            nats.publish(format!("nova.cache.dispatch.{}", name.unwrap()), bytes)
-                                .await?;
+                select! {
+                    event = events.next().fuse() => {
+                        if let Some(event) = event {
+                            match event {
+                                Event::Ready(ready) => {
+                                    info!("Logged in as {}", ready.user.name);
+                                }
+            
+                                _ => {
+                                    let name = event.kind().name();
+                                    if let Ok(dispatch_event) = DispatchEvent::try_from(event) {
+                                        let data = CachePayload {
+                                            tracing: Tracing {
+                                                node_id: "".to_string(),
+                                                span: None,
+                                            },
+                                            data: DispatchEventTagged {
+                                                data: dispatch_event,
+                                            },
+                                        };
+                                        let value = serde_json::to_string(&data)?;
+                                        debug!("nats send: {}", value);
+                                        let bytes = bytes::Bytes::from(value);
+                                        nats.publish(format!("nova.cache.dispatch.{}", name.unwrap()), bytes)
+                                            .await?;
+                                    }
+                                }
+                            }
+                        } else {
+                            break
                         }
-                    }
-                }
+                    },
+                    _ = stop => break
+                };
             }
 
+            info!("stopping shard...");
+            shard.shutdown();
+
             Ok(())
         })
     }
index 9261de268e4c9e6ae95d7405a8e25f9842fd384b..5c2698b316baa67aa3a9849caeb90e8d30a92378 100644 (file)
@@ -2,7 +2,7 @@ use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
 use serde::Deserialize;
 
 fn default_listening_address() -> SocketAddr {
-    SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080))
+    SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 8080))
 }
 
 #[derive(Debug, Deserialize, Clone)]
index 8b0dd52503e0459bbc4f08960d1098edbb733d4e..ea81ade96975bb817491b154b6841d4130143b36 100644 (file)
@@ -3,6 +3,7 @@ use std::{
     convert::TryFrom,
     hash::{Hash, Hasher},
     str::FromStr,
+    time::Instant,
 };
 
 use anyhow::bail;
@@ -38,7 +39,7 @@ fn normalize_path(request_path: &str) -> (&str, &str) {
 pub async fn handle_request(
     client: Client<HttpsConnector<HttpConnector>, Body>,
     ratelimiter: RemoteRatelimiter,
-    token: String,
+    token: &str,
     mut request: Request<Body>,
 ) -> Result<Response<Body>, anyhow::Error> {
     let (hash, uri_string) = {
@@ -57,7 +58,7 @@ pub async fn handle_request(
         let request_path = request.uri().path();
         let (api_path, trimmed_path) = normalize_path(&request_path);
 
-        let mut uri_string = format!("http://192.168.0.27:8000{}{}", api_path, trimmed_path);
+        let mut uri_string = format!("https://discord.com{}{}", api_path, trimmed_path);
         if let Some(query) = request.uri().query() {
             uri_string.push('?');
             uri_string.push_str(query);
@@ -79,6 +80,7 @@ pub async fn handle_request(
         (hash.finish().to_string(), uri_string)
     };
 
+    let start_ticket_request = Instant::now();
     let header_sender = match ratelimiter.ticket(hash).await {
         Ok(sender) => sender,
         Err(e) => {
@@ -86,6 +88,7 @@ pub async fn handle_request(
             bail!("failed to reteive ticket");
         }
     };
+    let time_took_ticket = Instant::now() - start_ticket_request;
 
     request.headers_mut().insert(
         AUTHORIZATION,
@@ -106,9 +109,7 @@ pub async fn handle_request(
     request.headers_mut().remove(AUTHORIZATION);
     request.headers_mut().append(
         AUTHORIZATION,
-        HeaderValue::from_static(
-            "Bot ODA3MTg4MzM1NzE3Mzg0MjEy.G3sXFM.8gY2sVYDAq2WuPWwDskAAEFLfTg8htooxME-LE",
-        ),
+        HeaderValue::from_str(&format!("Bot {}", token))?,
     );
 
     let uri = match Uri::from_str(&uri_string) {
@@ -119,14 +120,26 @@ pub async fn handle_request(
         }
     };
     *request.uri_mut() = uri;
-    let resp = match client.request(request).await {
+
+    let start_upstream_req = Instant::now();
+    let mut resp = match client.request(request).await {
         Ok(response) => response,
         Err(e) => {
             error!("Error when requesting the Discord API: {:?}", e);
             bail!("failed to request the discord api");
         }
     };
+    let upstream_time_took = Instant::now() - start_upstream_req;
 
+    resp.headers_mut().append(
+        "X-TicketRequest-Ms",
+        HeaderValue::from_str(&time_took_ticket.as_millis().to_string()).unwrap(),
+    );
+    resp.headers_mut().append(
+        "X-Upstream-Ms",
+        HeaderValue::from_str(&upstream_time_took.as_millis().to_string()).unwrap(),
+    );
+    
     let ratelimit_headers = resp
         .headers()
         .into_iter()
index 8d014ab1f8f3d08852263d392d1820f889b29307..07d835c2247dfadc9157aa42e425e0d253a60916 100644 (file)
@@ -9,7 +9,8 @@ use hyper::{
 use hyper_tls::HttpsConnector;
 use leash::{ignite, AnyhowResultFuture, Component};
 use shared::config::Settings;
-use std::convert::Infallible;
+use std::{convert::Infallible, sync::Arc};
+use tokio::sync::oneshot;
 
 mod config;
 mod handler;
@@ -20,21 +21,29 @@ impl Component for ReverseProxyServer {
     type Config = ReverseProxyConfig;
     const SERVICE_NAME: &'static str = "rest";
 
-    fn start(&self, settings: Settings<Self::Config>) -> AnyhowResultFuture<()> {
+    fn start(
+        &self,
+        settings: Settings<Self::Config>,
+        stop: oneshot::Receiver<()>,
+    ) -> AnyhowResultFuture<()> {
         Box::pin(async move {
             // Client to the remote ratelimiters
             let ratelimiter = ratelimit_client::RemoteRatelimiter::new();
             let client = Client::builder().build(HttpsConnector::new());
 
+            let token = Arc::new(settings.discord.token.clone());
             let service_fn = make_service_fn(move |_: &AddrStream| {
                 let client = client.clone();
                 let ratelimiter = ratelimiter.clone();
+                let token = token.clone();
                 async move {
                     Ok::<_, Infallible>(service_fn(move |request: Request<Body>| {
                         let client = client.clone();
                         let ratelimiter = ratelimiter.clone();
+                        let token = token.clone();
                         async move {
-                            handle_request(client, ratelimiter, "token".to_string(), request).await
+                            let token = token.as_str();
+                            handle_request(client, ratelimiter, token, request).await
                         }
                     }))
                 }
@@ -42,7 +51,11 @@ impl Component for ReverseProxyServer {
 
             let server = Server::bind(&settings.config.server.listening_adress).serve(service_fn);
 
-            server.await?;
+            server
+                .with_graceful_shutdown(async {
+                    stop.await.expect("should not fail");
+                })
+                .await?;
 
             Ok(())
         })
index 8263d156d70b7937d186ef4de0fa473475477921..87737dd5b9a92c74993d30b14b3fed5d03da44a3 100644 (file)
@@ -64,7 +64,7 @@ impl RemoteRatelimiter {
                 obj_clone.get_ratelimiters().await.unwrap();
                 tokio::select! {
                     () = &mut sleep => {
-                        println!("timer elapsed");
+                        debug!("timer elapsed");
                     },
                     _ = tx.recv() => {}
                 }
index 12a66080054f3610a86e36384ae55056afd1aab0..589b5bd772fb2f187f4994432d52aa89acfb402e 100644 (file)
@@ -17,6 +17,7 @@ lazy_static = "1.4.0"
 ed25519-dalek = "1"
 twilight-model = { version = "0.14" }
 anyhow = "1.0.68"
+futures-util = "0.3.25"
 
 [[bin]]
 name = "webhook"
index 68f6a5fc8949b0c21f2b611d42fc69d6e8eb25a3..e98de137ae4363248f4a008a88e54e1fcee0c322 100644 (file)
@@ -4,7 +4,7 @@ use ed25519_dalek::PublicKey;
 use serde::{Deserialize, Deserializer};
 
 fn default_listening_address() -> SocketAddr {
-    SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080))
+    SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 8080))
 }
 
 #[derive(Debug, Deserialize, Clone, Copy)]
index efd41474d438164314be9a6ad70119a8e6a55568..0215e511e112459f3a306684aa339fe8df322bd6 100644 (file)
@@ -9,6 +9,7 @@ use crate::{
 use hyper::Server;
 use leash::{ignite, AnyhowResultFuture, Component};
 use shared::{config::Settings, log::info, nats_crate::Client};
+use tokio::sync::oneshot;
 
 #[derive(Clone, Copy)]
 struct WebhookServer {}
@@ -17,7 +18,11 @@ impl Component for WebhookServer {
     type Config = WebhookConfig;
     const SERVICE_NAME: &'static str = "webhook";
 
-    fn start(&self, settings: Settings<Self::Config>) -> AnyhowResultFuture<()> {
+    fn start(
+        &self,
+        settings: Settings<Self::Config>,
+        stop: oneshot::Receiver<()>,
+    ) -> AnyhowResultFuture<()> {
         Box::pin(async move {
             info!("Starting server on {}", settings.server.listening_adress);
 
@@ -33,7 +38,9 @@ impl Component for WebhookServer {
 
             let server = Server::bind(&bind).serve(make_service);
 
-            server.await?;
+            server.with_graceful_shutdown(async {
+                stop.await.expect("should not fail");
+            }).await?;
 
             Ok(())
         })
index 5cd54a53350619b69e86958e98ae37277f0f0d75..32f385c5720b76fa6648c5b8f87be71a564feea6 100644 (file)
@@ -9,5 +9,5 @@ edition = "2021"
 shared = { path = "../shared" }
 anyhow = "1.0.68"
 tokio = { version = "1.23.0", features = ["full"] }
-
+pretty_env_logger = "0.4"
 serde = "1.0.152"
\ No newline at end of file
index 360db127848fc72125a964af0a581e172cb15653..1de768710df7425f8fbc25bd0f5cc181100cbcf9 100644 (file)
@@ -1,19 +1,29 @@
 use anyhow::Result;
 use serde::de::DeserializeOwned;
-use shared::config::Settings;
+use shared::{
+    config::Settings,
+    log::{error, info},
+};
 use std::{future::Future, pin::Pin};
+use tokio::{signal::{unix::SignalKind}, sync::oneshot};
 
 pub type AnyhowResultFuture<T> = Pin<Box<dyn Future<Output = Result<T>>>>;
 pub trait Component: Send + Sync + 'static + Sized {
     type Config: Default + Clone + DeserializeOwned;
 
     const SERVICE_NAME: &'static str;
-    fn start(&self, settings: Settings<Self::Config>) -> AnyhowResultFuture<()>;
+    fn start(
+        &self,
+        settings: Settings<Self::Config>,
+        stop: oneshot::Receiver<()>,
+    ) -> AnyhowResultFuture<()>;
     fn new() -> Self;
 
     fn _internal_start(self) -> AnyhowResultFuture<()> {
         Box::pin(async move {
+            pretty_env_logger::init();
             let settings = Settings::<Self::Config>::new(Self::SERVICE_NAME);
+            let (stop, stop_channel) = oneshot::channel();
 
             // Start the grpc healthcheck
             tokio::spawn(async move {});
@@ -21,7 +31,21 @@ pub trait Component: Send + Sync + 'static + Sized {
             // Start the prometheus monitoring job
             tokio::spawn(async move {});
 
-            self.start(settings?).await
+            tokio::spawn(async move {
+                match tokio::signal::unix::signal(SignalKind::terminate()).unwrap().recv().await {
+                    Some(()) => {
+                        info!("Stopping program.");
+
+                        stop.send(()).unwrap();
+                    }
+                    None => {
+                        error!("Unable to listen for shutdown signal");
+                        // we also shut down in case of error
+                    }
+                }
+            });
+
+            self.start(settings?, stop_channel).await
         })
     }
 }
@@ -41,6 +65,7 @@ macro_rules! ignite {
 #[cfg(test)]
 mod test {
     use serde::Deserialize;
+    use tokio::sync::oneshot;
 
     use crate::Component;
 
@@ -57,6 +82,7 @@ mod test {
         fn start(
             &self,
             _settings: shared::config::Settings<Self::Config>,
+            _stop: oneshot::Receiver<()>,
         ) -> crate::AnyhowResultFuture<()> {
             Box::pin(async move { Ok(()) })
         }
@@ -65,6 +91,6 @@ mod test {
             Self {}
         }
     }
-    
+
     ignite!(TestComponent);
 }
index ab19ce8e86c503decbf01676bb9fe6c91774f354..ce08fbc3f1792bc799657a247845f62ee3596706 100644 (file)
@@ -4,7 +4,6 @@ version = "0.1.0"
 edition = "2021"
 
 [dependencies]
-pretty_env_logger = "0.4"
 log = { version = "0.4", features = ["std"] }
 serde = { version = "1.0.8", features = ["derive"] }
 serde_repr = "0.1"