]> git.puffer.fish Git - matthieu/nova.git/commitdiff
better ratelimit, new go structure and better build system
authorMatthieuCoder <matthieu@matthieu-dev.xyz>
Fri, 13 Jan 2023 18:23:19 +0000 (22:23 +0400)
committerMatthieuCoder <matthieu@matthieu-dev.xyz>
Fri, 13 Jan 2023 18:23:19 +0000 (22:23 +0400)
47 files changed:
.gitignore
Cargo.lock
Cargo.toml
Makefile [new file with mode: 0644]
cmd/nova/nova.go [new file with mode: 0644]
config/default.yml [deleted file]
docker-compose.yaml
exes/all-in-one/Cargo.toml [new file with mode: 0644]
exes/all-in-one/Makefile [new file with mode: 0644]
exes/all-in-one/build.rs [new file with mode: 0644]
exes/all-in-one/main.go [new file with mode: 0644]
exes/all-in-one/src/errors.rs [new file with mode: 0644]
exes/all-in-one/src/ffi.rs [new file with mode: 0644]
exes/all-in-one/src/lib.rs [new file with mode: 0644]
exes/all-in-one/src/main.rs [new file with mode: 0644]
exes/all-in-one/src/utils.rs [new file with mode: 0644]
exes/all/.gitignore [deleted file]
exes/all/Cargo.toml [deleted file]
exes/all/Makefile [deleted file]
exes/all/build.rs [deleted file]
exes/all/build/.gitkeep [deleted file]
exes/all/main.go [deleted file]
exes/all/src/lib.rs [deleted file]
exes/gateway/src/lib.rs
exes/ratelimit/bench/req.rs [new file with mode: 0644]
exes/ratelimit/src/buckets/async_queue.rs [new file with mode: 0644]
exes/ratelimit/src/buckets/atomic_instant.rs [new file with mode: 0644]
exes/ratelimit/src/buckets/bucket.rs [new file with mode: 0644]
exes/ratelimit/src/buckets/mod.rs [new file with mode: 0644]
exes/ratelimit/src/buckets/redis_lock.rs [new file with mode: 0644]
exes/ratelimit/src/grpc.rs
exes/ratelimit/src/lib.rs
exes/ratelimit/src/redis_global_local_bucket_ratelimiter/bucket.rs [deleted file]
exes/ratelimit/src/redis_global_local_bucket_ratelimiter/mod.rs [deleted file]
exes/rest/src/handler.rs
exes/rest/src/lib.rs
exes/rest/src/ratelimit_client/mod.rs
exes/rest/src/ratelimit_client/remote_hashring.rs
exes/webhook/src/config.rs
exes/webhook/src/handler/signature.rs
internal/pkg/all-in-one/.gitignore [new file with mode: 0644]
internal/pkg/all-in-one/all-in-one.go [new file with mode: 0644]
internal/pkg/all-in-one/error_handler.h [new file with mode: 0644]
libs/shared/src/config.rs
libs/shared/src/error.rs [deleted file]
libs/shared/src/lib.rs
proto/nova/ratelimit/ratelimiter.proto

index 6c9334c78b0864528bd02b04ca78a5d24a0502f0..bce375af00ca9b29bbfb45dd01a7a32dc880dcef 100644 (file)
@@ -5,4 +5,6 @@ target/
 .idea\r
 config.yml\r
 \r
-config/
\ No newline at end of file
+config/*\r
+build/\r
+*.yml\r
index 8962206fcd1113b9dcaa6f062061aae8b62d5a9b..24585f9193143ef97b453df2f660aa9f55d3224c 100644 (file)
@@ -2,6 +2,15 @@
 # It is not intended for manual editing.
 version = 3
 
+[[package]]
+name = "addr2line"
+version = "0.19.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a76fd60b23679b7d19bd066031410fb7e458ccc5e958eb5c325888ce4baedc97"
+dependencies = [
+ "gimli",
+]
+
 [[package]]
 name = "adler"
 version = "1.0.2"
@@ -29,13 +38,14 @@ dependencies = [
 ]
 
 [[package]]
-name = "all"
+name = "all-in-one"
 version = "0.1.0"
 dependencies = [
  "anyhow",
  "cache",
  "cbindgen",
  "config",
+ "ctrlc",
  "gateway",
  "leash",
  "libc",
@@ -58,6 +68,9 @@ name = "anyhow"
 version = "1.0.68"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2cb2f989d18dd141ab8ae82f64d1a8cdd37e0840f73a406896cf5e99502fab61"
+dependencies = [
+ "backtrace",
+]
 
 [[package]]
 name = "arc-swap"
@@ -194,6 +207,21 @@ dependencies = [
  "tower-service",
 ]
 
+[[package]]
+name = "backtrace"
+version = "0.3.67"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "233d376d6d185f2a3093e58f283f60f880315b6c60075b01f36b3b85154564ca"
+dependencies = [
+ "addr2line",
+ "cc",
+ "cfg-if",
+ "libc",
+ "miniz_oxide",
+ "object",
+ "rustc-demangle",
+]
+
 [[package]]
 name = "base64"
 version = "0.13.1"
@@ -436,6 +464,16 @@ dependencies = [
  "typenum",
 ]
 
+[[package]]
+name = "ctrlc"
+version = "3.2.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1631ca6e3c59112501a9d87fd86f21591ff77acd31331e8a73f8d80a65bbdd71"
+dependencies = [
+ "nix",
+ "windows-sys 0.42.0",
+]
+
 [[package]]
 name = "curve25519-dalek"
 version = "3.2.0"
@@ -775,6 +813,12 @@ dependencies = [
  "wasi 0.11.0+wasi-snapshot-preview1",
 ]
 
+[[package]]
+name = "gimli"
+version = "0.27.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dec7af912d60cdbd3677c1af9352ebae6fb8394d165568a2234df0fa00f87793"
+
 [[package]]
 name = "glob"
 version = "0.3.0"
@@ -1180,6 +1224,18 @@ version = "0.8.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
 
+[[package]]
+name = "nix"
+version = "0.26.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "46a58d1d356c6597d08cde02c2f09d785b09e28711837b1ed667dc652c08a694"
+dependencies = [
+ "bitflags",
+ "cfg-if",
+ "libc",
+ "static_assertions",
+]
+
 [[package]]
 name = "nkeys"
 version = "0.2.0"
@@ -1244,6 +1300,15 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "object"
+version = "0.30.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2b8c786513eb403643f2a88c244c2aaa270ef2153f55094587d0c48a3cf22a83"
+dependencies = [
+ "memchr",
+]
+
 [[package]]
 name = "once_cell"
 version = "1.17.0"
@@ -1861,6 +1926,12 @@ dependencies = [
  "ordered-multimap",
 ]
 
+[[package]]
+name = "rustc-demangle"
+version = "0.1.21"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342"
+
 [[package]]
 name = "rustix"
 version = "0.36.6"
@@ -2180,6 +2251,12 @@ dependencies = [
  "der",
 ]
 
+[[package]]
+name = "static_assertions"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
+
 [[package]]
 name = "strsim"
 version = "0.10.0"
index 6fd4694689c0bca58609537108b8566bcb2efc83..691bdd9d2068893d5db17584d0ac76d0bdc0b218 100644 (file)
@@ -5,7 +5,7 @@ members = [
     "exes/rest/",\r
     "exes/webhook/",\r
     "exes/ratelimit/",\r
-    "exes/all",\r
+    "exes/all-in-one/",\r
 \r
     "libs/proto/",\r
     "libs/shared/",\r
diff --git a/Makefile b/Makefile
new file mode 100644 (file)
index 0000000..5008c89
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,21 @@
+# Creates the bin folder for build artifacts
+build/{bin,lib}:
+       @mkdir -p build/{lib,bin}
+
+# Builds all rust targets
+build/lib/liball_in_one.a build/bin/{cache,gateway,ratelimit,rest,webhook}: build/{bin,lib}
+       @echo "Building rust project"
+       cargo build --release
+       @cp target/release/liball_in_one.a build/lib
+       @cp target/release/{cache,gateway,ratelimit,rest,webhook} build/bin
+
+# Generated by a rust build script.
+internal/pkg/all-in-one/all-in-one.h: build/lib/liball_in_one.a
+
+# All in one program build
+build/bin/nova: build/lib/liball_in_one.a internal/pkg/all-in-one/all-in-one.h
+       go build -a -ldflags '-s' -o build/bin/nova cmd/nova/nova.go
+
+all: build/{lib,bin}/nova
+
+.PHONY: all
\ No newline at end of file
diff --git a/cmd/nova/nova.go b/cmd/nova/nova.go
new file mode 100644 (file)
index 0000000..cde4e46
--- /dev/null
@@ -0,0 +1,32 @@
+package main
+
+import (
+       "os"
+       "os/signal"
+       "syscall"
+
+       allinone "github.com/discordnova/nova/internal/pkg/all-in-one"
+)
+
+func main() {
+       allInOne, err := allinone.NewAllInOne()
+       if err != nil {
+               panic(err)
+       }
+       err = allInOne.Start()
+       if err != nil {
+               panic(err)
+       }
+       // Wait for a SIGINT
+       c := make(chan os.Signal, 1)
+       signal.Notify(c,
+               syscall.SIGHUP,
+               syscall.SIGINT,
+               syscall.SIGTERM,
+               syscall.SIGQUIT)
+       <-c
+
+       allInOne.Stop()
+
+       println("Arret de nova all in one")
+}
diff --git a/config/default.yml b/config/default.yml
deleted file mode 100644 (file)
index 1e42f07..0000000
+++ /dev/null
@@ -1,51 +0,0 @@
-gateway:
-  token: ODA3MTg4MzM1NzE3Mzg0MjEy.Gtk5vu.Ejt9d70tnB9W_tXYMUATsBU24nqwQjlUZy7QUo
-  intents: 3276799
-  shard: 0
-  shard_total: 1
-
-rest:
-  discord:
-    token: ODA3MTg4MzM1NzE3Mzg0MjEy.Gtk5vu.Ejt9d70tnB9W_tXYMUATsBU24nqwQjlUZy7QUo
-  server:
-    listening_adress: 0.0.0.0:8090
-
-webhook:
-  discord:
-    public_key: a3d9368eda990e11ca501655d219a2d88591de93d32037f62f453a8587a46ff5
-    client_id: 807188335717384212
-  server:
-    listening_adress: 0.0.0.0:8091
-
-cache:
-  toggles:
-    - channels_cache
-    - guilds_cache
-    - guild_schedules_cache
-    - stage_instances_cache
-    - integrations_cache
-    - members_cache
-    - bans_cache
-    - reactions_cache
-    - messages_cache
-    - threads_cache
-    - invites_cache
-    - roles_cache
-    - automoderation_cache
-    - voice_states_cache
-
-ratelimiter:
-  
-
-# Prometheus monitoring configuration
-monitoring:
-  enabled: false
-  address: 0.0.0.0
-  port: 9000
-
-# Nats broker configuration
-nats:
-  host: nats
-
-redis:
-  url: redis://redis
index 6d5ccdb89b4dab557956565c79786b7ee1ce85ea..712643319b5f9e11ae3265b6b7b9e33fcddebbfd 100644 (file)
@@ -9,6 +9,8 @@ services:
   
   redis:
     image: redis
+    ports:
+      - 6379:6379
 
   cache:
     image: ghcr.io/discordnova/nova/cache
diff --git a/exes/all-in-one/Cargo.toml b/exes/all-in-one/Cargo.toml
new file mode 100644 (file)
index 0000000..8fa4bde
--- /dev/null
@@ -0,0 +1,38 @@
+[package]
+name = "all-in-one"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+libc = "0.2.139"
+leash = { path = "../../libs/leash" }
+shared = { path = "../../libs/shared" }
+
+cache = { path = "../cache" }
+gateway = { path = "../gateway" }
+ratelimit = { path = "../ratelimit" }
+rest = { path = "../rest" }
+webhook = { path = "../webhook" }
+ctrlc = "3.2.4"
+
+tokio = { version = "1.23.1", features = ["rt"] }
+serde = "1.0.152"
+serde_json = "1.0.91"
+anyhow = { version = "1.0.68", features = ["backtrace"] }
+
+tracing = "0.1.37"
+
+config = "0.13.3"
+
+tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
+tracing-opentelemetry = "0.18.0"
+opentelemetry = { version ="0.18.0", features = ["rt-tokio"] }
+opentelemetry-otlp = { version = "0.11.0" }
+
+[lib]
+crate-type = ["staticlib", "rlib"]
+
+[build-dependencies]
+cbindgen = "0.24.3"
\ No newline at end of file
diff --git a/exes/all-in-one/Makefile b/exes/all-in-one/Makefile
new file mode 100644 (file)
index 0000000..8ed17c4
--- /dev/null
@@ -0,0 +1,15 @@
+clean:
+       rm ./build/*
+
+library:
+       cargo build --release
+
+build: library
+       cp ../../target/release/liball_in_one.a ./build
+       go build -a -ldflags '-s' -o build/all-in-one
+
+all: library build
+
+run: all-in-one
+       ./build/all-in-one
+
diff --git a/exes/all-in-one/build.rs b/exes/all-in-one/build.rs
new file mode 100644 (file)
index 0000000..83ca650
--- /dev/null
@@ -0,0 +1,27 @@
+extern crate cbindgen;
+
+use cbindgen::{Config, Language};
+use std::env;
+use std::error::Error;
+use std::path::PathBuf;
+
+/// Generates the headers for the go program.
+fn main() -> Result<(), Box<dyn Error>> {
+    let crate_dir = env::var("CARGO_MANIFEST_DIR")?;
+    let package_name = env::var("CARGO_PKG_NAME")?;
+
+    // We export the header file to build/{package_name}.h
+    let output_file = PathBuf::from("../../internal/pkg/all-in-one")
+        .join(format!("{}.h", package_name))
+        .display()
+        .to_string();
+
+    let config = Config {
+        language: Language::C,
+        ..Default::default()
+    };
+
+    cbindgen::generate_with_config(crate_dir, config)?.write_to_file(output_file);
+    
+    Ok(())
+}
diff --git a/exes/all-in-one/main.go b/exes/all-in-one/main.go
new file mode 100644 (file)
index 0000000..1de08db
--- /dev/null
@@ -0,0 +1,72 @@
+package main
+
+/*
+#cgo LDFLAGS: -L./build -lall_in_one -lz -lm
+#include "./build/all-in-one.h"
+*/
+import "C"
+
+import (
+       "fmt"
+       "os"
+       "os/signal"
+       "syscall"
+       "time"
+
+       "github.com/Jeffail/gabs"
+       "github.com/alicebob/miniredis/v2"
+
+       server "github.com/nats-io/nats-server/v2/server"
+)
+
+func main() {
+       // Intialise les logs de la librarie Rust
+       C.init_logs()
+       // Charge la configuration
+       str := C.GoString(C.load_config())
+
+       // Démarre une instance MiniRedis
+       mr := miniredis.NewMiniRedis()
+       err := mr.Start()
+
+       if err != nil {
+               panic(err)
+       }
+
+       // Démarre un serveur Nats
+       opts := &server.Options{}
+       opts.Host = "0.0.0.0"
+       ns, err := server.NewServer(opts)
+
+       if err != nil {
+               panic(err)
+       }
+
+       go ns.Start()
+
+       if !ns.ReadyForConnections(4 * time.Second) {
+               panic("not ready for connection")
+       }
+
+       // Edite le json de configuration donné
+       // Et injecte la configuration des servers Nats et MiniRedis
+       json, _ := gabs.ParseJSON([]byte(str))
+       json.Set(fmt.Sprintf("redis://%s", mr.Addr()), "redis", "url")
+       json.Set("localhost", "nats", "host")
+       json.Set(1, "webhook", "discord", "client_id")
+
+       // Démarre une instance de nova
+       instance := C.start_instance(C.CString(json.String()))
+
+       // Wait for a SIGINT
+       c := make(chan os.Signal, 1)
+       signal.Notify(c,
+               syscall.SIGHUP,
+               syscall.SIGINT,
+               syscall.SIGTERM,
+               syscall.SIGQUIT)
+       <-c
+
+       println("Arret de nova all in one")
+       C.stop_instance(instance)
+}
diff --git a/exes/all-in-one/src/errors.rs b/exes/all-in-one/src/errors.rs
new file mode 100644 (file)
index 0000000..f169fb4
--- /dev/null
@@ -0,0 +1,53 @@
+use std::cell::RefCell;
+
+use anyhow::Result;
+use libc::c_int;
+use tracing::error;
+
+thread_local! {
+    pub static ERROR_HANDLER: std::cell::RefCell<Option<unsafe extern "C" fn(libc::c_int, *mut libc::c_char)>>  = RefCell::new(None);
+}
+
+/// Update the most recent error, clearing whatever may have been there before.
+pub fn stacktrace(err: anyhow::Error) -> String {
+    format!("{err}")
+}
+
+pub fn wrap_result<T, F>(func: F) -> Option<T>
+where
+    F: Fn() -> Result<T>,
+{
+    let result = func();
+
+    match result {
+        Ok(ok) => Some(ok),
+        Err(error) => {
+            // Call the handler
+            handle_error(error);
+            None
+        }
+    }
+}
+
+pub fn handle_error(error: anyhow::Error) {
+    ERROR_HANDLER.with(|val| {
+        let mut stacktrace = stacktrace(error);
+
+        error!("Error emitted: {}", stacktrace);
+        if let Some(func) = *val.borrow() {
+
+            // Call the error handler
+            unsafe {
+                func(
+                    stacktrace.len() as c_int + 1,
+                    stacktrace.as_mut_ptr() as *mut i8,
+                );
+            }
+        }
+    });
+}
+
+#[cfg(test)]
+mod tests {
+    // todo
+}
diff --git a/exes/all-in-one/src/ffi.rs b/exes/all-in-one/src/ffi.rs
new file mode 100644 (file)
index 0000000..49586cf
--- /dev/null
@@ -0,0 +1,138 @@
+use std::{
+    ffi::{c_char, c_int, CString},
+    mem::take,
+    ptr,
+    str::FromStr,
+    time::Duration,
+};
+
+use gateway::GatewayServer;
+use opentelemetry::{global::set_text_map_propagator, sdk::propagation::TraceContextPropagator};
+use ratelimit::RatelimiterServerComponent;
+use rest::ReverseProxyServer;
+use tokio::{runtime::Runtime, sync::mpsc};
+use tracing::{debug, error};
+use tracing_subscriber::{
+    filter::Directive, fmt, prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt,
+    EnvFilter,
+};
+use webhook::WebhookServer;
+
+use crate::{
+    errors::{handle_error, wrap_result, ERROR_HANDLER},
+    utils::{load_config_file, start_component, AllInOneInstance},
+};
+
+#[no_mangle]
+pub unsafe extern "C" fn set_error_handler(func: unsafe extern "C" fn(c_int, *mut c_char)) {
+    debug!("Setting error handler");
+    ERROR_HANDLER.with(|prev| {
+        *prev.borrow_mut() = Some(func);
+    });
+}
+
+#[no_mangle]
+/// Loads the config json using the nova shared config loader
+pub extern "C" fn load_config() -> *mut c_char {
+    wrap_result(move || {
+        let config = serde_json::to_string(&load_config_file()?)?;
+        let c_str_song = CString::new(config)?;
+        Ok(c_str_song.into_raw())
+    })
+    .or(Some(ptr::null::<i8>() as *mut i8))
+    .expect("something has gone terribly wrong")
+}
+
+#[no_mangle]
+pub extern "C" fn stop_instance(instance: *mut AllInOneInstance) {
+    wrap_result(move || {
+        let mut instance = unsafe { Box::from_raw(instance) };
+        let handles = take(&mut instance.handles);
+        instance.runtime.block_on(async move {
+            for (name, sender, join) in handles {
+                debug!("Halting component {}", name);
+                let _ = sender
+                    .send(())
+                    .or_else(|_| Err(error!("Component {} is not online", name)));
+                match join.await {
+                    Ok(_) => {}
+                    Err(error) => error!("Task for component {} panic'ed {}", name, error),
+                };
+                debug!("Component {} halted", name);
+            }
+        });
+
+        instance.runtime.shutdown_timeout(Duration::from_secs(5));
+
+        Ok(())
+    });
+}
+
+#[no_mangle]
+pub extern "C" fn create_instance(config: *mut c_char) -> *mut AllInOneInstance {
+    wrap_result(move || {
+        let value = unsafe { CString::from_raw(config) };
+        let json = value.to_str()?;
+
+        // Main stop signal for this instance
+        let (error_sender, mut errors) = mpsc::channel(50);
+        let mut handles = vec![];
+
+        let runtime = Runtime::new()?;
+
+        // Setup the tracing system
+        set_text_map_propagator(TraceContextPropagator::new());
+        tracing_subscriber::registry()
+            .with(fmt::layer())
+            .with(
+                EnvFilter::builder()
+                    .with_default_directive(Directive::from_str("info").unwrap())
+                    .from_env()
+                    .unwrap(),
+            )
+            .init();
+
+        // Error handling task
+        runtime.spawn(async move {
+            while let Some(error) = errors.recv().await {
+                handle_error(error)
+            }
+        });
+
+        handles.push(start_component::<GatewayServer>(
+            json,
+            error_sender.clone(),
+            &runtime,
+        )?);
+
+        std::thread::sleep(Duration::from_secs(1));
+
+        handles.push(start_component::<RatelimiterServerComponent>(
+            json,
+            error_sender.clone(),
+            &runtime,
+        )?);
+
+        std::thread::sleep(Duration::from_secs(1));
+
+        handles.push(start_component::<ReverseProxyServer>(
+            json,
+            error_sender.clone(),
+            &runtime,
+        )?);
+
+        std::thread::sleep(Duration::from_secs(1));
+
+        handles.push(start_component::<WebhookServer>(
+            json,
+            error_sender.clone(),
+            &runtime,
+        )?);
+
+        let all_in_one = Box::into_raw(Box::new(AllInOneInstance { runtime, handles }));
+
+        Ok(all_in_one)
+    })
+    .or(Some(ptr::null::<i8>() as *mut AllInOneInstance))
+    .expect("something has gone terribly wrong")
+}
diff --git a/exes/all-in-one/src/lib.rs b/exes/all-in-one/src/lib.rs
new file mode 100644 (file)
index 0000000..5a41344
--- /dev/null
@@ -0,0 +1,3 @@
+pub mod utils;
+pub mod errors;
+pub mod ffi;
diff --git a/exes/all-in-one/src/main.rs b/exes/all-in-one/src/main.rs
new file mode 100644 (file)
index 0000000..e05b853
--- /dev/null
@@ -0,0 +1,20 @@
+use all_in_one::ffi::{create_instance, load_config, stop_instance};
+use std::sync::mpsc::channel;
+use ctrlc;
+
+fn main() {
+    let c = load_config();
+    let comp = create_instance(c);
+
+    // wait for signal
+    let (tx, rx) = channel();
+    
+    ctrlc::set_handler(move || tx.send(()).expect("Could not send signal on channel."))
+        .expect("Error setting Ctrl-C handler");
+
+    rx.recv().unwrap();
+
+    println!("Exiting.");
+
+    stop_instance(comp);
+}
\ No newline at end of file
diff --git a/exes/all-in-one/src/utils.rs b/exes/all-in-one/src/utils.rs
new file mode 100644 (file)
index 0000000..7a7134d
--- /dev/null
@@ -0,0 +1,82 @@
+use anyhow::Result;
+use config::{Config, Environment, File};
+use leash::Component;
+use serde::de::DeserializeOwned;
+use serde_json::Value;
+use shared::config::Settings;
+use tokio::{
+    runtime::Runtime,
+    sync::{mpsc, oneshot::Sender},
+    task::JoinHandle,
+};
+use tracing::{
+    debug,
+    log::{error, info},
+};
+
+/// Represents a all in one instance
+pub struct AllInOneInstance {
+    pub runtime: Runtime,
+    pub(crate) handles: Vec<(&'static str, Sender<()>, JoinHandle<()>)>,
+}
+
+/// Loads the settings from a component using a string
+fn load_settings_for<T: Default + DeserializeOwned + Clone>(
+    settings: &str,
+    name: &str,
+) -> Result<Settings<T>> {
+    let value: Value = serde_json::from_str(settings)?;
+    let section: T = serde_json::from_value(value.get(name).unwrap().to_owned())?;
+    let mut settings: Settings<T> = serde_json::from_value(value)?;
+    settings.config = section;
+
+    Ok(settings)
+}
+
+pub(crate) fn start_component<T: Component>(
+    json: &str,
+    error_sender: mpsc::Sender<anyhow::Error>,
+    runtime: &Runtime,
+) -> Result<(&'static str, Sender<()>, JoinHandle<()>)> {
+    let name = T::SERVICE_NAME;
+    let instance = T::new();
+
+    // We setup stop signals
+    let (stop, signal) = tokio::sync::oneshot::channel();
+    let settings = load_settings_for(json, name)?;
+
+    let handle = runtime.spawn(async move {
+        debug!("starting component {}", name);
+        match instance.start(settings, signal).await {
+            Ok(_) => info!("Component {} gracefully exited", name),
+            Err(error) => {
+                error!("Component {} exited with error {}", name, error);
+                error_sender
+                    .send(error)
+                    .await
+                    .expect("Couldn't send the error notification to the error mpsc");
+            }
+        }
+    });
+
+    Ok((name, stop, handle))
+}
+
+pub(crate) fn load_config_file() -> Result<Value> {
+    let mut builder = Config::builder();
+
+    builder = builder.add_source(File::with_name("config/default"));
+    let mode = std::env::var("ENV").unwrap_or_else(|_| "development".into());
+    info!("Configuration Environment: {}", mode);
+
+    builder = builder.add_source(File::with_name(&format!("config/{}", mode)).required(false));
+    builder = builder.add_source(File::with_name("config/local").required(false));
+
+    let env = Environment::with_prefix("NOVA").separator("__");
+    // we can configure each component using environment variables
+    builder = builder.add_source(env);
+
+    let config: Value = builder.build()?.try_deserialize()?;
+
+    Ok(config)
+}
diff --git a/exes/all/.gitignore b/exes/all/.gitignore
deleted file mode 100644 (file)
index 27531c7..0000000
+++ /dev/null
@@ -1,3 +0,0 @@
-build/*
-!build/.gitkeep
-config/
\ No newline at end of file
diff --git a/exes/all/Cargo.toml b/exes/all/Cargo.toml
deleted file mode 100644 (file)
index ef23e6b..0000000
+++ /dev/null
@@ -1,37 +0,0 @@
-[package]
-name = "all"
-version = "0.1.0"
-edition = "2021"
-
-# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
-
-[dependencies]
-libc = "0.2.139"
-leash = { path = "../../libs/leash" }
-shared = { path = "../../libs/shared" }
-
-cache = { path = "../cache" }
-gateway = { path = "../gateway" }
-ratelimit = { path = "../ratelimit" }
-rest = { path = "../rest" }
-webhook = { path = "../webhook" }
-
-tokio = { version = "1.23.1", features = ["rt"] }
-serde = "1.0.152"
-serde_json = "1.0.91"
-anyhow = "1.0.68"
-
-tracing = "0.1.37"
-
-config = "0.13.3"
-
-tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
-tracing-opentelemetry = "0.18.0"
-opentelemetry = { version ="0.18.0", features = ["rt-tokio"] }
-opentelemetry-otlp = { version = "0.11.0" }
-
-[lib]
-crate-type = ["staticlib"]
-
-[build-dependencies]
-cbindgen = "0.24.3"
\ No newline at end of file
diff --git a/exes/all/Makefile b/exes/all/Makefile
deleted file mode 100644 (file)
index 22a67e0..0000000
+++ /dev/null
@@ -1,14 +0,0 @@
-clean:
-       rm ./build/*
-
-library:
-       cargo build --release
-
-build: library
-       cp ../../target/release/liball.a ./build
-       go build -a -ldflags '-s' -o build/all
-
-all: library build
-
-run: all
-       ./build/all
diff --git a/exes/all/build.rs b/exes/all/build.rs
deleted file mode 100644 (file)
index 00d7afc..0000000
+++ /dev/null
@@ -1,24 +0,0 @@
-extern crate cbindgen;
-
-use cbindgen::{Config, Language};
-use std::env;
-use std::path::PathBuf;
-
-fn main() {
-    let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
-
-    let package_name = env::var("CARGO_PKG_NAME").unwrap();
-    let output_file = PathBuf::from("./build")
-        .join(format!("{}.h", package_name))
-        .display()
-        .to_string();
-
-    let config = Config {
-        language: Language::C,
-        ..Default::default()
-    };
-
-    cbindgen::generate_with_config(crate_dir, config)
-        .unwrap()
-        .write_to_file(output_file);
-}
diff --git a/exes/all/build/.gitkeep b/exes/all/build/.gitkeep
deleted file mode 100644 (file)
index e69de29..0000000
diff --git a/exes/all/main.go b/exes/all/main.go
deleted file mode 100644 (file)
index 3d528c9..0000000
+++ /dev/null
@@ -1,72 +0,0 @@
-package main
-
-/*
-#cgo LDFLAGS: -L./build -lall -lz -lm
-#include "./build/all.h"
-*/
-import "C"
-
-import (
-       "fmt"
-       "os"
-       "os/signal"
-       "syscall"
-       "time"
-
-       "github.com/Jeffail/gabs"
-       "github.com/alicebob/miniredis/v2"
-
-       server "github.com/nats-io/nats-server/v2/server"
-)
-
-func main() {
-       // Intialise les logs de la librarie Rust
-       C.init_logs()
-       // Charge la configuration
-       str := C.GoString(C.load_config())
-
-       // Démarre une instance MiniRedis
-       mr := miniredis.NewMiniRedis()
-       err := mr.Start()
-
-       if err != nil {
-               panic(err)
-       }
-
-       // Démarre un serveur Nats
-       opts := &server.Options{}
-       opts.Host = "0.0.0.0"
-       ns, err := server.NewServer(opts)
-
-       if err != nil {
-               panic(err)
-       }
-
-       go ns.Start()
-
-       if !ns.ReadyForConnections(4 * time.Second) {
-               panic("not ready for connection")
-       }
-
-       // Edite le json de configuration donné
-       // Et injecte la configuration des servers Nats et MiniRedis
-       json, _ := gabs.ParseJSON([]byte(str))
-       json.Set(fmt.Sprintf("redis://%s", mr.Addr()), "redis", "url")
-       json.Set("localhost", "nats", "host")
-       json.Set(1, "webhook", "discord", "client_id")
-
-       // Démarre une instance de nova
-       instance := C.start_instance(C.CString(json.String()))
-
-       // Wait for a SIGINT
-       c := make(chan os.Signal, 1)
-       signal.Notify(c,
-               syscall.SIGHUP,
-               syscall.SIGINT,
-               syscall.SIGTERM,
-               syscall.SIGQUIT)
-       <-c
-
-       println("Arret de nova all in one")
-       C.stop_instance(instance)
-}
diff --git a/exes/all/src/lib.rs b/exes/all/src/lib.rs
deleted file mode 100644 (file)
index 1c99109..0000000
+++ /dev/null
@@ -1,172 +0,0 @@
-#![allow(clippy::missing_safety_doc)]
-
-extern crate libc;
-use anyhow::Result;
-use config::{Config, Environment, File};
-use gateway::GatewayServer;
-use leash::Component;
-use opentelemetry::{
-    global,
-    sdk::{propagation::TraceContextPropagator, trace, Resource},
-    KeyValue,
-};
-use opentelemetry_otlp::WithExportConfig;
-use ratelimit::RatelimiterServerComponent;
-use rest::ReverseProxyServer;
-use serde::de::DeserializeOwned;
-use serde_json::Value;
-use shared::config::Settings;
-use std::{
-    env,
-    ffi::{CStr, CString},
-    str::FromStr,
-    time::Duration,
-};
-use tokio::{
-    runtime::Runtime,
-    sync::oneshot::{self, Sender},
-    task::JoinHandle,
-};
-use tracing::info;
-use tracing_subscriber::{
-    filter::Directive, fmt, prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt,
-    EnvFilter,
-};
-use webhook::WebhookServer;
-
-pub struct AllInOneInstance {
-    pub stop: Sender<Sender<()>>,
-    pub runtime: Runtime,
-}
-
-fn load_settings_for<T: Default + DeserializeOwned + Clone>(
-    settings: &str,
-    name: &str,
-) -> Result<Settings<T>> {
-    let value: Value = serde_json::from_str(settings)?;
-    let section: T = serde_json::from_value(value.get(name).unwrap().to_owned())?;
-    let mut settings: Settings<T> = serde_json::from_value(value)?;
-    settings.config = section;
-
-    Ok(settings)
-}
-
-// Start a component
-async fn start_component<T: Component>(
-    settings: String,
-    aio: &mut Vec<Sender<()>>,
-) -> JoinHandle<()> {
-    let name = T::SERVICE_NAME;
-    let instance = T::new();
-
-    let (stop, signal) = oneshot::channel();
-
-    aio.push(stop);
-
-    tokio::spawn(async move {
-        let config = load_settings_for::<<T as Component>::Config>(&settings, name).unwrap();
-        instance.start(config, signal).await.unwrap();
-    })
-}
-
-#[no_mangle]
-/// Loads the config json using the nova shared config loader
-pub extern "C" fn load_config() -> *const libc::c_char {
-    let mut builder = Config::builder();
-
-    builder = builder.add_source(File::with_name("config/default"));
-    let mode = env::var("ENV").unwrap_or_else(|_| "development".into());
-    info!("Configuration Environment: {}", mode);
-
-    builder = builder.add_source(File::with_name(&format!("config/{}", mode)).required(false));
-    builder = builder.add_source(File::with_name("config/local").required(false));
-
-    let env = Environment::with_prefix("NOVA").separator("__");
-    // we can configure each component using environment variables
-    builder = builder.add_source(env);
-
-    let config: Value = builder.build().unwrap().try_deserialize().unwrap();
-    let s = serde_json::to_string(&config).unwrap();
-
-    let c_str_song = CString::new(s).unwrap();
-    c_str_song.into_raw()
-}
-
-#[no_mangle]
-/// Initialise les logs des composants de nova
-/// Utilise la crate `pretty_log_env`
-pub extern "C" fn init_logs() {}
-
-#[no_mangle]
-/// Stops a nova instance
-pub unsafe extern "C" fn stop_instance(instance: *mut AllInOneInstance) {
-    let instance = Box::from_raw(instance);
-    let (tell_ready, ready) = oneshot::channel();
-    instance.stop.send(tell_ready).unwrap();
-    ready.blocking_recv().unwrap();
-    instance.runtime.shutdown_timeout(Duration::from_secs(5));
-}
-
-#[no_mangle]
-/// Initialized a new nova instance and an async runtime (tokio reactor)
-/// Dont forget to stop this instance using `stop_instance`
-pub unsafe extern "C" fn start_instance(config: *const libc::c_char) -> *mut AllInOneInstance {
-    let buf_name = unsafe { CStr::from_ptr(config).to_bytes() };
-    let settings = String::from_utf8(buf_name.to_vec()).unwrap();
-    let (stop, trigger_stop) = oneshot::channel();
-
-    // Initialize a tokio runtime
-    let rt = Runtime::new().unwrap();
-    rt.block_on(async move {
-        global::set_text_map_propagator(TraceContextPropagator::new());
-        let tracer =
-            opentelemetry_otlp::new_pipeline()
-                .tracing()
-                .with_trace_config(trace::config().with_resource(Resource::new(vec![
-                    KeyValue::new("service.name", "all-in-one"),
-                ])))
-                .with_exporter(opentelemetry_otlp::new_exporter().tonic().with_env())
-                .install_batch(opentelemetry::runtime::Tokio)
-                .unwrap();
-
-        let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
-
-        tracing_subscriber::registry()
-            .with(fmt::layer())
-            .with(telemetry)
-            .with(
-                EnvFilter::builder()
-                    .with_default_directive(Directive::from_str("info").unwrap())
-                    .from_env()
-                    .unwrap(),
-            )
-            .init();
-        // Start the gateway server
-
-        let mut aio = vec![];
-        let mut handles = vec![];
-
-        // Start components
-        handles.push(start_component::<GatewayServer>(settings.clone(), &mut aio).await);
-        handles
-            .push(start_component::<RatelimiterServerComponent>(settings.clone(), &mut aio).await);
-        handles.push(start_component::<ReverseProxyServer>(settings.clone(), &mut aio).await);
-        handles.push(start_component::<WebhookServer>(settings.clone(), &mut aio).await);
-
-        // wait for exit
-        let done: Sender<()> = trigger_stop.await.unwrap();
-
-        // Tell all the threads to stop.
-        while let Some(stop_signal) = aio.pop() {
-            stop_signal.send(()).unwrap();
-        }
-
-        // Wait for all the threads to finish.
-        while let Some(handle) = handles.pop() {
-            handle.await.unwrap();
-        }
-
-        done.send(()).unwrap();
-    });
-    Box::into_raw(Box::new(AllInOneInstance { stop, runtime: rt }))
-}
index 014b72addc039e24960a2e17b117fe59e7615ce7..ec3337b47ebf4698ac4491b815b60e8d4d8e0124 100644 (file)
@@ -24,6 +24,7 @@ impl<'a> Injector for MetadataMap<'a> {
 }
 
 pub struct GatewayServer {}
+
 impl Component for GatewayServer {
     type Config = GatewayConfig;
     const SERVICE_NAME: &'static str = "gateway";
diff --git a/exes/ratelimit/bench/req.rs b/exes/ratelimit/bench/req.rs
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/exes/ratelimit/src/buckets/async_queue.rs b/exes/ratelimit/src/buckets/async_queue.rs
new file mode 100644 (file)
index 0000000..70b4ebd
--- /dev/null
@@ -0,0 +1,39 @@
+use tokio::sync::{
+    mpsc::{self, UnboundedReceiver, UnboundedSender},
+    oneshot::Sender,
+    Mutex,
+};
+
+/// Queue of ratelimit requests for a bucket.
+#[derive(Debug)]
+pub struct AsyncQueue {
+    /// Receiver for the ratelimit requests.
+    rx: Mutex<UnboundedReceiver<Sender<()>>>,
+    /// Sender for the ratelimit requests.
+    tx: UnboundedSender<Sender<()>>,
+}
+
+impl AsyncQueue {
+    /// Add a new ratelimit request to the queue.
+    pub fn push(&self, tx: Sender<()>) {
+        let _sent = self.tx.send(tx);
+    }
+
+    /// Receive the first incoming ratelimit request.
+    pub async fn pop(&self) -> Option<Sender<()>> {
+        let mut rx = self.rx.lock().await;
+
+        rx.recv().await
+    }
+}
+
+impl Default for AsyncQueue {
+    fn default() -> Self {
+        let (tx, rx) = mpsc::unbounded_channel();
+
+        Self {
+            rx: Mutex::new(rx),
+            tx,
+        }
+    }
+}
diff --git a/exes/ratelimit/src/buckets/atomic_instant.rs b/exes/ratelimit/src/buckets/atomic_instant.rs
new file mode 100644 (file)
index 0000000..ec31808
--- /dev/null
@@ -0,0 +1,35 @@
+use std::{
+    sync::atomic::{AtomicU64, Ordering},
+    time::{Duration, SystemTime, UNIX_EPOCH},
+};
+
+#[derive(Default, Debug)]
+pub struct AtomicInstant(AtomicU64);
+
+impl AtomicInstant {
+    pub const fn empty() -> Self {
+        Self(AtomicU64::new(0))
+    }
+
+    pub fn elapsed(&self) -> Duration {
+        Duration::from_millis(
+            SystemTime::now()
+                .duration_since(UNIX_EPOCH)
+                .unwrap()
+                .as_millis() as u64 as u64
+                - self.0.load(Ordering::SeqCst),
+        )
+    }
+
+    pub fn as_millis(&self) -> u64 {
+        self.0.load(Ordering::SeqCst)
+    }
+
+    pub fn set_millis(&self, millis: u64) {
+        self.0.store(millis, Ordering::SeqCst);
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.as_millis() == 0
+    }
+}
diff --git a/exes/ratelimit/src/buckets/bucket.rs b/exes/ratelimit/src/buckets/bucket.rs
new file mode 100644 (file)
index 0000000..f4a9b43
--- /dev/null
@@ -0,0 +1,173 @@
+use std::{
+    sync::{
+        atomic::{AtomicU64, Ordering},
+        Arc,
+    },
+    time::Duration,
+};
+use tokio::{sync::oneshot, task::JoinHandle};
+use tracing::debug;
+use twilight_http_ratelimiting::headers::Present;
+
+use super::{async_queue::AsyncQueue, atomic_instant::AtomicInstant, redis_lock::RedisLock};
+
+#[derive(Clone, Debug)]
+pub enum TimeRemaining {
+    Finished,
+    NotStarted,
+    Some(Duration),
+}
+
+#[derive(Debug)]
+pub struct Bucket {
+    pub limit: AtomicU64,
+    /// Queue associated with this bucket.
+    pub queue: AsyncQueue,
+    /// Number of tickets remaining.
+    pub remaining: AtomicU64,
+    /// Duration after the [`Self::last_update`] time the bucket will refresh.
+    pub reset_after: AtomicU64,
+    /// When the bucket's ratelimit refresh countdown started (unix millis)
+    pub last_update: AtomicInstant,
+
+    pub tasks: Vec<JoinHandle<()>>,
+}
+
+impl Drop for Bucket {
+    fn drop(&mut self) {
+        for join in &self.tasks {
+            join.abort();
+        }
+    }
+}
+
+impl Bucket {
+    /// Create a new bucket for the specified [`Path`].
+    pub fn new(global: Arc<RedisLock>) -> Arc<Self> {
+        let tasks = vec![];
+
+        let this = Arc::new(Self {
+            limit: AtomicU64::new(u64::max_value()),
+            queue: AsyncQueue::default(),
+            remaining: AtomicU64::new(u64::max_value()),
+            reset_after: AtomicU64::new(u64::max_value()),
+            last_update: AtomicInstant::empty(),
+            tasks,
+        });
+
+        // Run with 4 dequeue tasks
+        for _ in 0..4 {
+            let this = this.clone();
+            let global = global.clone();
+            tokio::spawn(async move {
+                while let Some(element) = this.queue.pop().await {
+                    // we need to wait
+                    if let Some(duration) = global.locked_for().await {
+                        tokio::time::sleep(duration).await;
+                    }
+
+                    if this.remaining() == 0 {
+                        debug!("0 tickets remaining, we have to wait.");
+
+                        match this.time_remaining() {
+                            TimeRemaining::Finished => {
+                                this.try_reset();
+                            }
+                            TimeRemaining::Some(duration) => {
+                                debug!(milliseconds=%duration.as_millis(), "waiting for ratelimit");
+                                tokio::time::sleep(duration).await;
+
+                                this.try_reset();
+                            }
+                            TimeRemaining::NotStarted => {}
+                        }
+                    }
+
+                    this.remaining.fetch_sub(1, Ordering::Relaxed);
+                    let _ = element.send(()).map_err(|_| { debug!("response channel was closed.") });
+                }
+            });
+        }
+
+        this
+    }
+
+    /// Total number of tickets allotted in a cycle.
+    pub fn limit(&self) -> u64 {
+        self.limit.load(Ordering::Relaxed)
+    }
+
+    /// Number of tickets remaining.
+    pub fn remaining(&self) -> u64 {
+        self.remaining.load(Ordering::Relaxed)
+    }
+
+    /// Duration after the [`started_at`] time the bucket will refresh.
+    ///
+    /// [`started_at`]: Self::started_at
+    pub fn reset_after(&self) -> u64 {
+        self.reset_after.load(Ordering::Relaxed)
+    }
+
+    /// Time remaining until this bucket will reset.
+    pub fn time_remaining(&self) -> TimeRemaining {
+        let reset_after = self.reset_after();
+        let last_update = &self.last_update;
+
+        if last_update.is_empty() {
+            let elapsed = last_update.elapsed();
+
+            if elapsed > Duration::from_millis(reset_after) {
+                return TimeRemaining::Finished;
+            }
+
+            TimeRemaining::Some(Duration::from_millis(reset_after) - elapsed)
+        } else {
+            return TimeRemaining::NotStarted;
+        }
+    }
+
+    /// Try to reset this bucket's [`started_at`] value if it has finished.
+    ///
+    /// Returns whether resetting was possible.
+    ///
+    /// [`started_at`]: Self::started_at
+    pub fn try_reset(&self) -> bool {
+        if self.last_update.is_empty() {
+            return false;
+        }
+
+        if let TimeRemaining::Finished = self.time_remaining() {
+            self.remaining.store(self.limit(), Ordering::Relaxed);
+            self.last_update.set_millis(0);
+
+            true
+        } else {
+            false
+        }
+    }
+
+    /// Update this bucket's ratelimit data after a request has been made.
+    pub fn update(&self, ratelimits: Present, time: u64) {
+        let bucket_limit = self.limit();
+
+        if self.last_update.is_empty() {
+            self.last_update.set_millis(time);
+        }
+
+        if bucket_limit != ratelimits.limit() && bucket_limit == u64::max_value() {
+            self.reset_after
+                .store(ratelimits.reset_after(), Ordering::SeqCst);
+            self.limit.store(ratelimits.limit(), Ordering::SeqCst);
+        }
+
+        self.remaining
+            .store(ratelimits.remaining(), Ordering::Relaxed);
+    }
+
+    pub async fn ticket(&self) -> oneshot::Receiver<()> {
+        let (tx, rx) = oneshot::channel();
+        self.queue.push(tx);
+        rx
+    }
+}
diff --git a/exes/ratelimit/src/buckets/mod.rs b/exes/ratelimit/src/buckets/mod.rs
new file mode 100644 (file)
index 0000000..e2f52c8
--- /dev/null
@@ -0,0 +1,4 @@
+pub mod bucket;
+pub mod redis_lock;
+pub mod atomic_instant;
+pub mod async_queue;
\ No newline at end of file
diff --git a/exes/ratelimit/src/buckets/redis_lock.rs b/exes/ratelimit/src/buckets/redis_lock.rs
new file mode 100644 (file)
index 0000000..fdae149
--- /dev/null
@@ -0,0 +1,84 @@
+use std::{
+    sync::{atomic::AtomicU64, Arc},
+    time::{Duration, SystemTime},
+};
+
+use redis::{aio::MultiplexedConnection, AsyncCommands};
+use tokio::sync::Mutex;
+use tracing::debug;
+
+/// This is flawed and needs to be replaced sometime with the real RedisLock algorithm
+#[derive(Debug)]
+pub struct RedisLock {
+    redis: Mutex<MultiplexedConnection>,
+    is_locked: AtomicU64,
+}
+
+impl RedisLock {
+    /// Set the global ratelimit as exhausted.
+    pub async fn lock_for(self: &Arc<Self>, duration: Duration) {
+        debug!("locking globally for {}", duration.as_secs());
+        let _: () = self
+            .redis
+            .lock()
+            .await
+            .set_ex(
+                "nova:rls:lock",
+                1,
+                (duration.as_secs() + 1).try_into().unwrap(),
+            )
+            .await
+            .unwrap();
+
+        self.is_locked.store(
+            (SystemTime::now() + duration)
+                .duration_since(SystemTime::UNIX_EPOCH)
+                .unwrap()
+                .as_millis() as u64,
+            std::sync::atomic::Ordering::Relaxed,
+        );
+    }
+
+    pub async fn locked_for(self: &Arc<Self>) -> Option<Duration> {
+        let load = self.is_locked.load(std::sync::atomic::Ordering::Relaxed);
+        if load != 0 {
+            if load
+                > SystemTime::now()
+                    .duration_since(SystemTime::UNIX_EPOCH)
+                    .unwrap()
+                    .as_millis() as u64
+            {
+                return Some(Duration::from_millis(load));
+            } else {
+                self.is_locked
+                    .store(0, std::sync::atomic::Ordering::Relaxed);
+            }
+        }
+
+        let result = self.redis.lock().await.ttl::<_, i64>("nova:rls:lock").await;
+        match result {
+            Ok(remaining_time) => {
+                if remaining_time > 0 {
+                    let duration = Duration::from_secs(remaining_time as u64);
+                    debug!("external global lock detected, locking");
+                    self.lock_for(duration).await;
+                    Some(duration)
+                } else {
+                    None
+                }
+            }
+            Err(error) => {
+                debug!("redis call failed: {}", error);
+
+                None
+            }
+        }
+    }
+
+    pub fn new(redis: MultiplexedConnection) -> Arc<Self> {
+        Arc::new(Self {
+            redis: Mutex::new(redis),
+            is_locked: AtomicU64::new(0),
+        })
+    }
+}
index fbcf3b746d8337df8bee5d6fdcc3c318ddc52ffa..0819e4615202ce326ade20dcfa4c90f31a5322ab 100644 (file)
@@ -1,24 +1,33 @@
-use opentelemetry::{global, propagation::Extractor};
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::time::Duration;
+
+use opentelemetry::global;
+use opentelemetry::propagation::Extractor;
+use proto::nova::ratelimit::ratelimiter::HeadersSubmitRequest;
 use proto::nova::ratelimit::ratelimiter::{
-    ratelimiter_server::Ratelimiter, BucketSubmitTicketRequest, BucketSubmitTicketResponse,
+    ratelimiter_server::Ratelimiter, BucketSubmitTicketRequest,
 };
-use std::pin::Pin;
-use tokio::sync::mpsc;
-use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt};
-use tonic::{Request, Response, Status, Streaming};
-use tracing::{debug, debug_span, info, Instrument};
+use tokio::sync::RwLock;
+use tonic::Response;
+use tracing::debug;
 use tracing_opentelemetry::OpenTelemetrySpanExt;
-use twilight_http_ratelimiting::{ticket::TicketReceiver, RatelimitHeaders};
+use twilight_http_ratelimiting::RatelimitHeaders;
 
-use crate::redis_global_local_bucket_ratelimiter::RedisGlobalLocalBucketRatelimiter;
+use crate::buckets::bucket::Bucket;
+use crate::buckets::redis_lock::RedisLock;
 
 pub struct RLServer {
-    ratelimiter: RedisGlobalLocalBucketRatelimiter,
+    global: Arc<RedisLock>,
+    buckets: RwLock<HashMap<String, Arc<Bucket>>>,
 }
 
 impl RLServer {
-    pub fn new(ratelimiter: RedisGlobalLocalBucketRatelimiter) -> Self {
-        Self { ratelimiter }
+    pub fn new(redis_lock: Arc<RedisLock>) -> Self {
+        Self {
+            global: redis_lock,
+            buckets: RwLock::new(HashMap::new()),
+        }
     }
 }
 
@@ -44,66 +53,85 @@ impl<'a> Extractor for MetadataMap<'a> {
 
 #[tonic::async_trait]
 impl Ratelimiter for RLServer {
-    type SubmitTicketStream =
-        Pin<Box<dyn Stream<Item = Result<BucketSubmitTicketResponse, Status>> + Send>>;
-
-    async fn submit_ticket(
+    async fn submit_headers(
         &self,
-        req: Request<Streaming<BucketSubmitTicketRequest>>,
-    ) -> Result<Response<Self::SubmitTicketStream>, Status> {
+        request: tonic::Request<HeadersSubmitRequest>,
+    ) -> Result<tonic::Response<()>, tonic::Status> {
         let parent_cx =
-            global::get_text_map_propagator(|prop| prop.extract(&MetadataMap(req.metadata())));
+            global::get_text_map_propagator(|prop| prop.extract(&MetadataMap(request.metadata())));
         // Generate a tracing span as usual
         let span = tracing::span!(tracing::Level::INFO, "request process");
-
-        // Assign parent trace from external context
         span.set_parent(parent_cx);
 
-        let mut in_stream = req.into_inner();
-        let (tx, rx) = mpsc::channel(128);
-        let imrl = self.ratelimiter.clone();
-
-        // this spawn here is required if you want to handle connection error.
-        // If we just map `in_stream` and write it back as `out_stream` the `out_stream`
-        // will be drooped when connection error occurs and error will never be propagated
-        // to mapped version of `in_stream`.
-        tokio::spawn(async move {
-            let mut receiver: Option<TicketReceiver> = None;
-            while let Some(result) = in_stream.next().await {
-                let result = result.unwrap();
-
-                match result.data.unwrap() {
-                    proto::nova::ratelimit::ratelimiter::bucket_submit_ticket_request::Data::Path(path) => {
-                        let span = debug_span!("requesting ticket");
-                        let a = imrl.ticket(path).instrument(span).await.unwrap();
-                        receiver = Some(a);
-
-                        tx.send(Ok(BucketSubmitTicketResponse {
-                            accepted: 1
-                        })).await.unwrap();
-                    },
-                    proto::nova::ratelimit::ratelimiter::bucket_submit_ticket_request::Data::Headers(b) => {
-                        if let Some(recv) = receiver {
-                            let span = debug_span!("waiting for headers data");
-                            let recv = recv.instrument(span).await.unwrap();
-                            let rheaders = RatelimitHeaders::from_pairs(b.headers.iter().map(|f| (f.0.as_str(), f.1.as_bytes()))).unwrap();
-
-                            recv.headers(Some(rheaders)).unwrap();
-                            break;
-                        }
-                    },
-                }
+        let data = request.into_inner();
+
+        let ratelimit_headers = RatelimitHeaders::from_pairs(
+            data.headers.iter().map(|f| (f.0 as &str, f.1.as_bytes())),
+        )
+        .unwrap();
+
+        let bucket: Arc<Bucket> = if self.buckets.read().await.contains_key(&data.path) {
+            self.buckets
+                .read()
+                .await
+                .get(&data.path)
+                .expect("impossible")
+                .clone()
+        } else {
+            let bucket = Bucket::new(self.global.clone());
+            self.buckets.write().await.insert(data.path, bucket.clone());
+            bucket
+        };
+
+        match ratelimit_headers {
+            RatelimitHeaders::Global(global) => {
+                // If we are globally ratelimited, we lock using the redis lock
+                // This is using redis because a global ratelimit should be executed in all
+                // ratelimit workers.
+                debug!("global ratelimit headers detected: {}", global.retry_after());
+                self.global
+                    .lock_for(Duration::from_secs(global.retry_after()))
+                    .await;
+            }
+            RatelimitHeaders::None => {}
+            RatelimitHeaders::Present(present) => {
+                // we should update the bucket.
+                bucket.update(present, data.precise_time);
             }
+            _ => unreachable!(),
+        };
 
-            debug!("\tstream ended");
-            info!("request terminated");
-        }.instrument(span));
+        Ok(Response::new(()))
+    }
 
-        // echo just write the same data that was received
-        let out_stream = ReceiverStream::new(rx);
+    async fn submit_ticket(
+        &self,
+        request: tonic::Request<BucketSubmitTicketRequest>,
+    ) -> Result<tonic::Response<()>, tonic::Status> {
+        let parent_cx =
+            global::get_text_map_propagator(|prop| prop.extract(&MetadataMap(request.metadata())));
+        // Generate a tracing span as usual
+        let span = tracing::span!(tracing::Level::INFO, "request process");
+        span.set_parent(parent_cx);
 
-        Ok(Response::new(
-            Box::pin(out_stream) as Self::SubmitTicketStream
-        ))
+        let data = request.into_inner();
+
+        let bucket: Arc<Bucket> = if self.buckets.read().await.contains_key(&data.path) {
+            self.buckets
+                .read()
+                .await
+                .get(&data.path)
+                .expect("impossible")
+                .clone()
+        } else {
+            let bucket = Bucket::new(self.global.clone());
+            self.buckets.write().await.insert(data.path, bucket.clone());
+            bucket
+        };
+
+        // wait for the ticket to be accepted
+        bucket.ticket().await;
+
+        Ok(Response::new(()))
     }
 }
index 349c8acb726bd009181da6ab450e2f9581e208a6..5891b5853b54ef13e848c3b3405a068b49c6548b 100644 (file)
@@ -1,18 +1,18 @@
+use buckets::redis_lock::RedisLock;
 use config::RatelimitServerConfig;
 use grpc::RLServer;
 use leash::{AnyhowResultFuture, Component};
 use proto::nova::ratelimit::ratelimiter::ratelimiter_server::RatelimiterServer;
 use redis::aio::MultiplexedConnection;
-use redis_global_local_bucket_ratelimiter::RedisGlobalLocalBucketRatelimiter;
 use shared::config::Settings;
 use std::future::Future;
-use std::{net::ToSocketAddrs, pin::Pin};
+use std::{pin::Pin};
 use tokio::sync::oneshot;
 use tonic::transport::Server;
 
 mod config;
 mod grpc;
-mod redis_global_local_bucket_ratelimiter;
+mod buckets;
 
 pub struct RatelimiterServerComponent {}
 impl Component for RatelimiterServerComponent {
@@ -31,7 +31,7 @@ impl Component for RatelimiterServerComponent {
             >::into(settings.redis)
             .await?;
 
-            let server = RLServer::new(RedisGlobalLocalBucketRatelimiter::new(redis));
+            let server = RLServer::new(RedisLock::new(redis));
 
             Server::builder()
                 .add_service(RatelimiterServer::new(server))
diff --git a/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/bucket.rs b/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/bucket.rs
deleted file mode 100644 (file)
index b35dc45..0000000
+++ /dev/null
@@ -1,323 +0,0 @@
-//! [`Bucket`] management used by the [`super::InMemoryRatelimiter`] internally.
-//! Each bucket has an associated [`BucketQueue`] to queue an API request, which is
-//! consumed by the [`BucketQueueTask`] that manages the ratelimit for the bucket
-//! and respects the global ratelimit.
-
-use super::RedisLockPair;
-use std::{
-    collections::HashMap,
-    mem,
-    sync::{
-        atomic::{AtomicU64, Ordering},
-        Arc, Mutex,
-    },
-    time::{Duration, Instant},
-};
-use tokio::{
-    sync::{
-        mpsc::{self, UnboundedReceiver, UnboundedSender},
-        Mutex as AsyncMutex,
-    },
-    time::{sleep, timeout},
-};
-use twilight_http_ratelimiting::{headers::RatelimitHeaders, ticket::TicketNotifier};
-
-/// Time remaining until a bucket will reset.
-#[derive(Clone, Debug)]
-pub enum TimeRemaining {
-    /// Bucket has already reset.
-    Finished,
-    /// Bucket's ratelimit refresh countdown has not started yet.
-    NotStarted,
-    /// Amount of time until the bucket resets.
-    Some(Duration),
-}
-
-/// Ratelimit information for a bucket used in the [`super::InMemoryRatelimiter`].
-///
-/// A generic version not specific to this ratelimiter is [`crate::Bucket`].
-#[derive(Debug)]
-pub struct Bucket {
-    /// Total number of tickets allotted in a cycle.
-    pub limit: AtomicU64,
-    /// Path this ratelimit applies to.
-    pub path: String,
-    /// Queue associated with this bucket.
-    pub queue: BucketQueue,
-    /// Number of tickets remaining.
-    pub remaining: AtomicU64,
-    /// Duration after the [`Self::started_at`] time the bucket will refresh.
-    pub reset_after: AtomicU64,
-    /// When the bucket's ratelimit refresh countdown started.
-    pub started_at: Mutex<Option<Instant>>,
-}
-
-impl Bucket {
-    /// Create a new bucket for the specified [`Path`].
-    pub fn new(path: String) -> Self {
-        Self {
-            limit: AtomicU64::new(u64::max_value()),
-            path,
-            queue: BucketQueue::default(),
-            remaining: AtomicU64::new(u64::max_value()),
-            reset_after: AtomicU64::new(u64::max_value()),
-            started_at: Mutex::new(None),
-        }
-    }
-
-    /// Total number of tickets allotted in a cycle.
-    pub fn limit(&self) -> u64 {
-        self.limit.load(Ordering::Relaxed)
-    }
-
-    /// Number of tickets remaining.
-    pub fn remaining(&self) -> u64 {
-        self.remaining.load(Ordering::Relaxed)
-    }
-
-    /// Duration after the [`started_at`] time the bucket will refresh.
-    ///
-    /// [`started_at`]: Self::started_at
-    pub fn reset_after(&self) -> u64 {
-        self.reset_after.load(Ordering::Relaxed)
-    }
-
-    /// Time remaining until this bucket will reset.
-    pub fn time_remaining(&self) -> TimeRemaining {
-        let reset_after = self.reset_after();
-        let maybe_started_at = *self.started_at.lock().expect("bucket poisoned");
-
-        let started_at = if let Some(started_at) = maybe_started_at {
-            started_at
-        } else {
-            return TimeRemaining::NotStarted;
-        };
-
-        let elapsed = started_at.elapsed();
-
-        if elapsed > Duration::from_millis(reset_after) {
-            return TimeRemaining::Finished;
-        }
-
-        TimeRemaining::Some(Duration::from_millis(reset_after) - elapsed)
-    }
-
-    /// Try to reset this bucket's [`started_at`] value if it has finished.
-    ///
-    /// Returns whether resetting was possible.
-    ///
-    /// [`started_at`]: Self::started_at
-    pub fn try_reset(&self) -> bool {
-        if self.started_at.lock().expect("bucket poisoned").is_none() {
-            return false;
-        }
-
-        if let TimeRemaining::Finished = self.time_remaining() {
-            self.remaining.store(self.limit(), Ordering::Relaxed);
-            *self.started_at.lock().expect("bucket poisoned") = None;
-
-            true
-        } else {
-            false
-        }
-    }
-
-    /// Update this bucket's ratelimit data after a request has been made.
-    pub fn update(&self, ratelimits: Option<(u64, u64, u64)>) {
-        let bucket_limit = self.limit();
-
-        {
-            let mut started_at = self.started_at.lock().expect("bucket poisoned");
-
-            if started_at.is_none() {
-                started_at.replace(Instant::now());
-            }
-        }
-
-        if let Some((limit, remaining, reset_after)) = ratelimits {
-            if bucket_limit != limit && bucket_limit == u64::max_value() {
-                self.reset_after.store(reset_after, Ordering::SeqCst);
-                self.limit.store(limit, Ordering::SeqCst);
-            }
-
-            self.remaining.store(remaining, Ordering::Relaxed);
-        } else {
-            self.remaining.fetch_sub(1, Ordering::Relaxed);
-        }
-    }
-}
-
-/// Queue of ratelimit requests for a bucket.
-#[derive(Debug)]
-pub struct BucketQueue {
-    /// Receiver for the ratelimit requests.
-    rx: AsyncMutex<UnboundedReceiver<TicketNotifier>>,
-    /// Sender for the ratelimit requests.
-    tx: UnboundedSender<TicketNotifier>,
-}
-
-impl BucketQueue {
-    /// Add a new ratelimit request to the queue.
-    pub fn push(&self, tx: TicketNotifier) {
-        let _sent = self.tx.send(tx);
-    }
-
-    /// Receive the first incoming ratelimit request.
-    pub async fn pop(&self, timeout_duration: Duration) -> Option<TicketNotifier> {
-        let mut rx = self.rx.lock().await;
-
-        timeout(timeout_duration, rx.recv()).await.ok().flatten()
-    }
-}
-
-impl Default for BucketQueue {
-    fn default() -> Self {
-        let (tx, rx) = mpsc::unbounded_channel();
-
-        Self {
-            rx: AsyncMutex::new(rx),
-            tx,
-        }
-    }
-}
-
-/// A background task that handles ratelimit requests to a [`Bucket`]
-/// and processes them in order, keeping track of both the global and
-/// the [`Path`]-specific ratelimits.
-pub(super) struct BucketQueueTask {
-    /// The [`Bucket`] managed by this task.
-    bucket: Arc<Bucket>,
-    /// All buckets managed by the associated [`super::InMemoryRatelimiter`].
-    buckets: Arc<Mutex<HashMap<String, Arc<Bucket>>>>,
-    /// Global ratelimit data.
-    global: Arc<RedisLockPair>,
-    /// The [`Path`] this [`Bucket`] belongs to.
-    path: String,
-}
-
-impl BucketQueueTask {
-    /// Timeout to wait for response headers after initiating a request.
-    const WAIT: Duration = Duration::from_secs(10);
-
-    /// Create a new task to manage the ratelimit for a [`Bucket`].
-    pub fn new(
-        bucket: Arc<Bucket>,
-        buckets: Arc<Mutex<HashMap<String, Arc<Bucket>>>>,
-        global: Arc<RedisLockPair>,
-        path: String,
-    ) -> Self {
-        Self {
-            bucket,
-            buckets,
-            global,
-            path,
-        }
-    }
-
-    /// Process incoming ratelimit requests to this bucket and update the state
-    /// based on received [`RatelimitHeaders`].
-    #[tracing::instrument(name = "background queue task", skip(self), fields(path = ?self.path))]
-    pub async fn run(self) {
-        while let Some(queue_tx) = self.next().await {
-            if self.global.is_locked().await {
-                mem::drop(self.global.0.lock().await);
-            }
-
-            let ticket_headers = if let Some(ticket_headers) = queue_tx.available() {
-                ticket_headers
-            } else {
-                continue;
-            };
-
-            tracing::debug!("starting to wait for response headers");
-
-            match timeout(Self::WAIT, ticket_headers).await {
-                Ok(Ok(Some(headers))) => self.handle_headers(&headers).await,
-                Ok(Ok(None)) => {
-                    tracing::debug!("request aborted");
-                }
-                Ok(Err(_)) => {
-                    tracing::debug!("ticket channel closed");
-                }
-                Err(_) => {
-                    tracing::debug!("receiver timed out");
-                }
-            }
-        }
-
-        tracing::debug!("bucket appears finished, removing");
-
-        self.buckets
-            .lock()
-            .expect("ratelimit buckets poisoned")
-            .remove(&self.path);
-    }
-
-    /// Update the bucket's ratelimit state.
-    async fn handle_headers(&self, headers: &RatelimitHeaders) {
-        let ratelimits = match headers {
-            RatelimitHeaders::Global(global) => {
-                self.lock_global(Duration::from_secs(global.retry_after()))
-                    .await;
-
-                None
-            }
-            RatelimitHeaders::None => return,
-            RatelimitHeaders::Present(present) => {
-                Some((present.limit(), present.remaining(), present.reset_after()))
-            }
-            _ => unreachable!(),
-        };
-
-        tracing::debug!(path=?self.path, "updating bucket");
-        self.bucket.update(ratelimits);
-    }
-
-    /// Lock the global ratelimit for a specified duration.
-    async fn lock_global(&self, wait: Duration) {
-        tracing::debug!(path=?self.path, "request got global ratelimited");
-        self.global.lock_for(wait).await;
-    }
-
-    /// Get the next [`TicketNotifier`] in the queue.
-    async fn next(&self) -> Option<TicketNotifier> {
-        tracing::debug!(path=?self.path, "starting to get next in queue");
-
-        self.wait_if_needed().await;
-
-        self.bucket.queue.pop(Self::WAIT).await
-    }
-
-    /// Wait for this bucket to refresh if it isn't ready yet.
-    #[tracing::instrument(name = "waiting for bucket to refresh", skip(self), fields(path = ?self.path))]
-    async fn wait_if_needed(&self) {
-        let wait = {
-            if self.bucket.remaining() > 0 {
-                return;
-            }
-
-            tracing::debug!("0 tickets remaining, may have to wait");
-
-            match self.bucket.time_remaining() {
-                TimeRemaining::Finished => {
-                    self.bucket.try_reset();
-
-                    return;
-                }
-                TimeRemaining::NotStarted => return,
-                TimeRemaining::Some(dur) => dur,
-            }
-        };
-
-        tracing::debug!(
-            milliseconds=%wait.as_millis(),
-            "waiting for ratelimit to pass",
-        );
-
-        sleep(wait).await;
-
-        tracing::debug!("done waiting for ratelimit to pass");
-
-        self.bucket.try_reset();
-    }
-}
diff --git a/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/mod.rs b/exes/ratelimit/src/redis_global_local_bucket_ratelimiter/mod.rs
deleted file mode 100644 (file)
index a055b04..0000000
+++ /dev/null
@@ -1,101 +0,0 @@
-use self::bucket::{Bucket, BucketQueueTask};
-use redis::aio::MultiplexedConnection;
-use redis::AsyncCommands;
-use tokio::sync::Mutex;
-use twilight_http_ratelimiting::ticket::{self, TicketNotifier};
-use twilight_http_ratelimiting::GetTicketFuture;
-mod bucket;
-use std::future;
-use std::{
-    collections::hash_map::{Entry, HashMap},
-    sync::Arc,
-    time::Duration,
-};
-
-#[derive(Debug)]
-struct RedisLockPair(Mutex<MultiplexedConnection>);
-
-impl RedisLockPair {
-    /// Set the global ratelimit as exhausted.
-    pub async fn lock_for(&self, duration: Duration) {
-        let _: () = self
-            .0
-            .lock()
-            .await
-            .set_ex(
-                "nova:rls:lock",
-                1,
-                (duration.as_secs() + 1).try_into().unwrap(),
-            )
-            .await
-            .unwrap();
-    }
-
-    pub async fn is_locked(&self) -> bool {
-        self.0.lock().await.exists("nova:rls:lock").await.unwrap()
-    }
-}
-
-#[derive(Clone, Debug)]
-pub struct RedisGlobalLocalBucketRatelimiter {
-    buckets: Arc<std::sync::Mutex<HashMap<String, Arc<Bucket>>>>,
-
-    global: Arc<RedisLockPair>,
-}
-
-impl RedisGlobalLocalBucketRatelimiter {
-    #[must_use]
-    pub fn new(redis: MultiplexedConnection) -> Self {
-        Self {
-            buckets: Arc::default(),
-            global: Arc::new(RedisLockPair(Mutex::new(redis))),
-        }
-    }
-
-    fn entry(&self, path: String, tx: TicketNotifier) -> Option<Arc<Bucket>> {
-        let mut buckets = self.buckets.lock().expect("buckets poisoned");
-
-        match buckets.entry(path.clone()) {
-            Entry::Occupied(bucket) => {
-                tracing::debug!("got existing bucket: {path:?}");
-
-                bucket.get().queue.push(tx);
-
-                tracing::debug!("added request into bucket queue: {path:?}");
-
-                None
-            }
-            Entry::Vacant(entry) => {
-                tracing::debug!("making new bucket for path: {path:?}");
-
-                let bucket = Bucket::new(path);
-                bucket.queue.push(tx);
-
-                let bucket = Arc::new(bucket);
-                entry.insert(Arc::clone(&bucket));
-
-                Some(bucket)
-            }
-        }
-    }
-
-    pub fn ticket(&self, path: String) -> GetTicketFuture {
-        tracing::debug!("getting bucket for path: {path:?}");
-
-        let (tx, rx) = ticket::channel();
-
-        if let Some(bucket) = self.entry(path.clone(), tx) {
-            tokio::spawn(
-                BucketQueueTask::new(
-                    bucket,
-                    Arc::clone(&self.buckets),
-                    Arc::clone(&self.global),
-                    path,
-                )
-                .run(),
-            );
-        }
-
-        Box::pin(future::ready(Ok(rx)))
-    }
-}
index 6d58dce94f3a2d91ab073d931b2ce52be9add306..f59583d069c3f848607b3232af6b9ac4da54989a 100644 (file)
@@ -10,8 +10,9 @@ use std::{
     convert::TryFrom,
     hash::{Hash, Hasher},
     str::FromStr,
+    sync::Arc,
 };
-use tracing::{debug_span, error, instrument, Instrument};
+use tracing::{debug_span, error, info_span, instrument, Instrument};
 use twilight_http_ratelimiting::{Method, Path};
 
 use crate::ratelimit_client::RemoteRatelimiter;
@@ -37,8 +38,8 @@ fn normalize_path(request_path: &str) -> (&str, &str) {
 #[instrument]
 pub async fn handle_request(
     client: Client<HttpsConnector<HttpConnector>, Body>,
-    ratelimiter: RemoteRatelimiter,
-    token: &str,
+    ratelimiter: Arc<RemoteRatelimiter>,
+    token: String,
     mut request: Request<Body>,
 ) -> Result<Response<Body>, anyhow::Error> {
     let (hash, uri_string) = {
@@ -78,19 +79,12 @@ pub async fn handle_request(
 
         (hash.finish().to_string(), uri_string)
     };
+    // waits for the request to be authorized
+    ratelimiter
+        .ticket(hash.clone())
+        .instrument(debug_span!("ticket validation request"))
+        .await?;
 
-    let span = debug_span!("ticket validation request");
-    let header_sender = match span
-        .in_scope(|| ratelimiter.ticket(hash))
-        .await
-    {
-        Ok(sender) => sender,
-        Err(e) => {
-            error!("Failed to receive ticket for ratelimiting: {:?}", e);
-            bail!("failed to reteive ticket");
-        }
-    };
-    
     request
         .headers_mut()
         .insert(HOST, HeaderValue::from_static("discord.com"));
@@ -135,15 +129,18 @@ pub async fn handle_request(
         }
     };
 
-    let ratelimit_headers = resp
+    let headers = resp
         .headers()
         .into_iter()
-        .map(|(k, v)| (k.to_string(), v.to_str().unwrap().to_string()))
+        .map(|(k, v)| (k.to_string(), v.to_str().map(|f| f.to_string())))
+        .filter(|f| f.1.is_ok())
+        .map(|f| (f.0, f.1.expect("errors should be filtered")))
         .collect();
 
-    if header_sender.send(ratelimit_headers).is_err() {
-        error!("Error when sending ratelimit headers to ratelimiter");
-    };
+    let _ = ratelimiter
+        .submit_headers(hash, headers)
+        .instrument(info_span!("submitting headers"))
+        .await;
 
     Ok(resp)
 }
index 2cef63102cb938897f036853250510fe89991403..b2287b4b8db9c552d40835e59f9397a80689e695 100644 (file)
@@ -7,11 +7,12 @@ use hyper::{
     Body, Client, Request, Server,
 };
 use leash::{AnyhowResultFuture, Component};
-use opentelemetry::{global, trace::Tracer};
+use opentelemetry::{global};
 use opentelemetry_http::HeaderExtractor;
 use shared::config::Settings;
 use std::{convert::Infallible, sync::Arc};
 use tokio::sync::oneshot;
+use tracing_opentelemetry::OpenTelemetrySpanExt;
 
 mod config;
 mod handler;
@@ -29,7 +30,9 @@ impl Component for ReverseProxyServer {
     ) -> AnyhowResultFuture<()> {
         Box::pin(async move {
             // Client to the remote ratelimiters
-            let ratelimiter = ratelimit_client::RemoteRatelimiter::new(settings.config.clone());
+            let ratelimiter = Arc::new(ratelimit_client::RemoteRatelimiter::new(
+                settings.config.clone(),
+            ));
             let https = hyper_rustls::HttpsConnectorBuilder::new()
                 .with_native_roots()
                 .https_only()
@@ -37,31 +40,37 @@ impl Component for ReverseProxyServer {
                 .build();
 
             let client: Client<_, hyper::Body> = Client::builder().build(https);
-            let token = Arc::new(settings.discord.token.clone());
+            let token = settings.config.discord.token.clone();
+
             let service_fn = make_service_fn(move |_: &AddrStream| {
                 let client = client.clone();
                 let ratelimiter = ratelimiter.clone();
                 let token = token.clone();
                 async move {
                     Ok::<_, Infallible>(service_fn(move |request: Request<Body>| {
+                        let token = token.clone();
                         let parent_cx = global::get_text_map_propagator(|propagator| {
                             propagator.extract(&HeaderExtractor(request.headers()))
                         });
-                        let _span =
-                            global::tracer("").start_with_context("handle_request", &parent_cx);
+
+                        let span = tracing::span!(tracing::Level::INFO, "request process");
+                        span.set_parent(parent_cx);
 
                         let client = client.clone();
                         let ratelimiter = ratelimiter.clone();
-                        let token = token.clone();
+
                         async move {
-                            let token = token.as_str();
+                            let token = token.clone();
+                            let ratelimiter = ratelimiter.clone();
                             handle_request(client, ratelimiter, token, request).await
                         }
                     }))
                 }
             });
 
-            let server = Server::bind(&settings.config.server.listening_adress).http1_only(true).serve(service_fn);
+            let server = Server::bind(&settings.config.server.listening_adress)
+                .http1_only(true)
+                .serve(service_fn);
 
             server
                 .with_graceful_shutdown(async {
index 6212529319cd87f7ec86589d34fe0ce9bcfc9d90..ecd489c552b803c474e1b5ab213ca9331198bb02 100644 (file)
@@ -1,21 +1,18 @@
 use crate::config::ReverseProxyConfig;
 
 use self::remote_hashring::{HashRingWrapper, MetadataMap, VNode};
+use anyhow::anyhow;
 use opentelemetry::global;
-use proto::nova::ratelimit::ratelimiter::bucket_submit_ticket_request::{Data, Headers};
-use proto::nova::ratelimit::ratelimiter::BucketSubmitTicketRequest;
+use proto::nova::ratelimit::ratelimiter::{BucketSubmitTicketRequest, HeadersSubmitRequest};
 use std::collections::HashMap;
 use std::fmt::Debug;
 use std::future::Future;
 use std::pin::Pin;
 use std::sync::Arc;
-use std::time::UNIX_EPOCH;
 use std::time::{Duration, SystemTime};
-use tokio::sync::oneshot::{self};
-use tokio::sync::{broadcast, mpsc, RwLock};
-use tokio_stream::wrappers::ReceiverStream;
+use tokio::sync::{broadcast, RwLock};
 use tonic::Request;
-use tracing::{debug, debug_span, instrument, Instrument, Span};
+use tracing::{debug, error, info_span, instrument, trace_span, Instrument, Span};
 use tracing_opentelemetry::OpenTelemetrySpanExt;
 
 mod remote_hashring;
@@ -29,25 +26,21 @@ pub struct RemoteRatelimiter {
 
 impl Drop for RemoteRatelimiter {
     fn drop(&mut self) {
-        self.stop.clone().send(()).unwrap();
+        let _ = self
+            .stop
+            .clone()
+            .send(())
+            .map_err(|_| error!("ratelimiter was already stopped"));
     }
 }
 
-type IssueTicket = Pin<
-    Box<
-        dyn Future<Output = anyhow::Result<oneshot::Sender<HashMap<String, String>>>>
-            + Send
-            + 'static,
-    >,
->;
-
 impl RemoteRatelimiter {
     async fn get_ratelimiters(&self) -> Result<(), anyhow::Error> {
         // get list of dns responses
-        let responses = dns_lookup::lookup_host(&self.config.ratelimiter_address)
-            .unwrap()
+        let responses = dns_lookup::lookup_host(&self.config.ratelimiter_address)?
             .into_iter()
-            .map(|f| f.to_string());
+            .filter(|address| address.is_ipv4())
+            .map(|address| address.to_string());
 
         let mut write = self.remotes.write().await;
 
@@ -72,11 +65,19 @@ impl RemoteRatelimiter {
         // Task to update the ratelimiters in the background
         tokio::spawn(async move {
             loop {
+                debug!("refreshing");
+
+                match obj_clone.get_ratelimiters().await {
+                    Ok(_) => {
+                        debug!("refreshed ratelimiting servers")
+                    }
+                    Err(err) => {
+                        error!("refreshing ratelimiting servers failed {}", err);
+                    }
+                }
+
                 let sleep = tokio::time::sleep(Duration::from_secs(10));
                 tokio::pin!(sleep);
-
-                debug!("refreshing");
-                obj_clone.get_ratelimiters().await.unwrap();
                 tokio::select! {
                     () = &mut sleep => {
                         debug!("timer elapsed");
@@ -90,79 +91,79 @@ impl RemoteRatelimiter {
     }
 
     #[instrument(name = "ticket task")]
-    pub fn ticket(&self, path: String) -> IssueTicket {
+    pub fn ticket(
+        &self,
+        path: String,
+    ) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'static>> {
         let remotes = self.remotes.clone();
-        let (tx, rx) = oneshot::channel::<HashMap<String, String>>();
         Box::pin(
             async move {
-                // Get node managing this path
-                let mut node = (*remotes.read().await.get(&path).unwrap()).clone();
-
-                // Buffers for the gRPC streaming channel.
-                let (send, remote) = mpsc::channel(5);
-                let (do_request, wait) = oneshot::channel();
-                // Tonic requires a stream to be used; Since we use a mpsc channel, we can create a stream from it
-                let stream = ReceiverStream::new(remote);
-
-                let mut request = Request::new(stream);
-
-                let span = debug_span!("remote request");
+                // Getting the node managing this path
+                let mut node = remotes
+                    .write()
+                    .instrument(trace_span!("acquiring ring lock"))
+                    .await
+                    .get(&path)
+                    .and_then(|node| Some(node.clone()))
+                    .ok_or_else(|| {
+                        anyhow!(
+                            "did not compute ratelimit because no ratelimiter nodes are detected"
+                        )
+                    })?;
+
+                // Initialize span for tracing (headers injection)
+                let span = info_span!("remote request");
                 let context = span.context();
+                let mut request = Request::new(BucketSubmitTicketRequest { path });
                 global::get_text_map_propagator(|propagator| {
                     propagator.inject_context(&context, &mut MetadataMap(request.metadata_mut()))
                 });
 
-                // Start the grpc streaming
-                let ticket = node.submit_ticket(request).await?;
-
-                // First, send the request
-                send.send(BucketSubmitTicketRequest {
-                    data: Some(Data::Path(path)),
-                })
-                .await?;
-
-                // We continuously listen for events in the channel.
-                let span = debug_span!("stream worker");
-                tokio::spawn(
-                    async move {
-                        let span = debug_span!("waiting for ticket upstream");
-                        let message = ticket
-                            .into_inner()
-                            .message()
-                            .instrument(span)
-                            .await
-                            .unwrap()
-                            .unwrap();
-
-                        if message.accepted == 1 {
-                            debug!("request ticket was accepted");
-                            do_request.send(()).unwrap();
-                            let span = debug_span!("waiting for response headers");
-                            let headers = rx.instrument(span).await.unwrap();
-
-                            send.send(BucketSubmitTicketRequest {
-                                data: Some(Data::Headers(Headers {
-                                    precise_time: SystemTime::now()
-                                        .duration_since(UNIX_EPOCH)
-                                        .expect("time went backwards")
-                                        .as_millis()
-                                        as u64,
-                                    headers,
-                                })),
-                            })
-                            .await
-                            .unwrap();
-                        }
-                    }
-                    .instrument(span),
-                );
-
-                // Wait for the message to be sent
-                wait.await?;
+                // Requesting
+                node.submit_ticket(request)
+                    .instrument(info_span!("waiting for ticket response"))
+                    .await?;
 
-                Ok(tx)
+                Ok(())
             }
             .instrument(Span::current()),
         )
     }
+
+    pub fn submit_headers(
+        &self,
+        path: String,
+        headers: HashMap<String, String>,
+    ) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'static>> {
+        let remotes = self.remotes.clone();
+        Box::pin(async move {
+            let mut node = remotes
+                .write()
+                .instrument(trace_span!("acquiring ring lock"))
+                .await
+                .get(&path)
+                .and_then(|node| Some(node.clone()))
+                .ok_or_else(|| {
+                    anyhow!("did not compute ratelimit because no ratelimiter nodes are detected")
+                })?;
+
+            let span = info_span!("remote request");
+            let context = span.context();
+            let time = SystemTime::now()
+                .duration_since(SystemTime::UNIX_EPOCH)?
+                .as_millis();
+            let mut request = Request::new(HeadersSubmitRequest {
+                path,
+                precise_time: time as u64,
+                headers,
+            });
+            global::get_text_map_propagator(|propagator| {
+                propagator.inject_context(&context, &mut MetadataMap(request.metadata_mut()))
+            });
+
+            node.submit_headers(request).await?;
+
+            Ok(())
+        })
+    }
 }
index b8771e51d98e0b7ec61464f22e9c2dcd24d8af2e..ac025c802897d66187d964127cbc532a516e1f8e 100644 (file)
@@ -6,6 +6,7 @@ use std::hash::Hash;
 use std::ops::Deref;
 use std::ops::DerefMut;
 use tonic::transport::Channel;
+use tracing::debug;
 
 #[derive(Debug, Clone)]
 pub struct VNode {
@@ -49,15 +50,14 @@ impl<'a> Injector for MetadataMap<'a> {
 
 impl VNode {
     pub async fn new(address: String, port: u16) -> Result<Self, tonic::transport::Error> {
-        let client =
-            RatelimiterClient::connect(format!("http://{}:{}", address.clone(), port)).await?;
+        let host = format!("http://{}:{}", address.clone(), port);
+        debug!("connecting to {}", host);
+        let client = RatelimiterClient::connect(host).await?;
 
         Ok(VNode { client, address })
     }
 }
 
-unsafe impl Send for VNode {}
-
 #[repr(transparent)]
 #[derive(Default)]
 pub struct HashRingWrapper(hashring::HashRing<VNode>);
index 02543e65fa3adfc4c129739a61725ff5dabe95e0..b96f368b904e388168c16549f5057188959452fb 100644 (file)
@@ -32,7 +32,6 @@ where
 pub struct Discord {
     #[serde(deserialize_with = "deserialize_pk")]
     pub public_key: PublicKey,
-    pub client_id: u32,
 }
 
 #[derive(Debug, Deserialize, Clone, Default, Copy)]
index ece7b85c8951ca48a24c53d4007586939fda88b4..05221d3de71b01bce7e966cb1f16790bac0a4274 100644 (file)
@@ -1,5 +1,6 @@
 use ed25519_dalek::{PublicKey, Signature, Verifier};
 
+#[inline]
 pub fn validate_signature(public_key: &PublicKey, data: &[u8], hex_signature: &str) -> bool {
     let mut slice: [u8; Signature::BYTE_SIZE] = [0; Signature::BYTE_SIZE];
     let signature_result = hex::decode_to_slice(hex_signature, &mut slice);
diff --git a/internal/pkg/all-in-one/.gitignore b/internal/pkg/all-in-one/.gitignore
new file mode 100644 (file)
index 0000000..ca1584e
--- /dev/null
@@ -0,0 +1 @@
+all-in-one.h
diff --git a/internal/pkg/all-in-one/all-in-one.go b/internal/pkg/all-in-one/all-in-one.go
new file mode 100644 (file)
index 0000000..76c11f2
--- /dev/null
@@ -0,0 +1,87 @@
+package allinone
+
+/*
+#cgo LDFLAGS: -L../../../build/lib -lall_in_one -lz -lm
+#include "./all-in-one.h"
+#include "./error_handler.h"
+*/
+import "C"
+import (
+       "fmt"
+       "time"
+       "unsafe"
+
+       "github.com/Jeffail/gabs"
+       "github.com/alicebob/miniredis/v2"
+       "github.com/nats-io/nats-server/v2/server"
+)
+
+type AllInOne struct {
+       redis    *miniredis.Miniredis
+       nats     *server.Server
+       instance *C.AllInOneInstance
+}
+
+//export goErrorHandler
+func goErrorHandler(size C.int, start *C.char) {
+       dest := make([]byte, size)
+       copy(dest, (*(*[1024]byte)(unsafe.Pointer(start)))[:size:size])
+
+       println("Error from all in one runner: %s", string(dest))
+}
+
+func NewAllInOne() (*AllInOne, error) {
+       redis := miniredis.NewMiniRedis()
+       nats, err := server.NewServer(&server.Options{})
+
+       if err != nil {
+               return nil, err
+       }
+
+       return &AllInOne{
+               redis: redis,
+               nats:  nats,
+       }, nil
+}
+
+func (s *AllInOne) Start() error {
+       err := s.redis.Start()
+       if err != nil {
+               return err
+       }
+
+       go s.nats.Start()
+
+       if !s.nats.ReadyForConnections(5 * time.Second) {
+               return fmt.Errorf("nats server didn't start after 5 seconds, please check if there is another service listening on the same port as nats")
+       }
+       handler := C.ErrorHandler(C.allInOneErrorHandler)
+       // Set the error handler
+       C.set_error_handler(handler)
+       config := C.GoString(C.load_config())
+
+       json, _ := gabs.ParseJSON([]byte(config))
+       json.Set(fmt.Sprintf("redis://%s", s.redis.Addr()), "redis", "url")
+       json.Set("localhost", "nats", "host")
+       json.Set(1, "webhook", "discord", "client_id")
+
+       a := ""
+       a += ("Starting nova All-in-one!\n")
+       a += fmt.Sprintf(" * Rest proxy running on         : http://%s\n", json.Path("rest.server.listening_adress").Data().(string))
+       a += fmt.Sprintf(" * Webhook server running on     : http://%s\n", json.Path("webhook.server.listening_adress").Data().(string))
+       a += fmt.Sprintf(" * Ratelimiter server running on : grpc://%s\n", json.Path("ratelimiter.server.listening_adress").Data().(string))
+       a += (" * The gateway server should be running\n")
+       a += (" * The cache server should be running\n")
+       a += (" * Servers\n")
+       a += fmt.Sprintf("    * Running MiniREDIS on %s\n", s.redis.Addr())
+       a += fmt.Sprintf("    * Running NATS on %s\n", s.nats.ClientURL())
+       s.instance = C.create_instance(C.CString(json.String()))
+
+       print(a)
+
+       return nil
+}
+
+func (s *AllInOne) Stop() {
+       C.stop_instance(s.instance)
+}
diff --git a/internal/pkg/all-in-one/error_handler.h b/internal/pkg/all-in-one/error_handler.h
new file mode 100644 (file)
index 0000000..e04f68d
--- /dev/null
@@ -0,0 +1,8 @@
+extern void goErrorHandler(int, char*);
+
+typedef void (*ErrorHandler)(int, char*);
+
+__attribute__((weak))
+void allInOneErrorHandler(int size, char* string) {
+  goErrorHandler(size, string);
+}
\ No newline at end of file
index 73bc4493781408a516de20d1f6b38f27c4b68ff6..52f6c9228a6fdc839d724c6225a8a27e68f34b87 100644 (file)
@@ -2,8 +2,8 @@ use config::{Config, Environment, File};
 use serde::{de::DeserializeOwned, Deserialize};
 use std::{env, ops::Deref};
 use tracing::info;
+use anyhow::Result;
 
-use crate::error::GenericError;
 #[derive(Debug, Deserialize, Clone)]
 pub struct Settings<T: Clone + DeserializeOwned + Default> {
     #[serde(skip_deserializing)]
@@ -13,7 +13,7 @@ pub struct Settings<T: Clone + DeserializeOwned + Default> {
 }
 
 impl<T: Clone + DeserializeOwned + Default> Settings<T> {
-    pub fn new(service_name: &str) -> Result<Settings<T>, GenericError> {
+    pub fn new(service_name: &str) -> Result<Settings<T>> {
         let mut builder = Config::builder();
 
         builder = builder.add_source(File::with_name("config/default"));
diff --git a/libs/shared/src/error.rs b/libs/shared/src/error.rs
deleted file mode 100644 (file)
index 990dd1c..0000000
+++ /dev/null
@@ -1,18 +0,0 @@
-use config::ConfigError;
-use std::{fmt::Debug, io};
-use thiserror::Error;
-
-#[derive(Debug, Error)]
-pub enum GenericError {
-    #[error("invalid configuration")]
-    InvalidConfiguration(#[from] ConfigError),
-
-    #[error("invalid parameter `{0}`")]
-    InvalidParameter(String),
-
-    #[error("step `{0}` failed")]
-    StepFailed(String),
-
-    #[error("io error")]
-    Io(#[from] io::Error),
-}
index 68ff335f408079d881681b698c63b6b4b3a5d0ae..a714a1bb71f944911fca61c5e1a56bc0a145c5b2 100644 (file)
@@ -1,7 +1,6 @@
 /// This crate is all the utilities shared by the nova rust projects
 /// It includes logging, config and protocols.
 pub mod config;
-pub mod error;
 pub mod nats;
 pub mod payloads;
 pub mod redis;
index 3d7c75af7a7dd3b21a19807017f38204260323f8..ff1db0fcd2ee1b07f598eb907043e98e521b854e 100644 (file)
@@ -1,24 +1,20 @@
 syntax = "proto3";
 
+import "google/protobuf/empty.proto";
+
 package nova.ratelimit.ratelimiter;
 
 service Ratelimiter {
-    rpc SubmitTicket(stream BucketSubmitTicketRequest) returns (stream BucketSubmitTicketResponse);
+    rpc SubmitTicket(BucketSubmitTicketRequest) returns (google.protobuf.Empty);
+    rpc SubmitHeaders(HeadersSubmitRequest) returns (google.protobuf.Empty);
 }
 
 message BucketSubmitTicketRequest {
-    oneof data {
-        string path = 1;
-        Headers headers = 2;
-    }
-
-    message Headers {
-        map<string, string> headers = 1;
-        uint64 precise_time = 2;
-    }
-    
+    string path = 1;
 }
 
-message BucketSubmitTicketResponse {
-    int64 accepted = 1;
-}
+message HeadersSubmitRequest {
+    map<string, string> headers = 1;
+    uint64 precise_time = 2;
+    string path = 3;
+}
\ No newline at end of file