From 6d3caebc79f340d26262adc2ba789b10da4108aa Mon Sep 17 00:00:00 2001 From: Matthieu Date: Sun, 19 Sep 2021 18:12:41 +0400 Subject: [PATCH] shard + implementation of the payload deserialization & restructure --- .gitignore | 3 +- Cargo.lock | 71 ++++++++ bazel/docker.bzl | 3 +- common/rust/src/error.rs | 8 +- gateway/Cargo.toml | 5 +- gateway/config/default.toml | 5 + gateway/src/client/connection/utils.rs | 29 ---- gateway/src/client/mod.rs | 5 - gateway/src/client/payloads/dispatch.rs | 20 --- gateway/src/client/payloads/gateway.rs | 86 ---------- gateway/src/client/payloads/mod.rs | 3 - .../src/client/payloads/payloads/identify.rs | 19 --- gateway/src/client/payloads/payloads/mod.rs | 2 - gateway/src/client/payloads/structs.rs | 35 ---- gateway/src/client/shard/actions.rs | 68 -------- gateway/src/client/shard/connection.rs | 18 --- gateway/src/client/shard/mod.rs | 22 --- gateway/src/{client => }/connection/mod.rs | 10 +- gateway/src/{client => }/connection/stream.rs | 39 +++-- gateway/src/connection/utils.rs | 42 +++++ .../src/{client/error_utils.rs => error.rs} | 0 gateway/src/main.rs | 48 ++---- gateway/src/payloads/dispatch.rs | 38 +++++ gateway/src/payloads/events/mod.rs | 1 + gateway/src/payloads/events/ready.rs | 13 ++ gateway/src/payloads/events/resume.rs | 0 gateway/src/payloads/gateway.rs | 91 +++++++++++ gateway/src/payloads/mod.rs | 4 + .../payloads => payloads/opcodes}/hello.rs | 1 + gateway/src/payloads/opcodes/identify.rs | 47 ++++++ gateway/src/payloads/opcodes/mod.rs | 22 +++ gateway/src/payloads/opcodes/presence.rs | 63 ++++++++ gateway/src/payloads/opcodes/resume.rs | 8 + gateway/src/shard/actions.rs | 113 +++++++++++++ gateway/src/shard/connection.rs | 151 ++++++++++++++++++ gateway/src/shard/mod.rs | 50 ++++++ gateway/src/{client => }/shard/state.rs | 1 + gateway/src/{client => }/utils.rs | 1 - 38 files changed, 772 insertions(+), 373 deletions(-) delete mode 100644 gateway/src/client/connection/utils.rs delete mode 100644 gateway/src/client/mod.rs delete mode 100644 gateway/src/client/payloads/dispatch.rs delete mode 100644 gateway/src/client/payloads/gateway.rs delete mode 100644 gateway/src/client/payloads/mod.rs delete mode 100644 gateway/src/client/payloads/payloads/identify.rs delete mode 100644 gateway/src/client/payloads/payloads/mod.rs delete mode 100644 gateway/src/client/payloads/structs.rs delete mode 100644 gateway/src/client/shard/actions.rs delete mode 100644 gateway/src/client/shard/connection.rs delete mode 100644 gateway/src/client/shard/mod.rs rename gateway/src/{client => }/connection/mod.rs (83%) rename gateway/src/{client => }/connection/stream.rs (66%) create mode 100644 gateway/src/connection/utils.rs rename gateway/src/{client/error_utils.rs => error.rs} (100%) create mode 100644 gateway/src/payloads/dispatch.rs create mode 100644 gateway/src/payloads/events/mod.rs create mode 100644 gateway/src/payloads/events/ready.rs create mode 100644 gateway/src/payloads/events/resume.rs create mode 100644 gateway/src/payloads/gateway.rs create mode 100644 gateway/src/payloads/mod.rs rename gateway/src/{client/payloads/payloads => payloads/opcodes}/hello.rs (70%) create mode 100644 gateway/src/payloads/opcodes/identify.rs create mode 100644 gateway/src/payloads/opcodes/mod.rs create mode 100644 gateway/src/payloads/opcodes/presence.rs create mode 100644 gateway/src/payloads/opcodes/resume.rs create mode 100644 gateway/src/shard/actions.rs create mode 100644 gateway/src/shard/connection.rs create mode 100644 gateway/src/shard/mod.rs rename gateway/src/{client => }/shard/state.rs (94%) rename gateway/src/{client => }/utils.rs (99%) diff --git a/.gitignore b/.gitignore index 5563ebd..2cd395b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ bazel-* .vscode ratelimiter/target -target/ \ No newline at end of file +target/ +**/local* \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index a5a9375..fa35862 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -682,6 +682,9 @@ dependencies = [ "futures-core", "futures-util", "log", + "num", + "num-derive", + "num-traits 0.2.14", "pretty_env_logger", "serde 1.0.130", "serde_json", @@ -1227,6 +1230,51 @@ dependencies = [ "rand 0.8.4", ] +[[package]] +name = "num" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43db66d1170d347f9a065114077f7dccb00c1b9478c89384490a3425279a4606" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits 0.2.14", +] + +[[package]] +name = "num-bigint" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74e768dff5fb39a41b3bcd30bb25cf989706c90d028d1ad71971987aa309d535" +dependencies = [ + "autocfg", + "num-integer", + "num-traits 0.2.14", +] + +[[package]] +name = "num-complex" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085" +dependencies = [ + "num-traits 0.2.14", +] + +[[package]] +name = "num-derive" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "876a53fff98e03a936a674b29568b0e605f06b29372c2489ff4de23f1949743d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "num-integer" version = "0.1.44" @@ -1237,6 +1285,29 @@ dependencies = [ "num-traits 0.2.14", ] +[[package]] +name = "num-iter" +version = "0.1.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2021c8337a54d21aca0d59a92577a029af9431cb59b909b03252b9c164fad59" +dependencies = [ + "autocfg", + "num-integer", + "num-traits 0.2.14", +] + +[[package]] +name = "num-rational" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d41702bd167c2df5520b384281bc111a4b5efcf7fbc4c9c222c815b07e0a6a6a" +dependencies = [ + "autocfg", + "num-bigint", + "num-integer", + "num-traits 0.2.14", +] + [[package]] name = "num-traits" version = "0.1.43" diff --git a/bazel/docker.bzl b/bazel/docker.bzl index be6a1d0..7ec1674 100644 --- a/bazel/docker.bzl +++ b/bazel/docker.bzl @@ -3,10 +3,9 @@ load("@io_bazel_rules_docker//toolchains/docker:toolchain.bzl", "toolchain_configure") load("@io_bazel_rules_docker//repositories:repositories.bzl", "repositories") load("@io_bazel_rules_docker//repositories:deps.bzl", "deps") -load("@io_bazel_rules_docker//container:container.bzl", "container_pull") +load("@io_bazel_rules_docker//container:container.bzl", "container_pull", "container_image") load("@io_bazel_rules_docker//docker/package_managers:download_pkgs.bzl", "download_pkgs") load("@io_bazel_rules_docker//docker/package_managers:install_pkgs.bzl", "install_pkgs") -load("@io_bazel_rules_docker//container:container.bzl", "container_image") load( "@io_bazel_rules_docker//go:image.bzl", diff --git a/common/rust/src/error.rs b/common/rust/src/error.rs index dcb7a54..b602940 100644 --- a/common/rust/src/error.rs +++ b/common/rust/src/error.rs @@ -1,5 +1,6 @@ use std::fmt; +#[derive(Debug)] pub struct NovaError { pub message: String, } @@ -9,10 +10,3 @@ impl fmt::Display for NovaError { write!(f, "An error occured wihind the nova system: {}", self.message) // user-facing output } } - -impl fmt::Debug for NovaError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{{ file: {}, line: {} }}", file!(), line!()) // programmer-facing output - } -} - diff --git a/gateway/Cargo.toml b/gateway/Cargo.toml index b72a91e..d45615f 100644 --- a/gateway/Cargo.toml +++ b/gateway/Cargo.toml @@ -22,4 +22,7 @@ enumflags2 = { version ="0.7.1", features = ["serde"] } common = { path = "../common/rust" } tokio-scoped = "0.1.0" futures = "0.3.17" -async-trait = "0.1.51" \ No newline at end of file +async-trait = "0.1.51" +num-traits = "0.2" +num-derive = "0.3" +num = "0.4" \ No newline at end of file diff --git a/gateway/config/default.toml b/gateway/config/default.toml index d999fc9..252ff32 100644 --- a/gateway/config/default.toml +++ b/gateway/config/default.toml @@ -5,3 +5,8 @@ enabled = false host = "localhost" [gateway] +max_reconnects = 5 +reconnect_delay_growth_factor = 1.25 +reconnect_delay_minimum = 5000 +reconnect_delay_maximum = 60000 +intents = 32767 \ No newline at end of file diff --git a/gateway/src/client/connection/utils.rs b/gateway/src/client/connection/utils.rs deleted file mode 100644 index 49ccbcc..0000000 --- a/gateway/src/client/connection/utils.rs +++ /dev/null @@ -1,29 +0,0 @@ -use super::Connection; -use crate::client::{error_utils::GatewayError}; -use std::str::from_utf8; -use tokio_tungstenite::tungstenite::Message; - -impl Connection { - pub(crate) async fn _handle_message( - &mut self, - data: &Message, - ) -> Result { - match data { - Message::Text(text) => self._handle_discord_message(&text).await, - Message::Binary(message) => { - self._handle_discord_message(from_utf8(message).unwrap()) - .await - } - _ => Err(GatewayError::from("unknown error".to_string())), - } - } - - async fn _handle_discord_message( - &mut self, - raw_message: &str, - ) -> Result { - let a: Result = serde_json::from_str(raw_message); - let message = a.unwrap(); - Ok(message) - } -} diff --git a/gateway/src/client/mod.rs b/gateway/src/client/mod.rs deleted file mode 100644 index 51d8995..0000000 --- a/gateway/src/client/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub mod connection; -pub mod payloads; -pub mod shard; -pub mod utils; -mod error_utils; \ No newline at end of file diff --git a/gateway/src/client/payloads/dispatch.rs b/gateway/src/client/payloads/dispatch.rs deleted file mode 100644 index 62893b1..0000000 --- a/gateway/src/client/payloads/dispatch.rs +++ /dev/null @@ -1,20 +0,0 @@ -use serde::{Deserialize, Serialize}; -use serde_json::Value; - -#[derive(Serialize, Deserialize, Clone, Debug)] -pub struct Ready { - #[serde(rename = "v")] - version: u64, - user: Value, - guilds: Vec, - session_id: String, - shard: Option<[i64;2]>, - application: Value, -} - -#[derive(Serialize, Deserialize, Clone, Debug)] -#[serde(tag = "t", content = "d")] -pub enum Dispatch { - #[serde(rename = "READY")] - Ready(Ready), -} \ No newline at end of file diff --git a/gateway/src/client/payloads/gateway.rs b/gateway/src/client/payloads/gateway.rs deleted file mode 100644 index 788a05b..0000000 --- a/gateway/src/client/payloads/gateway.rs +++ /dev/null @@ -1,86 +0,0 @@ -use super::dispatch::Dispatch; -use super::payloads::hello::Hello; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use serde_repr::{Deserialize_repr, Serialize_repr}; - -macro_rules! num_to_enum { - ($num:expr => $enm:ident<$tpe:ty>{ $($fld:ident),+ }; $err:expr) => ({ - match $num { - $(_ if $num == $enm::$fld as $tpe => { $enm::$fld })+ - _ => $err - } - }); -} - -#[derive(Serialize_repr, Deserialize_repr, PartialEq, Debug)] -#[repr(u8)] -pub enum OpCodes { - Dispatch = 0, - Heartbeat = 1, - Identify = 2, - PresenceUpdate = 3, - VoiceStateUpdate = 4, - Resume = 6, - Reconnect = 7, - RequestGuildMembers = 8, - InvalidSession = 9, - Hello = 10, - HeartbeatACK = 11, -} - -#[derive(Serialize, Deserialize, PartialEq, Debug)] -#[serde(bound(deserialize = "T: Deserialize<'de>"))] -pub struct FullMessage { - #[serde(rename = "d")] - pub dispatch_type: Option, - #[serde(rename = "s")] - pub sequence: Option, - pub op: OpCodes, - #[serde(rename = "d")] - pub data: T, -} - -pub enum Message { - Dispatch(FullMessage), - Reconnect(FullMessage<()>), - InvalidSession(FullMessage), - Hello(FullMessage), - HeartbeatACK(FullMessage<()>), -} - -impl<'de> serde::Deserialize<'de> for Message { - fn deserialize>(d: D) -> Result { - let value = Value::deserialize(d)?; - let val = value.get("op").and_then(Value::as_u64).unwrap(); - let op_code = num_to_enum!( - val => OpCodes{ - Dispatch, - Heartbeat, - Identify, - PresenceUpdate, - VoiceStateUpdate, - Resume, - Reconnect, - RequestGuildMembers, - InvalidSession, - Hello, - HeartbeatACK - }; - panic!("Cannot convert number to `MyEnum`") - ); - - match op_code { - OpCodes::Dispatch => Ok(Message::Dispatch(FullMessage::deserialize(value).unwrap())), - OpCodes::Reconnect => Ok(Message::Reconnect(FullMessage::deserialize(value).unwrap())), - OpCodes::InvalidSession => Ok(Message::InvalidSession( - FullMessage::deserialize(value).unwrap(), - )), - OpCodes::Hello => Ok(Message::Hello(FullMessage::deserialize(value).unwrap())), - OpCodes::HeartbeatACK => Ok(Message::HeartbeatACK( - FullMessage::deserialize(value).unwrap(), - )), - _ => panic!("Cannot convert"), - } - } -} diff --git a/gateway/src/client/payloads/mod.rs b/gateway/src/client/payloads/mod.rs deleted file mode 100644 index e43a323..0000000 --- a/gateway/src/client/payloads/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod payloads; -pub mod dispatch; -pub mod gateway; \ No newline at end of file diff --git a/gateway/src/client/payloads/payloads/identify.rs b/gateway/src/client/payloads/payloads/identify.rs deleted file mode 100644 index 83f038a..0000000 --- a/gateway/src/client/payloads/payloads/identify.rs +++ /dev/null @@ -1,19 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize)] -pub struct IdentifyProprerties { - #[serde(rename = "$os")] - pub os: String, - #[serde(rename = "$browser")] - pub browser: String, - #[serde(rename = "$device")] - pub device: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct Identify { - pub token: String, - pub intents: u16, - pub properties: IdentifyProprerties, - pub shard: Option<[i64; 2]>, -} \ No newline at end of file diff --git a/gateway/src/client/payloads/payloads/mod.rs b/gateway/src/client/payloads/payloads/mod.rs deleted file mode 100644 index aa5a6de..0000000 --- a/gateway/src/client/payloads/payloads/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod hello; -pub mod identify; \ No newline at end of file diff --git a/gateway/src/client/payloads/structs.rs b/gateway/src/client/payloads/structs.rs deleted file mode 100644 index 1f186c6..0000000 --- a/gateway/src/client/payloads/structs.rs +++ /dev/null @@ -1,35 +0,0 @@ -use enumflags2::{bitflags, BitFlags}; - -#[bitflags] -#[repr(u16)] -#[derive(Clone, Copy, Debug)] -pub enum Intents { - Guilds = 1 << 0, - GuildMembers = 1 << 1, - GuildBans = 1 << 2, - GuildEmojisAndStickers = 1 << 3, - GuildIntegrations = 1 << 4, - GuildWebhoks = 1 << 5, - GuildInvites = 1 << 6, - GuildVoiceStates = 1 << 7, - GuildPresences = 1 << 8, - GuildMessages = 1 << 9, - GuildMessagesReactions = 1 << 10, - GuildMessageTyping = 1 << 11, - DirectMessages = 1 << 12, - DirectMessagesReactions = 1 << 13, - DirectMessageTyping = 1 << 14, -} - -pub struct Sharding { - pub total_shards: i64, - pub current_shard: i64 -} - -/// Config for the client connection. -pub struct ClientConfig { - pub token: String, - pub large_threshold: Option, - pub shard: Option, - pub intents: BitFlags -} \ No newline at end of file diff --git a/gateway/src/client/shard/actions.rs b/gateway/src/client/shard/actions.rs deleted file mode 100644 index cb29ace..0000000 --- a/gateway/src/client/shard/actions.rs +++ /dev/null @@ -1,68 +0,0 @@ -use futures::SinkExt; -use log::error; -use serde_json::Value; - -use crate::client::payloads::gateway::{FullMessage, OpCodes}; - -use super::Shard; - -/// Implement the available actions for nova in the gateway. -impl Shard { - /// Updates the presence of the current shard. - #[allow(dead_code)] - pub async fn presence_update(&mut self) -> Result<(), ()> { - if let Some(connection) = &mut self.connection { - connection - .send(FullMessage { - dispatch_type: None, - sequence: None, - op: OpCodes::PresenceUpdate, - // todo: proper payload for this - data: Value::Null, - }) - .await - .unwrap(); - } else { - error!("the connection is not open") - } - Ok(()) - } - /// Updates the voice status of the current shard in a certain channel. - #[allow(dead_code)] - pub async fn voice_state_update(&mut self) -> Result<(), ()> { - if let Some(connection) = &mut self.connection { - connection - .send(FullMessage { - dispatch_type: None, - sequence: None, - op: OpCodes::VoiceStateUpdate, - // todo: proper payload for this - data: Value::Null, - }) - .await - .unwrap(); - } else { - error!("the connection is not open") - } - Ok(()) - } - /// Ask discord for more informations about offline guild members. - #[allow(dead_code)] - pub async fn request_guild_members(&mut self) -> Result<(), ()> { - if let Some(connection) = &mut self.connection { - connection - .send(FullMessage { - dispatch_type: None, - sequence: None, - op: OpCodes::RequestGuildMembers, - // todo: proper payload for this - data: Value::Null, - }) - .await - .unwrap(); - } else { - error!("the connection is not open") - } - Ok(()) - } -} diff --git a/gateway/src/client/shard/connection.rs b/gateway/src/client/shard/connection.rs deleted file mode 100644 index 3395ff2..0000000 --- a/gateway/src/client/shard/connection.rs +++ /dev/null @@ -1,18 +0,0 @@ -use super::Shard; -use crate::client::connection::Connection; -use log::info; - -impl Shard { - pub async fn start(self: &mut Self) { - let mut should_exit = false; - - while !should_exit { - info!("Starting connection for shard"); - // create the new connection - let mut connection = Connection::new(); - connection.start().await.unwrap(); - self.connection = Some(connection); - should_exit = true; - } - } -} diff --git a/gateway/src/client/shard/mod.rs b/gateway/src/client/shard/mod.rs deleted file mode 100644 index aec93d6..0000000 --- a/gateway/src/client/shard/mod.rs +++ /dev/null @@ -1,22 +0,0 @@ -use self::state::SessionState; - -use super::connection::Connection; -mod actions; -mod connection; -mod state; - -/// Represents a shard & all the reconnection logic related to it -pub struct Shard { - connection: Option, - state: SessionState, -} - -impl Shard { - /// Creates a new shard instance - pub fn new() -> Self { - Shard { - connection: None, - state: SessionState::default(), - } - } -} diff --git a/gateway/src/client/connection/mod.rs b/gateway/src/connection/mod.rs similarity index 83% rename from gateway/src/client/connection/mod.rs rename to gateway/src/connection/mod.rs index 24ef334..c60068a 100644 --- a/gateway/src/client/connection/mod.rs +++ b/gateway/src/connection/mod.rs @@ -1,8 +1,8 @@ -use super::{error_utils::GatewayError, utils::get_gateway_url}; use tokio::net::TcpStream; -use tokio_tungstenite::{ - connect_async, tungstenite::handshake::client::Request, MaybeTlsStream, WebSocketStream, -}; +use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::http::Request}; + +use crate::{error::GatewayError, utils::get_gateway_url}; + mod stream; mod utils; @@ -33,4 +33,4 @@ impl Connection { Ok(()) } } -} +} \ No newline at end of file diff --git a/gateway/src/client/connection/stream.rs b/gateway/src/connection/stream.rs similarity index 66% rename from gateway/src/client/connection/stream.rs rename to gateway/src/connection/stream.rs index 6a6f5c9..dbfab60 100644 --- a/gateway/src/client/connection/stream.rs +++ b/gateway/src/connection/stream.rs @@ -1,5 +1,6 @@ +use crate::{error::GatewayError, payloads::gateway::BaseMessage}; + use super::Connection; -use crate::client::{error_utils::GatewayError}; use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt}; use log::info; use serde::Serialize; @@ -9,8 +10,9 @@ use std::{ }; use tokio_tungstenite::tungstenite::Message; +/// Implementation of the Stream trait for the Connection impl Stream for Connection { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { // first, when a poll is called, we check if the connection is still open @@ -30,7 +32,7 @@ impl Stream for Connection { Err(e) => Poll::Ready(Some(Err(e))), }, // unknown behaviour? - Poll::Pending => unimplemented!(), + Poll::Pending => unreachable!(), } } Err(e) => Poll::Ready(Some(Err(GatewayError::from(e)))), @@ -38,7 +40,7 @@ impl Stream for Connection { // if no message is available, we return none, it's the end of the stream None => { info!("tokio_tungstenite stream finished successfully"); - Box::pin(conn.close(None)).poll_unpin(cx); + let _ = Box::pin(conn.close(None)).poll_unpin(cx); self.connection = None; Poll::Ready(None) } @@ -53,21 +55,22 @@ impl Stream for Connection { } } -impl Sink> for Connection { - type Error = GatewayError; +/// Implementation of the Sink trait for the Connection +impl Sink> for Connection { + type Error = tokio_tungstenite::tungstenite::Error; #[allow(dead_code)] - fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - if let Some(_) = &self.connection { + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Some(conn) = &mut self.connection { // a connection exists, we can send data - Poll::Ready(Ok(())) + conn.poll_ready_unpin(cx) } else { Poll::Pending } } #[allow(dead_code)] - fn start_send(mut self: Pin<&mut Self>, item: crate::client::payloads::gateway::FullMessage) -> Result<(), Self::Error> { + fn start_send(mut self: Pin<&mut Self>, item: BaseMessage) -> Result<(), Self::Error> { if let Some(conn) = &mut self.connection { let message = serde_json::to_string(&item); conn.start_send_unpin(Message::Text(message.unwrap())) @@ -77,12 +80,20 @@ impl Sink> for Co } #[allow(dead_code)] - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Some(conn) = &mut self.connection { + conn.poll_flush_unpin(cx) + } else { + Poll::Pending + } } #[allow(dead_code)] - fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Some(conn) = &mut self.connection { + conn.poll_close_unpin(cx) + } else { + Poll::Pending + } } } diff --git a/gateway/src/connection/utils.rs b/gateway/src/connection/utils.rs new file mode 100644 index 0000000..fb07229 --- /dev/null +++ b/gateway/src/connection/utils.rs @@ -0,0 +1,42 @@ +use std::str::from_utf8; +use tokio_tungstenite::tungstenite::Message; +use log::info; + +use crate::error::GatewayError; + +use super::Connection; + +impl Connection { + + /// Handles the websocket events and calls the _handle_discord_message function for the deserialization. + pub(super) async fn _handle_message( + &mut self, + data: &Message, + ) -> Result { + match data { + Message::Text(text) => self._handle_discord_message(&text).await, + Message::Binary(message) => { + match from_utf8(message) { + Ok(data) => self._handle_discord_message(data).await, + Err(err) => Err(GatewayError::from(err.to_string())), + } + }, + Message::Close(close_frame) => { + info!("Discord connection closed {:?}", close_frame); + Err(GatewayError::from("connection closed".to_string())) + }, + _ => Err(GatewayError::from(format!("unknown variant of message specified to the handler {}", data).to_string())), + } + } + + /// Handle the decompression and deserialization process of a discord payload. + pub(super) async fn _handle_discord_message( + &mut self, + raw_message: &str, + ) -> Result { + match serde_json::from_str(raw_message) { + Ok(message) => Ok(message), + Err(err) => Err(GatewayError::from(err.to_string())), + } + } +} diff --git a/gateway/src/client/error_utils.rs b/gateway/src/error.rs similarity index 100% rename from gateway/src/client/error_utils.rs rename to gateway/src/error.rs diff --git a/gateway/src/main.rs b/gateway/src/main.rs index 003a903..4c42c7a 100644 --- a/gateway/src/main.rs +++ b/gateway/src/main.rs @@ -1,41 +1,19 @@ -mod client; - -use client::connection::Connection; use common::config::Settings; -use futures::StreamExt; -use log::info; -use serde_json::Value; - -use crate::client::payloads::{dispatch::Dispatch, gateway::{FullMessage, Message, OpCodes}, payloads::identify::{Identify, IdentifyProprerties}}; +use shard::{Shard, ShardConfig}; +#[macro_use] +extern crate num_derive; -#[tokio::main] -async fn main() { - let settings: Settings = Settings::new("gateway").unwrap(); +pub mod connection; +mod error; +mod utils; +mod shard; +mod payloads; - let mut conn = Connection::new(); - conn.start().await.unwrap(); - loop { - if let Some(val) = conn.next().await { - let data = val.as_ref().unwrap(); - match data { - Message::Dispatch(dispatch) => { - match &dispatch.data { - Dispatch::Ready(_ready) => { - - }, - } - }, - Message::Reconnect(_) => todo!(), - Message::InvalidSession(_) => todo!(), - Message::Hello(_hello) => { - info!("Server said hello! {:?}", _hello); - }, - Message::HeartbeatACK(_) => todo!(), - } - } else { - break; - } - } +#[tokio::main] +async fn main() { + let settings: Settings = Settings::new("gateway").unwrap(); + let mut shard = Shard::new(settings.config); + shard.start().await; } diff --git a/gateway/src/payloads/dispatch.rs b/gateway/src/payloads/dispatch.rs new file mode 100644 index 0000000..b2ddd89 --- /dev/null +++ b/gateway/src/payloads/dispatch.rs @@ -0,0 +1,38 @@ +use log::info; +use serde::{Deserialize, Deserializer}; + +use serde_json::Value; + +use super::{events::ready::Ready, opcodes::OpCodes}; + +/// Represents an unknown event not handled by the gateway itself +#[derive(Clone, Debug, PartialEq, Deserialize)] +pub struct UnknownDispatch { + pub t: String, + pub d: Value, + pub s: i64, + pub op: OpCodes, +} + +#[derive(Clone, Debug, PartialEq, Deserialize)] +#[serde(tag = "t", content = "d")] +#[serde(remote = "Dispatch")] +pub enum Dispatch { + #[serde(rename = "READY")] + Ready(Ready), + #[serde(rename = "RESUMED")] + Resumed(()), + + #[serde(skip_deserializing)] + Other(UnknownDispatch), +} + +impl<'de> Deserialize<'de> for Dispatch { + fn deserialize(deserializer: D) -> Result + where D: Deserializer<'de> + { + info!("hey"); + let s = UnknownDispatch::deserialize(deserializer)?; + Ok(Self::Other(s)) + } +} diff --git a/gateway/src/payloads/events/mod.rs b/gateway/src/payloads/events/mod.rs new file mode 100644 index 0000000..3fef2d9 --- /dev/null +++ b/gateway/src/payloads/events/mod.rs @@ -0,0 +1 @@ +pub mod ready; \ No newline at end of file diff --git a/gateway/src/payloads/events/ready.rs b/gateway/src/payloads/events/ready.rs new file mode 100644 index 0000000..a5ec291 --- /dev/null +++ b/gateway/src/payloads/events/ready.rs @@ -0,0 +1,13 @@ +use serde::Deserialize; +use serde_json::Value; + +#[derive(Deserialize, Clone, Debug, PartialEq)] +pub struct Ready { + #[serde(rename = "v")] + pub version: u64, + pub user: Value, + pub guilds: Vec, + pub session_id: String, + pub shard: Option<[i64;2]>, + pub application: Value, +} diff --git a/gateway/src/payloads/events/resume.rs b/gateway/src/payloads/events/resume.rs new file mode 100644 index 0000000..e69de29 diff --git a/gateway/src/payloads/gateway.rs b/gateway/src/payloads/gateway.rs new file mode 100644 index 0000000..e8dff96 --- /dev/null +++ b/gateway/src/payloads/gateway.rs @@ -0,0 +1,91 @@ +use super::{dispatch::Dispatch, opcodes::{OpCodes, hello::Hello}}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use serde::de::Error; + +#[derive(Serialize, Deserialize, PartialEq, Debug)] +#[serde(bound(deserialize = "T: Deserialize<'de> + std::fmt::Debug"))] +pub struct BaseMessage { + pub t: Option, + #[serde(rename = "s")] + pub sequence: Option, + pub op: OpCodes, + #[serde(rename = "d")] + pub data: T, +} + +#[derive(Debug)] +pub enum Message { + Dispatch(BaseMessage), + Reconnect(BaseMessage<()>), + InvalidSession(BaseMessage), + Hello(BaseMessage), + HeartbeatACK(BaseMessage<()>), +} + +impl<'de> serde::Deserialize<'de> for Message { + fn deserialize>(d: D) -> Result where D::Error : Error { + let value = Value::deserialize(d)?; + let val = value.get("op").and_then(Value::as_u64).unwrap(); + + if let Some(op) = num::FromPrimitive::from_u64(val) { + match op { + OpCodes::Dispatch => { + match Dispatch::deserialize(&value) { + Ok(data) => { + + let mut t = None; + if let Some(t_value) = &value.get("t") { + // this is safe since we know this is a string + t = Some(t_value.to_string()); + } + let mut sequence = None; + + if let Some(sequence_value) = value.get("s") { + if let Some(sequence_uint) = sequence_value.as_u64() { + sequence = Some(sequence_uint); + } + } + + Ok(Message::Dispatch(BaseMessage { + op, + t, + sequence, + data + })) + }, + Err(e) => Err(Error::custom(e)), + } + }, + + OpCodes::Reconnect => { + match BaseMessage::deserialize(value) { + Ok(data) => Ok(Message::Reconnect(data)), + Err(e) => Err(Error::custom(e)), + } + }, + OpCodes::InvalidSession => { + match BaseMessage::deserialize(value) { + Ok(data) => Ok(Message::InvalidSession(data)), + Err(e) => Err(Error::custom(e)), + } + }, + OpCodes::Hello => { + match BaseMessage::deserialize(value) { + Ok(data) => Ok(Message::Hello(data)), + Err(e) => Err(Error::custom(e)), + } + }, + OpCodes::HeartbeatACK => { + match BaseMessage::deserialize(value) { + Ok(data) => Ok(Message::HeartbeatACK(data)), + Err(e) => Err(Error::custom(e)), + } + }, + _ => panic!("Cannot convert"), + } + } else { + todo!(); + } + } +} diff --git a/gateway/src/payloads/mod.rs b/gateway/src/payloads/mod.rs new file mode 100644 index 0000000..e9849a7 --- /dev/null +++ b/gateway/src/payloads/mod.rs @@ -0,0 +1,4 @@ +pub mod opcodes; +pub mod dispatch; +pub mod gateway; +pub mod events; diff --git a/gateway/src/client/payloads/payloads/hello.rs b/gateway/src/payloads/opcodes/hello.rs similarity index 70% rename from gateway/src/client/payloads/payloads/hello.rs rename to gateway/src/payloads/opcodes/hello.rs index 0690a61..3d8fd0f 100644 --- a/gateway/src/client/payloads/payloads/hello.rs +++ b/gateway/src/payloads/opcodes/hello.rs @@ -1,5 +1,6 @@ use serde::{Serialize, Deserialize}; +/// The first message sent by the gateway to initialize the heartbeating #[derive(Debug, Serialize, Deserialize)] pub struct Hello { #[serde(rename = "heartbeat_interval")] diff --git a/gateway/src/payloads/opcodes/identify.rs b/gateway/src/payloads/opcodes/identify.rs new file mode 100644 index 0000000..5929c33 --- /dev/null +++ b/gateway/src/payloads/opcodes/identify.rs @@ -0,0 +1,47 @@ +use enumflags2::{BitFlags, bitflags}; +use serde::{Deserialize, Serialize}; +use super::presence::PresenceUpdate; + + +#[bitflags] +#[repr(u16)] +#[derive(Clone, Copy, Debug)] +pub enum Intents { + Guilds = 1 << 0, + GuildMembers = 1 << 1, + GuildBans = 1 << 2, + GuildEmojisAndStickers = 1 << 3, + GuildIntegrations = 1 << 4, + GuildWebhoks = 1 << 5, + GuildInvites = 1 << 6, + GuildVoiceStates = 1 << 7, + GuildPresences = 1 << 8, + GuildMessages = 1 << 9, + GuildMessagesReactions = 1 << 10, + GuildMessageTyping = 1 << 11, + DirectMessages = 1 << 12, + DirectMessagesReactions = 1 << 13, + DirectMessageTyping = 1 << 14, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct IdentifyProprerties { + #[serde(rename = "$os")] + pub os: String, + #[serde(rename = "$browser")] + pub browser: String, + #[serde(rename = "$device")] + pub device: String, +} + +/// Messages sent by the shard to log-in to the gateway. +#[derive(Debug, Serialize, Deserialize)] +pub struct Identify { + pub token: String, + pub properties: IdentifyProprerties, + pub compress: Option, + pub large_threshold: Option, + pub shard: Option<[u64; 2]>, + pub presence: Option, + pub intents: BitFlags, +} \ No newline at end of file diff --git a/gateway/src/payloads/opcodes/mod.rs b/gateway/src/payloads/opcodes/mod.rs new file mode 100644 index 0000000..cfa453a --- /dev/null +++ b/gateway/src/payloads/opcodes/mod.rs @@ -0,0 +1,22 @@ +pub mod hello; +pub mod identify; +pub mod resume; +pub mod presence; +use serde_repr::{Deserialize_repr, Serialize_repr}; + +#[derive(Serialize_repr, Deserialize_repr, PartialEq, Debug, Clone, FromPrimitive, ToPrimitive)] +#[repr(u8)] + +pub enum OpCodes { + Dispatch = 0, + Heartbeat = 1, + Identify = 2, + PresenceUpdate = 3, + VoiceStateUpdate = 4, + Resume = 6, + Reconnect = 7, + RequestGuildMembers = 8, + InvalidSession = 9, + Hello = 10, + HeartbeatACK = 11, +} \ No newline at end of file diff --git a/gateway/src/payloads/opcodes/presence.rs b/gateway/src/payloads/opcodes/presence.rs new file mode 100644 index 0000000..a6c5773 --- /dev/null +++ b/gateway/src/payloads/opcodes/presence.rs @@ -0,0 +1,63 @@ +use serde_repr::{Deserialize_repr, Serialize_repr}; +use serde::{Deserialize, Serialize}; +#[derive(Serialize_repr, Deserialize_repr, Debug)] +#[repr(u8)] +pub enum ActivityType { + Game = 0, + Streaming = 1, + Listening = 2, + Watching = 3, + Custom = 4, + Competing = 5 +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ActivityTimestamps { + start: u64, + end: u64, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ActivityEmoji { + name: String, + id: Option, + animated: Option +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Activity { + name: String, + #[serde(rename = "type")] + t: ActivityType, + + url: Option, + created_at: i64, + timestamp: Option, + application_id: Option, + details: Option, + state: Option, + emoji: Option, + // todo: implement more +} + +#[derive(Serialize, Deserialize, Debug)] +pub enum PresenceStatus { + #[serde(rename = "online")] + Online, + #[serde(rename = "dnd")] + Dnd, + #[serde(rename = "idle")] + Idle, + #[serde(rename = "invisible")] + Invisible, + #[serde(rename = "offline")] + Offline +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct PresenceUpdate { + since: u64, + activities: Vec, + status: PresenceStatus, + afk: bool, +} diff --git a/gateway/src/payloads/opcodes/resume.rs b/gateway/src/payloads/opcodes/resume.rs new file mode 100644 index 0000000..e1bba91 --- /dev/null +++ b/gateway/src/payloads/opcodes/resume.rs @@ -0,0 +1,8 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct Resume { + pub token: String, + pub session_id: String, + pub seq: u64, +} \ No newline at end of file diff --git a/gateway/src/shard/actions.rs b/gateway/src/shard/actions.rs new file mode 100644 index 0000000..513d1a8 --- /dev/null +++ b/gateway/src/shard/actions.rs @@ -0,0 +1,113 @@ +use std::env; + +use futures::SinkExt; +use log::{debug, error}; +use serde::Serialize; +use serde_json::Value; +use std::fmt::Debug; + +use crate::{error::GatewayError, payloads::{gateway::BaseMessage, opcodes::{OpCodes, identify::{Identify, IdentifyProprerties}, presence::PresenceUpdate, resume::Resume}}}; + +use super::Shard; + +/// Implement the available actions for nova in the gateway. +impl Shard { + + /// sends a message through the websocket + pub async fn _send(&mut self, message: BaseMessage) -> Result<(), GatewayError> { + debug!("senging message {:?}", message); + if let Some(connection) = &mut self.connection { + if let Err(e) = connection.conn.send(message).await { + error!("failed to send message {:?}", e); + Err(GatewayError::from(e)) + } else { + Ok(()) + } + } else { + Err(GatewayError::from("no open connection".to_string())) + } + } + + pub async fn _identify(&mut self) -> Result<(), GatewayError> { + if let Some(state) = self.state.clone() { + self._send(BaseMessage{ + t: None, + sequence: None, + op: OpCodes::Resume, + data: Resume { + token: self.config.token.clone(), + seq: state.sequence, + session_id: state.session_id.clone(), + }, + }).await + } else { + self._send(BaseMessage{ + t: None, + sequence: None, + op: OpCodes::Identify, + data: Identify { + token: self.config.token.clone(), + intents: self.config.intents, + properties: IdentifyProprerties { + os: env::consts::OS.to_string(), + browser: "Nova".to_string(), + device: "Nova".to_string(), + }, + shard: Some([0,2]), + compress: Some(false), + large_threshold: Some(500), + presence: None, + }, + }).await + } + } + + pub async fn _disconnect(&mut self) {} + + /// Updates the presence of the current shard. + #[allow(dead_code)] + pub async fn presence_update(&mut self, update: PresenceUpdate) -> Result<(), GatewayError> { + self._send(BaseMessage{ + t: None, + sequence: None, + op: OpCodes::PresenceUpdate, + data: update, + }).await + } + /// Updates the voice status of the current shard in a certain channel. + #[allow(dead_code)] + pub async fn voice_state_update(&mut self) -> Result<(), GatewayError> { + if let Some(connection) = &mut self.connection { + connection.conn + .send(BaseMessage { + t: None, + sequence: None, + op: OpCodes::VoiceStateUpdate, + // todo: proper payload for this + data: Value::Null, + }) + .await? + } else { + error!("the connection is not open") + } + Ok(()) + } + /// Ask discord for more informations about offline guild members. + #[allow(dead_code)] + pub async fn request_guild_members(&mut self) -> Result<(), GatewayError> { + if let Some(connection) = &mut self.connection { + connection.conn + .send(BaseMessage { + t: None, + sequence: None, + op: OpCodes::RequestGuildMembers, + // todo: proper payload for this + data: Value::Null, + }) + .await? + } else { + error!("the connection is not open") + } + Ok(()) + } +} diff --git a/gateway/src/shard/connection.rs b/gateway/src/shard/connection.rs new file mode 100644 index 0000000..6f8503c --- /dev/null +++ b/gateway/src/shard/connection.rs @@ -0,0 +1,151 @@ +use std::{ + cmp::{max, min}, + convert::TryInto, + time::Duration, +}; + +use crate::{ + connection::Connection, + payloads::{ + dispatch::Dispatch, + gateway::{BaseMessage, Message}, + }, + shard::state::SessionState, +}; + +use super::{state::ConnectionState, ConnectionWithState, Shard}; +use futures::StreamExt; +use log::{error, info}; +use tokio::{select, time::sleep}; + +impl Shard { + pub async fn start(self: &mut Self) { + let mut reconnects = 1; + info!("Starting shard"); + + while reconnects < self.config.max_reconnects { + info!("Starting connection for shard"); + self._shard_task().await; + // when the shard got disconnected, the shard task ends + reconnects += 1; + + // wait reconnects min(max(reconnects * reconnect_delay_growth_factor, reconnect_delay_minimum),reconnect_delay_maximum) + if reconnects < self.config.max_reconnects { + let time = min( + self.config.reconnect_delay_maximum, + max( + ((reconnects as f32) * self.config.reconnect_delay_growth_factor) as usize, + self.config.reconnect_delay_minimum, + ), + ); + info!( + "The shard got disconnected, waiting for reconnect ({}ms)", + time + ); + sleep(Duration::from_millis(time.try_into().unwrap())).await; + } + } + info!( + "The shard got disconnected too many times and reached the maximum {}", + self.config.max_reconnects + ); + } + + async fn _shard_task(&mut self) { + // create the new connection + let mut connection = Connection::new(); + connection.start().await.unwrap(); + self.connection = Some(ConnectionWithState { + conn: connection, + state: ConnectionState::default(), + }); + loop { + if let Some(connection) = &mut self.connection { + select!( + payload = connection.conn.next() => { + match payload { + Some(data) => match data { + Ok(message) => self._handle(&message).await, + Err(error) => { + error!("An error occured while being connected to Discord: {:?}", error); + return; + }, + }, + None => { + info!("Connection terminated"); + return; + }, + } + } + ) + } + } + } + + fn _util_set_seq(&mut self, seq: Option) { + if let Some(seq) = seq { + if let Some(state) = &mut self.state { + state.sequence = seq; + } + } + } + + async fn _handle(&mut self, message: &Message) { + match message { + Message::Dispatch(msg) => { + self._util_set_seq(msg.sequence); + self._dispatch(&msg).await; + } + // we need to reconnect to the gateway + Message::Reconnect(msg) => { + self._util_set_seq(msg.sequence); + info!("gateway disconnect requested"); + self._disconnect().await; + } + Message::InvalidSession(msg) => { + self._util_set_seq(msg.sequence); + info!("invalid session"); + let data = msg.data; + if !data { + info!("session removed"); + // reset the session data + self.state = None; + if let Err(e) = self._identify().await { + error!("error while sending identify: {:?}", e); + } + } + } + Message::HeartbeatACK(msg) => { + self._util_set_seq(msg.sequence); + info!("heartbeat ack received"); + } + Message::Hello(msg) => { + self._util_set_seq(msg.sequence); + info!("server hello received"); + if let Err(e) = self._identify().await { + error!("error while sending identify: {:?}", e); + } + }, + } + } + + async fn _dispatch(&mut self, dispatch: &BaseMessage) { + match &dispatch.data { + Dispatch::Ready(ready) => { + info!("received gateway dispatch hello"); + info!( + "Logged in as {}", + ready.user.get("username").unwrap().to_string() + ); + self.state = Some(SessionState { + sequence: dispatch.sequence.unwrap(), + session_id: ready.session_id.clone(), + }); + } + Dispatch::Resumed(_) => { + info!("session resumed"); + } + Dispatch::Other(data) => {} + } + } +} diff --git a/gateway/src/shard/mod.rs b/gateway/src/shard/mod.rs new file mode 100644 index 0000000..b458451 --- /dev/null +++ b/gateway/src/shard/mod.rs @@ -0,0 +1,50 @@ +use enumflags2::BitFlags; +use serde::{Deserialize, Serialize}; +use crate::{connection::Connection, payloads::opcodes::identify::Intents}; +use self::state::{ConnectionState, SessionState}; +mod actions; +mod connection; +mod state; + +#[derive(Debug, Deserialize, Serialize, Default, Clone)] +pub struct Sharding { + pub total_shards: i64, + pub current_shard: i64 +} + + +#[derive(Debug, Deserialize, Serialize, Default, Clone)] +pub struct ShardConfig { + pub max_reconnects: usize, + pub reconnect_delay_growth_factor: f32, + pub reconnect_delay_minimum: usize, + pub reconnect_delay_maximum: usize, + pub token: String, + + pub large_threshold: Option, + pub shard: Option, + pub intents: BitFlags +} + +struct ConnectionWithState { + conn: Connection, + state: ConnectionState, +} + +/// Represents a shard & all the reconnection logic related to it +pub struct Shard { + connection: Option, + state: Option, + config: ShardConfig +} + +impl Shard { + /// Creates a new shard instance + pub fn new(config: ShardConfig) -> Self { + Shard { + connection: None, + state: None, + config, + } + } +} diff --git a/gateway/src/client/shard/state.rs b/gateway/src/shard/state.rs similarity index 94% rename from gateway/src/client/shard/state.rs rename to gateway/src/shard/state.rs index 4d40911..caed268 100644 --- a/gateway/src/client/shard/state.rs +++ b/gateway/src/shard/state.rs @@ -1,6 +1,7 @@ use std::time::Instant; /// This struct represents the state of a session +#[derive(Clone, Debug)] pub struct SessionState { pub sequence: u64, pub session_id: String, diff --git a/gateway/src/client/utils.rs b/gateway/src/utils.rs similarity index 99% rename from gateway/src/client/utils.rs rename to gateway/src/utils.rs index 141740e..48a9aed 100644 --- a/gateway/src/client/utils.rs +++ b/gateway/src/utils.rs @@ -1,4 +1,3 @@ - /// Formats a url of connection to the gateway pub fn get_gateway_url (compress: bool, encoding: &str, v: i16) -> String { return format!( -- 2.39.5