diff options
| author | Matthieu <matthieu@developershouse.xyz> | 2021-09-19 18:12:41 +0400 | 
|---|---|---|
| committer | Matthieu <matthieu@developershouse.xyz> | 2021-09-19 18:12:41 +0400 | 
| commit | 6d3caebc79f340d26262adc2ba789b10da4108aa (patch) | |
| tree | 208e1bfe0cf036c0561555e0fc35beabfd261375 | |
| parent | 88300c45202a228d54c1e99dd3b295ef3fb9aabd (diff) | |
shard + implementation of the payload deserialization & restructure
38 files changed, 772 insertions, 373 deletions
@@ -1,4 +1,5 @@  bazel-*
  .vscode
  ratelimiter/target
 -target/
\ No newline at end of file +target/
 +**/local*
\ No newline at end of file @@ -682,6 +682,9 @@ dependencies = [   "futures-core",   "futures-util",   "log", + "num", + "num-derive", + "num-traits 0.2.14",   "pretty_env_logger",   "serde 1.0.130",   "serde_json", @@ -1228,6 +1231,51 @@ dependencies = [  ]  [[package]] +name = "num" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43db66d1170d347f9a065114077f7dccb00c1b9478c89384490a3425279a4606" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits 0.2.14", +] + +[[package]] +name = "num-bigint" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74e768dff5fb39a41b3bcd30bb25cf989706c90d028d1ad71971987aa309d535" +dependencies = [ + "autocfg", + "num-integer", + "num-traits 0.2.14", +] + +[[package]] +name = "num-complex" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085" +dependencies = [ + "num-traits 0.2.14", +] + +[[package]] +name = "num-derive" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "876a53fff98e03a936a674b29568b0e605f06b29372c2489ff4de23f1949743d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]]  name = "num-integer"  version = "0.1.44"  source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1238,6 +1286,29 @@ dependencies = [  ]  [[package]] +name = "num-iter" +version = "0.1.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2021c8337a54d21aca0d59a92577a029af9431cb59b909b03252b9c164fad59" +dependencies = [ + "autocfg", + "num-integer", + "num-traits 0.2.14", +] + +[[package]] +name = "num-rational" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d41702bd167c2df5520b384281bc111a4b5efcf7fbc4c9c222c815b07e0a6a6a" +dependencies = [ + "autocfg", + "num-bigint", + "num-integer", + "num-traits 0.2.14", +] + +[[package]]  name = "num-traits"  version = "0.1.43"  source = "registry+https://github.com/rust-lang/crates.io-index" diff --git a/bazel/docker.bzl b/bazel/docker.bzl index be6a1d0..7ec1674 100644 --- a/bazel/docker.bzl +++ b/bazel/docker.bzl @@ -3,10 +3,9 @@  load("@io_bazel_rules_docker//toolchains/docker:toolchain.bzl", "toolchain_configure")  load("@io_bazel_rules_docker//repositories:repositories.bzl", "repositories")  load("@io_bazel_rules_docker//repositories:deps.bzl", "deps") -load("@io_bazel_rules_docker//container:container.bzl", "container_pull") +load("@io_bazel_rules_docker//container:container.bzl", "container_pull", "container_image")  load("@io_bazel_rules_docker//docker/package_managers:download_pkgs.bzl", "download_pkgs")  load("@io_bazel_rules_docker//docker/package_managers:install_pkgs.bzl", "install_pkgs") -load("@io_bazel_rules_docker//container:container.bzl", "container_image")  load(      "@io_bazel_rules_docker//go:image.bzl", diff --git a/common/rust/src/error.rs b/common/rust/src/error.rs index dcb7a54..b602940 100644 --- a/common/rust/src/error.rs +++ b/common/rust/src/error.rs @@ -1,5 +1,6 @@  use std::fmt; +#[derive(Debug)]  pub struct NovaError {      pub message: String,  } @@ -9,10 +10,3 @@ impl fmt::Display for NovaError {          write!(f, "An error occured wihind the nova system: {}", self.message) // user-facing output      }  } - -impl fmt::Debug for NovaError { -    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { -        write!(f, "{{ file: {}, line: {} }}", file!(), line!()) // programmer-facing output -    } -} - diff --git a/gateway/Cargo.toml b/gateway/Cargo.toml index b72a91e..d45615f 100644 --- a/gateway/Cargo.toml +++ b/gateway/Cargo.toml @@ -22,4 +22,7 @@ enumflags2 = { version ="0.7.1", features = ["serde"] }  common = { path = "../common/rust" }  tokio-scoped = "0.1.0"  futures = "0.3.17" -async-trait = "0.1.51"
\ No newline at end of file +async-trait = "0.1.51" +num-traits = "0.2" +num-derive = "0.3" +num = "0.4"
\ No newline at end of file diff --git a/gateway/config/default.toml b/gateway/config/default.toml index d999fc9..252ff32 100644 --- a/gateway/config/default.toml +++ b/gateway/config/default.toml @@ -5,3 +5,8 @@ enabled = false  host = "localhost"  [gateway] +max_reconnects = 5 +reconnect_delay_growth_factor = 1.25 +reconnect_delay_minimum = 5000 +reconnect_delay_maximum = 60000 +intents = 32767
\ No newline at end of file diff --git a/gateway/src/client/connection/utils.rs b/gateway/src/client/connection/utils.rs deleted file mode 100644 index 49ccbcc..0000000 --- a/gateway/src/client/connection/utils.rs +++ /dev/null @@ -1,29 +0,0 @@ -use super::Connection; -use crate::client::{error_utils::GatewayError}; -use std::str::from_utf8; -use tokio_tungstenite::tungstenite::Message; - -impl Connection { -    pub(crate) async fn _handle_message( -        &mut self, -        data: &Message, -    ) -> Result<crate::client::payloads::gateway::Message, GatewayError> { -        match data { -            Message::Text(text) => self._handle_discord_message(&text).await, -            Message::Binary(message) => { -                self._handle_discord_message(from_utf8(message).unwrap()) -                    .await -            } -            _ => Err(GatewayError::from("unknown error".to_string())), -        } -    } - -    async fn _handle_discord_message( -        &mut self, -        raw_message: &str, -    ) -> Result<crate::client::payloads::gateway::Message, GatewayError> { -        let a: Result<crate::client::payloads::gateway::Message, serde_json::Error> = serde_json::from_str(raw_message); -        let message = a.unwrap(); -        Ok(message) -    } -} diff --git a/gateway/src/client/mod.rs b/gateway/src/client/mod.rs deleted file mode 100644 index 51d8995..0000000 --- a/gateway/src/client/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub mod connection; -pub mod payloads; -pub mod shard; -pub mod utils; -mod error_utils;
\ No newline at end of file diff --git a/gateway/src/client/payloads/dispatch.rs b/gateway/src/client/payloads/dispatch.rs deleted file mode 100644 index 62893b1..0000000 --- a/gateway/src/client/payloads/dispatch.rs +++ /dev/null @@ -1,20 +0,0 @@ -use serde::{Deserialize, Serialize}; -use serde_json::Value; - -#[derive(Serialize, Deserialize, Clone, Debug)] -pub struct Ready { -    #[serde(rename = "v")] -    version: u64, -    user: Value, -    guilds: Vec<Value>, -    session_id: String, -    shard: Option<[i64;2]>, -    application: Value, -} - -#[derive(Serialize, Deserialize, Clone, Debug)] -#[serde(tag = "t", content = "d")] -pub enum Dispatch { -    #[serde(rename = "READY")] -    Ready(Ready), -}
\ No newline at end of file diff --git a/gateway/src/client/payloads/gateway.rs b/gateway/src/client/payloads/gateway.rs deleted file mode 100644 index 788a05b..0000000 --- a/gateway/src/client/payloads/gateway.rs +++ /dev/null @@ -1,86 +0,0 @@ -use super::dispatch::Dispatch; -use super::payloads::hello::Hello; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use serde_repr::{Deserialize_repr, Serialize_repr}; - -macro_rules! num_to_enum { -    ($num:expr => $enm:ident<$tpe:ty>{ $($fld:ident),+ }; $err:expr) => ({ -        match $num { -            $(_ if $num == $enm::$fld as $tpe => { $enm::$fld })+ -            _ => $err -        } -    }); -} - -#[derive(Serialize_repr, Deserialize_repr, PartialEq, Debug)] -#[repr(u8)] -pub enum OpCodes { -    Dispatch = 0, -    Heartbeat = 1, -    Identify = 2, -    PresenceUpdate = 3, -    VoiceStateUpdate = 4, -    Resume = 6, -    Reconnect = 7, -    RequestGuildMembers = 8, -    InvalidSession = 9, -    Hello = 10, -    HeartbeatACK = 11, -} - -#[derive(Serialize, Deserialize, PartialEq, Debug)] -#[serde(bound(deserialize = "T: Deserialize<'de>"))] -pub struct FullMessage<T> { -    #[serde(rename = "d")] -    pub dispatch_type: Option<String>, -    #[serde(rename = "s")] -    pub sequence: Option<OpCodes>, -    pub op: OpCodes, -    #[serde(rename = "d")] -    pub data: T, -} - -pub enum Message { -    Dispatch(FullMessage<Dispatch>), -    Reconnect(FullMessage<()>), -    InvalidSession(FullMessage<bool>), -    Hello(FullMessage<Hello>), -    HeartbeatACK(FullMessage<()>), -} - -impl<'de> serde::Deserialize<'de> for Message { -    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> { -        let value = Value::deserialize(d)?; -        let val = value.get("op").and_then(Value::as_u64).unwrap(); -        let op_code = num_to_enum!( -            val => OpCodes<u64>{ -                Dispatch, -                Heartbeat, -                Identify, -                PresenceUpdate, -                VoiceStateUpdate, -                Resume, -                Reconnect, -                RequestGuildMembers, -                InvalidSession, -                Hello, -                HeartbeatACK -            }; -            panic!("Cannot convert number to `MyEnum`") -        ); - -        match op_code { -            OpCodes::Dispatch => Ok(Message::Dispatch(FullMessage::deserialize(value).unwrap())), -            OpCodes::Reconnect => Ok(Message::Reconnect(FullMessage::deserialize(value).unwrap())), -            OpCodes::InvalidSession => Ok(Message::InvalidSession( -                FullMessage::deserialize(value).unwrap(), -            )), -            OpCodes::Hello => Ok(Message::Hello(FullMessage::deserialize(value).unwrap())), -            OpCodes::HeartbeatACK => Ok(Message::HeartbeatACK( -                FullMessage::deserialize(value).unwrap(), -            )), -            _ => panic!("Cannot convert"), -        } -    } -} diff --git a/gateway/src/client/payloads/mod.rs b/gateway/src/client/payloads/mod.rs deleted file mode 100644 index e43a323..0000000 --- a/gateway/src/client/payloads/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod payloads; -pub mod dispatch; -pub mod gateway;
\ No newline at end of file diff --git a/gateway/src/client/payloads/payloads/identify.rs b/gateway/src/client/payloads/payloads/identify.rs deleted file mode 100644 index 83f038a..0000000 --- a/gateway/src/client/payloads/payloads/identify.rs +++ /dev/null @@ -1,19 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize)] -pub struct IdentifyProprerties { -    #[serde(rename = "$os")] -    pub os: String, -    #[serde(rename = "$browser")] -    pub browser: String, -    #[serde(rename = "$device")] -    pub device: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct Identify { -    pub token: String, -    pub intents: u16, -    pub properties: IdentifyProprerties, -    pub shard: Option<[i64; 2]>, -}
\ No newline at end of file diff --git a/gateway/src/client/payloads/payloads/mod.rs b/gateway/src/client/payloads/payloads/mod.rs deleted file mode 100644 index aa5a6de..0000000 --- a/gateway/src/client/payloads/payloads/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod hello; -pub mod identify;
\ No newline at end of file diff --git a/gateway/src/client/payloads/structs.rs b/gateway/src/client/payloads/structs.rs deleted file mode 100644 index 1f186c6..0000000 --- a/gateway/src/client/payloads/structs.rs +++ /dev/null @@ -1,35 +0,0 @@ -use enumflags2::{bitflags, BitFlags}; - -#[bitflags] -#[repr(u16)] -#[derive(Clone, Copy, Debug)] -pub enum Intents { -    Guilds = 1 << 0, -    GuildMembers = 1 << 1, -    GuildBans = 1 << 2, -    GuildEmojisAndStickers = 1 << 3, -    GuildIntegrations = 1 << 4, -    GuildWebhoks = 1 << 5, -    GuildInvites = 1 << 6, -    GuildVoiceStates = 1 << 7, -    GuildPresences = 1 << 8, -    GuildMessages = 1 << 9, -    GuildMessagesReactions = 1 << 10, -    GuildMessageTyping = 1 << 11, -    DirectMessages = 1 << 12, -    DirectMessagesReactions = 1 << 13, -    DirectMessageTyping = 1 << 14, -} - -pub struct Sharding { -    pub total_shards: i64, -    pub current_shard: i64 -} - -/// Config for the client connection. -pub struct ClientConfig { -    pub token: String, -    pub large_threshold: Option<u64>, -    pub shard: Option<Sharding>, -    pub intents: BitFlags<Intents> -}
\ No newline at end of file diff --git a/gateway/src/client/shard/actions.rs b/gateway/src/client/shard/actions.rs deleted file mode 100644 index cb29ace..0000000 --- a/gateway/src/client/shard/actions.rs +++ /dev/null @@ -1,68 +0,0 @@ -use futures::SinkExt; -use log::error; -use serde_json::Value; - -use crate::client::payloads::gateway::{FullMessage, OpCodes}; - -use super::Shard; - -/// Implement the available actions for nova in the gateway. -impl Shard { -    /// Updates the presence of the current shard. -    #[allow(dead_code)] -    pub async fn presence_update(&mut self) -> Result<(), ()> { -        if let Some(connection) = &mut self.connection { -            connection -                .send(FullMessage { -                    dispatch_type: None, -                    sequence: None, -                    op: OpCodes::PresenceUpdate, -                    // todo: proper payload for this -                    data: Value::Null, -                }) -                .await -                .unwrap(); -        } else { -            error!("the connection is not open") -        } -        Ok(()) -    } -    /// Updates the voice status of the current shard in a certain channel. -    #[allow(dead_code)] -    pub async fn voice_state_update(&mut self) -> Result<(), ()> { -        if let Some(connection) = &mut self.connection { -            connection -                .send(FullMessage { -                    dispatch_type: None, -                    sequence: None, -                    op: OpCodes::VoiceStateUpdate, -                    // todo: proper payload for this -                    data: Value::Null, -                }) -                .await -                .unwrap(); -        } else { -            error!("the connection is not open") -        } -        Ok(()) -    } -    /// Ask discord for more informations about offline guild members. -    #[allow(dead_code)] -    pub async fn request_guild_members(&mut self) -> Result<(), ()> { -        if let Some(connection) = &mut self.connection { -            connection -                .send(FullMessage { -                    dispatch_type: None, -                    sequence: None, -                    op: OpCodes::RequestGuildMembers, -                    // todo: proper payload for this -                    data: Value::Null, -                }) -                .await -                .unwrap(); -        } else { -            error!("the connection is not open") -        } -        Ok(()) -    } -} diff --git a/gateway/src/client/shard/connection.rs b/gateway/src/client/shard/connection.rs deleted file mode 100644 index 3395ff2..0000000 --- a/gateway/src/client/shard/connection.rs +++ /dev/null @@ -1,18 +0,0 @@ -use super::Shard; -use crate::client::connection::Connection; -use log::info; - -impl Shard { -    pub async fn start(self: &mut Self) { -        let mut should_exit = false; - -        while !should_exit { -            info!("Starting connection for shard"); -            // create the new connection -            let mut connection = Connection::new(); -            connection.start().await.unwrap(); -            self.connection = Some(connection); -            should_exit = true; -        } -    } -} diff --git a/gateway/src/client/shard/mod.rs b/gateway/src/client/shard/mod.rs deleted file mode 100644 index aec93d6..0000000 --- a/gateway/src/client/shard/mod.rs +++ /dev/null @@ -1,22 +0,0 @@ -use self::state::SessionState; - -use super::connection::Connection; -mod actions; -mod connection; -mod state; - -/// Represents a shard & all the reconnection logic related to it -pub struct Shard { -    connection: Option<Connection>, -    state: SessionState, -} - -impl Shard { -    /// Creates a new shard instance -    pub fn new() -> Self { -        Shard { -            connection: None, -            state: SessionState::default(), -        } -    } -} diff --git a/gateway/src/client/connection/mod.rs b/gateway/src/connection/mod.rs index 24ef334..c60068a 100644 --- a/gateway/src/client/connection/mod.rs +++ b/gateway/src/connection/mod.rs @@ -1,8 +1,8 @@ -use super::{error_utils::GatewayError, utils::get_gateway_url};  use tokio::net::TcpStream; -use tokio_tungstenite::{ -    connect_async, tungstenite::handshake::client::Request, MaybeTlsStream, WebSocketStream, -}; +use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::http::Request}; + +use crate::{error::GatewayError, utils::get_gateway_url}; +  mod stream;  mod utils; @@ -33,4 +33,4 @@ impl Connection {              Ok(())          }      } -} +}
\ No newline at end of file diff --git a/gateway/src/client/connection/stream.rs b/gateway/src/connection/stream.rs index 6a6f5c9..dbfab60 100644 --- a/gateway/src/client/connection/stream.rs +++ b/gateway/src/connection/stream.rs @@ -1,5 +1,6 @@ +use crate::{error::GatewayError, payloads::gateway::BaseMessage}; +  use super::Connection; -use crate::client::{error_utils::GatewayError};  use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};  use log::info;  use serde::Serialize; @@ -9,8 +10,9 @@ use std::{  };  use tokio_tungstenite::tungstenite::Message; +/// Implementation of the Stream trait for the Connection  impl Stream for Connection { -    type Item = Result<crate::client::payloads::gateway::Message, GatewayError>; +    type Item = Result<crate::payloads::gateway::Message, GatewayError>;      fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {          // first, when a poll is called, we check if the connection is still open @@ -30,7 +32,7 @@ impl Stream for Connection {                                          Err(e) => Poll::Ready(Some(Err(e))),                                      },                                      // unknown behaviour? -                                    Poll::Pending => unimplemented!(), +                                    Poll::Pending => unreachable!(),                                  }                              }                              Err(e) => Poll::Ready(Some(Err(GatewayError::from(e)))), @@ -38,7 +40,7 @@ impl Stream for Connection {                          // if no message is available, we return none, it's the end of the stream                          None => {                              info!("tokio_tungstenite stream finished successfully"); -                            Box::pin(conn.close(None)).poll_unpin(cx); +                            let _ = Box::pin(conn.close(None)).poll_unpin(cx);                              self.connection = None;                              Poll::Ready(None)                          } @@ -53,21 +55,22 @@ impl Stream for Connection {      }  } -impl<T: Serialize> Sink<crate::client::payloads::gateway::FullMessage<T>> for Connection { -    type Error = GatewayError; +/// Implementation of the Sink trait for the Connection +impl<T: Serialize> Sink<BaseMessage<T>> for Connection { +    type Error = tokio_tungstenite::tungstenite::Error;      #[allow(dead_code)] -    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { -        if let Some(_) = &self.connection { +    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { +        if let Some(conn) = &mut self.connection {              // a connection exists, we can send data -            Poll::Ready(Ok(())) +            conn.poll_ready_unpin(cx)          } else {              Poll::Pending          }      }      #[allow(dead_code)] -    fn start_send(mut self: Pin<&mut Self>, item: crate::client::payloads::gateway::FullMessage<T>) -> Result<(), Self::Error> { +    fn start_send(mut self: Pin<&mut Self>, item: BaseMessage<T>) -> Result<(), Self::Error> {          if let Some(conn) = &mut self.connection {              let message = serde_json::to_string(&item);              conn.start_send_unpin(Message::Text(message.unwrap())) @@ -77,12 +80,20 @@ impl<T: Serialize> Sink<crate::client::payloads::gateway::FullMessage<T>> for Co      }      #[allow(dead_code)] -    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { -        Poll::Ready(Ok(())) +    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { +        if let Some(conn) = &mut self.connection { +            conn.poll_flush_unpin(cx) +        } else { +            Poll::Pending +        }      }      #[allow(dead_code)] -    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { -        Poll::Ready(Ok(())) +    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { +        if let Some(conn) = &mut self.connection { +            conn.poll_close_unpin(cx) +        } else { +            Poll::Pending +        }      }  } diff --git a/gateway/src/connection/utils.rs b/gateway/src/connection/utils.rs new file mode 100644 index 0000000..fb07229 --- /dev/null +++ b/gateway/src/connection/utils.rs @@ -0,0 +1,42 @@ +use std::str::from_utf8; +use tokio_tungstenite::tungstenite::Message; +use log::info; + +use crate::error::GatewayError; + +use super::Connection; + +impl Connection { + +    /// Handles the websocket events and calls the _handle_discord_message function for the deserialization. +    pub(super) async fn _handle_message( +        &mut self, +        data: &Message, +    ) -> Result<crate::payloads::gateway::Message, GatewayError> { +        match data { +            Message::Text(text) => self._handle_discord_message(&text).await, +            Message::Binary(message) => { +                match from_utf8(message) { +                    Ok(data) => self._handle_discord_message(data).await, +                    Err(err) => Err(GatewayError::from(err.to_string())), +                } +            }, +            Message::Close(close_frame) => { +                info!("Discord connection closed {:?}", close_frame); +                Err(GatewayError::from("connection closed".to_string())) +            }, +            _ => Err(GatewayError::from(format!("unknown variant of message specified to the handler {}", data).to_string())), +        } +    } + +    /// Handle the decompression and deserialization process of a discord payload. +    pub(super) async fn _handle_discord_message( +        &mut self, +        raw_message: &str, +    ) -> Result<crate::payloads::gateway::Message, GatewayError> { +        match serde_json::from_str(raw_message) { +            Ok(message) => Ok(message), +            Err(err) => Err(GatewayError::from(err.to_string())), +        } +    } +} diff --git a/gateway/src/client/error_utils.rs b/gateway/src/error.rs index 603caab..603caab 100644 --- a/gateway/src/client/error_utils.rs +++ b/gateway/src/error.rs diff --git a/gateway/src/main.rs b/gateway/src/main.rs index 003a903..4c42c7a 100644 --- a/gateway/src/main.rs +++ b/gateway/src/main.rs @@ -1,41 +1,19 @@ -mod client; - -use client::connection::Connection;  use common::config::Settings; -use futures::StreamExt; -use log::info; -use serde_json::Value; - -use crate::client::payloads::{dispatch::Dispatch, gateway::{FullMessage, Message, OpCodes}, payloads::identify::{Identify, IdentifyProprerties}}; +use shard::{Shard, ShardConfig}; +#[macro_use] +extern crate num_derive; -#[tokio::main] -async fn main() { -    let settings: Settings<Value> = Settings::new("gateway").unwrap(); +pub mod connection; +mod error; +mod utils; +mod shard; +mod payloads; -    let mut conn = Connection::new(); -    conn.start().await.unwrap(); -    loop { -        if let Some(val) = conn.next().await { -            let data = val.as_ref().unwrap(); -            match data { -                Message::Dispatch(dispatch) => { -                    match &dispatch.data { -                        Dispatch::Ready(_ready) => { -                             -                        }, -                    } -                }, -                Message::Reconnect(_) => todo!(), -                Message::InvalidSession(_) => todo!(), -                Message::Hello(_hello) => { -                    info!("Server said hello! {:?}", _hello); -                }, -                Message::HeartbeatACK(_) => todo!(), -            } -        } else { -            break; -        } -    } +#[tokio::main] +async fn main() { +    let settings: Settings<ShardConfig> = Settings::new("gateway").unwrap(); +    let mut shard = Shard::new(settings.config); +    shard.start().await;  } diff --git a/gateway/src/payloads/dispatch.rs b/gateway/src/payloads/dispatch.rs new file mode 100644 index 0000000..b2ddd89 --- /dev/null +++ b/gateway/src/payloads/dispatch.rs @@ -0,0 +1,38 @@ +use log::info; +use serde::{Deserialize, Deserializer}; + +use serde_json::Value; + +use super::{events::ready::Ready, opcodes::OpCodes}; + +/// Represents an unknown event not handled by the gateway itself +#[derive(Clone, Debug, PartialEq, Deserialize)] +pub struct UnknownDispatch { +    pub t: String, +    pub d: Value, +    pub s: i64, +    pub op: OpCodes, +} + +#[derive(Clone, Debug, PartialEq, Deserialize)] +#[serde(tag = "t", content = "d")] +#[serde(remote = "Dispatch")] +pub enum Dispatch { +    #[serde(rename = "READY")] +    Ready(Ready), +    #[serde(rename = "RESUMED")] +    Resumed(()), + +    #[serde(skip_deserializing)] +    Other(UnknownDispatch), +} + +impl<'de> Deserialize<'de> for Dispatch { +    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> +        where D: Deserializer<'de> +    { +        info!("hey"); +        let s = UnknownDispatch::deserialize(deserializer)?; +        Ok(Self::Other(s)) +    } +} diff --git a/gateway/src/payloads/events/mod.rs b/gateway/src/payloads/events/mod.rs new file mode 100644 index 0000000..3fef2d9 --- /dev/null +++ b/gateway/src/payloads/events/mod.rs @@ -0,0 +1 @@ +pub mod ready;
\ No newline at end of file diff --git a/gateway/src/payloads/events/ready.rs b/gateway/src/payloads/events/ready.rs new file mode 100644 index 0000000..a5ec291 --- /dev/null +++ b/gateway/src/payloads/events/ready.rs @@ -0,0 +1,13 @@ +use serde::Deserialize; +use serde_json::Value; + +#[derive(Deserialize, Clone, Debug, PartialEq)] +pub struct Ready { +    #[serde(rename = "v")] +    pub version: u64, +    pub user: Value, +    pub guilds: Vec<Value>, +    pub session_id: String, +    pub shard: Option<[i64;2]>, +    pub application: Value, +} diff --git a/gateway/src/payloads/events/resume.rs b/gateway/src/payloads/events/resume.rs new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/gateway/src/payloads/events/resume.rs diff --git a/gateway/src/payloads/gateway.rs b/gateway/src/payloads/gateway.rs new file mode 100644 index 0000000..e8dff96 --- /dev/null +++ b/gateway/src/payloads/gateway.rs @@ -0,0 +1,91 @@ +use super::{dispatch::Dispatch, opcodes::{OpCodes, hello::Hello}}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use serde::de::Error; + +#[derive(Serialize, Deserialize, PartialEq, Debug)] +#[serde(bound(deserialize = "T: Deserialize<'de> + std::fmt::Debug"))] +pub struct BaseMessage<T> { +    pub t: Option<String>, +    #[serde(rename = "s")] +    pub sequence: Option<u64>, +    pub op: OpCodes, +    #[serde(rename = "d")] +    pub data: T, +} + +#[derive(Debug)] +pub enum Message { +    Dispatch(BaseMessage<Dispatch>), +    Reconnect(BaseMessage<()>), +    InvalidSession(BaseMessage<bool>), +    Hello(BaseMessage<Hello>), +    HeartbeatACK(BaseMessage<()>), +} + +impl<'de> serde::Deserialize<'de> for Message { +    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> where D::Error : Error { +        let value = Value::deserialize(d)?; +        let val = value.get("op").and_then(Value::as_u64).unwrap(); + +        if let Some(op) = num::FromPrimitive::from_u64(val) { +            match op { +                OpCodes::Dispatch => { +                    match Dispatch::deserialize(&value) { +                        Ok(data) => { + +                            let mut t = None; +                            if let Some(t_value) = &value.get("t") { +                                // this is safe since we know this is a string +                                t = Some(t_value.to_string()); +                            } +                            let mut sequence = None; + +                            if let Some(sequence_value) = value.get("s") { +                                if let Some(sequence_uint) = sequence_value.as_u64() { +                                    sequence = Some(sequence_uint); +                                } +                            } + +                            Ok(Message::Dispatch(BaseMessage { +                                op, +                                t, +                                sequence, +                                data +                            })) +                        }, +                        Err(e) => Err(Error::custom(e)), +                    } +                }, +                 +                OpCodes::Reconnect => { +                    match BaseMessage::deserialize(value) { +                        Ok(data) => Ok(Message::Reconnect(data)), +                        Err(e) => Err(Error::custom(e)), +                    } +                }, +                OpCodes::InvalidSession => { +                    match BaseMessage::deserialize(value) { +                        Ok(data) => Ok(Message::InvalidSession(data)), +                        Err(e) => Err(Error::custom(e)), +                    } +                }, +                OpCodes::Hello => { +                    match BaseMessage::deserialize(value) { +                        Ok(data) => Ok(Message::Hello(data)), +                        Err(e) => Err(Error::custom(e)), +                    } +                }, +                OpCodes::HeartbeatACK => { +                    match BaseMessage::deserialize(value) { +                        Ok(data) => Ok(Message::HeartbeatACK(data)), +                        Err(e) => Err(Error::custom(e)), +                    } +                }, +                _ => panic!("Cannot convert"), +            } +        } else { +            todo!(); +        } +    } +} diff --git a/gateway/src/payloads/mod.rs b/gateway/src/payloads/mod.rs new file mode 100644 index 0000000..e9849a7 --- /dev/null +++ b/gateway/src/payloads/mod.rs @@ -0,0 +1,4 @@ +pub mod opcodes; +pub mod dispatch; +pub mod gateway; +pub mod events; diff --git a/gateway/src/client/payloads/payloads/hello.rs b/gateway/src/payloads/opcodes/hello.rs index 0690a61..3d8fd0f 100644 --- a/gateway/src/client/payloads/payloads/hello.rs +++ b/gateway/src/payloads/opcodes/hello.rs @@ -1,5 +1,6 @@  use serde::{Serialize, Deserialize}; +/// The first message sent by the gateway to initialize the heartbeating  #[derive(Debug, Serialize, Deserialize)]  pub struct Hello {      #[serde(rename = "heartbeat_interval")] diff --git a/gateway/src/payloads/opcodes/identify.rs b/gateway/src/payloads/opcodes/identify.rs new file mode 100644 index 0000000..5929c33 --- /dev/null +++ b/gateway/src/payloads/opcodes/identify.rs @@ -0,0 +1,47 @@ +use enumflags2::{BitFlags, bitflags}; +use serde::{Deserialize, Serialize}; +use super::presence::PresenceUpdate; + + +#[bitflags] +#[repr(u16)] +#[derive(Clone, Copy, Debug)] +pub enum Intents { +    Guilds = 1 << 0, +    GuildMembers = 1 << 1, +    GuildBans = 1 << 2, +    GuildEmojisAndStickers = 1 << 3, +    GuildIntegrations = 1 << 4, +    GuildWebhoks = 1 << 5, +    GuildInvites = 1 << 6, +    GuildVoiceStates = 1 << 7, +    GuildPresences = 1 << 8, +    GuildMessages = 1 << 9, +    GuildMessagesReactions = 1 << 10, +    GuildMessageTyping = 1 << 11, +    DirectMessages = 1 << 12, +    DirectMessagesReactions = 1 << 13, +    DirectMessageTyping = 1 << 14, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct IdentifyProprerties { +    #[serde(rename = "$os")] +    pub os: String, +    #[serde(rename = "$browser")] +    pub browser: String, +    #[serde(rename = "$device")] +    pub device: String, +} + +/// Messages sent by the shard to log-in to the gateway. +#[derive(Debug, Serialize, Deserialize)] +pub struct Identify { +    pub token: String, +    pub properties: IdentifyProprerties, +    pub compress: Option<bool>, +    pub large_threshold: Option<u64>, +    pub shard: Option<[u64; 2]>, +    pub presence: Option<PresenceUpdate>, +    pub intents: BitFlags<Intents>, +}
\ No newline at end of file diff --git a/gateway/src/payloads/opcodes/mod.rs b/gateway/src/payloads/opcodes/mod.rs new file mode 100644 index 0000000..cfa453a --- /dev/null +++ b/gateway/src/payloads/opcodes/mod.rs @@ -0,0 +1,22 @@ +pub mod hello; +pub mod identify; +pub mod resume; +pub mod presence; +use serde_repr::{Deserialize_repr, Serialize_repr}; + +#[derive(Serialize_repr, Deserialize_repr, PartialEq, Debug, Clone, FromPrimitive, ToPrimitive)] +#[repr(u8)] + +pub enum OpCodes { +    Dispatch = 0, +    Heartbeat = 1, +    Identify = 2, +    PresenceUpdate = 3, +    VoiceStateUpdate = 4, +    Resume = 6, +    Reconnect = 7, +    RequestGuildMembers = 8, +    InvalidSession = 9, +    Hello = 10, +    HeartbeatACK = 11, +}
\ No newline at end of file diff --git a/gateway/src/payloads/opcodes/presence.rs b/gateway/src/payloads/opcodes/presence.rs new file mode 100644 index 0000000..a6c5773 --- /dev/null +++ b/gateway/src/payloads/opcodes/presence.rs @@ -0,0 +1,63 @@ +use serde_repr::{Deserialize_repr, Serialize_repr}; +use serde::{Deserialize, Serialize}; +#[derive(Serialize_repr, Deserialize_repr, Debug)] +#[repr(u8)] +pub enum ActivityType { +    Game = 0, +    Streaming = 1, +    Listening = 2, +    Watching = 3, +    Custom = 4, +    Competing = 5 +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ActivityTimestamps { +    start: u64, +    end: u64, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ActivityEmoji { +    name: String, +    id: Option<String>, +    animated: Option<bool> +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Activity { +    name: String, +    #[serde(rename = "type")] +    t: ActivityType, + +    url: Option<String>, +    created_at: i64, +    timestamp: Option<ActivityTimestamps>, +    application_id: Option<String>, +    details: Option<String>, +    state: Option<String>, +    emoji: Option<ActivityEmoji>, +    // todo: implement more +} + +#[derive(Serialize, Deserialize, Debug)] +pub enum PresenceStatus { +    #[serde(rename = "online")] +    Online, +    #[serde(rename = "dnd")] +    Dnd, +    #[serde(rename = "idle")] +    Idle, +    #[serde(rename = "invisible")] +    Invisible, +    #[serde(rename = "offline")] +    Offline +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct PresenceUpdate { +    since: u64, +    activities: Vec<Activity>, +    status: PresenceStatus, +    afk: bool, +} diff --git a/gateway/src/payloads/opcodes/resume.rs b/gateway/src/payloads/opcodes/resume.rs new file mode 100644 index 0000000..e1bba91 --- /dev/null +++ b/gateway/src/payloads/opcodes/resume.rs @@ -0,0 +1,8 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct Resume { +    pub token: String, +    pub session_id: String, +    pub seq: u64, +}
\ No newline at end of file diff --git a/gateway/src/shard/actions.rs b/gateway/src/shard/actions.rs new file mode 100644 index 0000000..513d1a8 --- /dev/null +++ b/gateway/src/shard/actions.rs @@ -0,0 +1,113 @@ +use std::env; + +use futures::SinkExt; +use log::{debug, error}; +use serde::Serialize; +use serde_json::Value; +use std::fmt::Debug; + +use crate::{error::GatewayError, payloads::{gateway::BaseMessage, opcodes::{OpCodes, identify::{Identify, IdentifyProprerties}, presence::PresenceUpdate, resume::Resume}}}; + +use super::Shard; + +/// Implement the available actions for nova in the gateway. +impl Shard { + +    /// sends a message through the websocket +    pub async fn _send<T: Serialize + Debug>(&mut self, message: BaseMessage<T>) -> Result<(), GatewayError> { +        debug!("senging message {:?}", message); +        if let Some(connection) = &mut self.connection { +            if let Err(e) = connection.conn.send(message).await { +                error!("failed to send message {:?}", e); +                Err(GatewayError::from(e)) +            } else { +                Ok(()) +            } +        } else { +            Err(GatewayError::from("no open connection".to_string())) +        } +    } + +    pub async fn _identify(&mut self) -> Result<(), GatewayError> { +        if let Some(state) = self.state.clone()  { +            self._send(BaseMessage{ +                t: None, +                sequence: None, +                op: OpCodes::Resume, +                data: Resume { +                    token: self.config.token.clone(), +                    seq: state.sequence, +                    session_id: state.session_id.clone(), +                }, +            }).await +        } else { +            self._send(BaseMessage{ +                t: None, +                sequence: None, +                op: OpCodes::Identify, +                data: Identify { +                    token: self.config.token.clone(), +                    intents: self.config.intents, +                    properties: IdentifyProprerties { +                        os: env::consts::OS.to_string(), +                        browser: "Nova".to_string(), +                        device: "Nova".to_string(), +                    }, +                    shard: Some([0,2]), +                    compress: Some(false), +                    large_threshold: Some(500), +                    presence: None, +                }, +            }).await +        } +    } + +    pub async fn _disconnect(&mut self) {} + +    /// Updates the presence of the current shard. +    #[allow(dead_code)] +    pub async fn presence_update(&mut self, update: PresenceUpdate) -> Result<(), GatewayError> { +        self._send(BaseMessage{ +            t: None, +            sequence: None, +            op: OpCodes::PresenceUpdate, +            data: update, +        }).await +    } +    /// Updates the voice status of the current shard in a certain channel. +    #[allow(dead_code)] +    pub async fn voice_state_update(&mut self) -> Result<(), GatewayError> { +        if let Some(connection) = &mut self.connection { +            connection.conn +                .send(BaseMessage { +                    t: None, +                    sequence: None, +                    op: OpCodes::VoiceStateUpdate, +                    // todo: proper payload for this +                    data: Value::Null, +                }) +                .await? +        } else { +            error!("the connection is not open") +        } +        Ok(()) +    } +    /// Ask discord for more informations about offline guild members. +    #[allow(dead_code)] +    pub async fn request_guild_members(&mut self) -> Result<(), GatewayError> { +        if let Some(connection) = &mut self.connection { +            connection.conn +                .send(BaseMessage { +                    t: None, +                    sequence: None, +                    op: OpCodes::RequestGuildMembers, +                    // todo: proper payload for this +                    data: Value::Null, +                }) +                .await? +        } else { +            error!("the connection is not open") +        } +        Ok(()) +    } +} diff --git a/gateway/src/shard/connection.rs b/gateway/src/shard/connection.rs new file mode 100644 index 0000000..6f8503c --- /dev/null +++ b/gateway/src/shard/connection.rs @@ -0,0 +1,151 @@ +use std::{ +    cmp::{max, min}, +    convert::TryInto, +    time::Duration, +}; + +use crate::{ +    connection::Connection, +    payloads::{ +        dispatch::Dispatch, +        gateway::{BaseMessage, Message}, +    }, +    shard::state::SessionState, +}; + +use super::{state::ConnectionState, ConnectionWithState, Shard}; +use futures::StreamExt; +use log::{error, info}; +use tokio::{select, time::sleep}; + +impl Shard { +    pub async fn start(self: &mut Self) { +        let mut reconnects = 1; +        info!("Starting shard"); + +        while reconnects < self.config.max_reconnects { +            info!("Starting connection for shard"); +            self._shard_task().await; +            // when the shard got disconnected, the shard task ends +            reconnects += 1; + +            // wait reconnects min(max(reconnects * reconnect_delay_growth_factor, reconnect_delay_minimum),reconnect_delay_maximum) +            if reconnects < self.config.max_reconnects { +                let time = min( +                    self.config.reconnect_delay_maximum, +                    max( +                        ((reconnects as f32) * self.config.reconnect_delay_growth_factor) as usize, +                        self.config.reconnect_delay_minimum, +                    ), +                ); +                info!( +                    "The shard got disconnected, waiting for reconnect ({}ms)", +                    time +                ); +                sleep(Duration::from_millis(time.try_into().unwrap())).await; +            } +        } +        info!( +            "The shard got disconnected too many times and reached the maximum {}", +            self.config.max_reconnects +        ); +    } + +    async fn _shard_task(&mut self) { +        // create the new connection +        let mut connection = Connection::new(); +        connection.start().await.unwrap(); +        self.connection = Some(ConnectionWithState { +            conn: connection, +            state: ConnectionState::default(), +        }); +        loop { +            if let Some(connection) = &mut self.connection { +                select!( +                    payload = connection.conn.next() => { +                        match payload { +                            Some(data) => match data { +                                Ok(message) => self._handle(&message).await, +                                Err(error) => { +                                    error!("An error occured while being connected to Discord: {:?}", error); +                                    return; +                                }, +                            }, +                            None => { +                                info!("Connection terminated"); +                                return; +                            }, +                        } +                    } +                ) +            } +        } +    } + +    fn _util_set_seq(&mut self, seq: Option<u64>) { +        if let Some(seq) = seq { +            if let Some(state) = &mut self.state { +                state.sequence = seq; +            } +        } +    } + +    async fn _handle(&mut self, message: &Message) { +        match message { +            Message::Dispatch(msg) => { +                self._util_set_seq(msg.sequence); +                self._dispatch(&msg).await; +            } +            // we need to reconnect to the gateway +            Message::Reconnect(msg) => { +                self._util_set_seq(msg.sequence); +                info!("gateway disconnect requested"); +                self._disconnect().await; +            } +            Message::InvalidSession(msg) => { +                self._util_set_seq(msg.sequence); +                info!("invalid session"); +                let data = msg.data; +                if !data { +                    info!("session removed"); +                    // reset the session data +                    self.state = None; +                    if let Err(e) = self._identify().await { +                        error!("error while sending identify: {:?}", e); +                    } +                } +            } +            Message::HeartbeatACK(msg) => { +                self._util_set_seq(msg.sequence); +                info!("heartbeat ack received"); +            } +            Message::Hello(msg) => { +                self._util_set_seq(msg.sequence); +                info!("server hello received"); +                if let Err(e) = self._identify().await { +                    error!("error while sending identify: {:?}", e); +                } +            }, +        } +    } + +    async fn _dispatch(&mut self, dispatch: &BaseMessage<Dispatch>) { +        match &dispatch.data { +            Dispatch::Ready(ready) => { +                info!("received gateway dispatch hello"); +                info!( +                    "Logged in as {}", +                    ready.user.get("username").unwrap().to_string() +                ); +                self.state = Some(SessionState { +                    sequence: dispatch.sequence.unwrap(), +                    session_id: ready.session_id.clone(), +                }); +            } +            Dispatch::Resumed(_) => { +                info!("session resumed"); +            } +            Dispatch::Other(data) => {} +        } +    } +} diff --git a/gateway/src/shard/mod.rs b/gateway/src/shard/mod.rs new file mode 100644 index 0000000..b458451 --- /dev/null +++ b/gateway/src/shard/mod.rs @@ -0,0 +1,50 @@ +use enumflags2::BitFlags; +use serde::{Deserialize, Serialize}; +use crate::{connection::Connection, payloads::opcodes::identify::Intents}; +use self::state::{ConnectionState, SessionState}; +mod actions; +mod connection; +mod state; + +#[derive(Debug, Deserialize, Serialize, Default, Clone)] +pub struct Sharding { +    pub total_shards: i64, +    pub current_shard: i64 +} + + +#[derive(Debug, Deserialize, Serialize, Default, Clone)] +pub struct ShardConfig { +    pub max_reconnects: usize, +    pub reconnect_delay_growth_factor: f32, +    pub reconnect_delay_minimum: usize, +    pub reconnect_delay_maximum: usize, +    pub token: String, +     +    pub large_threshold: Option<u64>, +    pub shard: Option<Sharding>, +    pub intents: BitFlags<Intents> +} + +struct ConnectionWithState { +    conn: Connection, +    state: ConnectionState, +} + +/// Represents a shard & all the reconnection logic related to it +pub struct Shard { +    connection: Option<ConnectionWithState>, +    state: Option<SessionState>, +    config: ShardConfig +} + +impl Shard { +    /// Creates a new shard instance +    pub fn new(config: ShardConfig) -> Self { +        Shard { +            connection: None, +            state: None, +            config, +        } +    } +} diff --git a/gateway/src/client/shard/state.rs b/gateway/src/shard/state.rs index 4d40911..caed268 100644 --- a/gateway/src/client/shard/state.rs +++ b/gateway/src/shard/state.rs @@ -1,6 +1,7 @@  use std::time::Instant;  /// This struct represents the state of a session +#[derive(Clone, Debug)]  pub struct SessionState {      pub sequence: u64,      pub session_id: String, diff --git a/gateway/src/client/utils.rs b/gateway/src/utils.rs index 141740e..48a9aed 100644 --- a/gateway/src/client/utils.rs +++ b/gateway/src/utils.rs @@ -1,4 +1,3 @@ -  /// Formats a url of connection to the gateway  pub fn get_gateway_url (compress: bool, encoding: &str, v: i16) -> String {      return format!(  | 
