diff options
| author | MatthieuCoder <matthieu@matthieu-dev.xyz> | 2023-01-13 22:23:19 +0400 | 
|---|---|---|
| committer | MatthieuCoder <matthieu@matthieu-dev.xyz> | 2023-01-13 22:23:19 +0400 | 
| commit | b1e17001e3fce2874e4bb1354196c90a5fd7acd0 (patch) | |
| tree | 95aa49d203c37afda042befcfc8c0fd679a7e1e8 /exes/rest | |
| parent | bca576c35e54d5505c62e42e8be416c797e84b6b (diff) | |
better ratelimit, new go structure and better build system
Diffstat (limited to 'exes/rest')
| -rw-r--r-- | exes/rest/src/handler.rs | 37 | ||||
| -rw-r--r-- | exes/rest/src/lib.rs | 25 | ||||
| -rw-r--r-- | exes/rest/src/ratelimit_client/mod.rs | 169 | ||||
| -rw-r--r-- | exes/rest/src/ratelimit_client/remote_hashring.rs | 8 | 
4 files changed, 123 insertions, 116 deletions
diff --git a/exes/rest/src/handler.rs b/exes/rest/src/handler.rs index 6d58dce..f59583d 100644 --- a/exes/rest/src/handler.rs +++ b/exes/rest/src/handler.rs @@ -10,8 +10,9 @@ use std::{      convert::TryFrom,      hash::{Hash, Hasher},      str::FromStr, +    sync::Arc,  }; -use tracing::{debug_span, error, instrument, Instrument}; +use tracing::{debug_span, error, info_span, instrument, Instrument};  use twilight_http_ratelimiting::{Method, Path};  use crate::ratelimit_client::RemoteRatelimiter; @@ -37,8 +38,8 @@ fn normalize_path(request_path: &str) -> (&str, &str) {  #[instrument]  pub async fn handle_request(      client: Client<HttpsConnector<HttpConnector>, Body>, -    ratelimiter: RemoteRatelimiter, -    token: &str, +    ratelimiter: Arc<RemoteRatelimiter>, +    token: String,      mut request: Request<Body>,  ) -> Result<Response<Body>, anyhow::Error> {      let (hash, uri_string) = { @@ -78,19 +79,12 @@ pub async fn handle_request(          (hash.finish().to_string(), uri_string)      }; +    // waits for the request to be authorized +    ratelimiter +        .ticket(hash.clone()) +        .instrument(debug_span!("ticket validation request")) +        .await?; -    let span = debug_span!("ticket validation request"); -    let header_sender = match span -        .in_scope(|| ratelimiter.ticket(hash)) -        .await -    { -        Ok(sender) => sender, -        Err(e) => { -            error!("Failed to receive ticket for ratelimiting: {:?}", e); -            bail!("failed to reteive ticket"); -        } -    }; -          request          .headers_mut()          .insert(HOST, HeaderValue::from_static("discord.com")); @@ -135,15 +129,18 @@ pub async fn handle_request(          }      }; -    let ratelimit_headers = resp +    let headers = resp          .headers()          .into_iter() -        .map(|(k, v)| (k.to_string(), v.to_str().unwrap().to_string())) +        .map(|(k, v)| (k.to_string(), v.to_str().map(|f| f.to_string()))) +        .filter(|f| f.1.is_ok()) +        .map(|f| (f.0, f.1.expect("errors should be filtered")))          .collect(); -    if header_sender.send(ratelimit_headers).is_err() { -        error!("Error when sending ratelimit headers to ratelimiter"); -    }; +    let _ = ratelimiter +        .submit_headers(hash, headers) +        .instrument(info_span!("submitting headers")) +        .await;      Ok(resp)  } diff --git a/exes/rest/src/lib.rs b/exes/rest/src/lib.rs index 2cef631..b2287b4 100644 --- a/exes/rest/src/lib.rs +++ b/exes/rest/src/lib.rs @@ -7,11 +7,12 @@ use hyper::{      Body, Client, Request, Server,  };  use leash::{AnyhowResultFuture, Component}; -use opentelemetry::{global, trace::Tracer}; +use opentelemetry::{global};  use opentelemetry_http::HeaderExtractor;  use shared::config::Settings;  use std::{convert::Infallible, sync::Arc};  use tokio::sync::oneshot; +use tracing_opentelemetry::OpenTelemetrySpanExt;  mod config;  mod handler; @@ -29,7 +30,9 @@ impl Component for ReverseProxyServer {      ) -> AnyhowResultFuture<()> {          Box::pin(async move {              // Client to the remote ratelimiters -            let ratelimiter = ratelimit_client::RemoteRatelimiter::new(settings.config.clone()); +            let ratelimiter = Arc::new(ratelimit_client::RemoteRatelimiter::new( +                settings.config.clone(), +            ));              let https = hyper_rustls::HttpsConnectorBuilder::new()                  .with_native_roots()                  .https_only() @@ -37,31 +40,37 @@ impl Component for ReverseProxyServer {                  .build();              let client: Client<_, hyper::Body> = Client::builder().build(https); -            let token = Arc::new(settings.discord.token.clone()); +            let token = settings.config.discord.token.clone(); +              let service_fn = make_service_fn(move |_: &AddrStream| {                  let client = client.clone();                  let ratelimiter = ratelimiter.clone();                  let token = token.clone();                  async move {                      Ok::<_, Infallible>(service_fn(move |request: Request<Body>| { +                        let token = token.clone();                          let parent_cx = global::get_text_map_propagator(|propagator| {                              propagator.extract(&HeaderExtractor(request.headers()))                          }); -                        let _span = -                            global::tracer("").start_with_context("handle_request", &parent_cx); + +                        let span = tracing::span!(tracing::Level::INFO, "request process"); +                        span.set_parent(parent_cx);                          let client = client.clone();                          let ratelimiter = ratelimiter.clone(); -                        let token = token.clone(); +                          async move { -                            let token = token.as_str(); +                            let token = token.clone(); +                            let ratelimiter = ratelimiter.clone();                              handle_request(client, ratelimiter, token, request).await                          }                      }))                  }              }); -            let server = Server::bind(&settings.config.server.listening_adress).http1_only(true).serve(service_fn); +            let server = Server::bind(&settings.config.server.listening_adress) +                .http1_only(true) +                .serve(service_fn);              server                  .with_graceful_shutdown(async { diff --git a/exes/rest/src/ratelimit_client/mod.rs b/exes/rest/src/ratelimit_client/mod.rs index 6212529..ecd489c 100644 --- a/exes/rest/src/ratelimit_client/mod.rs +++ b/exes/rest/src/ratelimit_client/mod.rs @@ -1,21 +1,18 @@  use crate::config::ReverseProxyConfig;  use self::remote_hashring::{HashRingWrapper, MetadataMap, VNode}; +use anyhow::anyhow;  use opentelemetry::global; -use proto::nova::ratelimit::ratelimiter::bucket_submit_ticket_request::{Data, Headers}; -use proto::nova::ratelimit::ratelimiter::BucketSubmitTicketRequest; +use proto::nova::ratelimit::ratelimiter::{BucketSubmitTicketRequest, HeadersSubmitRequest};  use std::collections::HashMap;  use std::fmt::Debug;  use std::future::Future;  use std::pin::Pin;  use std::sync::Arc; -use std::time::UNIX_EPOCH;  use std::time::{Duration, SystemTime}; -use tokio::sync::oneshot::{self}; -use tokio::sync::{broadcast, mpsc, RwLock}; -use tokio_stream::wrappers::ReceiverStream; +use tokio::sync::{broadcast, RwLock};  use tonic::Request; -use tracing::{debug, debug_span, instrument, Instrument, Span}; +use tracing::{debug, error, info_span, instrument, trace_span, Instrument, Span};  use tracing_opentelemetry::OpenTelemetrySpanExt;  mod remote_hashring; @@ -29,25 +26,21 @@ pub struct RemoteRatelimiter {  impl Drop for RemoteRatelimiter {      fn drop(&mut self) { -        self.stop.clone().send(()).unwrap(); +        let _ = self +            .stop +            .clone() +            .send(()) +            .map_err(|_| error!("ratelimiter was already stopped"));      }  } -type IssueTicket = Pin< -    Box< -        dyn Future<Output = anyhow::Result<oneshot::Sender<HashMap<String, String>>>> -            + Send -            + 'static, -    >, ->; -  impl RemoteRatelimiter {      async fn get_ratelimiters(&self) -> Result<(), anyhow::Error> {          // get list of dns responses -        let responses = dns_lookup::lookup_host(&self.config.ratelimiter_address) -            .unwrap() +        let responses = dns_lookup::lookup_host(&self.config.ratelimiter_address)?              .into_iter() -            .map(|f| f.to_string()); +            .filter(|address| address.is_ipv4()) +            .map(|address| address.to_string());          let mut write = self.remotes.write().await; @@ -72,11 +65,19 @@ impl RemoteRatelimiter {          // Task to update the ratelimiters in the background          tokio::spawn(async move {              loop { +                debug!("refreshing"); + +                match obj_clone.get_ratelimiters().await { +                    Ok(_) => { +                        debug!("refreshed ratelimiting servers") +                    } +                    Err(err) => { +                        error!("refreshing ratelimiting servers failed {}", err); +                    } +                } +                  let sleep = tokio::time::sleep(Duration::from_secs(10));                  tokio::pin!(sleep); - -                debug!("refreshing"); -                obj_clone.get_ratelimiters().await.unwrap();                  tokio::select! {                      () = &mut sleep => {                          debug!("timer elapsed"); @@ -90,79 +91,79 @@ impl RemoteRatelimiter {      }      #[instrument(name = "ticket task")] -    pub fn ticket(&self, path: String) -> IssueTicket { +    pub fn ticket( +        &self, +        path: String, +    ) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'static>> {          let remotes = self.remotes.clone(); -        let (tx, rx) = oneshot::channel::<HashMap<String, String>>();          Box::pin(              async move { -                // Get node managing this path -                let mut node = (*remotes.read().await.get(&path).unwrap()).clone(); - -                // Buffers for the gRPC streaming channel. -                let (send, remote) = mpsc::channel(5); -                let (do_request, wait) = oneshot::channel(); -                // Tonic requires a stream to be used; Since we use a mpsc channel, we can create a stream from it -                let stream = ReceiverStream::new(remote); - -                let mut request = Request::new(stream); - -                let span = debug_span!("remote request"); +                // Getting the node managing this path +                let mut node = remotes +                    .write() +                    .instrument(trace_span!("acquiring ring lock")) +                    .await +                    .get(&path) +                    .and_then(|node| Some(node.clone())) +                    .ok_or_else(|| { +                        anyhow!( +                            "did not compute ratelimit because no ratelimiter nodes are detected" +                        ) +                    })?; + +                // Initialize span for tracing (headers injection) +                let span = info_span!("remote request");                  let context = span.context(); +                let mut request = Request::new(BucketSubmitTicketRequest { path });                  global::get_text_map_propagator(|propagator| {                      propagator.inject_context(&context, &mut MetadataMap(request.metadata_mut()))                  }); -                // Start the grpc streaming -                let ticket = node.submit_ticket(request).await?; - -                // First, send the request -                send.send(BucketSubmitTicketRequest { -                    data: Some(Data::Path(path)), -                }) -                .await?; - -                // We continuously listen for events in the channel. -                let span = debug_span!("stream worker"); -                tokio::spawn( -                    async move { -                        let span = debug_span!("waiting for ticket upstream"); -                        let message = ticket -                            .into_inner() -                            .message() -                            .instrument(span) -                            .await -                            .unwrap() -                            .unwrap(); - -                        if message.accepted == 1 { -                            debug!("request ticket was accepted"); -                            do_request.send(()).unwrap(); -                            let span = debug_span!("waiting for response headers"); -                            let headers = rx.instrument(span).await.unwrap(); - -                            send.send(BucketSubmitTicketRequest { -                                data: Some(Data::Headers(Headers { -                                    precise_time: SystemTime::now() -                                        .duration_since(UNIX_EPOCH) -                                        .expect("time went backwards") -                                        .as_millis() -                                        as u64, -                                    headers, -                                })), -                            }) -                            .await -                            .unwrap(); -                        } -                    } -                    .instrument(span), -                ); - -                // Wait for the message to be sent -                wait.await?; +                // Requesting +                node.submit_ticket(request) +                    .instrument(info_span!("waiting for ticket response")) +                    .await?; -                Ok(tx) +                Ok(())              }              .instrument(Span::current()),          )      } + +    pub fn submit_headers( +        &self, +        path: String, +        headers: HashMap<String, String>, +    ) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'static>> { +        let remotes = self.remotes.clone(); +        Box::pin(async move { +            let mut node = remotes +                .write() +                .instrument(trace_span!("acquiring ring lock")) +                .await +                .get(&path) +                .and_then(|node| Some(node.clone())) +                .ok_or_else(|| { +                    anyhow!("did not compute ratelimit because no ratelimiter nodes are detected") +                })?; + +            let span = info_span!("remote request"); +            let context = span.context(); +            let time = SystemTime::now() +                .duration_since(SystemTime::UNIX_EPOCH)? +                .as_millis(); +            let mut request = Request::new(HeadersSubmitRequest { +                path, +                precise_time: time as u64, +                headers, +            }); +            global::get_text_map_propagator(|propagator| { +                propagator.inject_context(&context, &mut MetadataMap(request.metadata_mut())) +            }); + +            node.submit_headers(request).await?; + +            Ok(()) +        }) +    }  } diff --git a/exes/rest/src/ratelimit_client/remote_hashring.rs b/exes/rest/src/ratelimit_client/remote_hashring.rs index b8771e5..ac025c8 100644 --- a/exes/rest/src/ratelimit_client/remote_hashring.rs +++ b/exes/rest/src/ratelimit_client/remote_hashring.rs @@ -6,6 +6,7 @@ use std::hash::Hash;  use std::ops::Deref;  use std::ops::DerefMut;  use tonic::transport::Channel; +use tracing::debug;  #[derive(Debug, Clone)]  pub struct VNode { @@ -49,15 +50,14 @@ impl<'a> Injector for MetadataMap<'a> {  impl VNode {      pub async fn new(address: String, port: u16) -> Result<Self, tonic::transport::Error> { -        let client = -            RatelimiterClient::connect(format!("http://{}:{}", address.clone(), port)).await?; +        let host = format!("http://{}:{}", address.clone(), port); +        debug!("connecting to {}", host); +        let client = RatelimiterClient::connect(host).await?;          Ok(VNode { client, address })      }  } -unsafe impl Send for VNode {} -  #[repr(transparent)]  #[derive(Default)]  pub struct HashRingWrapper(hashring::HashRing<VNode>);  | 
