author    MatthieuCoder <matthieu@matthieu-dev.xyz>    2023-01-02 20:07:48 +0400
committer MatthieuCoder <matthieu@matthieu-dev.xyz>    2023-01-02 20:07:48 +0400
commit    562eea54627527130de9ceee1d791f2948e431d2 (patch)
tree      29ae82421f1a21cfd860e1db3353963d306a983e
parent    f152af136f24f309cd95e645cbc2e06b776a01d7 (diff)
add ratelimiter service
-rw-r--r--  Cargo.lock  1
-rw-r--r--  exes/.gitignore  1
-rw-r--r--  exes/ratelimit/Cargo.toml  21
-rw-r--r--  exes/ratelimit/src/grpc.rs  79
-rw-r--r--  exes/ratelimit/src/main.rs  47
-rw-r--r--  exes/ratelimit/src/redis_global_local_bucket_ratelimiter/bucket.rs  323
-rw-r--r--  exes/ratelimit/src/redis_global_local_bucket_ratelimiter/mod.rs  99
-rw-r--r--  proto/nova/ratelimit/ratelimiter.proto  13
8 files changed, 570 insertions, 14 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 8181e8d..08e01ef 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1792,6 +1792,7 @@ dependencies = [
"anyhow",
"futures-util",
"hyper",
+ "leash",
"proto",
"serde",
"serde_json",
diff --git a/exes/.gitignore b/exes/.gitignore
deleted file mode 100644
index 0f02d23..0000000
--- a/exes/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-ratelimit/
\ No newline at end of file
diff --git a/exes/ratelimit/Cargo.toml b/exes/ratelimit/Cargo.toml
new file mode 100644
index 0000000..a28d2d0
--- /dev/null
+++ b/exes/ratelimit/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "ratelimit"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+shared = { path = "../../libs/shared" }
+proto = { path = "../../libs/proto" }
+leash = { path = "../../libs/leash" }
+hyper = { version = "0.14", features = ["full"] }
+tokio = { version = "1", features = ["full"] }
+serde = { version = "1.0.8", features = ["derive"] }
+twilight-http-ratelimiting = { git = "https://github.com/MatthieuCoder/twilight.git" }
+anyhow = "*"
+futures-util = "0.3.17"
+tracing = "*"
+serde_json = { version = "1.0" }
+tonic = "0.8.3"
+tokio-stream = "0.1.11" \ No newline at end of file
diff --git a/exes/ratelimit/src/grpc.rs b/exes/ratelimit/src/grpc.rs
new file mode 100644
index 0000000..a75c329
--- /dev/null
+++ b/exes/ratelimit/src/grpc.rs
@@ -0,0 +1,79 @@
+
+use std::pin::Pin;
+
+use futures_util::Stream;
+use proto::nova::ratelimit::ratelimiter::{ratelimiter_server::Ratelimiter, BucketSubmitTicketResponse, BucketSubmitTicketRequest};
+use tokio::sync::mpsc;
+use tokio_stream::{wrappers::ReceiverStream, StreamExt};
+use tonic::{Request, Response, Status, Streaming};
+use twilight_http_ratelimiting::{ticket::TicketReceiver, RatelimitHeaders};
+
+use crate::redis_global_local_bucket_ratelimiter::RedisGlobalLocalBucketRatelimiter;
+
+pub struct RLServer {
+ ratelimiter: RedisGlobalLocalBucketRatelimiter,
+}
+
+impl RLServer {
+ pub fn new(ratelimiter: RedisGlobalLocalBucketRatelimiter) -> Self {
+ Self { ratelimiter }
+ }
+}
+
+#[tonic::async_trait]
+impl Ratelimiter for RLServer {
+
+ type SubmitTicketStream =
+ Pin<Box<dyn Stream<Item = Result<BucketSubmitTicketResponse, Status>> + Send>>;
+
+ async fn submit_ticket(
+ &self,
+ req: Request<Streaming<BucketSubmitTicketRequest>>,
+ ) -> Result<Response<Self::SubmitTicketStream>, Status> {
+ let mut in_stream = req.into_inner();
+ let (tx, rx) = mpsc::channel(128);
+ let imrl = self.ratelimiter.clone();
+
+ // This spawn is required to handle connection errors. If we simply mapped
+ // `in_stream` into `out_stream`, the `out_stream` would be dropped when a
+ // connection error occurs and the error would never be propagated to the
+ // mapped version of `in_stream`.
+ tokio::spawn(async move {
+ let mut receiver: Option<TicketReceiver> = None;
+ while let Some(result) = in_stream.next().await {
+ let result = result.unwrap();
+
+ match result.data.unwrap() {
+ proto::nova::ratelimit::ratelimiter::bucket_submit_ticket_request::Data::Path(path) => {
+ let a = imrl.ticket(path).await.unwrap();
+ receiver = Some(a);
+
+
+ tx.send(Ok(BucketSubmitTicketResponse {
+ accepted: 1
+ })).await.unwrap();
+
+ },
+ proto::nova::ratelimit::ratelimiter::bucket_submit_ticket_request::Data::Headers(b) => {
+ if let Some(recv) = receiver {
+ let recv = recv.await.unwrap();
+ let rheaders = RatelimitHeaders::from_pairs(b.headers.iter().map(|f| (f.0.as_str(), f.1.as_bytes()))).unwrap();
+
+ recv.headers(Some(rheaders)).unwrap();
+
+ break;
+ }
+ },
+ }
+ }
+ println!("\tstream ended");
+ });
+
+ // the output stream forwards the responses sent on `tx` above
+ let out_stream = ReceiverStream::new(rx);
+
+ Ok(Response::new(
+ Box::pin(out_stream) as Self::SubmitTicketStream
+ ))
+ }
+}
\ No newline at end of file
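
For context, a client drives the SubmitTicket stream in two steps: it sends the request path, waits for the `accepted` response, performs its HTTP call, and then streams the response ratelimit headers back so the bucket state can be updated. Below is a minimal client-side sketch assuming the tonic-generated `RatelimiterClient` and a local endpoint; neither is part of this commit.

    // Hypothetical client of the SubmitTicket stream (sketch only).
    use proto::nova::ratelimit::ratelimiter::{
        bucket_submit_ticket_request::Data, ratelimiter_client::RatelimiterClient,
        BucketSubmitTicketRequest,
    };
    use tokio::sync::mpsc;
    use tokio_stream::{wrappers::ReceiverStream, StreamExt};

    async fn submit(path: String) -> anyhow::Result<()> {
        let mut client = RatelimiterClient::connect("http://localhost:8080").await?;
        let (tx, rx) = mpsc::channel(16);

        let mut responses = client
            .submit_ticket(ReceiverStream::new(rx))
            .await?
            .into_inner();

        // 1. Ask for a ticket for this path.
        tx.send(BucketSubmitTicketRequest {
            data: Some(Data::Path(path)),
        })
        .await
        .expect("request stream closed");

        // 2. Wait until the ratelimiter accepts the request.
        let _accepted = responses.next().await;

        // 3. Perform the HTTP request, then send a Data::Headers message with
        //    the response's ratelimit headers so the bucket can be updated.
        Ok(())
    }
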
diff --git a/exes/ratelimit/src/main.rs b/exes/ratelimit/src/main.rs
new file mode 100644
index 0000000..8e43b34
--- /dev/null
+++ b/exes/ratelimit/src/main.rs
@@ -0,0 +1,47 @@
+use std::net::ToSocketAddrs;
+
+use futures_util::FutureExt;
+use grpc::RLServer;
+use leash::{ignite, AnyhowResultFuture, Component};
+use proto::nova::ratelimit::ratelimiter::ratelimiter_server::RatelimiterServer;
+use redis_global_local_bucket_ratelimiter::RedisGlobalLocalBucketRatelimiter;
+use shared::{config::Settings, redis_crate::Client};
+use tokio::sync::oneshot;
+use tonic::transport::Server;
+
+mod grpc;
+mod redis_global_local_bucket_ratelimiter;
+
+struct RatelimiterServerComponent {}
+impl Component for RatelimiterServerComponent {
+ type Config = ();
+ const SERVICE_NAME: &'static str = "ratelimit";
+
+ fn start(
+ &self,
+ settings: Settings<Self::Config>,
+ stop: oneshot::Receiver<()>,
+ ) -> AnyhowResultFuture<()> {
+ Box::pin(async move {
+ // let config = Arc::new(settings.config);
+ let redis: Client = settings.redis.into();
+ let server = RLServer::new(RedisGlobalLocalBucketRatelimiter::new(redis.into()));
+
+ Server::builder()
+ .add_service(RatelimiterServer::new(server))
+ .serve_with_shutdown(
+ "0.0.0.0:8080".to_socket_addrs().unwrap().next().unwrap(),
+ stop.map(|_| ()),
+ )
+ .await?;
+
+ Ok(())
+ })
+ }
+
+ fn new() -> Self {
+ Self {}
+ }
+}
+
+ignite!(RatelimiterServerComponent);
diff --git a/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/bucket.rs b/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/bucket.rs
new file mode 100644
index 0000000..d739acf
--- /dev/null
+++ b/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/bucket.rs
@@ -0,0 +1,323 @@
+//! [`Bucket`] management used internally by the [`super::RedisGlobalLocalBucketRatelimiter`].
+//! Each bucket has an associated [`BucketQueue`] to queue an API request, which is
+//! consumed by the [`BucketQueueTask`] that manages the ratelimit for the bucket
+//! and respects the global ratelimit.
+
+use super::RedisLockPair;
+use twilight_http_ratelimiting::{headers::RatelimitHeaders, ticket::TicketNotifier};
+use std::{
+ collections::HashMap,
+ mem,
+ sync::{
+ atomic::{AtomicU64, Ordering},
+ Arc, Mutex,
+ },
+ time::{Duration, Instant},
+};
+use tokio::{
+ sync::{
+ mpsc::{self, UnboundedReceiver, UnboundedSender},
+ Mutex as AsyncMutex,
+ },
+ time::{sleep, timeout},
+};
+
+/// Time remaining until a bucket will reset.
+#[derive(Clone, Debug)]
+pub enum TimeRemaining {
+ /// Bucket has already reset.
+ Finished,
+ /// Bucket's ratelimit refresh countdown has not started yet.
+ NotStarted,
+ /// Amount of time until the bucket resets.
+ Some(Duration),
+}
+
+/// Ratelimit information for a bucket used in the
+/// [`super::RedisGlobalLocalBucketRatelimiter`].
+#[derive(Debug)]
+pub struct Bucket {
+ /// Total number of tickets allotted in a cycle.
+ pub limit: AtomicU64,
+ /// Path this ratelimit applies to.
+ pub path: String,
+ /// Queue associated with this bucket.
+ pub queue: BucketQueue,
+ /// Number of tickets remaining.
+ pub remaining: AtomicU64,
+ /// Duration after the [`Self::started_at`] time the bucket will refresh.
+ pub reset_after: AtomicU64,
+ /// When the bucket's ratelimit refresh countdown started.
+ pub started_at: Mutex<Option<Instant>>,
+}
+
+impl Bucket {
+ /// Create a new bucket for the specified request path.
+ pub fn new(path: String) -> Self {
+ Self {
+ limit: AtomicU64::new(u64::max_value()),
+ path,
+ queue: BucketQueue::default(),
+ remaining: AtomicU64::new(u64::max_value()),
+ reset_after: AtomicU64::new(u64::max_value()),
+ started_at: Mutex::new(None),
+ }
+ }
+
+ /// Total number of tickets allotted in a cycle.
+ pub fn limit(&self) -> u64 {
+ self.limit.load(Ordering::Relaxed)
+ }
+
+ /// Number of tickets remaining.
+ pub fn remaining(&self) -> u64 {
+ self.remaining.load(Ordering::Relaxed)
+ }
+
+ /// Duration after the [`started_at`] time the bucket will refresh.
+ ///
+ /// [`started_at`]: Self::started_at
+ pub fn reset_after(&self) -> u64 {
+ self.reset_after.load(Ordering::Relaxed)
+ }
+
+ /// Time remaining until this bucket will reset.
+ pub fn time_remaining(&self) -> TimeRemaining {
+ let reset_after = self.reset_after();
+ let maybe_started_at = *self.started_at.lock().expect("bucket poisoned");
+
+ let started_at = if let Some(started_at) = maybe_started_at {
+ started_at
+ } else {
+ return TimeRemaining::NotStarted;
+ };
+
+ let elapsed = started_at.elapsed();
+
+ if elapsed > Duration::from_millis(reset_after) {
+ return TimeRemaining::Finished;
+ }
+
+ TimeRemaining::Some(Duration::from_millis(reset_after) - elapsed)
+ }
+
+ /// Try to reset this bucket's [`started_at`] value if it has finished.
+ ///
+ /// Returns whether resetting was possible.
+ ///
+ /// [`started_at`]: Self::started_at
+ pub fn try_reset(&self) -> bool {
+ if self.started_at.lock().expect("bucket poisoned").is_none() {
+ return false;
+ }
+
+ if let TimeRemaining::Finished = self.time_remaining() {
+ self.remaining.store(self.limit(), Ordering::Relaxed);
+ *self.started_at.lock().expect("bucket poisoned") = None;
+
+ true
+ } else {
+ false
+ }
+ }
+
+ /// Update this bucket's ratelimit data after a request has been made.
+ pub fn update(&self, ratelimits: Option<(u64, u64, u64)>) {
+ let bucket_limit = self.limit();
+
+ {
+ let mut started_at = self.started_at.lock().expect("bucket poisoned");
+
+ if started_at.is_none() {
+ started_at.replace(Instant::now());
+ }
+ }
+
+ if let Some((limit, remaining, reset_after)) = ratelimits {
+ if bucket_limit != limit && bucket_limit == u64::max_value() {
+ self.reset_after.store(reset_after, Ordering::SeqCst);
+ self.limit.store(limit, Ordering::SeqCst);
+ }
+
+ self.remaining.store(remaining, Ordering::Relaxed);
+ } else {
+ self.remaining.fetch_sub(1, Ordering::Relaxed);
+ }
+ }
+}
+
+/// Queue of ratelimit requests for a bucket.
+#[derive(Debug)]
+pub struct BucketQueue {
+ /// Receiver for the ratelimit requests.
+ rx: AsyncMutex<UnboundedReceiver<TicketNotifier>>,
+ /// Sender for the ratelimit requests.
+ tx: UnboundedSender<TicketNotifier>,
+}
+
+impl BucketQueue {
+ /// Add a new ratelimit request to the queue.
+ pub fn push(&self, tx: TicketNotifier) {
+ let _sent = self.tx.send(tx);
+ }
+
+ /// Receive the first incoming ratelimit request.
+ pub async fn pop(&self, timeout_duration: Duration) -> Option<TicketNotifier> {
+ let mut rx = self.rx.lock().await;
+
+ timeout(timeout_duration, rx.recv()).await.ok().flatten()
+ }
+}
+
+impl Default for BucketQueue {
+ fn default() -> Self {
+ let (tx, rx) = mpsc::unbounded_channel();
+
+ Self {
+ rx: AsyncMutex::new(rx),
+ tx,
+ }
+ }
+}
+
+/// A background task that handles ratelimit requests to a [`Bucket`]
+/// and processes them in order, keeping track of both the global and
+/// the path-specific ratelimits.
+pub(super) struct BucketQueueTask {
+ /// The [`Bucket`] managed by this task.
+ bucket: Arc<Bucket>,
+ /// All buckets managed by the associated [`super::RedisGlobalLocalBucketRatelimiter`].
+ buckets: Arc<Mutex<HashMap<String, Arc<Bucket>>>>,
+ /// Global ratelimit data.
+ global: Arc<RedisLockPair>,
+ /// The path this [`Bucket`] belongs to.
+ path: String,
+}
+
+impl BucketQueueTask {
+ /// Timeout to wait for response headers after initiating a request.
+ const WAIT: Duration = Duration::from_secs(10);
+
+ /// Create a new task to manage the ratelimit for a [`Bucket`].
+ pub fn new(
+ bucket: Arc<Bucket>,
+ buckets: Arc<Mutex<HashMap<String, Arc<Bucket>>>>,
+ global: Arc<RedisLockPair>,
+ path: String,
+ ) -> Self {
+ Self {
+ bucket,
+ buckets,
+ global,
+ path,
+ }
+ }
+
+ /// Process incoming ratelimit requests to this bucket and update the state
+ /// based on received [`RatelimitHeaders`].
+ #[tracing::instrument(name = "background queue task", skip(self), fields(path = ?self.path))]
+ pub async fn run(self) {
+ while let Some(queue_tx) = self.next().await {
+ if self.global.is_locked().await {
+ mem::drop(self.global.0.lock().await);
+ }
+
+ let ticket_headers = if let Some(ticket_headers) = queue_tx.available() {
+ ticket_headers
+ } else {
+ continue;
+ };
+
+ tracing::debug!("starting to wait for response headers");
+
+ match timeout(Self::WAIT, ticket_headers).await {
+ Ok(Ok(Some(headers))) => self.handle_headers(&headers).await,
+ Ok(Ok(None)) => {
+ tracing::debug!("request aborted");
+ }
+ Ok(Err(_)) => {
+ tracing::debug!("ticket channel closed");
+ }
+ Err(_) => {
+ tracing::debug!("receiver timed out");
+ }
+ }
+ }
+
+ tracing::debug!("bucket appears finished, removing");
+
+ self.buckets
+ .lock()
+ .expect("ratelimit buckets poisoned")
+ .remove(&self.path);
+ }
+
+ /// Update the bucket's ratelimit state.
+ async fn handle_headers(&self, headers: &RatelimitHeaders) {
+ let ratelimits = match headers {
+ RatelimitHeaders::Global(global) => {
+ self.lock_global(Duration::from_secs(global.retry_after()))
+ .await;
+
+ None
+ }
+ RatelimitHeaders::None => return,
+ RatelimitHeaders::Present(present) => {
+ Some((present.limit(), present.remaining(), present.reset_after()))
+ },
+ _ => unreachable!(),
+ };
+
+ tracing::debug!(path=?self.path, "updating bucket");
+ self.bucket.update(ratelimits);
+ }
+
+ /// Lock the global ratelimit for a specified duration.
+ async fn lock_global(&self, wait: Duration) {
+ tracing::debug!(path=?self.path, "request got global ratelimited");
+ self.global.lock_for(wait).await;
+ }
+
+ /// Get the next [`TicketNotifier`] in the queue.
+ async fn next(&self) -> Option<TicketNotifier> {
+ tracing::debug!(path=?self.path, "starting to get next in queue");
+
+ self.wait_if_needed().await;
+
+ self.bucket.queue.pop(Self::WAIT).await
+ }
+
+ /// Wait for this bucket to refresh if it isn't ready yet.
+ #[tracing::instrument(name = "waiting for bucket to refresh", skip(self), fields(path = ?self.path))]
+ async fn wait_if_needed(&self) {
+ let wait = {
+ if self.bucket.remaining() > 0 {
+ return;
+ }
+
+ tracing::debug!("0 tickets remaining, may have to wait");
+
+ match self.bucket.time_remaining() {
+ TimeRemaining::Finished => {
+ self.bucket.try_reset();
+
+ return;
+ }
+ TimeRemaining::NotStarted => return,
+ TimeRemaining::Some(dur) => dur,
+ }
+ };
+
+ tracing::debug!(
+ milliseconds=%wait.as_millis(),
+ "waiting for ratelimit to pass",
+ );
+
+ sleep(wait).await;
+
+ tracing::debug!("done waiting for ratelimit to pass");
+
+ self.bucket.try_reset();
+ }
+}
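
As a side note, the bucket state machine above can be exercised on its own: `update` records the limit, remaining and reset_after values reported by response headers and starts the refresh countdown, a call without parsed headers consumes a single ticket, and `try_reset` refills the bucket once `time_remaining` reports `Finished`. A hypothetical test sketch, not part of this commit, with made-up values:

    #[cfg(test)]
    mod tests {
        use super::{Bucket, TimeRemaining};
        use std::time::Duration;

        #[test]
        fn bucket_lifecycle() {
            let bucket = Bucket::new("channels/1234/messages".to_string());

            // Headers reported: limit 5, 4 remaining, reset after 2500 ms.
            bucket.update(Some((5, 4, 2_500)));
            assert_eq!(bucket.limit(), 5);
            assert_eq!(bucket.remaining(), 4);

            // A request without parsed headers consumes one ticket.
            bucket.update(None);
            assert_eq!(bucket.remaining(), 3);

            // The refresh countdown started with the first update.
            match bucket.time_remaining() {
                TimeRemaining::Some(d) => assert!(d <= Duration::from_millis(2_500)),
                TimeRemaining::Finished => assert!(bucket.try_reset()),
                TimeRemaining::NotStarted => unreachable!("countdown was started"),
            }
        }
    }
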
diff --git a/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/mod.rs b/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/mod.rs
new file mode 100644
index 0000000..c759db9
--- /dev/null
+++ b/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/mod.rs
@@ -0,0 +1,99 @@
+use self::bucket::{Bucket, BucketQueueTask};
+use shared::redis_crate::{Client, Commands};
+use twilight_http_ratelimiting::ticket::{self, TicketNotifier};
+use twilight_http_ratelimiting::GetTicketFuture;
+mod bucket;
+
+use futures_util::future;
+use std::{
+ collections::hash_map::{Entry, HashMap},
+ sync::{Arc, Mutex},
+ time::Duration,
+};
+
+#[derive(Debug)]
+struct RedisLockPair(tokio::sync::Mutex<Client>);
+
+impl RedisLockPair {
+ /// Set the global ratelimit as exhausted.
+ pub async fn lock_for(&self, duration: Duration) {
+ let _: () = self
+ .0
+ .lock()
+ .await
+ .set_ex(
+ "nova:rls:lock",
+ 1,
+ (duration.as_secs() + 1).try_into().unwrap(),
+ )
+ .unwrap();
+ }
+
+ pub async fn is_locked(&self) -> bool {
+ self.0.lock().await.exists("nova:rls:lock").unwrap()
+ }
+}
+
+#[derive(Clone, Debug)]
+pub struct RedisGlobalLocalBucketRatelimiter {
+ buckets: Arc<Mutex<HashMap<String, Arc<Bucket>>>>,
+
+ global: Arc<RedisLockPair>,
+}
+
+impl RedisGlobalLocalBucketRatelimiter {
+ #[must_use]
+ pub fn new(redis: tokio::sync::Mutex<Client>) -> Self {
+ Self {
+ buckets: Arc::default(),
+ global: Arc::new(RedisLockPair(redis)),
+ }
+ }
+
+ fn entry(&self, path: String, tx: TicketNotifier) -> Option<Arc<Bucket>> {
+ let mut buckets = self.buckets.lock().expect("buckets poisoned");
+
+ match buckets.entry(path.clone()) {
+ Entry::Occupied(bucket) => {
+ tracing::debug!("got existing bucket: {path:?}");
+
+ bucket.get().queue.push(tx);
+
+ tracing::debug!("added request into bucket queue: {path:?}");
+
+ None
+ }
+ Entry::Vacant(entry) => {
+ tracing::debug!("making new bucket for path: {path:?}");
+
+ let bucket = Bucket::new(path);
+ bucket.queue.push(tx);
+
+ let bucket = Arc::new(bucket);
+ entry.insert(Arc::clone(&bucket));
+
+ Some(bucket)
+ }
+ }
+ }
+
+ pub fn ticket(&self, path: String) -> GetTicketFuture {
+ tracing::debug!("getting bucket for path: {path:?}");
+
+ let (tx, rx) = ticket::channel();
+
+ if let Some(bucket) = self.entry(path.clone(), tx) {
+ tokio::spawn(
+ BucketQueueTask::new(
+ bucket,
+ Arc::clone(&self.buckets),
+ Arc::clone(&self.global),
+ path,
+ )
+ .run(),
+ );
+ }
+
+ Box::pin(future::ok(rx))
+ }
+}
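
The ratelimiter can also be driven without the gRPC layer, which is what grpc.rs does internally: await `ticket()` to get a receiver, await the receiver until the bucket queue (and the Redis-backed global lock) let the request through, then hand the parsed response headers back. A minimal in-crate sketch; the helper name and path are illustrative, not part of this commit:

    // Hypothetical helper inside the crate (sketch only).
    use crate::redis_global_local_bucket_ratelimiter::RedisGlobalLocalBucketRatelimiter;

    async fn acquire_and_release(ratelimiter: &RedisGlobalLocalBucketRatelimiter) {
        // Queue up for a ticket on this bucket.
        let receiver = ratelimiter
            .ticket("channels/1234/messages".to_string())
            .await
            .expect("ticket request failed");

        // Resolves once the bucket (and the global lock) allow the request.
        let sender = receiver.await.expect("bucket queue closed");

        // ...perform the HTTP request here, then feed the parsed response
        // headers back so the bucket can update its state; `None` marks the
        // request as aborted.
        sender.headers(None).expect("bucket task stopped");
    }
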
diff --git a/proto/nova/ratelimit/ratelimiter.proto b/proto/nova/ratelimit/ratelimiter.proto
index 34d5b6f..3d7c75a 100644
--- a/proto/nova/ratelimit/ratelimiter.proto
+++ b/proto/nova/ratelimit/ratelimiter.proto
@@ -3,22 +3,9 @@ syntax = "proto3";
package nova.ratelimit.ratelimiter;
service Ratelimiter {
- rpc GetBucketInformation(BucketInformationRequest) returns (BucketInformationResponse);
rpc SubmitTicket(stream BucketSubmitTicketRequest) returns (stream BucketSubmitTicketResponse);
}
-message BucketInformationRequest {
- string path = 1;
-}
-
-message BucketInformationResponse {
- uint64 limit = 1;
- uint64 remaining = 2;
- uint64 reset_after = 3;
- uint64 started_at = 4;
-}
-
-
message BucketSubmitTicketRequest {
oneof data {
string path = 1;