summaryrefslogtreecommitdiff
path: root/exes/rest
diff options
context:
space:
mode:
authorMatthieuCoder <matthieu@matthieu-dev.xyz>2022-12-31 17:07:30 +0400
committerMatthieuCoder <matthieu@matthieu-dev.xyz>2022-12-31 17:07:30 +0400
commit65652932f77ce194a10cbc8dd42f3064e2c1a132 (patch)
tree4ca18a9317c4e561e917e9dd0cf39b695b43bc34 /exes/rest
parenta16bafdf5b0ec52fa0d73458597eee7c34ea5e7b (diff)
updates and bazel removal
Diffstat (limited to 'exes/rest')
-rw-r--r--exes/rest/Cargo.toml16
-rw-r--r--exes/rest/src/config.rs18
-rw-r--r--exes/rest/src/main.rs46
-rw-r--r--exes/rest/src/proxy/mod.rs138
-rw-r--r--exes/rest/src/ratelimit/mod.rs155
5 files changed, 373 insertions, 0 deletions
diff --git a/exes/rest/Cargo.toml b/exes/rest/Cargo.toml
new file mode 100644
index 0000000..7b5b2b5
--- /dev/null
+++ b/exes/rest/Cargo.toml
@@ -0,0 +1,16 @@
+[package]
+name = "rest"
+version = "0.1.0"
+edition = "2018"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+shared = { path = "../../libs/shared" }
+hyper = { version = "0.14", features = ["full"] }
+tokio = { version = "1", features = ["full"] }
+serde = { version = "1.0.8", features = ["derive"] }
+futures-util = "0.3.17"
+hyper-tls = "0.5.0"
+lazy_static = "1.4.0"
+xxhash-rust = { version = "0.8.2", features = ["xxh32"] } \ No newline at end of file
diff --git a/exes/rest/src/config.rs b/exes/rest/src/config.rs
new file mode 100644
index 0000000..559929f
--- /dev/null
+++ b/exes/rest/src/config.rs
@@ -0,0 +1,18 @@
+use serde::Deserialize;
+
+#[derive(Debug, Deserialize, Clone, Default)]
+pub struct ServerSettings {
+ pub port: u16,
+ pub address: String,
+}
+
+#[derive(Debug, Deserialize, Clone, Default)]
+pub struct Discord {
+ pub token: String
+}
+
+#[derive(Debug, Deserialize, Clone, Default)]
+pub struct Config {
+ pub server: ServerSettings,
+ pub discord: Discord,
+}
diff --git a/exes/rest/src/main.rs b/exes/rest/src/main.rs
new file mode 100644
index 0000000..9fa6ce7
--- /dev/null
+++ b/exes/rest/src/main.rs
@@ -0,0 +1,46 @@
+use std::{convert::Infallible, sync::Arc};
+
+use crate::{config::Config, ratelimit::Ratelimiter};
+use shared::{
+ config::Settings,
+ log::{error, info},
+ redis_crate::Client,
+};
+use hyper::{server::conn::AddrStream, service::make_service_fn, Server};
+use std::net::ToSocketAddrs;
+use tokio::sync::Mutex;
+
+use crate::proxy::ServiceProxy;
+
+mod config;
+mod proxy;
+mod ratelimit;
+
+#[tokio::main]
+async fn main() {
+ let settings: Settings<Config> = Settings::new("rest").unwrap();
+ let config = Arc::new(settings.config);
+ let redis_client: Client = settings.redis.into();
+ let redis = Arc::new(Mutex::new(
+ redis_client.get_async_connection().await.unwrap(),
+ ));
+ let ratelimiter = Arc::new(Ratelimiter::new(redis));
+
+ let addr = format!("{}:{}", config.server.address, config.server.port)
+ .to_socket_addrs()
+ .unwrap()
+ .next()
+ .unwrap();
+
+ let service_fn = make_service_fn(move |_: &AddrStream| {
+ let service_proxy = ServiceProxy::new(config.clone(), ratelimiter.clone());
+ async move { Ok::<_, Infallible>(service_proxy) }
+ });
+
+ let server = Server::bind(&addr).serve(service_fn);
+
+ info!("starting ratelimit server");
+ if let Err(e) = server.await {
+ error!("server error: {}", e);
+ }
+}
diff --git a/exes/rest/src/proxy/mod.rs b/exes/rest/src/proxy/mod.rs
new file mode 100644
index 0000000..65d77aa
--- /dev/null
+++ b/exes/rest/src/proxy/mod.rs
@@ -0,0 +1,138 @@
+use crate::{config::Config, ratelimit::Ratelimiter};
+use hyper::{
+ client::HttpConnector, header::HeaderValue, http::uri::Parts, service::Service, Body, Client,
+ Request, Response, Uri,
+};
+use hyper_tls::HttpsConnector;
+use shared::{
+ log::debug,
+ prometheus::{labels, opts, register_counter, register_histogram_vec, Counter, HistogramVec},
+};
+use std::{future::Future, pin::Pin, sync::Arc, task::Poll};
+use tokio::sync::Mutex;
+
+lazy_static::lazy_static! {
+ static ref HTTP_COUNTER: Counter = register_counter!(opts!(
+ "nova_rest_http_requests_total",
+ "Number of HTTP requests made.",
+ labels! {"handler" => "all",}
+ ))
+ .unwrap();
+
+ static ref HTTP_REQ_HISTOGRAM: HistogramVec = register_histogram_vec!(
+ "nova_rest_http_request_duration_seconds",
+ "The HTTP request latencies in seconds.",
+ &["handler"]
+ )
+ .unwrap();
+
+ static ref HTTP_COUNTER_STATUS: Counter = register_counter!(opts!(
+ "nova_rest_http_requests_status",
+ "Number of HTTP requests made by status",
+ labels! {"" => ""}
+ ))
+ .unwrap();
+}
+
+#[derive(Clone)]
+pub struct ServiceProxy {
+ client: Client<HttpsConnector<HttpConnector>>,
+ ratelimiter: Arc<Ratelimiter>,
+ config: Arc<Config>,
+ fail: Arc<Mutex<i32>>,
+}
+
+impl Service<Request<Body>> for ServiceProxy {
+ type Response = Response<Body>;
+ type Error = hyper::Error;
+ type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
+
+ fn poll_ready(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Result<(), Self::Error>> {
+ match self.client.poll_ready(cx) {
+ Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
+ Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
+ Poll::Pending => Poll::Pending,
+ }
+ }
+
+ fn call(&mut self, mut req: Request<hyper::Body>) -> Self::Future {
+ HTTP_COUNTER.inc();
+
+ let timer = HTTP_REQ_HISTOGRAM.with_label_values(&["all"]).start_timer();
+ let host = "discord.com";
+ let mut new_parts = Parts::default();
+
+ let path = req.uri().path().to_string();
+
+ new_parts.scheme = Some("https".parse().unwrap());
+ new_parts.authority = Some(host.parse().unwrap());
+ new_parts.path_and_query = Some(path.parse().unwrap());
+
+ *req.uri_mut() = Uri::from_parts(new_parts).unwrap();
+
+ let headers = req.headers_mut();
+ headers.remove("user-agent");
+ headers.insert("Host", HeaderValue::from_str("discord.com").unwrap());
+ headers.insert(
+ "Authorization",
+ HeaderValue::from_str(&format!("Bot {}", self.config.discord.token)).unwrap(),
+ );
+
+ println!("{:?}", headers);
+
+ let client = self.client.clone();
+ let ratelimiter = self.ratelimiter.clone();
+ let fail = self.fail.clone();
+
+ return Box::pin(async move {
+ let resp = match ratelimiter.before_request(&req).await {
+ Ok(allowed) => match allowed {
+ crate::ratelimit::RatelimiterResponse::Ratelimited => {
+ debug!("ratelimited");
+ Ok(Response::builder().body("ratelimited".into()).unwrap())
+ }
+ _ => {
+ debug!("forwarding request");
+ match client.request(req).await {
+ Ok(mut response) => {
+ ratelimiter.after_request(&path, &response).await;
+ if response.status() != 200 {
+ *fail.lock().await += 1
+ }
+ response.headers_mut().insert(
+ "x-fails",
+ HeaderValue::from_str(&format!("{}", fail.lock().await))
+ .unwrap(),
+ );
+ Ok(response)
+ }
+ Err(e) => Err(e),
+ }
+ }
+ },
+ Err(e) => Ok(Response::builder()
+ .body(format!("server error: {}", e).into())
+ .unwrap()),
+ };
+ timer.observe_duration();
+ resp
+ });
+ }
+}
+
+impl ServiceProxy {
+ pub fn new(config: Arc<Config>, ratelimiter: Arc<Ratelimiter>) -> Self {
+ let https = HttpsConnector::new();
+ let client = Client::builder().build::<_, hyper::Body>(https);
+ let fail = Arc::new(Mutex::new(0));
+ ServiceProxy {
+ client,
+ config,
+ ratelimiter,
+ fail,
+ }
+ }
+}
diff --git a/exes/rest/src/ratelimit/mod.rs b/exes/rest/src/ratelimit/mod.rs
new file mode 100644
index 0000000..132bfd3
--- /dev/null
+++ b/exes/rest/src/ratelimit/mod.rs
@@ -0,0 +1,155 @@
+use shared::{
+ log::debug,
+ redis_crate::{aio::Connection, AsyncCommands}, error::GenericError,
+};
+use hyper::{Body, Request, Response};
+use std::{
+ convert::TryInto,
+ sync::Arc,
+ time::{SystemTime, UNIX_EPOCH},
+};
+use tokio::sync::Mutex;
+use xxhash_rust::xxh32::xxh32;
+
+pub enum RatelimiterResponse {
+ NoSuchUrl,
+ Ratelimited,
+ Pass,
+}
+
+pub struct Ratelimiter {
+ redis: Arc<Mutex<Connection>>,
+}
+
+impl Ratelimiter {
+ pub fn new(redis: Arc<Mutex<Connection>>) -> Ratelimiter {
+ return Ratelimiter { redis };
+ }
+
+ pub async fn before_request(
+ &self,
+ request: &Request<Body>,
+ ) -> Result<RatelimiterResponse, GenericError> {
+ // we lookup if the route hash is stored in the redis table
+ let path = request.uri().path();
+ let hash = xxh32(path.as_bytes(), 32);
+ let mut redis = self.redis.lock().await;
+
+ let start = SystemTime::now();
+ let since_the_epoch = start
+ .duration_since(UNIX_EPOCH)
+ .expect("Time went backwards");
+
+ // global rate litmit
+ match redis
+ .get::<String, Option<i32>>(format!(
+ "nova:rest:ratelimit:global:{}",
+ since_the_epoch.as_secs()
+ ))
+ .await
+ {
+ Ok(value) => {
+ match value {
+ Some(value) => {
+ debug!("incr: {}", value);
+ if value >= 49 {
+ return Ok(RatelimiterResponse::Ratelimited);
+ }
+ }
+ None => {
+ let key =
+ format!("nova:rest:ratelimit:global:{}", since_the_epoch.as_secs());
+ // init global ratelimit
+ redis.set_ex::<String, i32, ()>(key, 0, 2).await.unwrap();
+ }
+ }
+ }
+ Err(_) => {
+ return Err(GenericError::StepFailed("radis ratelimit check".to_string()));
+ }
+ };
+
+ // we lookup the corresponding bucket for this url
+ match redis
+ .get::<String, Option<String>>(format!("nova:rest:ratelimit:url_bucket:{}", hash))
+ .await
+ {
+ Ok(bucket) => match bucket {
+ Some(bucket) => {
+ match redis
+ .exists::<String, bool>(format!("nova:rest:ratelimit:lock:{}", bucket))
+ .await
+ {
+ Ok(exists) => {
+ if exists {
+ Ok(RatelimiterResponse::Ratelimited)
+ } else {
+ Ok(RatelimiterResponse::Pass)
+ }
+ }
+ Err(_) => Err(GenericError::StepFailed("radis ratelimit check".to_string())),
+ }
+ }
+ None => Ok(RatelimiterResponse::NoSuchUrl),
+ },
+ Err(_) => Err(GenericError::StepFailed("radis ratelimit check".to_string())),
+ }
+ }
+
+ fn parse_headers(&self, response: &Response<Body>) -> Option<(String, i32, i32)> {
+ if let Some(bucket) = response.headers().get("X-RateLimit-Bucket") {
+ let bucket = bucket.to_str().unwrap().to_string();
+
+ let remaining = response.headers().get("X-RateLimit-Remaining").unwrap();
+ let reset = response.headers().get("X-RateLimit-Reset-After").unwrap();
+
+ let remaining_i32 = remaining.to_str().unwrap().parse::<i32>().unwrap();
+ let reset_ms_i32 = reset.to_str().unwrap().parse::<f32>().unwrap().ceil() as i32;
+ return Some((bucket, remaining_i32, reset_ms_i32));
+ } else {
+ None
+ }
+ }
+
+ pub async fn after_request(&self, path: &str, response: &Response<Body>) {
+ let hash = xxh32(path.as_bytes(), 32);
+ // verified earlier
+
+ let mut redis = self.redis.lock().await;
+
+ let start = SystemTime::now();
+ let since_the_epoch = start
+ .duration_since(UNIX_EPOCH)
+ .expect("Time went backwards");
+
+ redis
+ .incr::<String, i32, ()>(
+ format!("nova:rest:ratelimit:global:{}", since_the_epoch.as_secs()),
+ 1,
+ )
+ .await
+ .unwrap();
+ if let Some((bucket, remaining, reset)) = self.parse_headers(response) {
+ if remaining <= 1 {
+ // we set a lock for the bucket until the timeout passes
+ redis
+ .set_ex::<String, bool, ()>(
+ format!("nova:rest:ratelimit:lock:{}", bucket),
+ true,
+ reset.try_into().unwrap(),
+ )
+ .await
+ .unwrap();
+ }
+
+ redis
+ .set_ex::<String, String, ()>(
+ format!("nova:rest:ratelimit:url_bucket:{}", hash),
+ bucket,
+ reset.try_into().unwrap(),
+ )
+ .await
+ .unwrap();
+ }
+ }
+}