summaryrefslogtreecommitdiff
path: root/exes/rest
diff options
context:
space:
mode:
authorMatthieuCoder <matthieu@matthieu-dev.xyz>2023-01-13 22:23:19 +0400
committerMatthieuCoder <matthieu@matthieu-dev.xyz>2023-01-13 22:23:19 +0400
commitb1e17001e3fce2874e4bb1354196c90a5fd7acd0 (patch)
tree95aa49d203c37afda042befcfc8c0fd679a7e1e8 /exes/rest
parentbca576c35e54d5505c62e42e8be416c797e84b6b (diff)
better ratelimit, new go structure and better build system
Diffstat (limited to 'exes/rest')
-rw-r--r--exes/rest/src/handler.rs37
-rw-r--r--exes/rest/src/lib.rs25
-rw-r--r--exes/rest/src/ratelimit_client/mod.rs169
-rw-r--r--exes/rest/src/ratelimit_client/remote_hashring.rs8
4 files changed, 123 insertions, 116 deletions
diff --git a/exes/rest/src/handler.rs b/exes/rest/src/handler.rs
index 6d58dce..f59583d 100644
--- a/exes/rest/src/handler.rs
+++ b/exes/rest/src/handler.rs
@@ -10,8 +10,9 @@ use std::{
convert::TryFrom,
hash::{Hash, Hasher},
str::FromStr,
+ sync::Arc,
};
-use tracing::{debug_span, error, instrument, Instrument};
+use tracing::{debug_span, error, info_span, instrument, Instrument};
use twilight_http_ratelimiting::{Method, Path};
use crate::ratelimit_client::RemoteRatelimiter;
@@ -37,8 +38,8 @@ fn normalize_path(request_path: &str) -> (&str, &str) {
#[instrument]
pub async fn handle_request(
client: Client<HttpsConnector<HttpConnector>, Body>,
- ratelimiter: RemoteRatelimiter,
- token: &str,
+ ratelimiter: Arc<RemoteRatelimiter>,
+ token: String,
mut request: Request<Body>,
) -> Result<Response<Body>, anyhow::Error> {
let (hash, uri_string) = {
@@ -78,19 +79,12 @@ pub async fn handle_request(
(hash.finish().to_string(), uri_string)
};
+ // waits for the request to be authorized
+ ratelimiter
+ .ticket(hash.clone())
+ .instrument(debug_span!("ticket validation request"))
+ .await?;
- let span = debug_span!("ticket validation request");
- let header_sender = match span
- .in_scope(|| ratelimiter.ticket(hash))
- .await
- {
- Ok(sender) => sender,
- Err(e) => {
- error!("Failed to receive ticket for ratelimiting: {:?}", e);
- bail!("failed to reteive ticket");
- }
- };
-
request
.headers_mut()
.insert(HOST, HeaderValue::from_static("discord.com"));
@@ -135,15 +129,18 @@ pub async fn handle_request(
}
};
- let ratelimit_headers = resp
+ let headers = resp
.headers()
.into_iter()
- .map(|(k, v)| (k.to_string(), v.to_str().unwrap().to_string()))
+ .map(|(k, v)| (k.to_string(), v.to_str().map(|f| f.to_string())))
+ .filter(|f| f.1.is_ok())
+ .map(|f| (f.0, f.1.expect("errors should be filtered")))
.collect();
- if header_sender.send(ratelimit_headers).is_err() {
- error!("Error when sending ratelimit headers to ratelimiter");
- };
+ let _ = ratelimiter
+ .submit_headers(hash, headers)
+ .instrument(info_span!("submitting headers"))
+ .await;
Ok(resp)
}
diff --git a/exes/rest/src/lib.rs b/exes/rest/src/lib.rs
index 2cef631..b2287b4 100644
--- a/exes/rest/src/lib.rs
+++ b/exes/rest/src/lib.rs
@@ -7,11 +7,12 @@ use hyper::{
Body, Client, Request, Server,
};
use leash::{AnyhowResultFuture, Component};
-use opentelemetry::{global, trace::Tracer};
+use opentelemetry::{global};
use opentelemetry_http::HeaderExtractor;
use shared::config::Settings;
use std::{convert::Infallible, sync::Arc};
use tokio::sync::oneshot;
+use tracing_opentelemetry::OpenTelemetrySpanExt;
mod config;
mod handler;
@@ -29,7 +30,9 @@ impl Component for ReverseProxyServer {
) -> AnyhowResultFuture<()> {
Box::pin(async move {
// Client to the remote ratelimiters
- let ratelimiter = ratelimit_client::RemoteRatelimiter::new(settings.config.clone());
+ let ratelimiter = Arc::new(ratelimit_client::RemoteRatelimiter::new(
+ settings.config.clone(),
+ ));
let https = hyper_rustls::HttpsConnectorBuilder::new()
.with_native_roots()
.https_only()
@@ -37,31 +40,37 @@ impl Component for ReverseProxyServer {
.build();
let client: Client<_, hyper::Body> = Client::builder().build(https);
- let token = Arc::new(settings.discord.token.clone());
+ let token = settings.config.discord.token.clone();
+
let service_fn = make_service_fn(move |_: &AddrStream| {
let client = client.clone();
let ratelimiter = ratelimiter.clone();
let token = token.clone();
async move {
Ok::<_, Infallible>(service_fn(move |request: Request<Body>| {
+ let token = token.clone();
let parent_cx = global::get_text_map_propagator(|propagator| {
propagator.extract(&HeaderExtractor(request.headers()))
});
- let _span =
- global::tracer("").start_with_context("handle_request", &parent_cx);
+
+ let span = tracing::span!(tracing::Level::INFO, "request process");
+ span.set_parent(parent_cx);
let client = client.clone();
let ratelimiter = ratelimiter.clone();
- let token = token.clone();
+
async move {
- let token = token.as_str();
+ let token = token.clone();
+ let ratelimiter = ratelimiter.clone();
handle_request(client, ratelimiter, token, request).await
}
}))
}
});
- let server = Server::bind(&settings.config.server.listening_adress).http1_only(true).serve(service_fn);
+ let server = Server::bind(&settings.config.server.listening_adress)
+ .http1_only(true)
+ .serve(service_fn);
server
.with_graceful_shutdown(async {
diff --git a/exes/rest/src/ratelimit_client/mod.rs b/exes/rest/src/ratelimit_client/mod.rs
index 6212529..ecd489c 100644
--- a/exes/rest/src/ratelimit_client/mod.rs
+++ b/exes/rest/src/ratelimit_client/mod.rs
@@ -1,21 +1,18 @@
use crate::config::ReverseProxyConfig;
use self::remote_hashring::{HashRingWrapper, MetadataMap, VNode};
+use anyhow::anyhow;
use opentelemetry::global;
-use proto::nova::ratelimit::ratelimiter::bucket_submit_ticket_request::{Data, Headers};
-use proto::nova::ratelimit::ratelimiter::BucketSubmitTicketRequest;
+use proto::nova::ratelimit::ratelimiter::{BucketSubmitTicketRequest, HeadersSubmitRequest};
use std::collections::HashMap;
use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
-use std::time::UNIX_EPOCH;
use std::time::{Duration, SystemTime};
-use tokio::sync::oneshot::{self};
-use tokio::sync::{broadcast, mpsc, RwLock};
-use tokio_stream::wrappers::ReceiverStream;
+use tokio::sync::{broadcast, RwLock};
use tonic::Request;
-use tracing::{debug, debug_span, instrument, Instrument, Span};
+use tracing::{debug, error, info_span, instrument, trace_span, Instrument, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;
mod remote_hashring;
@@ -29,25 +26,21 @@ pub struct RemoteRatelimiter {
impl Drop for RemoteRatelimiter {
fn drop(&mut self) {
- self.stop.clone().send(()).unwrap();
+ let _ = self
+ .stop
+ .clone()
+ .send(())
+ .map_err(|_| error!("ratelimiter was already stopped"));
}
}
-type IssueTicket = Pin<
- Box<
- dyn Future<Output = anyhow::Result<oneshot::Sender<HashMap<String, String>>>>
- + Send
- + 'static,
- >,
->;
-
impl RemoteRatelimiter {
async fn get_ratelimiters(&self) -> Result<(), anyhow::Error> {
// get list of dns responses
- let responses = dns_lookup::lookup_host(&self.config.ratelimiter_address)
- .unwrap()
+ let responses = dns_lookup::lookup_host(&self.config.ratelimiter_address)?
.into_iter()
- .map(|f| f.to_string());
+ .filter(|address| address.is_ipv4())
+ .map(|address| address.to_string());
let mut write = self.remotes.write().await;
@@ -72,11 +65,19 @@ impl RemoteRatelimiter {
// Task to update the ratelimiters in the background
tokio::spawn(async move {
loop {
+ debug!("refreshing");
+
+ match obj_clone.get_ratelimiters().await {
+ Ok(_) => {
+ debug!("refreshed ratelimiting servers")
+ }
+ Err(err) => {
+ error!("refreshing ratelimiting servers failed {}", err);
+ }
+ }
+
let sleep = tokio::time::sleep(Duration::from_secs(10));
tokio::pin!(sleep);
-
- debug!("refreshing");
- obj_clone.get_ratelimiters().await.unwrap();
tokio::select! {
() = &mut sleep => {
debug!("timer elapsed");
@@ -90,79 +91,79 @@ impl RemoteRatelimiter {
}
#[instrument(name = "ticket task")]
- pub fn ticket(&self, path: String) -> IssueTicket {
+ pub fn ticket(
+ &self,
+ path: String,
+ ) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'static>> {
let remotes = self.remotes.clone();
- let (tx, rx) = oneshot::channel::<HashMap<String, String>>();
Box::pin(
async move {
- // Get node managing this path
- let mut node = (*remotes.read().await.get(&path).unwrap()).clone();
-
- // Buffers for the gRPC streaming channel.
- let (send, remote) = mpsc::channel(5);
- let (do_request, wait) = oneshot::channel();
- // Tonic requires a stream to be used; Since we use a mpsc channel, we can create a stream from it
- let stream = ReceiverStream::new(remote);
-
- let mut request = Request::new(stream);
-
- let span = debug_span!("remote request");
+ // Getting the node managing this path
+ let mut node = remotes
+ .write()
+ .instrument(trace_span!("acquiring ring lock"))
+ .await
+ .get(&path)
+ .and_then(|node| Some(node.clone()))
+ .ok_or_else(|| {
+ anyhow!(
+ "did not compute ratelimit because no ratelimiter nodes are detected"
+ )
+ })?;
+
+ // Initialize span for tracing (headers injection)
+ let span = info_span!("remote request");
let context = span.context();
+ let mut request = Request::new(BucketSubmitTicketRequest { path });
global::get_text_map_propagator(|propagator| {
propagator.inject_context(&context, &mut MetadataMap(request.metadata_mut()))
});
- // Start the grpc streaming
- let ticket = node.submit_ticket(request).await?;
-
- // First, send the request
- send.send(BucketSubmitTicketRequest {
- data: Some(Data::Path(path)),
- })
- .await?;
-
- // We continuously listen for events in the channel.
- let span = debug_span!("stream worker");
- tokio::spawn(
- async move {
- let span = debug_span!("waiting for ticket upstream");
- let message = ticket
- .into_inner()
- .message()
- .instrument(span)
- .await
- .unwrap()
- .unwrap();
-
- if message.accepted == 1 {
- debug!("request ticket was accepted");
- do_request.send(()).unwrap();
- let span = debug_span!("waiting for response headers");
- let headers = rx.instrument(span).await.unwrap();
-
- send.send(BucketSubmitTicketRequest {
- data: Some(Data::Headers(Headers {
- precise_time: SystemTime::now()
- .duration_since(UNIX_EPOCH)
- .expect("time went backwards")
- .as_millis()
- as u64,
- headers,
- })),
- })
- .await
- .unwrap();
- }
- }
- .instrument(span),
- );
-
- // Wait for the message to be sent
- wait.await?;
+ // Requesting
+ node.submit_ticket(request)
+ .instrument(info_span!("waiting for ticket response"))
+ .await?;
- Ok(tx)
+ Ok(())
}
.instrument(Span::current()),
)
}
+
+ pub fn submit_headers(
+ &self,
+ path: String,
+ headers: HashMap<String, String>,
+ ) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'static>> {
+ let remotes = self.remotes.clone();
+ Box::pin(async move {
+ let mut node = remotes
+ .write()
+ .instrument(trace_span!("acquiring ring lock"))
+ .await
+ .get(&path)
+ .and_then(|node| Some(node.clone()))
+ .ok_or_else(|| {
+ anyhow!("did not compute ratelimit because no ratelimiter nodes are detected")
+ })?;
+
+ let span = info_span!("remote request");
+ let context = span.context();
+ let time = SystemTime::now()
+ .duration_since(SystemTime::UNIX_EPOCH)?
+ .as_millis();
+ let mut request = Request::new(HeadersSubmitRequest {
+ path,
+ precise_time: time as u64,
+ headers,
+ });
+ global::get_text_map_propagator(|propagator| {
+ propagator.inject_context(&context, &mut MetadataMap(request.metadata_mut()))
+ });
+
+ node.submit_headers(request).await?;
+
+ Ok(())
+ })
+ }
}
diff --git a/exes/rest/src/ratelimit_client/remote_hashring.rs b/exes/rest/src/ratelimit_client/remote_hashring.rs
index b8771e5..ac025c8 100644
--- a/exes/rest/src/ratelimit_client/remote_hashring.rs
+++ b/exes/rest/src/ratelimit_client/remote_hashring.rs
@@ -6,6 +6,7 @@ use std::hash::Hash;
use std::ops::Deref;
use std::ops::DerefMut;
use tonic::transport::Channel;
+use tracing::debug;
#[derive(Debug, Clone)]
pub struct VNode {
@@ -49,15 +50,14 @@ impl<'a> Injector for MetadataMap<'a> {
impl VNode {
pub async fn new(address: String, port: u16) -> Result<Self, tonic::transport::Error> {
- let client =
- RatelimiterClient::connect(format!("http://{}:{}", address.clone(), port)).await?;
+ let host = format!("http://{}:{}", address.clone(), port);
+ debug!("connecting to {}", host);
+ let client = RatelimiterClient::connect(host).await?;
Ok(VNode { client, address })
}
}
-unsafe impl Send for VNode {}
-
#[repr(transparent)]
#[derive(Default)]
pub struct HashRingWrapper(hashring::HashRing<VNode>);