1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
  | 
use config::ReverseProxyConfig;
use handler::handle_request;
use hyper::{
    server::conn::AddrStream,
    service::{make_service_fn, service_fn},
    Body, Client, Request, Server,
};
use leash::{AnyhowResultFuture, Component};
use opentelemetry::global;
use opentelemetry_http::HeaderExtractor;
use shared::config::Settings;
use std::{convert::Infallible, sync::Arc};
use tokio::sync::oneshot;
use tracing_opentelemetry::OpenTelemetrySpanExt;
mod config;
mod handler;
mod ratelimit_client;
pub struct ReverseProxyServer {}
impl Component for ReverseProxyServer {
    type Config = ReverseProxyConfig;
    const SERVICE_NAME: &'static str = "rest";
    fn start(
        &self,
        settings: Settings<Self::Config>,
        stop: oneshot::Receiver<()>,
    ) -> AnyhowResultFuture<()> {
        Box::pin(async move {
            // Client to the remote ratelimiters
            let ratelimiter = Arc::new(ratelimit_client::RemoteRatelimiter::new(
                settings.config.clone(),
            ));
            let https = hyper_rustls::HttpsConnectorBuilder::new()
                .with_native_roots()
                .https_or_http()
                .enable_http1()
                .build();
            let client: Client<_, hyper::Body> = Client::builder().build(https);
            let token = settings.config.discord.token.clone();
            let service_fn = make_service_fn(move |_: &AddrStream| {
                let client = client.clone();
                let ratelimiter = ratelimiter.clone();
                let token = token.clone();
                async move {
                    Ok::<_, Infallible>(service_fn(move |request: Request<Body>| {
                        let token = token.clone();
                        let parent_cx = global::get_text_map_propagator(|propagator| {
                            propagator.extract(&HeaderExtractor(request.headers()))
                        });
                        let span = tracing::span!(tracing::Level::INFO, "request process");
                        span.set_parent(parent_cx);
                        let client = client.clone();
                        let ratelimiter = ratelimiter.clone();
                        async move {
                            let token = token.clone();
                            let ratelimiter = ratelimiter.clone();
                            handle_request(client, ratelimiter, token, request).await
                        }
                    }))
                }
            });
            let server = Server::bind(&settings.config.server.listening_adress)
                .http1_only(true)
                .serve(service_fn);
            server
                .with_graceful_shutdown(async {
                    stop.await.expect("should not fail");
                })
                .await?;
            Ok(())
        })
    }
    fn new() -> Self {
        Self {}
    }
}
  |