diff options
| author | MatthieuCoder <matthieu@matthieu-dev.xyz> | 2022-12-31 17:07:30 +0400 | 
|---|---|---|
| committer | MatthieuCoder <matthieu@matthieu-dev.xyz> | 2022-12-31 17:07:30 +0400 | 
| commit | 65652932f77ce194a10cbc8dd42f3064e2c1a132 (patch) | |
| tree | 4ca18a9317c4e561e917e9dd0cf39b695b43bc34 /exes/rest | |
| parent | a16bafdf5b0ec52fa0d73458597eee7c34ea5e7b (diff) | |
updates and bazel removal
Diffstat (limited to 'exes/rest')
| -rw-r--r-- | exes/rest/Cargo.toml | 16 | ||||
| -rw-r--r-- | exes/rest/src/config.rs | 18 | ||||
| -rw-r--r-- | exes/rest/src/main.rs | 46 | ||||
| -rw-r--r-- | exes/rest/src/proxy/mod.rs | 138 | ||||
| -rw-r--r-- | exes/rest/src/ratelimit/mod.rs | 155 | 
5 files changed, 373 insertions, 0 deletions
diff --git a/exes/rest/Cargo.toml b/exes/rest/Cargo.toml new file mode 100644 index 0000000..7b5b2b5 --- /dev/null +++ b/exes/rest/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "rest" +version = "0.1.0" +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +shared = { path = "../../libs/shared" } +hyper = { version = "0.14", features = ["full"] } +tokio = { version = "1", features = ["full"] } +serde = { version = "1.0.8", features = ["derive"] } +futures-util = "0.3.17" +hyper-tls = "0.5.0" +lazy_static = "1.4.0" +xxhash-rust = { version = "0.8.2", features = ["xxh32"] }
\ No newline at end of file diff --git a/exes/rest/src/config.rs b/exes/rest/src/config.rs new file mode 100644 index 0000000..559929f --- /dev/null +++ b/exes/rest/src/config.rs @@ -0,0 +1,18 @@ +use serde::Deserialize; + +#[derive(Debug, Deserialize, Clone, Default)] +pub struct ServerSettings { +    pub port: u16, +    pub address: String, +} + +#[derive(Debug, Deserialize, Clone, Default)] +pub struct Discord { +    pub token: String +} + +#[derive(Debug, Deserialize, Clone, Default)] +pub struct Config { +    pub server: ServerSettings, +    pub discord: Discord, +} diff --git a/exes/rest/src/main.rs b/exes/rest/src/main.rs new file mode 100644 index 0000000..9fa6ce7 --- /dev/null +++ b/exes/rest/src/main.rs @@ -0,0 +1,46 @@ +use std::{convert::Infallible, sync::Arc}; + +use crate::{config::Config, ratelimit::Ratelimiter}; +use shared::{ +    config::Settings, +    log::{error, info}, +    redis_crate::Client, +}; +use hyper::{server::conn::AddrStream, service::make_service_fn, Server}; +use std::net::ToSocketAddrs; +use tokio::sync::Mutex; + +use crate::proxy::ServiceProxy; + +mod config; +mod proxy; +mod ratelimit; + +#[tokio::main] +async fn main() { +    let settings: Settings<Config> = Settings::new("rest").unwrap(); +    let config = Arc::new(settings.config); +    let redis_client: Client = settings.redis.into(); +    let redis = Arc::new(Mutex::new( +        redis_client.get_async_connection().await.unwrap(), +    )); +    let ratelimiter = Arc::new(Ratelimiter::new(redis)); + +    let addr = format!("{}:{}", config.server.address, config.server.port) +        .to_socket_addrs() +        .unwrap() +        .next() +        .unwrap(); + +    let service_fn = make_service_fn(move |_: &AddrStream| { +        let service_proxy = ServiceProxy::new(config.clone(), ratelimiter.clone()); +        async move { Ok::<_, Infallible>(service_proxy) } +    }); + +    let server = Server::bind(&addr).serve(service_fn); + +    info!("starting ratelimit server"); +    if let Err(e) = server.await { +        error!("server error: {}", e); +    } +} diff --git a/exes/rest/src/proxy/mod.rs b/exes/rest/src/proxy/mod.rs new file mode 100644 index 0000000..65d77aa --- /dev/null +++ b/exes/rest/src/proxy/mod.rs @@ -0,0 +1,138 @@ +use crate::{config::Config, ratelimit::Ratelimiter}; +use hyper::{ +    client::HttpConnector, header::HeaderValue, http::uri::Parts, service::Service, Body, Client, +    Request, Response, Uri, +}; +use hyper_tls::HttpsConnector; +use shared::{ +    log::debug, +    prometheus::{labels, opts, register_counter, register_histogram_vec, Counter, HistogramVec}, +}; +use std::{future::Future, pin::Pin, sync::Arc, task::Poll}; +use tokio::sync::Mutex; + +lazy_static::lazy_static! { +    static ref HTTP_COUNTER: Counter = register_counter!(opts!( +        "nova_rest_http_requests_total", +        "Number of HTTP requests made.", +        labels! {"handler" => "all",} +    )) +    .unwrap(); + +    static ref HTTP_REQ_HISTOGRAM: HistogramVec = register_histogram_vec!( +        "nova_rest_http_request_duration_seconds", +        "The HTTP request latencies in seconds.", +        &["handler"] +    ) +    .unwrap(); + +    static ref HTTP_COUNTER_STATUS: Counter = register_counter!(opts!( +        "nova_rest_http_requests_status", +        "Number of HTTP requests made by status", +        labels! {"" => ""} +    )) +    .unwrap(); +} + +#[derive(Clone)] +pub struct ServiceProxy { +    client: Client<HttpsConnector<HttpConnector>>, +    ratelimiter: Arc<Ratelimiter>, +    config: Arc<Config>, +    fail: Arc<Mutex<i32>>, +} + +impl Service<Request<Body>> for ServiceProxy { +    type Response = Response<Body>; +    type Error = hyper::Error; +    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>; + +    fn poll_ready( +        &mut self, +        cx: &mut std::task::Context<'_>, +    ) -> std::task::Poll<Result<(), Self::Error>> { +        match self.client.poll_ready(cx) { +            Poll::Ready(Ok(())) => Poll::Ready(Ok(())), +            Poll::Ready(Err(e)) => Poll::Ready(Err(e)), +            Poll::Pending => Poll::Pending, +        } +    } + +    fn call(&mut self, mut req: Request<hyper::Body>) -> Self::Future { +        HTTP_COUNTER.inc(); + +        let timer = HTTP_REQ_HISTOGRAM.with_label_values(&["all"]).start_timer(); +        let host = "discord.com"; +        let mut new_parts = Parts::default(); + +        let path = req.uri().path().to_string(); + +        new_parts.scheme = Some("https".parse().unwrap()); +        new_parts.authority = Some(host.parse().unwrap()); +        new_parts.path_and_query = Some(path.parse().unwrap()); + +        *req.uri_mut() = Uri::from_parts(new_parts).unwrap(); + +        let headers = req.headers_mut(); +        headers.remove("user-agent"); +        headers.insert("Host", HeaderValue::from_str("discord.com").unwrap()); +        headers.insert( +            "Authorization", +            HeaderValue::from_str(&format!("Bot {}", self.config.discord.token)).unwrap(), +        ); + +        println!("{:?}", headers); + +        let client = self.client.clone(); +        let ratelimiter = self.ratelimiter.clone(); +        let fail = self.fail.clone(); + +        return Box::pin(async move { +            let resp = match ratelimiter.before_request(&req).await { +                Ok(allowed) => match allowed { +                    crate::ratelimit::RatelimiterResponse::Ratelimited => { +                        debug!("ratelimited"); +                        Ok(Response::builder().body("ratelimited".into()).unwrap()) +                    } +                    _ => { +                        debug!("forwarding request"); +                        match client.request(req).await { +                            Ok(mut response) => { +                                ratelimiter.after_request(&path, &response).await; +                                if response.status() != 200 { +                                    *fail.lock().await += 1 +                                } +                                response.headers_mut().insert( +                                    "x-fails", +                                    HeaderValue::from_str(&format!("{}", fail.lock().await)) +                                        .unwrap(), +                                ); +                                Ok(response) +                            } +                            Err(e) => Err(e), +                        } +                    } +                }, +                Err(e) => Ok(Response::builder() +                    .body(format!("server error: {}", e).into()) +                    .unwrap()), +            }; +            timer.observe_duration(); +            resp +        }); +    } +} + +impl ServiceProxy { +    pub fn new(config: Arc<Config>, ratelimiter: Arc<Ratelimiter>) -> Self { +        let https = HttpsConnector::new(); +        let client = Client::builder().build::<_, hyper::Body>(https); +        let fail = Arc::new(Mutex::new(0)); +        ServiceProxy { +            client, +            config, +            ratelimiter, +            fail, +        } +    } +} diff --git a/exes/rest/src/ratelimit/mod.rs b/exes/rest/src/ratelimit/mod.rs new file mode 100644 index 0000000..132bfd3 --- /dev/null +++ b/exes/rest/src/ratelimit/mod.rs @@ -0,0 +1,155 @@ +use shared::{ +    log::debug, +    redis_crate::{aio::Connection, AsyncCommands}, error::GenericError, +}; +use hyper::{Body, Request, Response}; +use std::{ +    convert::TryInto, +    sync::Arc, +    time::{SystemTime, UNIX_EPOCH}, +}; +use tokio::sync::Mutex; +use xxhash_rust::xxh32::xxh32; + +pub enum RatelimiterResponse { +    NoSuchUrl, +    Ratelimited, +    Pass, +} + +pub struct Ratelimiter { +    redis: Arc<Mutex<Connection>>, +} + +impl Ratelimiter { +    pub fn new(redis: Arc<Mutex<Connection>>) -> Ratelimiter { +        return Ratelimiter { redis }; +    } + +    pub async fn before_request( +        &self, +        request: &Request<Body>, +    ) -> Result<RatelimiterResponse, GenericError> { +        // we lookup if the route hash is stored in the redis table +        let path = request.uri().path(); +        let hash = xxh32(path.as_bytes(), 32); +        let mut redis = self.redis.lock().await; + +        let start = SystemTime::now(); +        let since_the_epoch = start +            .duration_since(UNIX_EPOCH) +            .expect("Time went backwards"); + +        // global rate litmit +        match redis +            .get::<String, Option<i32>>(format!( +                "nova:rest:ratelimit:global:{}", +                since_the_epoch.as_secs() +            )) +            .await +        { +            Ok(value) => { +                match value { +                    Some(value) => { +                        debug!("incr: {}", value); +                        if value >= 49 { +                            return Ok(RatelimiterResponse::Ratelimited); +                        } +                    } +                    None => { +                        let key = +                            format!("nova:rest:ratelimit:global:{}", since_the_epoch.as_secs()); +                        // init global ratelimit +                        redis.set_ex::<String, i32, ()>(key, 0, 2).await.unwrap(); +                    } +                } +            } +            Err(_) => { +                return Err(GenericError::StepFailed("radis ratelimit check".to_string())); +            } +        }; + +        // we lookup the corresponding bucket for this url +        match redis +            .get::<String, Option<String>>(format!("nova:rest:ratelimit:url_bucket:{}", hash)) +            .await +        { +            Ok(bucket) => match bucket { +                Some(bucket) => { +                    match redis +                        .exists::<String, bool>(format!("nova:rest:ratelimit:lock:{}", bucket)) +                        .await +                    { +                        Ok(exists) => { +                            if exists { +                                Ok(RatelimiterResponse::Ratelimited) +                            } else { +                                Ok(RatelimiterResponse::Pass) +                            } +                        } +                        Err(_) =>  Err(GenericError::StepFailed("radis ratelimit check".to_string())), +                    } +                } +                None => Ok(RatelimiterResponse::NoSuchUrl), +            }, +            Err(_) => Err(GenericError::StepFailed("radis ratelimit check".to_string())), +        } +    } + +    fn parse_headers(&self, response: &Response<Body>) -> Option<(String, i32, i32)> { +        if let Some(bucket) = response.headers().get("X-RateLimit-Bucket") { +            let bucket = bucket.to_str().unwrap().to_string(); + +            let remaining = response.headers().get("X-RateLimit-Remaining").unwrap(); +            let reset = response.headers().get("X-RateLimit-Reset-After").unwrap(); + +            let remaining_i32 = remaining.to_str().unwrap().parse::<i32>().unwrap(); +            let reset_ms_i32 = reset.to_str().unwrap().parse::<f32>().unwrap().ceil() as i32; +            return Some((bucket, remaining_i32, reset_ms_i32)); +        } else { +            None +        } +    } + +    pub async fn after_request(&self, path: &str, response: &Response<Body>) { +        let hash = xxh32(path.as_bytes(), 32); +        // verified earlier + +        let mut redis = self.redis.lock().await; + +        let start = SystemTime::now(); +        let since_the_epoch = start +            .duration_since(UNIX_EPOCH) +            .expect("Time went backwards"); + +        redis +            .incr::<String, i32, ()>( +                format!("nova:rest:ratelimit:global:{}", since_the_epoch.as_secs()), +                1, +            ) +            .await +            .unwrap(); +        if let Some((bucket, remaining, reset)) = self.parse_headers(response) { +            if remaining <= 1 { +                // we set a lock for the bucket until the timeout passes +                redis +                    .set_ex::<String, bool, ()>( +                        format!("nova:rest:ratelimit:lock:{}", bucket), +                        true, +                        reset.try_into().unwrap(), +                    ) +                    .await +                    .unwrap(); +            } + +            redis +                .set_ex::<String, String, ()>( +                    format!("nova:rest:ratelimit:url_bucket:{}", hash), +                    bucket, +                    reset.try_into().unwrap(), +                ) +                .await +                .unwrap(); +        } +    } +}  | 
