--- /dev/null
+
+use std::pin::Pin;
+
+use futures_util::Stream;
+use proto::nova::ratelimit::ratelimiter::{ratelimiter_server::Ratelimiter, BucketSubmitTicketResponse, BucketSubmitTicketRequest};
+use tokio::sync::mpsc;
+use tokio_stream::{wrappers::ReceiverStream, StreamExt};
+use tonic::{Request, Response, Status, Streaming};
+use twilight_http_ratelimiting::{ticket::TicketReceiver, RatelimitHeaders};
+
+use crate::redis_global_local_bucket_ratelimiter::RedisGlobalLocalBucketRatelimiter;
+
+pub struct RLServer {
+ ratelimiter: RedisGlobalLocalBucketRatelimiter,
+}
+
+impl RLServer {
+ pub fn new(ratelimiter: RedisGlobalLocalBucketRatelimiter) -> Self {
+ Self { ratelimiter }
+ }
+}
+
+#[tonic::async_trait]
+impl Ratelimiter for RLServer {
+
+ type SubmitTicketStream =
+ Pin<Box<dyn Stream<Item = Result<BucketSubmitTicketResponse, Status>> + Send>>;
+
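+ /// Handle one ticket stream per in-flight Discord request: the client first
+ /// sends the request path, waits for the `accepted` response before issuing
+ /// its HTTP call, and then reports the ratelimit headers it received so the
+ /// bucket state can be updated.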
+ async fn submit_ticket(
+ &self,
+ req: Request<Streaming<BucketSubmitTicketRequest>>,
+ ) -> Result<Response<Self::SubmitTicketStream>, Status> {
+ let mut in_stream = req.into_inner();
+ let (tx, rx) = mpsc::channel(128);
+ let imrl = self.ratelimiter.clone();
+
+ // This spawn is required to handle connection errors: if we simply mapped
+ // `in_stream` and wrote it back as `out_stream`, the `out_stream` would be
+ // dropped when a connection error occurs and the error would never be
+ // propagated to the mapped version of `in_stream`.
+ tokio::spawn(async move {
+ let mut receiver: Option<TicketReceiver> = None;
+ while let Some(result) = in_stream.next().await {
+ // Stop reading if the client disconnected or the stream yielded an error.
+ let Ok(message) = result else { break };
+ // Ignore frames that carry no payload.
+ let Some(data) = message.data else { continue };
+
+ match data {
+ proto::nova::ratelimit::ratelimiter::bucket_submit_ticket_request::Data::Path(path) => {
+ // Queue a ticket for this path; keep the receiver around until the
+ // client reports the response headers.
+ let ticket_receiver = imrl.ticket(path).await.unwrap();
+ receiver = Some(ticket_receiver);
+
+ // Tell the client it may perform its request.
+ tx.send(Ok(BucketSubmitTicketResponse {
+ accepted: 1
+ })).await.unwrap();
+
+ },
+ proto::nova::ratelimit::ratelimiter::bucket_submit_ticket_request::Data::Headers(b) => {
+ // Take the receiver out of the `Option` so it is not moved out of
+ // again on a later iteration.
+ if let Some(recv) = receiver.take() {
+ // Wait for the bucket to hand over the ticket sender, then report
+ // the response's ratelimit headers back to it.
+ let sender = recv.await.unwrap();
+ let rheaders = RatelimitHeaders::from_pairs(b.headers.iter().map(|f| (f.0.as_str(), f.1.as_bytes()))).unwrap();
+
+ sender.headers(Some(rheaders)).unwrap();
+
+ break;
+ }
+ },
+ }
+ }
+ println!("\tstream ended");
+ });
+
+ // Return the receiving end of the channel as the response stream.
+ let out_stream = ReceiverStream::new(rx);
+
+ Ok(Response::new(
+ Box::pin(out_stream) as Self::SubmitTicketStream
+ ))
+ }
+}
\ No newline at end of file
--- /dev/null
+//! [`Bucket`] management used internally by the
+//! [`super::RedisGlobalLocalBucketRatelimiter`]. Each bucket has an associated
+//! [`BucketQueue`] in which API requests are queued; the queue is consumed by the
+//! [`BucketQueueTask`], which manages the ratelimit for the bucket and respects
+//! the global ratelimit.
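+//!
+//! A rough sketch of how the pieces fit together, mirroring what the parent
+//! module does when a ticket is requested (`notifier`, `buckets`, and `global`
+//! are illustrative values, not items defined here):
+//!
+//! ```ignore
+//! let path = "/channels/123/messages".to_string();
+//! let bucket = Arc::new(Bucket::new(path.clone()));
+//! bucket.queue.push(notifier);
+//!
+//! // One background task per bucket drains the queue and enforces the ratelimit.
+//! tokio::spawn(BucketQueueTask::new(bucket, buckets, global, path).run());
+//! ```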
+
+use super::RedisLockPair;
+use twilight_http_ratelimiting::{headers::RatelimitHeaders, ticket::TicketNotifier};
+use std::{
+ collections::HashMap,
+ mem,
+ sync::{
+ atomic::{AtomicU64, Ordering},
+ Arc, Mutex,
+ },
+ time::{Duration, Instant},
+};
+use tokio::{
+ sync::{
+ mpsc::{self, UnboundedReceiver, UnboundedSender},
+ Mutex as AsyncMutex,
+ },
+ time::{sleep, timeout},
+};
+
+/// Time remaining until a bucket will reset.
+#[derive(Clone, Debug)]
+pub enum TimeRemaining {
+ /// Bucket has already reset.
+ Finished,
+ /// Bucket's ratelimit refresh countdown has not started yet.
+ NotStarted,
+ /// Amount of time until the bucket resets.
+ Some(Duration),
+}
+
+/// Ratelimit information for a bucket used in the
+/// [`super::RedisGlobalLocalBucketRatelimiter`].
+///
+/// A generic version not specific to this ratelimiter is
+/// [`twilight_http_ratelimiting::Bucket`].
+#[derive(Debug)]
+pub struct Bucket {
+ /// Total number of tickets allotted in a cycle.
+ pub limit: AtomicU64,
+ /// Path this ratelimit applies to.
+ pub path: String,
+ /// Queue associated with this bucket.
+ pub queue: BucketQueue,
+ /// Number of tickets remaining.
+ pub remaining: AtomicU64,
+ /// Number of milliseconds after the [`Self::started_at`] time at which the bucket refreshes.
+ pub reset_after: AtomicU64,
+ /// When the bucket's ratelimit refresh countdown started.
+ pub started_at: Mutex<Option<Instant>>,
+}
+
+impl Bucket {
+ /// Create a new bucket for the specified path.
+ pub fn new(path: String) -> Self {
+ Self {
+ limit: AtomicU64::new(u64::MAX),
+ path,
+ queue: BucketQueue::default(),
+ remaining: AtomicU64::new(u64::MAX),
+ reset_after: AtomicU64::new(u64::MAX),
+ started_at: Mutex::new(None),
+ }
+ }
+
+ /// Total number of tickets allotted in a cycle.
+ pub fn limit(&self) -> u64 {
+ self.limit.load(Ordering::Relaxed)
+ }
+
+ /// Number of tickets remaining.
+ pub fn remaining(&self) -> u64 {
+ self.remaining.load(Ordering::Relaxed)
+ }
+
+ /// Number of milliseconds after the [`started_at`] time at which the bucket refreshes.
+ ///
+ /// [`started_at`]: Self::started_at
+ pub fn reset_after(&self) -> u64 {
+ self.reset_after.load(Ordering::Relaxed)
+ }
+
+ /// Time remaining until this bucket will reset.
+ pub fn time_remaining(&self) -> TimeRemaining {
+ let reset_after = self.reset_after();
+ let maybe_started_at = *self.started_at.lock().expect("bucket poisoned");
+
+ let started_at = if let Some(started_at) = maybe_started_at {
+ started_at
+ } else {
+ return TimeRemaining::NotStarted;
+ };
+
+ let elapsed = started_at.elapsed();
+
+ if elapsed > Duration::from_millis(reset_after) {
+ return TimeRemaining::Finished;
+ }
+
+ TimeRemaining::Some(Duration::from_millis(reset_after) - elapsed)
+ }
+
+ /// Try to reset this bucket's [`started_at`] value if it has finished.
+ ///
+ /// Returns whether resetting was possible.
+ ///
+ /// [`started_at`]: Self::started_at
+ pub fn try_reset(&self) -> bool {
+ if self.started_at.lock().expect("bucket poisoned").is_none() {
+ return false;
+ }
+
+ if let TimeRemaining::Finished = self.time_remaining() {
+ self.remaining.store(self.limit(), Ordering::Relaxed);
+ *self.started_at.lock().expect("bucket poisoned") = None;
+
+ true
+ } else {
+ false
+ }
+ }
+
+ /// Update this bucket's ratelimit data after a request has been made.
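+ ///
+ /// A small sketch of the two cases (the numbers are made up):
+ ///
+ /// ```ignore
+ /// // Response headers were present: limit 5, 4 remaining, reset after 2500 ms.
+ /// bucket.update(Some((5, 4, 2500)));
+ ///
+ /// // No usable headers: just decrement the remaining count.
+ /// bucket.update(None);
+ /// ```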
+ pub fn update(&self, ratelimits: Option<(u64, u64, u64)>) {
+ let bucket_limit = self.limit();
+
+ {
+ let mut started_at = self.started_at.lock().expect("bucket poisoned");
+
+ if started_at.is_none() {
+ started_at.replace(Instant::now());
+ }
+ }
+
+ if let Some((limit, remaining, reset_after)) = ratelimits {
+ if bucket_limit != limit && bucket_limit == u64::MAX {
+ self.reset_after.store(reset_after, Ordering::SeqCst);
+ self.limit.store(limit, Ordering::SeqCst);
+ }
+
+ self.remaining.store(remaining, Ordering::Relaxed);
+ } else {
+ self.remaining.fetch_sub(1, Ordering::Relaxed);
+ }
+ }
+}
+
+/// Queue of ratelimit requests for a bucket.
+#[derive(Debug)]
+pub struct BucketQueue {
+ /// Receiver for the ratelimit requests.
+ rx: AsyncMutex<UnboundedReceiver<TicketNotifier>>,
+ /// Sender for the ratelimit requests.
+ tx: UnboundedSender<TicketNotifier>,
+}
+
+impl BucketQueue {
+ /// Add a new ratelimit request to the queue.
+ pub fn push(&self, tx: TicketNotifier) {
+ let _sent = self.tx.send(tx);
+ }
+
+ /// Receive the first incoming ratelimit request.
+ pub async fn pop(&self, timeout_duration: Duration) -> Option<TicketNotifier> {
+ let mut rx = self.rx.lock().await;
+
+ timeout(timeout_duration, rx.recv()).await.ok().flatten()
+ }
+}
+
+impl Default for BucketQueue {
+ fn default() -> Self {
+ let (tx, rx) = mpsc::unbounded_channel();
+
+ Self {
+ rx: AsyncMutex::new(rx),
+ tx,
+ }
+ }
+}
+
+/// A background task that handles ratelimit requests to a [`Bucket`]
+/// and processes them in order, keeping track of both the global and
+/// the path-specific ratelimits.
+pub(super) struct BucketQueueTask {
+ /// The [`Bucket`] managed by this task.
+ bucket: Arc<Bucket>,
+ /// All buckets managed by the associated [`super::RedisGlobalLocalBucketRatelimiter`].
+ buckets: Arc<Mutex<HashMap<String, Arc<Bucket>>>>,
+ /// Global ratelimit data.
+ global: Arc<RedisLockPair>,
+ /// The path this [`Bucket`] belongs to.
+ path: String,
+}
+
+impl BucketQueueTask {
+ /// Timeout to wait for response headers after initiating a request.
+ const WAIT: Duration = Duration::from_secs(10);
+
+ /// Create a new task to manage the ratelimit for a [`Bucket`].
+ pub fn new(
+ bucket: Arc<Bucket>,
+ buckets: Arc<Mutex<HashMap<String, Arc<Bucket>>>>,
+ global: Arc<RedisLockPair>,
+ path: String,
+ ) -> Self {
+ Self {
+ bucket,
+ buckets,
+ global,
+ path,
+ }
+ }
+
+ /// Process incoming ratelimit requests to this bucket and update the state
+ /// based on received [`RatelimitHeaders`].
+ #[tracing::instrument(name = "background queue task", skip(self), fields(path = ?self.path))]
+ pub async fn run(self) {
+ while let Some(queue_tx) = self.next().await {
+ if self.global.is_locked().await {
+ mem::drop(self.global.0.lock().await);
+ }
+
+ let ticket_headers = if let Some(ticket_headers) = queue_tx.available() {
+ ticket_headers
+ } else {
+ continue;
+ };
+
+ tracing::debug!("starting to wait for response headers");
+
+ match timeout(Self::WAIT, ticket_headers).await {
+ Ok(Ok(Some(headers))) => self.handle_headers(&headers).await,
+ Ok(Ok(None)) => {
+ tracing::debug!("request aborted");
+ }
+ Ok(Err(_)) => {
+ tracing::debug!("ticket channel closed");
+ }
+ Err(_) => {
+ tracing::debug!("receiver timed out");
+ }
+ }
+ }
+
+ tracing::debug!("bucket appears finished, removing");
+
+ self.buckets
+ .lock()
+ .expect("ratelimit buckets poisoned")
+ .remove(&self.path);
+ }
+
+ /// Update the bucket's ratelimit state.
+ async fn handle_headers(&self, headers: &RatelimitHeaders) {
+ let ratelimits = match headers {
+ RatelimitHeaders::Global(global) => {
+ self.lock_global(Duration::from_secs(global.retry_after()))
+ .await;
+
+ None
+ }
+ RatelimitHeaders::None => return,
+ RatelimitHeaders::Present(present) => {
+ Some((present.limit(), present.remaining(), present.reset_after()))
+ }
+ // No other variants of `RatelimitHeaders` are expected here.
+ _ => unreachable!(),
+ };
+
+ tracing::debug!(path=?self.path, "updating bucket");
+ self.bucket.update(ratelimits);
+ }
+
+ /// Lock the global ratelimit for a specified duration.
+ async fn lock_global(&self, wait: Duration) {
+ tracing::debug!(path=?self.path, "request got global ratelimited");
+ self.global.lock_for(wait).await;
+ }
+
+ /// Get the next [`TicketNotifier`] in the queue.
+ async fn next(&self) -> Option<TicketNotifier> {
+ tracing::debug!(path=?self.path, "starting to get next in queue");
+
+ self.wait_if_needed().await;
+
+ self.bucket.queue.pop(Self::WAIT).await
+ }
+
+ /// Wait for this bucket to refresh if it isn't ready yet.
+ #[tracing::instrument(name = "waiting for bucket to refresh", skip(self), fields(path = ?self.path))]
+ async fn wait_if_needed(&self) {
+ let wait = {
+ if self.bucket.remaining() > 0 {
+ return;
+ }
+
+ tracing::debug!("0 tickets remaining, may have to wait");
+
+ match self.bucket.time_remaining() {
+ TimeRemaining::Finished => {
+ self.bucket.try_reset();
+
+ return;
+ }
+ TimeRemaining::NotStarted => return,
+ TimeRemaining::Some(dur) => dur,
+ }
+ };
+
+ tracing::debug!(
+ milliseconds=%wait.as_millis(),
+ "waiting for ratelimit to pass",
+ );
+
+ sleep(wait).await;
+
+ tracing::debug!("done waiting for ratelimit to pass");
+
+ self.bucket.try_reset();
+ }
+}
--- /dev/null
+use self::bucket::{Bucket, BucketQueueTask};
+use shared::redis_crate::{Client, Commands};
+use twilight_http_ratelimiting::ticket::{self, TicketNotifier};
+use twilight_http_ratelimiting::GetTicketFuture;
+mod bucket;
+
+use futures_util::future;
+use std::{
+ collections::hash_map::{Entry, HashMap},
+ sync::{Arc, Mutex},
+ time::Duration,
+};
+
+/// Wrapper around the shared Redis client used to coordinate the global
+/// ratelimit lock across ratelimiter instances.
+#[derive(Debug)]
+struct RedisLockPair(tokio::sync::Mutex<Client>);
+
+impl RedisLockPair {
+ /// Set the global ratelimit as exhausted.
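+ ///
+ /// This writes the `nova:rls:lock` key to Redis with an expiry one second
+ /// longer than the requested duration, so every ratelimiter instance sharing
+ /// this Redis server observes the same global lock.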
+ pub async fn lock_for(&self, duration: Duration) {
+ let _: () = self
+ .0
+ .lock()
+ .await
+ .set_ex(
+ "nova:rls:lock",
+ 1,
+ (duration.as_secs() + 1).try_into().unwrap(),
+ )
+ .unwrap();
+ }
+
+ /// Whether the global ratelimit is currently marked as exhausted.
+ pub async fn is_locked(&self) -> bool {
+ self.0.lock().await.exists("nova:rls:lock").unwrap()
+ }
+}
+
+#[derive(Clone, Debug)]
+pub struct RedisGlobalLocalBucketRatelimiter {
+ buckets: Arc<Mutex<HashMap<String, Arc<Bucket>>>>,
+
+ global: Arc<RedisLockPair>,
+}
+
+impl RedisGlobalLocalBucketRatelimiter {
+ #[must_use]
+ pub fn new(redis: tokio::sync::Mutex<Client>) -> Self {
+ Self {
+ buckets: Arc::default(),
+ global: Arc::new(RedisLockPair(redis)),
+ }
+ }
+
+ /// Queue the ticket notifier on the bucket for `path`.
+ ///
+ /// Returns the newly created bucket if one did not exist yet, so the caller
+ /// can spawn a [`BucketQueueTask`] for it.
+ fn entry(&self, path: String, tx: TicketNotifier) -> Option<Arc<Bucket>> {
+ let mut buckets = self.buckets.lock().expect("buckets poisoned");
+
+ match buckets.entry(path.clone()) {
+ Entry::Occupied(bucket) => {
+ tracing::debug!("got existing bucket: {path:?}");
+
+ bucket.get().queue.push(tx);
+
+ tracing::debug!("added request into bucket queue: {path:?}");
+
+ None
+ }
+ Entry::Vacant(entry) => {
+ tracing::debug!("making new bucket for path: {path:?}");
+
+ let bucket = Bucket::new(path);
+ bucket.queue.push(tx);
+
+ let bucket = Arc::new(bucket);
+ entry.insert(Arc::clone(&bucket));
+
+ Some(bucket)
+ }
+ }
+ }
+
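+ /// Request a ticket for `path`.
+ ///
+ /// The returned future resolves to a `TicketReceiver`; awaiting that receiver
+ /// yields a ticket sender once the bucket grants the request. A rough
+ /// caller-side sketch (`redis_client` and `parsed_headers` are placeholders):
+ ///
+ /// ```ignore
+ /// let ratelimiter = RedisGlobalLocalBucketRatelimiter::new(tokio::sync::Mutex::new(redis_client));
+ ///
+ /// // Wait until the bucket for this path grants a ticket.
+ /// let receiver = ratelimiter.ticket("/channels/123/messages".to_string()).await?;
+ /// let sender = receiver.await?;
+ ///
+ /// // Perform the HTTP request, then report the ratelimit headers that came back
+ /// // so the bucket (and the global lock) can be updated.
+ /// sender.headers(Some(parsed_headers))?;
+ /// ```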
+ pub fn ticket(&self, path: String) -> GetTicketFuture {
+ tracing::debug!("getting bucket for path: {path:?}");
+
+ let (tx, rx) = ticket::channel();
+
+ if let Some(bucket) = self.entry(path.clone(), tx) {
+ tokio::spawn(
+ BucketQueueTask::new(
+ bucket,
+ Arc::clone(&self.buckets),
+ Arc::clone(&self.global),
+ path,
+ )
+ .run(),
+ );
+ }
+
+ Box::pin(future::ok(rx))
+ }
+}