summaryrefslogtreecommitdiff
path: root/exes/rest/src/lib.rs
blob: 0910d5e3147cda2ae81088071092842b4c812a40 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
use config::ReverseProxyConfig;

use handler::handle_request;
use hyper::{
    server::conn::AddrStream,
    service::{make_service_fn, service_fn},
    Body, Client, Request, Server,
};
use leash::{AnyhowResultFuture, Component};
use opentelemetry::{global, trace::{Tracer}};
use opentelemetry_http::HeaderExtractor;
use shared::config::Settings;
use std::{convert::Infallible, sync::Arc};
use tokio::sync::oneshot;

mod config;
mod handler;
mod ratelimit_client;

/// Component running the `rest` reverse proxy HTTP server.
///
/// The struct itself holds no state: everything the server needs (HTTPS
/// client, ratelimiter client, token) is built inside [`Component::start`]
/// from the settings it receives.
pub struct ReverseProxyServer {}

impl Component for ReverseProxyServer {
    type Config = ReverseProxyConfig;
    const SERVICE_NAME: &'static str = "rest";

    /// Starts the proxy and serves requests until `stop` resolves.
    ///
    /// Builds a remote ratelimiter client and an HTTPS-only HTTP/1 hyper
    /// client, binds a hyper server on `settings.config.server.listening_adress`
    /// and forwards every incoming request to `handle_request`.
    ///
    /// The server shuts down gracefully once `stop` receives a value, or once
    /// its sender is dropped.
    ///
    /// # Errors
    ///
    /// The returned future resolves to an error if the hyper server
    /// terminates with one.
    fn start(
        &self,
        settings: Settings<Self::Config>,
        stop: oneshot::Receiver<()>,
    ) -> AnyhowResultFuture<()> {
        Box::pin(async move {
            // Client to the remote ratelimiters
            let ratelimiter = ratelimit_client::RemoteRatelimiter::new();

            // Outgoing connections: native root certificates, HTTPS only, HTTP/1.
            let https = hyper_rustls::HttpsConnectorBuilder::new()
                .with_native_roots()
                .https_only()
                .enable_http1()
                .build();
            let client: Client<_, hyper::Body> = Client::builder().build(https);

            // Shared behind an `Arc` so each request only bumps a refcount
            // instead of copying the token string.
            let token = Arc::new(settings.discord.token.clone());

            let service_fn = make_service_fn(move |_: &AddrStream| {
                // One set of handles per accepted connection.
                let client = client.clone();
                let ratelimiter = ratelimiter.clone();
                let token = token.clone();
                async move {
                    Ok::<_, Infallible>(service_fn(move |request: Request<Body>| {
                        // Continue the caller's trace if the request headers
                        // carry a propagated context.
                        let parent_cx = global::get_text_map_propagator(|propagator| {
                            propagator.extract(&HeaderExtractor(request.headers()))
                        });
                        let span = global::tracer("")
                            .start_with_context("handle_request", &parent_cx);

                        // One set of handles per request.
                        let client = client.clone();
                        let ratelimiter = ratelimiter.clone();
                        let token = token.clone();
                        async move {
                            // The span is moved into the future so that it lives
                            // until the request has been handled. Bound only in
                            // the enclosing closure, it would be dropped as soon
                            // as the future is returned, i.e. before
                            // `handle_request` is ever polled.
                            let _span = span;
                            let token = token.as_str();
                            handle_request(client, ratelimiter, token, request).await
                        }
                    }))
                }
            });

            let server = Server::bind(&settings.config.server.listening_adress).serve(service_fn);

            server
                .with_graceful_shutdown(async {
                    // `Err` only means the sender was dropped without sending.
                    // That is treated as a shutdown request as well, instead of
                    // panicking inside the server future.
                    let _ = stop.await;
                })
                .await?;

            Ok(())
        })
    }

    /// Creates the (stateless) component.
    fn new() -> Self {
        Self {}
    }
}