From: MatthieuCoder
Date: Mon, 2 Jan 2023 16:07:48 +0000 (+0400)
Subject: add ratelimiter service
X-Git-Tag: v0.1~29
X-Git-Url: https://git.puffer.fish/?a=commitdiff_plain;h=562eea54627527130de9ceee1d791f2948e431d2;p=matthieu%2Fnova.git

add ratelimiter service
---

diff --git a/Cargo.lock b/Cargo.lock
index 8181e8d..08e01ef 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1792,6 +1792,7 @@ dependencies = [
  "anyhow",
  "futures-util",
  "hyper",
+ "leash",
  "proto",
  "serde",
  "serde_json",
diff --git a/exes/.gitignore b/exes/.gitignore
deleted file mode 100644
index 0f02d23..0000000
--- a/exes/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-ratelimit/
\ No newline at end of file
diff --git a/exes/ratelimit/Cargo.toml b/exes/ratelimit/Cargo.toml
new file mode 100644
index 0000000..a28d2d0
--- /dev/null
+++ b/exes/ratelimit/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "ratelimit"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+shared = { path = "../../libs/shared" }
+proto = { path = "../../libs/proto" }
+leash = { path = "../../libs/leash" }
+hyper = { version = "0.14", features = ["full"] }
+tokio = { version = "1", features = ["full"] }
+serde = { version = "1.0.8", features = ["derive"] }
+twilight-http-ratelimiting = { git = "https://github.com/MatthieuCoder/twilight.git" }
+anyhow = "*"
+futures-util = "0.3.17"
+tracing = "*"
+serde_json = { version = "1.0" }
+tonic = "0.8.3"
+tokio-stream = "0.1.11"
\ No newline at end of file
diff --git a/exes/ratelimit/src/grpc.rs b/exes/ratelimit/src/grpc.rs
new file mode 100644
index 0000000..a75c329
--- /dev/null
+++ b/exes/ratelimit/src/grpc.rs
@@ -0,0 +1,79 @@
+
+use std::pin::Pin;
+
+use futures_util::Stream;
+use proto::nova::ratelimit::ratelimiter::{ratelimiter_server::Ratelimiter, BucketSubmitTicketResponse, BucketSubmitTicketRequest};
+use tokio::sync::mpsc;
+use tokio_stream::{wrappers::ReceiverStream, StreamExt};
+use tonic::{Request, Response, Status, Streaming};
+use twilight_http_ratelimiting::{ticket::TicketReceiver, RatelimitHeaders};
+
+use crate::redis_global_local_bucket_ratelimiter::RedisGlobalLocalBucketRatelimiter;
+
+pub struct RLServer {
+    ratelimiter: RedisGlobalLocalBucketRatelimiter,
+}
+
+impl RLServer {
+    pub fn new(ratelimiter: RedisGlobalLocalBucketRatelimiter) -> Self {
+        Self { ratelimiter }
+    }
+}
+
+#[tonic::async_trait]
+impl Ratelimiter for RLServer {
+
+    type SubmitTicketStream =
+        Pin<Box<dyn Stream<Item = Result<BucketSubmitTicketResponse, Status>> + Send>>;
+
+    async fn submit_ticket(
+        &self,
+        req: Request<Streaming<BucketSubmitTicketRequest>>,
+    ) -> Result<Response<Self::SubmitTicketStream>, Status> {
+        let mut in_stream = req.into_inner();
+        let (tx, rx) = mpsc::channel(128);
+        let imrl = self.ratelimiter.clone();
+
+        // This spawn is required to handle connection errors: if we just
+        // mapped `in_stream` into `out_stream`, the `out_stream` would be
+        // dropped when a connection error occurs, and the error would never
+        // be propagated to the mapped version of `in_stream`.
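+        //
+        // Stream protocol, as implemented below: the client opens the
+        // stream with a `Data::Path` message; we queue a ticket for that
+        // path with the ratelimiter and reply with `accepted: 1`. The
+        // client then performs its HTTP request and sends the response's
+        // ratelimit headers back as a `Data::Headers` message, which is
+        // forwarded through the ticket to update the bucket.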
+        tokio::spawn(async move {
+            let mut receiver: Option<TicketReceiver> = None;
+            while let Some(result) = in_stream.next().await {
+                let result = result.unwrap();
+
+                match result.data.unwrap() {
+                    proto::nova::ratelimit::ratelimiter::bucket_submit_ticket_request::Data::Path(path) => {
+                        let a = imrl.ticket(path).await.unwrap();
+                        receiver = Some(a);
+
+                        tx.send(Ok(BucketSubmitTicketResponse {
+                            accepted: 1
+                        })).await.unwrap();
+                    },
+                    proto::nova::ratelimit::ratelimiter::bucket_submit_ticket_request::Data::Headers(b) => {
+                        if let Some(recv) = receiver {
+                            let recv = recv.await.unwrap();
+                            let rheaders = RatelimitHeaders::from_pairs(b.headers.iter().map(|f| (f.0.as_str(), f.1.as_bytes()))).unwrap();
+
+                            recv.headers(Some(rheaders)).unwrap();
+
+                            break;
+                        }
+                    },
+                }
+            }
+            println!("\tstream ended");
+        });
+
+        // Hand the receiving half of the channel back to the caller as the
+        // response stream.
+        let out_stream = ReceiverStream::new(rx);
+
+        Ok(Response::new(
+            Box::pin(out_stream) as Self::SubmitTicketStream
+        ))
+    }
+}
\ No newline at end of file
diff --git a/exes/ratelimit/src/main.rs b/exes/ratelimit/src/main.rs
new file mode 100644
index 0000000..8e43b34
--- /dev/null
+++ b/exes/ratelimit/src/main.rs
@@ -0,0 +1,47 @@
+use std::net::ToSocketAddrs;
+
+use futures_util::FutureExt;
+use grpc::RLServer;
+use leash::{ignite, AnyhowResultFuture, Component};
+use proto::nova::ratelimit::ratelimiter::ratelimiter_server::RatelimiterServer;
+use redis_global_local_bucket_ratelimiter::RedisGlobalLocalBucketRatelimiter;
+use shared::{config::Settings, redis_crate::Client};
+use tokio::sync::oneshot;
+use tonic::transport::Server;
+
+mod grpc;
+mod redis_global_local_bucket_ratelimiter;
+
+struct RatelimiterServerComponent {}
+impl Component for RatelimiterServerComponent {
+    type Config = ();
+    const SERVICE_NAME: &'static str = "ratelimit";
+
+    fn start(
+        &self,
+        settings: Settings<Self::Config>,
+        stop: oneshot::Receiver<()>,
+    ) -> AnyhowResultFuture<()> {
+        Box::pin(async move {
+            // let config = Arc::new(settings.config);
+            let redis: Client = settings.redis.into();
+            let server = RLServer::new(RedisGlobalLocalBucketRatelimiter::new(redis.into()));
+
+            Server::builder()
+                .add_service(RatelimiterServer::new(server))
+                .serve_with_shutdown(
+                    "0.0.0.0:8080".to_socket_addrs().unwrap().next().unwrap(),
+                    stop.map(|_| ()),
+                )
+                .await?;
+
+            Ok(())
+        })
+    }
+
+    fn new() -> Self {
+        Self {}
+    }
+}
+
+ignite!(RatelimiterServerComponent);
diff --git a/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/bucket.rs b/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/bucket.rs
new file mode 100644
index 0000000..d739acf
--- /dev/null
+++ b/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/bucket.rs
@@ -0,0 +1,323 @@
+//! [`Bucket`] management used by the [`super::RedisGlobalLocalBucketRatelimiter`] internally.
+//! Each bucket has an associated [`BucketQueue`] to queue an API request, which is
+//! consumed by the [`BucketQueueTask`] that manages the ratelimit for the bucket
+//! and respects the global ratelimit.
+
+use super::RedisLockPair;
+use twilight_http_ratelimiting::{headers::RatelimitHeaders, ticket::TicketNotifier};
+use std::{
+    collections::HashMap,
+    mem,
+    sync::{
+        atomic::{AtomicU64, Ordering},
+        Arc, Mutex,
+    },
+    time::{Duration, Instant},
+};
+use tokio::{
+    sync::{
+        mpsc::{self, UnboundedReceiver, UnboundedSender},
+        Mutex as AsyncMutex,
+    },
+    time::{sleep, timeout},
+};
+
+/// Time remaining until a bucket will reset.
+#[derive(Clone, Debug)]
+pub enum TimeRemaining {
+    /// Bucket has already reset.
+    Finished,
+    /// Bucket's ratelimit refresh countdown has not started yet.
+    NotStarted,
+    /// Amount of time until the bucket resets.
+    Some(Duration),
+}
+
+/// Ratelimit information for a bucket used in the [`super::RedisGlobalLocalBucketRatelimiter`].
+///
+/// A generic version not specific to this ratelimiter is [`twilight_http_ratelimiting::Bucket`].
+#[derive(Debug)]
+pub struct Bucket {
+    /// Total number of tickets allotted in a cycle.
+    pub limit: AtomicU64,
+    /// Path this ratelimit applies to.
+    pub path: String,
+    /// Queue associated with this bucket.
+    pub queue: BucketQueue,
+    /// Number of tickets remaining.
+    pub remaining: AtomicU64,
+    /// Duration after the [`Self::started_at`] time the bucket will refresh.
+    pub reset_after: AtomicU64,
+    /// When the bucket's ratelimit refresh countdown started.
+    pub started_at: Mutex<Option<Instant>>,
+}
+
+impl Bucket {
+    /// Create a new bucket for the specified path.
+    pub fn new(path: String) -> Self {
+        Self {
+            limit: AtomicU64::new(u64::max_value()),
+            path,
+            queue: BucketQueue::default(),
+            remaining: AtomicU64::new(u64::max_value()),
+            reset_after: AtomicU64::new(u64::max_value()),
+            started_at: Mutex::new(None),
+        }
+    }
+
+    /// Total number of tickets allotted in a cycle.
+    pub fn limit(&self) -> u64 {
+        self.limit.load(Ordering::Relaxed)
+    }
+
+    /// Number of tickets remaining.
+    pub fn remaining(&self) -> u64 {
+        self.remaining.load(Ordering::Relaxed)
+    }
+
+    /// Duration after the [`started_at`] time the bucket will refresh.
+    ///
+    /// [`started_at`]: Self::started_at
+    pub fn reset_after(&self) -> u64 {
+        self.reset_after.load(Ordering::Relaxed)
+    }
+
+    /// Time remaining until this bucket will reset.
+    pub fn time_remaining(&self) -> TimeRemaining {
+        let reset_after = self.reset_after();
+        let maybe_started_at = *self.started_at.lock().expect("bucket poisoned");
+
+        let started_at = if let Some(started_at) = maybe_started_at {
+            started_at
+        } else {
+            return TimeRemaining::NotStarted;
+        };
+
+        let elapsed = started_at.elapsed();
+
+        if elapsed > Duration::from_millis(reset_after) {
+            return TimeRemaining::Finished;
+        }
+
+        TimeRemaining::Some(Duration::from_millis(reset_after) - elapsed)
+    }
+
+    /// Try to reset this bucket's [`started_at`] value if it has finished.
+    ///
+    /// Returns whether resetting was possible.
+    ///
+    /// [`started_at`]: Self::started_at
+    pub fn try_reset(&self) -> bool {
+        if self.started_at.lock().expect("bucket poisoned").is_none() {
+            return false;
+        }
+
+        if let TimeRemaining::Finished = self.time_remaining() {
+            self.remaining.store(self.limit(), Ordering::Relaxed);
+            *self.started_at.lock().expect("bucket poisoned") = None;
+
+            true
+        } else {
+            false
+        }
+    }
+
+    /// Update this bucket's ratelimit data after a request has been made.
+    pub fn update(&self, ratelimits: Option<(u64, u64, u64)>) {
+        let bucket_limit = self.limit();
+
+        {
+            let mut started_at = self.started_at.lock().expect("bucket poisoned");
+
+            if started_at.is_none() {
+                started_at.replace(Instant::now());
+            }
+        }
+
+        if let Some((limit, remaining, reset_after)) = ratelimits {
+            if bucket_limit != limit && bucket_limit == u64::max_value() {
+                self.reset_after.store(reset_after, Ordering::SeqCst);
+                self.limit.store(limit, Ordering::SeqCst);
+            }
+
+            self.remaining.store(remaining, Ordering::Relaxed);
+        } else {
+            self.remaining.fetch_sub(1, Ordering::Relaxed);
+        }
+    }
+}
+
+/// Queue of ratelimit requests for a bucket.
+#[derive(Debug)]
+pub struct BucketQueue {
+    /// Receiver for the ratelimit requests.
+    rx: AsyncMutex<UnboundedReceiver<TicketNotifier>>,
+    /// Sender for the ratelimit requests.
+    tx: UnboundedSender<TicketNotifier>,
+}
+
+impl BucketQueue {
+    /// Add a new ratelimit request to the queue.
+    pub fn push(&self, tx: TicketNotifier) {
+        let _sent = self.tx.send(tx);
+    }
+
+    /// Receive the first incoming ratelimit request.
+    pub async fn pop(&self, timeout_duration: Duration) -> Option<TicketNotifier> {
+        let mut rx = self.rx.lock().await;
+
+        timeout(timeout_duration, rx.recv()).await.ok().flatten()
+    }
+}
+
+impl Default for BucketQueue {
+    fn default() -> Self {
+        let (tx, rx) = mpsc::unbounded_channel();
+
+        Self {
+            rx: AsyncMutex::new(rx),
+            tx,
+        }
+    }
+}
+
+/// A background task that handles ratelimit requests to a [`Bucket`]
+/// and processes them in order, keeping track of both the global and
+/// the path-specific ratelimits.
+pub(super) struct BucketQueueTask {
+    /// The [`Bucket`] managed by this task.
+    bucket: Arc<Bucket>,
+    /// All buckets managed by the associated [`super::RedisGlobalLocalBucketRatelimiter`].
+    buckets: Arc<Mutex<HashMap<String, Arc<Bucket>>>>,
+    /// Global ratelimit data.
+    global: Arc<RedisLockPair>,
+    /// The path this [`Bucket`] belongs to.
+    path: String,
+}
+
+impl BucketQueueTask {
+    /// Timeout to wait for response headers after initiating a request.
+    const WAIT: Duration = Duration::from_secs(10);
+
+    /// Create a new task to manage the ratelimit for a [`Bucket`].
+    pub fn new(
+        bucket: Arc<Bucket>,
+        buckets: Arc<Mutex<HashMap<String, Arc<Bucket>>>>,
+        global: Arc<RedisLockPair>,
+        path: String,
+    ) -> Self {
+        Self {
+            bucket,
+            buckets,
+            global,
+            path,
+        }
+    }
+
+    /// Process incoming ratelimit requests to this bucket and update the state
+    /// based on received [`RatelimitHeaders`].
+    #[tracing::instrument(name = "background queue task", skip(self), fields(path = ?self.path))]
+    pub async fn run(self) {
+        while let Some(queue_tx) = self.next().await {
+            if self.global.is_locked().await {
+                mem::drop(self.global.0.lock().await);
+            }
+
+            let ticket_headers = if let Some(ticket_headers) = queue_tx.available() {
+                ticket_headers
+            } else {
+                continue;
+            };
+
+            tracing::debug!("starting to wait for response headers");
+
+            match timeout(Self::WAIT, ticket_headers).await {
+                Ok(Ok(Some(headers))) => self.handle_headers(&headers).await,
+                Ok(Ok(None)) => {
+                    tracing::debug!("request aborted");
+                }
+                Ok(Err(_)) => {
+                    tracing::debug!("ticket channel closed");
+                }
+                Err(_) => {
+                    tracing::debug!("receiver timed out");
+                }
+            }
+        }
+
+        tracing::debug!("bucket appears finished, removing");
+
+        self.buckets
+            .lock()
+            .expect("ratelimit buckets poisoned")
+            .remove(&self.path);
+    }
+
+    /// Update the bucket's ratelimit state.
+    async fn handle_headers(&self, headers: &RatelimitHeaders) {
+        let ratelimits = match headers {
+            RatelimitHeaders::Global(global) => {
+                self.lock_global(Duration::from_secs(global.retry_after()))
+                    .await;
+
+                None
+            }
+            RatelimitHeaders::None => return,
+            RatelimitHeaders::Present(present) => {
+                Some((present.limit(), present.remaining(), present.reset_after()))
+            }
+            _ => unreachable!(),
+        };
+
+        tracing::debug!(path = ?self.path, "updating bucket");
+        self.bucket.update(ratelimits);
+    }
+
+    /// Lock the global ratelimit for a specified duration.
+    async fn lock_global(&self, wait: Duration) {
+        tracing::debug!(path = ?self.path, "request got globally ratelimited");
+        self.global.lock_for(wait).await;
+    }
+
+    /// Get the next [`TicketNotifier`] in the queue.
+    async fn next(&self) -> Option<TicketNotifier> {
+        tracing::debug!(path = ?self.path, "starting to get next in queue");
+
+        self.wait_if_needed().await;
+
+        self.bucket.queue.pop(Self::WAIT).await
+    }
+
+    /// Wait for this bucket to refresh if it isn't ready yet.
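+    ///
+    /// With no tickets remaining, this sleeps until the bucket's
+    /// `reset_after` window (measured from `started_at`) has elapsed, then
+    /// tries to reset the bucket before handing out the next ticket.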
+    #[tracing::instrument(name = "waiting for bucket to refresh", skip(self), fields(path = ?self.path))]
+    async fn wait_if_needed(&self) {
+        let wait = {
+            if self.bucket.remaining() > 0 {
+                return;
+            }
+
+            tracing::debug!("0 tickets remaining, may have to wait");
+
+            match self.bucket.time_remaining() {
+                TimeRemaining::Finished => {
+                    self.bucket.try_reset();
+
+                    return;
+                }
+                TimeRemaining::NotStarted => return,
+                TimeRemaining::Some(dur) => dur,
+            }
+        };
+
+        tracing::debug!(
+            milliseconds = %wait.as_millis(),
+            "waiting for ratelimit to pass",
+        );
+
+        sleep(wait).await;
+
+        tracing::debug!("done waiting for ratelimit to pass");
+
+        self.bucket.try_reset();
+    }
+}
diff --git a/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/mod.rs b/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/mod.rs
new file mode 100644
index 0000000..c759db9
--- /dev/null
+++ b/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/mod.rs
@@ -0,0 +1,99 @@
+use self::bucket::{Bucket, BucketQueueTask};
+use shared::redis_crate::{Client, Commands};
+use twilight_http_ratelimiting::ticket::{self, TicketNotifier};
+use twilight_http_ratelimiting::GetTicketFuture;
+mod bucket;
+
+use futures_util::future;
+use std::{
+    collections::hash_map::{Entry, HashMap},
+    sync::{Arc, Mutex},
+    time::Duration,
+};
+
+#[derive(Debug)]
+struct RedisLockPair(tokio::sync::Mutex<Client>);
+
+impl RedisLockPair {
+    /// Set the global ratelimit as exhausted.
+    pub async fn lock_for(&self, duration: Duration) {
+        let _: () = self
+            .0
+            .lock()
+            .await
+            .set_ex(
+                "nova:rls:lock",
+                1,
+                (duration.as_secs() + 1).try_into().unwrap(),
+            )
+            .unwrap();
+    }
+
+    pub async fn is_locked(&self) -> bool {
+        self.0.lock().await.exists("nova:rls:lock").unwrap()
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct RedisGlobalLocalBucketRatelimiter {
+    buckets: Arc<Mutex<HashMap<String, Arc<Bucket>>>>,
+
+    global: Arc<RedisLockPair>,
+}
+
+impl RedisGlobalLocalBucketRatelimiter {
+    #[must_use]
+    pub fn new(redis: tokio::sync::Mutex<Client>) -> Self {
+        Self {
+            buckets: Arc::default(),
+            global: Arc::new(RedisLockPair(redis)),
+        }
+    }
+
+    fn entry(&self, path: String, tx: TicketNotifier) -> Option<Arc<Bucket>> {
+        let mut buckets = self.buckets.lock().expect("buckets poisoned");
+
+        match buckets.entry(path.clone()) {
+            Entry::Occupied(bucket) => {
+                tracing::debug!("got existing bucket: {path:?}");
+
+                bucket.get().queue.push(tx);
+
+                tracing::debug!("added request into bucket queue: {path:?}");
+
+                None
+            }
+            Entry::Vacant(entry) => {
+                tracing::debug!("making new bucket for path: {path:?}");
+
+                let bucket = Bucket::new(path);
+                bucket.queue.push(tx);
+
+                let bucket = Arc::new(bucket);
+                entry.insert(Arc::clone(&bucket));
+
+                Some(bucket)
+            }
+        }
+    }
+
+    pub fn ticket(&self, path: String) -> GetTicketFuture {
+        tracing::debug!("getting bucket for path: {path:?}");
+
+        let (tx, rx) = ticket::channel();
+
+        if let Some(bucket) = self.entry(path.clone(), tx) {
+            tokio::spawn(
+                BucketQueueTask::new(
+                    bucket,
+                    Arc::clone(&self.buckets),
+                    Arc::clone(&self.global),
+                    path,
+                )
+                .run(),
+            );
+        }
+
+        Box::pin(future::ok(rx))
+    }
+}
diff --git a/proto/nova/ratelimit/ratelimiter.proto b/proto/nova/ratelimit/ratelimiter.proto
index 34d5b6f..3d7c75a 100644
--- a/proto/nova/ratelimit/ratelimiter.proto
+++ b/proto/nova/ratelimit/ratelimiter.proto
@@ -3,22 +3,9 @@ syntax = "proto3";
 package nova.ratelimit.ratelimiter;
 
 service Ratelimiter {
-  rpc GetBucketInformation(BucketInformationRequest) returns (BucketInformationResponse);
   rpc SubmitTicket(stream BucketSubmitTicketRequest) returns (stream BucketSubmitTicketResponse);
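+  // Protocol: the client first submits the request path, waits for the
+  // server to accept the ticket, performs its HTTP request, then reports
+  // the response's ratelimit headers back on the same stream.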
 }
 
-message BucketInformationRequest {
-  string path = 1;
-}
-
-message BucketInformationResponse {
-  uint64 limit = 1;
-  uint64 remaining = 2;
-  uint64 reset_after = 3;
-  uint64 started_at = 4;
-}
-
-
 message BucketSubmitTicketRequest {
   oneof data {
     string path = 1;
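
To illustrate the round trip of the new SubmitTicket RPC, here is a minimal, hypothetical client sketch (not part of this commit). It assumes the tonic-generated `ratelimiter_client::RatelimiterClient` module for the service above; the example path is a placeholder, and the final `Data::Headers` step is only sketched in a comment because the headers message definition is truncated here.

    use proto::nova::ratelimit::ratelimiter::{
        bucket_submit_ticket_request::Data, ratelimiter_client::RatelimiterClient,
        BucketSubmitTicketRequest,
    };
    use tokio::sync::mpsc;
    use tokio_stream::{wrappers::ReceiverStream, StreamExt};

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Connect to the ratelimiter service started by main.rs.
        let mut client = RatelimiterClient::connect("http://127.0.0.1:8080").await?;

        // Request side of the bidirectional stream.
        let (tx, rx) = mpsc::channel(8);
        let mut responses = client
            .submit_ticket(ReceiverStream::new(rx))
            .await?
            .into_inner();

        // 1. Ask for a ticket for a route (hypothetical example path).
        tx.send(BucketSubmitTicketRequest {
            data: Some(Data::Path("channels/123/messages".to_string())),
        })
        .await?;

        // 2. Wait for the server to accept the ticket.
        if let Some(response) = responses.next().await {
            assert_eq!(response?.accepted, 1);
            // 3. Perform the HTTP request here, then report the response's
            //    ratelimit headers back with a `Data::Headers` message so
            //    the server can update the bucket.
        }

        Ok(())
    }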