git.puffer.fish Git - matthieu/nova.git/commitdiff
add ratelimiter service
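
This commit introduces the `ratelimit` executable: a tonic gRPC server exposing the ratelimiter as a bidirectional SubmitTicket stream. Buckets are tracked in memory per path (adapted from twilight-http-ratelimiting's in-memory ratelimiter), while the global ratelimit is coordinated across instances through a Redis lock. The old GetBucketInformation RPC is removed from the proto.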
author MatthieuCoder <matthieu@matthieu-dev.xyz>
Mon, 2 Jan 2023 16:07:48 +0000 (20:07 +0400)
committer MatthieuCoder <matthieu@matthieu-dev.xyz>
Mon, 2 Jan 2023 16:07:48 +0000 (20:07 +0400)
Cargo.lock
exes/.gitignore [deleted file]
exes/ratelimit/Cargo.toml [new file with mode: 0644]
exes/ratelimit/src/grpc.rs [new file with mode: 0644]
exes/ratelimit/src/main.rs [new file with mode: 0644]
exes/ratelimit/src/redis_global_local_bucket_ratelimiter/bucket.rs [new file with mode: 0644]
exes/ratelimit/src/redis_global_local_bucket_ratelimiter/mod.rs [new file with mode: 0644]
proto/nova/ratelimit/ratelimiter.proto

diff --git a/Cargo.lock b/Cargo.lock
index 8181e8d70a519eab52261eb73cf93c0e53874d1a..08e01efe042cbfd3c1efd3469ef269db142f9b8c 100644 (file)
@@ -1792,6 +1792,7 @@ dependencies = [
  "anyhow",
  "futures-util",
  "hyper",
+ "leash",
  "proto",
  "serde",
  "serde_json",
diff --git a/exes/.gitignore b/exes/.gitignore
deleted file mode 100644 (file)
index 0f02d23..0000000
+++ /dev/null
@@ -1 +0,0 @@
-ratelimit/
\ No newline at end of file
diff --git a/exes/ratelimit/Cargo.toml b/exes/ratelimit/Cargo.toml
new file mode 100644 (file)
index 0000000..a28d2d0
--- /dev/null
@@ -0,0 +1,21 @@
+[package]
+name = "ratelimit"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+shared = { path = "../../libs/shared" }
+proto = { path = "../../libs/proto" }
+leash = { path = "../../libs/leash" }
+hyper = { version = "0.14", features = ["full"] }
+tokio = { version = "1", features = ["full"] }
+serde = { version = "1.0.8", features = ["derive"] }
+twilight-http-ratelimiting = { git = "https://github.com/MatthieuCoder/twilight.git" }
+anyhow = "*"
+futures-util = "0.3.17"
+tracing = "*"
+serde_json = { version = "1.0" }
+tonic = "0.8.3"
+tokio-stream = "0.1.11"
\ No newline at end of file
diff --git a/exes/ratelimit/src/grpc.rs b/exes/ratelimit/src/grpc.rs
new file mode 100644 (file)
index 0000000..a75c329
--- /dev/null
@@ -0,0 +1,79 @@
+
+use std::pin::Pin;
+
+use futures_util::Stream;
+use proto::nova::ratelimit::ratelimiter::{ratelimiter_server::Ratelimiter, BucketSubmitTicketResponse, BucketSubmitTicketRequest};
+use tokio::sync::mpsc;
+use tokio_stream::{wrappers::ReceiverStream, StreamExt};
+use tonic::{Request, Response, Status, Streaming};
+use twilight_http_ratelimiting::{ticket::TicketReceiver, RatelimitHeaders};
+
+use crate::redis_global_local_bucket_ratelimiter::RedisGlobalLocalBucketRatelimiter;
+
+pub struct RLServer {
+    ratelimiter: RedisGlobalLocalBucketRatelimiter,
+}
+
+impl RLServer {
+    pub fn new(ratelimiter: RedisGlobalLocalBucketRatelimiter) -> Self {
+        Self { ratelimiter }
+    }
+}
+
+#[tonic::async_trait]
+impl Ratelimiter for RLServer {
+
+    type SubmitTicketStream =
+        Pin<Box<dyn Stream<Item = Result<BucketSubmitTicketResponse, Status>> + Send>>;
+
+    async fn submit_ticket(
+        &self,
+        req: Request<Streaming<BucketSubmitTicketRequest>>,
+    ) -> Result<Response<Self::SubmitTicketStream>, Status> {
+        let mut in_stream = req.into_inner();
+        let (tx, rx) = mpsc::channel(128);
+        let imrl = self.ratelimiter.clone();
+
+        // this spawn is required to handle connection errors. If we simply
+        // mapped `in_stream` and wrote it back as `out_stream`, the `out_stream`
+        // would be dropped when a connection error occurs, and the error would
+        // never be propagated to the mapped version of `in_stream`.
+        tokio::spawn(async move {
+            let mut receiver: Option<TicketReceiver> = None;
+            while let Some(result) = in_stream.next().await {
+                let result = result.unwrap();
+
+                match result.data.unwrap() {
+                    proto::nova::ratelimit::ratelimiter::bucket_submit_ticket_request::Data::Path(path) => {
+                        let a = imrl.ticket(path).await.unwrap();
+                        receiver = Some(a);
+
+                        tx.send(Ok(BucketSubmitTicketResponse {
+                            accepted: 1
+                        })).await.unwrap();
+
+                    },
+                    proto::nova::ratelimit::ratelimiter::bucket_submit_ticket_request::Data::Headers(b) => {
+                        if let Some(recv) = receiver {
+                            let recv = recv.await.unwrap();
+                            let rheaders = RatelimitHeaders::from_pairs(b.headers.iter().map(|f| (f.0.as_str(), f.1.as_bytes()))).unwrap();
+                            
+                            recv.headers(Some(rheaders)).unwrap();
+
+                            break;
+                        }
+                    },
+                }
+            }
+            println!("\tstream ended");
+        });
+
+        // forward the responses produced by the task above to the client
+        let out_stream = ReceiverStream::new(rx);
+
+        Ok(Response::new(
+            Box::pin(out_stream) as Self::SubmitTicketStream
+        ))
+    }
+}
\ No newline at end of file
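
For reference, a client is expected to drive this stream by sending the bucket path first, waiting for the `accepted` response, performing its HTTP request, and then reporting the response headers. A minimal sketch of that flow, assuming the tonic-generated `RatelimiterClient` and the address used in main.rs (the `Headers` payload type is not shown in this diff, so it is left as a comment):

    use proto::nova::ratelimit::ratelimiter::{
        bucket_submit_ticket_request::Data, ratelimiter_client::RatelimiterClient,
        BucketSubmitTicketRequest,
    };
    use tokio_stream::wrappers::ReceiverStream;

    async fn submit_ticket(path: String) -> anyhow::Result<()> {
        let mut client = RatelimiterClient::connect("http://127.0.0.1:8080").await?;

        let (tx, rx) = tokio::sync::mpsc::channel(2);
        let mut responses = client
            .submit_ticket(ReceiverStream::new(rx))
            .await?
            .into_inner();

        // 1. request a ticket for this bucket path and wait for acceptance
        tx.send(BucketSubmitTicketRequest {
            data: Some(Data::Path(path)),
        })
        .await?;
        let _accepted = responses.message().await?;

        // 2. perform the HTTP request here, then send a Data::Headers message
        //    with the response headers so the server can update the bucket.

        Ok(())
    }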
diff --git a/exes/ratelimit/src/main.rs b/exes/ratelimit/src/main.rs
new file mode 100644 (file)
index 0000000..8e43b34
--- /dev/null
@@ -0,0 +1,47 @@
+use std::net::ToSocketAddrs;
+
+use futures_util::FutureExt;
+use grpc::RLServer;
+use leash::{ignite, AnyhowResultFuture, Component};
+use proto::nova::ratelimit::ratelimiter::ratelimiter_server::RatelimiterServer;
+use redis_global_local_bucket_ratelimiter::RedisGlobalLocalBucketRatelimiter;
+use shared::{config::Settings, redis_crate::Client};
+use tokio::sync::oneshot;
+use tonic::transport::Server;
+
+mod grpc;
+mod redis_global_local_bucket_ratelimiter;
+
+struct RatelimiterServerComponent {}
+impl Component for RatelimiterServerComponent {
+    type Config = ();
+    const SERVICE_NAME: &'static str = "ratelimit";
+
+    fn start(
+        &self,
+        settings: Settings<Self::Config>,
+        stop: oneshot::Receiver<()>,
+    ) -> AnyhowResultFuture<()> {
+        Box::pin(async move {
+            // let config = Arc::new(settings.config);
+            let redis: Client = settings.redis.into();
+            let server = RLServer::new(RedisGlobalLocalBucketRatelimiter::new(redis.into()));
+
+            Server::builder()
+                .add_service(RatelimiterServer::new(server))
+                .serve_with_shutdown(
+                    "0.0.0.0:8080".to_socket_addrs().unwrap().next().unwrap(),
+                    stop.map(|_| ()),
+                )
+                .await?;
+
+            Ok(())
+        })
+    }
+
+    fn new() -> Self {
+        Self {}
+    }
+}
+
+ignite!(RatelimiterServerComponent);
diff --git a/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/bucket.rs b/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/bucket.rs
new file mode 100644 (file)
index 0000000..d739acf
--- /dev/null
@@ -0,0 +1,323 @@
+//! [`Bucket`] management used by the [`super::RedisGlobalLocalBucketRatelimiter`] internally.
+//! Each bucket has an associated [`BucketQueue`] to queue an API request, which is
+//! consumed by the [`BucketQueueTask`] that manages the ratelimit for the bucket
+//! and respects the global ratelimit.
+
+use super::RedisLockPair;
+use twilight_http_ratelimiting::{headers::RatelimitHeaders, ticket::TicketNotifier};
+use std::{
+    collections::HashMap,
+    mem,
+    sync::{
+        atomic::{AtomicU64, Ordering},
+        Arc, Mutex,
+    },
+    time::{Duration, Instant},
+};
+use tokio::{
+    sync::{
+        mpsc::{self, UnboundedReceiver, UnboundedSender},
+        Mutex as AsyncMutex,
+    },
+    time::{sleep, timeout},
+};
+
+/// Time remaining until a bucket will reset.
+#[derive(Clone, Debug)]
+pub enum TimeRemaining {
+    /// Bucket has already reset.
+    Finished,
+    /// Bucket's ratelimit refresh countdown has not started yet.
+    NotStarted,
+    /// Amount of time until the bucket resets.
+    Some(Duration),
+}
+
+/// Ratelimit information for a bucket used in the [`super::RedisGlobalLocalBucketRatelimiter`].
+///
+/// A generic version not specific to this ratelimiter is [`twilight_http_ratelimiting::Bucket`].
+#[derive(Debug)]
+pub struct Bucket {
+    /// Total number of tickets allotted in a cycle.
+    pub limit: AtomicU64,
+    /// Path this ratelimit applies to.
+    pub path: String,
+    /// Queue associated with this bucket.
+    pub queue: BucketQueue,
+    /// Number of tickets remaining.
+    pub remaining: AtomicU64,
+    /// Duration after the [`Self::started_at`] time the bucket will refresh.
+    pub reset_after: AtomicU64,
+    /// When the bucket's ratelimit refresh countdown started.
+    pub started_at: Mutex<Option<Instant>>,
+}
+
+impl Bucket {
+    /// Create a new bucket for the specified path.
+    pub fn new(path: String) -> Self {
+        Self {
+            limit: AtomicU64::new(u64::max_value()),
+            path,
+            queue: BucketQueue::default(),
+            remaining: AtomicU64::new(u64::max_value()),
+            reset_after: AtomicU64::new(u64::max_value()),
+            started_at: Mutex::new(None),
+        }
+    }
+
+    /// Total number of tickets allotted in a cycle.
+    pub fn limit(&self) -> u64 {
+        self.limit.load(Ordering::Relaxed)
+    }
+
+    /// Number of tickets remaining.
+    pub fn remaining(&self) -> u64 {
+        self.remaining.load(Ordering::Relaxed)
+    }
+
+    /// Duration after the [`started_at`] time the bucket will refresh.
+    ///
+    /// [`started_at`]: Self::started_at
+    pub fn reset_after(&self) -> u64 {
+        self.reset_after.load(Ordering::Relaxed)
+    }
+
+    /// Time remaining until this bucket will reset.
+    pub fn time_remaining(&self) -> TimeRemaining {
+        let reset_after = self.reset_after();
+        let maybe_started_at = *self.started_at.lock().expect("bucket poisoned");
+
+        let started_at = if let Some(started_at) = maybe_started_at {
+            started_at
+        } else {
+            return TimeRemaining::NotStarted;
+        };
+
+        let elapsed = started_at.elapsed();
+
+        if elapsed > Duration::from_millis(reset_after) {
+            return TimeRemaining::Finished;
+        }
+
+        TimeRemaining::Some(Duration::from_millis(reset_after) - elapsed)
+    }
+
+    /// Try to reset this bucket's [`started_at`] value if it has finished.
+    ///
+    /// Returns whether resetting was possible.
+    ///
+    /// [`started_at`]: Self::started_at
+    pub fn try_reset(&self) -> bool {
+        if self.started_at.lock().expect("bucket poisoned").is_none() {
+            return false;
+        }
+
+        if let TimeRemaining::Finished = self.time_remaining() {
+            self.remaining.store(self.limit(), Ordering::Relaxed);
+            *self.started_at.lock().expect("bucket poisoned") = None;
+
+            true
+        } else {
+            false
+        }
+    }
+
+    /// Update this bucket's ratelimit data after a request has been made.
+    pub fn update(&self, ratelimits: Option<(u64, u64, u64)>) {
+        let bucket_limit = self.limit();
+
+        {
+            let mut started_at = self.started_at.lock().expect("bucket poisoned");
+
+            if started_at.is_none() {
+                started_at.replace(Instant::now());
+            }
+        }
+
+        if let Some((limit, remaining, reset_after)) = ratelimits {
+            if bucket_limit != limit && bucket_limit == u64::max_value() {
+                self.reset_after.store(reset_after, Ordering::SeqCst);
+                self.limit.store(limit, Ordering::SeqCst);
+            }
+
+            self.remaining.store(remaining, Ordering::Relaxed);
+        } else {
+            self.remaining.fetch_sub(1, Ordering::Relaxed);
+        }
+    }
+}
+
+/// Queue of ratelimit requests for a bucket.
+#[derive(Debug)]
+pub struct BucketQueue {
+    /// Receiver for the ratelimit requests.
+    rx: AsyncMutex<UnboundedReceiver<TicketNotifier>>,
+    /// Sender for the ratelimit requests.
+    tx: UnboundedSender<TicketNotifier>,
+}
+
+impl BucketQueue {
+    /// Add a new ratelimit request to the queue.
+    pub fn push(&self, tx: TicketNotifier) {
+        let _sent = self.tx.send(tx);
+    }
+
+    /// Receive the first incoming ratelimit request.
+    pub async fn pop(&self, timeout_duration: Duration) -> Option<TicketNotifier> {
+        let mut rx = self.rx.lock().await;
+
+        timeout(timeout_duration, rx.recv()).await.ok().flatten()
+    }
+}
+
+impl Default for BucketQueue {
+    fn default() -> Self {
+        let (tx, rx) = mpsc::unbounded_channel();
+
+        Self {
+            rx: AsyncMutex::new(rx),
+            tx,
+        }
+    }
+}
+
+/// A background task that handles ratelimit requests to a [`Bucket`]
+/// and processes them in order, keeping track of both the global and
+/// the path-specific ratelimits.
+pub(super) struct BucketQueueTask {
+    /// The [`Bucket`] managed by this task.
+    bucket: Arc<Bucket>,
+    /// All buckets managed by the associated [`super::RedisGlobalLocalBucketRatelimiter`].
+    buckets: Arc<Mutex<HashMap<String, Arc<Bucket>>>>,
+    /// Global ratelimit data.
+    global: Arc<RedisLockPair>,
+    /// The path this [`Bucket`] belongs to.
+    path: String,
+}
+
+impl BucketQueueTask {
+    /// Timeout to wait for response headers after initiating a request.
+    const WAIT: Duration = Duration::from_secs(10);
+
+    /// Create a new task to manage the ratelimit for a [`Bucket`].
+    pub fn new(
+        bucket: Arc<Bucket>,
+        buckets: Arc<Mutex<HashMap<String, Arc<Bucket>>>>,
+        global: Arc<RedisLockPair>,
+        path: String,
+    ) -> Self {
+        Self {
+            bucket,
+            buckets,
+            global,
+            path,
+        }
+    }
+
+    /// Process incoming ratelimit requests to this bucket and update the state
+    /// based on received [`RatelimitHeaders`].
+    #[tracing::instrument(name = "background queue task", skip(self), fields(path = ?self.path))]
+    pub async fn run(self) {
+        while let Some(queue_tx) = self.next().await {
+            if self.global.is_locked().await {
+                mem::drop(self.global.0.lock().await);
+            }
+
+            let ticket_headers = if let Some(ticket_headers) = queue_tx.available() {
+                ticket_headers
+            } else {
+                continue;
+            };
+
+            tracing::debug!("starting to wait for response headers");
+
+            match timeout(Self::WAIT, ticket_headers).await {
+                Ok(Ok(Some(headers))) => self.handle_headers(&headers).await,
+                Ok(Ok(None)) => {
+                    tracing::debug!("request aborted");
+                }
+                Ok(Err(_)) => {
+                    tracing::debug!("ticket channel closed");
+                }
+                Err(_) => {
+                    tracing::debug!("receiver timed out");
+                }
+            }
+        }
+
+        tracing::debug!("bucket appears finished, removing");
+
+        self.buckets
+            .lock()
+            .expect("ratelimit buckets poisoned")
+            .remove(&self.path);
+    }
+
+    /// Update the bucket's ratelimit state.
+    async fn handle_headers(&self, headers: &RatelimitHeaders) {
+        let ratelimits = match headers {
+            RatelimitHeaders::Global(global) => {
+                self.lock_global(Duration::from_secs(global.retry_after()))
+                    .await;
+
+                None
+            }
+            RatelimitHeaders::None => return,
+            RatelimitHeaders::Present(present) => {
+                Some((present.limit(), present.remaining(), present.reset_after()))
+            }
+            _ => unreachable!(),
+        };
+
+        tracing::debug!(path=?self.path, "updating bucket");
+        self.bucket.update(ratelimits);
+    }
+
+    /// Lock the global ratelimit for a specified duration.
+    async fn lock_global(&self, wait: Duration) {
+        tracing::debug!(path=?self.path, "request got global ratelimited");
+        self.global.lock_for(wait).await;
+    }
+
+    /// Get the next [`TicketNotifier`] in the queue.
+    async fn next(&self) -> Option<TicketNotifier> {
+        tracing::debug!(path=?self.path, "starting to get next in queue");
+
+        self.wait_if_needed().await;
+
+        self.bucket.queue.pop(Self::WAIT).await
+    }
+
+    /// Wait for this bucket to refresh if it isn't ready yet.
+    #[tracing::instrument(name = "waiting for bucket to refresh", skip(self), fields(path = ?self.path))]
+    async fn wait_if_needed(&self) {
+        let wait = {
+            if self.bucket.remaining() > 0 {
+                return;
+            }
+
+            tracing::debug!("0 tickets remaining, may have to wait");
+
+            match self.bucket.time_remaining() {
+                TimeRemaining::Finished => {
+                    self.bucket.try_reset();
+
+                    return;
+                }
+                TimeRemaining::NotStarted => return,
+                TimeRemaining::Some(dur) => dur,
+            }
+        };
+
+        tracing::debug!(
+            milliseconds=%wait.as_millis(),
+            "waiting for ratelimit to pass",
+        );
+
+        sleep(wait).await;
+
+        tracing::debug!("done waiting for ratelimit to pass");
+
+        self.bucket.try_reset();
+    }
+}
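
As an illustration of the state machine documented above (not part of the commit; the values are made up), a bucket evolves as follows:

    // a fresh bucket starts fully permissive: limit == u64::MAX
    let bucket = Bucket::new("channels/123/messages".to_string());

    // first parsed headers: limit 5, remaining 4, reset_after 2500 ms
    bucket.update(Some((5, 4, 2500)));
    assert_eq!(bucket.limit(), 5);
    assert_eq!(bucket.remaining(), 4);

    // a request without parsed headers only decrements the local counter
    bucket.update(None);
    assert_eq!(bucket.remaining(), 3);

    // once 2500 ms have elapsed since `started_at`, time_remaining()
    // returns Finished and try_reset() restores `remaining` to `limit`.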
diff --git a/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/mod.rs b/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/mod.rs
new file mode 100644 (file)
index 0000000..c759db9
--- /dev/null
@@ -0,0 +1,99 @@
+use self::bucket::{Bucket, BucketQueueTask};
+use shared::redis_crate::{Client, Commands};
+use twilight_http_ratelimiting::ticket::{self, TicketNotifier};
+use twilight_http_ratelimiting::GetTicketFuture;
+mod bucket;
+
+use futures_util::future;
+use std::{
+    collections::hash_map::{Entry, HashMap},
+    sync::{Arc, Mutex},
+    time::Duration,
+};
+
+#[derive(Debug)]
+struct RedisLockPair(tokio::sync::Mutex<Client>);
+
+impl RedisLockPair {
+    /// Set the global ratelimit as exhausted.
+    pub async fn lock_for(&self, duration: Duration) {
+        let _: () = self
+            .0
+            .lock()
+            .await
+            .set_ex(
+                "nova:rls:lock",
+                1,
+                (duration.as_secs() + 1).try_into().unwrap(),
+            )
+            .unwrap();
+    }
+
+    pub async fn is_locked(&self) -> bool {
+        self.0.lock().await.exists("nova:rls:lock").unwrap()
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct RedisGlobalLocalBucketRatelimiter {
+    buckets: Arc<Mutex<HashMap<String, Arc<Bucket>>>>,
+
+    global: Arc<RedisLockPair>,
+}
+
+impl RedisGlobalLocalBucketRatelimiter {
+    #[must_use]
+    pub fn new(redis: tokio::sync::Mutex<Client>) -> Self {
+        Self {
+            buckets: Arc::default(),
+            global: Arc::new(RedisLockPair(redis)),
+        }
+    }
+
+    fn entry(&self, path: String, tx: TicketNotifier) -> Option<Arc<Bucket>> {
+        let mut buckets = self.buckets.lock().expect("buckets poisoned");
+
+        match buckets.entry(path.clone()) {
+            Entry::Occupied(bucket) => {
+                tracing::debug!("got existing bucket: {path:?}");
+
+                bucket.get().queue.push(tx);
+
+                tracing::debug!("added request into bucket queue: {path:?}");
+
+                None
+            }
+            Entry::Vacant(entry) => {
+                tracing::debug!("making new bucket for path: {path:?}");
+
+                let bucket = Bucket::new(path);
+                bucket.queue.push(tx);
+
+                let bucket = Arc::new(bucket);
+                entry.insert(Arc::clone(&bucket));
+
+                Some(bucket)
+            }
+        }
+    }
+
+    pub fn ticket(&self, path: String) -> GetTicketFuture {
+        tracing::debug!("getting bucket for path: {path:?}");
+
+        let (tx, rx) = ticket::channel();
+
+        if let Some(bucket) = self.entry(path.clone(), tx) {
+            tokio::spawn(
+                BucketQueueTask::new(
+                    bucket,
+                    Arc::clone(&self.buckets),
+                    Arc::clone(&self.global),
+                    path,
+                )
+                .run(),
+            );
+        }
+
+        Box::pin(future::ok(rx))
+    }
+}
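
Mirroring how grpc.rs consumes this type, the intended call pattern is: `ticket` resolves to a `TicketReceiver`, awaiting that receiver yields a sender once the bucket queue grants the ticket, and the caller then reports the parsed response headers. A hedged sketch, assuming an already-constructed redis `client: Client`:

    let ratelimiter =
        RedisGlobalLocalBucketRatelimiter::new(tokio::sync::Mutex::new(client));

    // queue a request for this path
    let receiver = ratelimiter
        .ticket("channels/123/messages".to_string())
        .await
        .unwrap();

    // resolves once the bucket grants the ticket
    let sender = receiver.await.unwrap();

    // ... perform the HTTP request and parse its RatelimitHeaders ...
    sender.headers(None).unwrap(); // None if the response had no ratelimit headers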
diff --git a/proto/nova/ratelimit/ratelimiter.proto b/proto/nova/ratelimit/ratelimiter.proto
index 34d5b6fe9a71c456371bef6fc8a3df12fa7104be..3d7c75af7a7dd3b21a19807017f38204260323f8 100644 (file)
@@ -3,22 +3,9 @@ syntax = "proto3";
 package nova.ratelimit.ratelimiter;
 
 service Ratelimiter {
-    rpc GetBucketInformation(BucketInformationRequest) returns (BucketInformationResponse);
     rpc SubmitTicket(stream BucketSubmitTicketRequest) returns (stream BucketSubmitTicketResponse);
 }
 
-message BucketInformationRequest {
-    string path = 1;
-}
-
-message BucketInformationResponse {
-    uint64 limit = 1;
-    uint64 remaining = 2;
-    uint64 reset_after = 3;
-    uint64 started_at = 4;
-}
-
-
 message BucketSubmitTicketRequest {
     oneof data {
         string path = 1;