diff options
Diffstat (limited to 'gateway/src')
| -rw-r--r-- | gateway/src/config.rs | 14 | ||||
| -rw-r--r-- | gateway/src/connection/mod.rs | 36 | ||||
| -rw-r--r-- | gateway/src/connection/stream.rs | 99 | ||||
| -rw-r--r-- | gateway/src/connection/utils.rs | 42 | ||||
| -rw-r--r-- | gateway/src/error.rs | 22 | ||||
| -rw-r--r-- | gateway/src/main.rs | 69 | ||||
| -rw-r--r-- | gateway/src/payloads/dispatch.rs | 46 | ||||
| -rw-r--r-- | gateway/src/payloads/events/mod.rs | 1 | ||||
| -rw-r--r-- | gateway/src/payloads/events/ready.rs | 13 | ||||
| -rw-r--r-- | gateway/src/payloads/gateway.rs | 78 | ||||
| -rw-r--r-- | gateway/src/payloads/mod.rs | 4 | ||||
| -rw-r--r-- | gateway/src/payloads/opcodes/hello.rs | 8 | ||||
| -rw-r--r-- | gateway/src/payloads/opcodes/identify.rs | 47 | ||||
| -rw-r--r-- | gateway/src/payloads/opcodes/mod.rs | 22 | ||||
| -rw-r--r-- | gateway/src/payloads/opcodes/presence.rs | 63 | ||||
| -rw-r--r-- | gateway/src/payloads/opcodes/resume.rs | 8 | ||||
| -rw-r--r-- | gateway/src/shard/actions.rs | 132 | ||||
| -rw-r--r-- | gateway/src/shard/connection.rs | 193 | ||||
| -rw-r--r-- | gateway/src/shard/mod.rs | 49 | ||||
| -rw-r--r-- | gateway/src/shard/state.rs | 35 | ||||
| -rw-r--r-- | gateway/src/utils.rs | 8 |
21 files changed, 69 insertions, 920 deletions
diff --git a/gateway/src/config.rs b/gateway/src/config.rs new file mode 100644 index 0000000..999892b --- /dev/null +++ b/gateway/src/config.rs @@ -0,0 +1,14 @@ +use common::serde::{Deserialize, Serialize}; +use twilight_gateway::Intents; + +#[derive(Serialize, Deserialize, Clone)] +pub struct Config { + pub token: String, + pub intents: Intents +} + +impl Default for Config { + fn default() -> Self { + Self { intents: Intents::empty(), token: String::default() } + } +} diff --git a/gateway/src/connection/mod.rs b/gateway/src/connection/mod.rs deleted file mode 100644 index c60068a..0000000 --- a/gateway/src/connection/mod.rs +++ /dev/null @@ -1,36 +0,0 @@ -use tokio::net::TcpStream; -use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::http::Request}; - -use crate::{error::GatewayError, utils::get_gateway_url}; - -mod stream; -mod utils; - -/// Underlying representation of a Discord event stream -/// that streams the Event payloads to the shard structure -pub struct Connection { - /// The channel given by tokio_tungstenite that represents the websocket connection - connection: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>, -} - -impl Connection { - pub fn new() -> Self { - Connection { connection: None } - } - - pub async fn start(&mut self) -> Result<(), GatewayError> { - let request = Request::builder() - .uri(get_gateway_url(false, "json", 9)) - .body(()) - .unwrap(); - - let connection_result = connect_async(request).await; - // we connect outselves to the websocket server - if let Err(err) = connection_result { - Err(GatewayError::from(err)) - } else { - self.connection = Some(connection_result.unwrap().0); - Ok(()) - } - } -}
\ No newline at end of file diff --git a/gateway/src/connection/stream.rs b/gateway/src/connection/stream.rs deleted file mode 100644 index 5a12daf..0000000 --- a/gateway/src/connection/stream.rs +++ /dev/null @@ -1,99 +0,0 @@ -use crate::{error::GatewayError, payloads::gateway::BaseMessage}; - -use super::Connection; -use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt}; -use common::log::info; -use serde::Serialize; -use std::{ - pin::Pin, - task::{Context, Poll}, -}; -use tokio_tungstenite::tungstenite::Message; - -/// Implementation of the Stream trait for the Connection -impl Stream for Connection { - type Item = Result<crate::payloads::gateway::Message, GatewayError>; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - // first, when a poll is called, we check if the connection is still open - if let Some(conn) = &mut self.connection { - // we need to wait poll the message using the tokio_tungstenite stream - let message = conn.poll_next_unpin(cx); - - match message { - Poll::Ready(packet) => { - // if data is available, we can continue - match packet { - Some(result) => match result { - Ok(message) => { - match Box::pin(self._handle_message(&message)).poll_unpin(cx) { - Poll::Ready(data) => match data { - Ok(d) => Poll::Ready(Some(Ok(d))), - Err(e) => Poll::Ready(Some(Err(e))), - }, - // unknown behaviour? - Poll::Pending => unreachable!(), - } - } - Err(e) => Poll::Ready(Some(Err(GatewayError::from(e)))), - }, - // if no message is available, we return none, it's the end of the stream - None => { - info!("tokio_tungstenite stream finished successfully"); - let _ = Box::pin(conn.close(None)).poll_unpin(cx); - self.connection = None; - Poll::Ready(None) - } - } - } - // if the message is pending, we return the same result - Poll::Pending => Poll::Pending, - } - } else { - Poll::Ready(None) - } - } -} - -/// Implementation of the Sink trait for the Connection -impl<T: Serialize> Sink<BaseMessage<T>> for Connection { - type Error = tokio_tungstenite::tungstenite::Error; - - #[allow(dead_code)] - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { - if let Some(conn) = &mut self.connection { - // a connection exists, we can send data - conn.poll_ready_unpin(cx) - } else { - Poll::Pending - } - } - - #[allow(dead_code)] - fn start_send(mut self: Pin<&mut Self>, item: BaseMessage<T>) -> Result<(), Self::Error> { - if let Some(conn) = &mut self.connection { - let message = serde_json::to_string(&item); - conn.start_send_unpin(Message::Text(message.unwrap())) - .unwrap(); - } - Ok(()) - } - - #[allow(dead_code)] - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { - if let Some(conn) = &mut self.connection { - conn.poll_flush_unpin(cx) - } else { - Poll::Pending - } - } - - #[allow(dead_code)] - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { - if let Some(conn) = &mut self.connection { - conn.poll_close_unpin(cx) - } else { - Poll::Pending - } - } -} diff --git a/gateway/src/connection/utils.rs b/gateway/src/connection/utils.rs deleted file mode 100644 index bb425da..0000000 --- a/gateway/src/connection/utils.rs +++ /dev/null @@ -1,42 +0,0 @@ -use std::str::from_utf8; -use tokio_tungstenite::tungstenite::Message; -use common::log::info; - -use crate::error::GatewayError; - -use super::Connection; - -impl Connection { - - /// Handles the websocket events and calls the _handle_discord_message function for the deserialization. - pub(super) async fn _handle_message( - &mut self, - data: &Message, - ) -> Result<crate::payloads::gateway::Message, GatewayError> { - match data { - Message::Text(text) => self._handle_discord_message(&text).await, - Message::Binary(message) => { - match from_utf8(message) { - Ok(data) => self._handle_discord_message(data).await, - Err(err) => Err(GatewayError::from(err.to_string())), - } - }, - Message::Close(close_frame) => { - info!("Discord connection closed {:?}", close_frame); - Err(GatewayError::from("connection closed".to_string())) - }, - _ => Err(GatewayError::from(format!("unknown variant of message specified to the handler {}", data).to_string())), - } - } - - /// Handle the decompression and deserialization process of a discord payload. - pub(super) async fn _handle_discord_message( - &mut self, - raw_message: &str, - ) -> Result<crate::payloads::gateway::Message, GatewayError> { - match serde_json::from_str(raw_message) { - Ok(message) => Ok(message), - Err(err) => Err(GatewayError::from(err.to_string())), - } - } -} diff --git a/gateway/src/error.rs b/gateway/src/error.rs deleted file mode 100644 index eb3a245..0000000 --- a/gateway/src/error.rs +++ /dev/null @@ -1,22 +0,0 @@ -use common::error::NovaError; - -#[derive(Debug)] -pub struct GatewayError(NovaError); - -impl From<tokio_tungstenite::tungstenite::Error> for GatewayError { - fn from(e: tokio_tungstenite::tungstenite::Error) -> Self { - GatewayError { - 0: NovaError { - message: e.to_string(), - }, - } - } -} - -impl From<String> for GatewayError { - fn from(e: String) -> Self { - GatewayError { - 0: NovaError { message: e }, - } - } -} diff --git a/gateway/src/main.rs b/gateway/src/main.rs index 4c42c7a..4f67183 100644 --- a/gateway/src/main.rs +++ b/gateway/src/main.rs @@ -1,19 +1,60 @@ -use common::config::Settings; -use shard::{Shard, ShardConfig}; -#[macro_use] -extern crate num_derive; +use common::{ + config::Settings, + log::{debug, info}, + nats_crate::Connection, + payloads::{CachePayload, SerializeHelper, Tracing}, +}; +use config::Config; +use std::error::Error; +use twilight_gateway::{Event, Shard}; +mod config; +use futures::StreamExt; -pub mod connection; -mod error; -mod utils; -mod shard; -mod payloads; +#[tokio::main] +async fn main() -> Result<(), Box<dyn Error + Send + Sync>> { + let settings: Settings<Config> = Settings::new("gateway").unwrap(); + let (shard, mut events) = Shard::new(settings.config.token, settings.config.intents); + let nats: Connection = settings.nats.into(); + shard.start().await?; + while let Some(event) = events.next().await { + match event { + Event::Ready(ready) => { + info!("Logged in as {}", ready.user.name); + } + Event::Resumed => {} + Event::GatewayHeartbeat(_) => {} + Event::GatewayHeartbeatAck => {} + Event::GatewayInvalidateSession(_) => {} + Event::GatewayReconnect => {} + Event::GatewayHello(_) => {} -#[tokio::main] -async fn main() { - let settings: Settings<ShardConfig> = Settings::new("gateway").unwrap(); - let mut shard = Shard::new(settings.config); - shard.start().await; + Event::ShardConnected(_) => {} + Event::ShardConnecting(_) => {} + Event::ShardDisconnected(_) => {} + Event::ShardIdentifying(_) => {} + Event::ShardReconnecting(_) => {} + Event::ShardPayload(_) => {} + Event::ShardResuming(_) => {} + + _ => { + let data = CachePayload { + tracing: Tracing { + node_id: "".to_string(), + span: None, + }, + data: SerializeHelper(event), + }; + let value = serde_json::to_string(&data)?; + debug!("nats send: {}", value); + nats.request( + &format!("nova.cache.discord.{}", data.data.0.kind().name().unwrap()), + value, + )?; + } + } + } + + Ok(()) } diff --git a/gateway/src/payloads/dispatch.rs b/gateway/src/payloads/dispatch.rs deleted file mode 100644 index 9eca9c5..0000000 --- a/gateway/src/payloads/dispatch.rs +++ /dev/null @@ -1,46 +0,0 @@ -use serde::{Deserialize, Deserializer, Serialize}; - -use serde_json::Value; - -use super::gateway::BaseMessage; - -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] -pub struct Ready { - #[serde(rename = "v")] - pub version: u64, - pub user: Value, - pub guilds: Vec<Value>, - pub session_id: String, - pub shard: Option<[i64;2]>, - pub application: Value, -} - -#[derive(Clone, Debug, PartialEq, Deserialize)] -#[serde(tag = "t", content = "d")] -pub enum FakeDispatch { - #[serde(rename = "READY")] - Ready(Ready), - Other(Value), -} - -#[derive(Clone, Debug, PartialEq)] -pub enum Dispatch { - Ready(Ready), - Other(BaseMessage<Value>) -} - -impl<'de> Deserialize<'de> for Dispatch { - fn deserialize<D>(d: D) -> Result<Self, D::Error> - where - D: Deserializer<'de>, - { - // todo: error handling - let value = Value::deserialize(d)?; - - if value.get("t").unwrap() == "READY" { - Ok(Dispatch::Ready(Ready::deserialize(value.get("d").unwrap()).unwrap())) - } else { - Ok(Dispatch::Other(BaseMessage::deserialize(value).unwrap())) - } - } -}
\ No newline at end of file diff --git a/gateway/src/payloads/events/mod.rs b/gateway/src/payloads/events/mod.rs deleted file mode 100644 index 3fef2d9..0000000 --- a/gateway/src/payloads/events/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod ready;
\ No newline at end of file diff --git a/gateway/src/payloads/events/ready.rs b/gateway/src/payloads/events/ready.rs deleted file mode 100644 index a5ec291..0000000 --- a/gateway/src/payloads/events/ready.rs +++ /dev/null @@ -1,13 +0,0 @@ -use serde::Deserialize; -use serde_json::Value; - -#[derive(Deserialize, Clone, Debug, PartialEq)] -pub struct Ready { - #[serde(rename = "v")] - pub version: u64, - pub user: Value, - pub guilds: Vec<Value>, - pub session_id: String, - pub shard: Option<[i64;2]>, - pub application: Value, -} diff --git a/gateway/src/payloads/gateway.rs b/gateway/src/payloads/gateway.rs deleted file mode 100644 index 4f24890..0000000 --- a/gateway/src/payloads/gateway.rs +++ /dev/null @@ -1,78 +0,0 @@ -use super::{ - dispatch::Dispatch, - opcodes::{hello::Hello, OpCodes}, -}; -use serde::de::Error; -use serde::{Deserialize, Serialize}; -use serde_json::Value; - -#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] -#[serde(bound(deserialize = "T: Deserialize<'de> + std::fmt::Debug"))] -pub struct BaseMessage<T> { - pub t: Option<String>, - #[serde(rename = "s")] - pub sequence: Option<u64>, - pub op: OpCodes, - #[serde(rename = "d")] - pub data: T, -} - -#[derive(Debug)] -pub enum Message { - Dispatch(BaseMessage<Dispatch>), - Reconnect(BaseMessage<()>), - InvalidSession(BaseMessage<bool>), - Hello(BaseMessage<Hello>), - HeartbeatACK(BaseMessage<()>), -} - -impl<'de> serde::Deserialize<'de> for Message { - fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> - where - D::Error: Error, - { - let value = Value::deserialize(d)?; - let val = value.get("op").and_then(Value::as_u64).unwrap(); - - if let Some(op) = num::FromPrimitive::from_u64(val) { - match op { - OpCodes::Dispatch => { - // todo: remove unwrap - let t = Some(value.get("t").unwrap().to_string()); - let sequence = value.get("s").unwrap().as_u64(); - - // we need to find a better solution than clone - match Dispatch::deserialize(value) { - Ok(data) => Ok(Message::Dispatch(BaseMessage { - op, - t, - sequence, - data, - })), - Err(e) => Err(Error::custom(e)), - } - } - - OpCodes::Reconnect => match BaseMessage::deserialize(value) { - Ok(data) => Ok(Message::Reconnect(data)), - Err(e) => Err(Error::custom(e)), - }, - OpCodes::InvalidSession => match BaseMessage::deserialize(value) { - Ok(data) => Ok(Message::InvalidSession(data)), - Err(e) => Err(Error::custom(e)), - }, - OpCodes::Hello => match BaseMessage::deserialize(value) { - Ok(data) => Ok(Message::Hello(data)), - Err(e) => Err(Error::custom(e)), - }, - OpCodes::HeartbeatACK => match BaseMessage::deserialize(value) { - Ok(data) => Ok(Message::HeartbeatACK(data)), - Err(e) => Err(Error::custom(e)), - }, - _ => panic!("Cannot convert"), - } - } else { - Err(Error::custom("unknown opcode")) - } - } -} diff --git a/gateway/src/payloads/mod.rs b/gateway/src/payloads/mod.rs deleted file mode 100644 index e9849a7..0000000 --- a/gateway/src/payloads/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod opcodes; -pub mod dispatch; -pub mod gateway; -pub mod events; diff --git a/gateway/src/payloads/opcodes/hello.rs b/gateway/src/payloads/opcodes/hello.rs deleted file mode 100644 index 3d8fd0f..0000000 --- a/gateway/src/payloads/opcodes/hello.rs +++ /dev/null @@ -1,8 +0,0 @@ -use serde::{Serialize, Deserialize}; - -/// The first message sent by the gateway to initialize the heartbeating -#[derive(Debug, Serialize, Deserialize)] -pub struct Hello { - #[serde(rename = "heartbeat_interval")] - pub heartbeat_interval: u64 -} diff --git a/gateway/src/payloads/opcodes/identify.rs b/gateway/src/payloads/opcodes/identify.rs deleted file mode 100644 index 5929c33..0000000 --- a/gateway/src/payloads/opcodes/identify.rs +++ /dev/null @@ -1,47 +0,0 @@ -use enumflags2::{BitFlags, bitflags}; -use serde::{Deserialize, Serialize}; -use super::presence::PresenceUpdate; - - -#[bitflags] -#[repr(u16)] -#[derive(Clone, Copy, Debug)] -pub enum Intents { - Guilds = 1 << 0, - GuildMembers = 1 << 1, - GuildBans = 1 << 2, - GuildEmojisAndStickers = 1 << 3, - GuildIntegrations = 1 << 4, - GuildWebhoks = 1 << 5, - GuildInvites = 1 << 6, - GuildVoiceStates = 1 << 7, - GuildPresences = 1 << 8, - GuildMessages = 1 << 9, - GuildMessagesReactions = 1 << 10, - GuildMessageTyping = 1 << 11, - DirectMessages = 1 << 12, - DirectMessagesReactions = 1 << 13, - DirectMessageTyping = 1 << 14, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct IdentifyProprerties { - #[serde(rename = "$os")] - pub os: String, - #[serde(rename = "$browser")] - pub browser: String, - #[serde(rename = "$device")] - pub device: String, -} - -/// Messages sent by the shard to log-in to the gateway. -#[derive(Debug, Serialize, Deserialize)] -pub struct Identify { - pub token: String, - pub properties: IdentifyProprerties, - pub compress: Option<bool>, - pub large_threshold: Option<u64>, - pub shard: Option<[u64; 2]>, - pub presence: Option<PresenceUpdate>, - pub intents: BitFlags<Intents>, -}
\ No newline at end of file diff --git a/gateway/src/payloads/opcodes/mod.rs b/gateway/src/payloads/opcodes/mod.rs deleted file mode 100644 index cfa453a..0000000 --- a/gateway/src/payloads/opcodes/mod.rs +++ /dev/null @@ -1,22 +0,0 @@ -pub mod hello; -pub mod identify; -pub mod resume; -pub mod presence; -use serde_repr::{Deserialize_repr, Serialize_repr}; - -#[derive(Serialize_repr, Deserialize_repr, PartialEq, Debug, Clone, FromPrimitive, ToPrimitive)] -#[repr(u8)] - -pub enum OpCodes { - Dispatch = 0, - Heartbeat = 1, - Identify = 2, - PresenceUpdate = 3, - VoiceStateUpdate = 4, - Resume = 6, - Reconnect = 7, - RequestGuildMembers = 8, - InvalidSession = 9, - Hello = 10, - HeartbeatACK = 11, -}
\ No newline at end of file diff --git a/gateway/src/payloads/opcodes/presence.rs b/gateway/src/payloads/opcodes/presence.rs deleted file mode 100644 index a6c5773..0000000 --- a/gateway/src/payloads/opcodes/presence.rs +++ /dev/null @@ -1,63 +0,0 @@ -use serde_repr::{Deserialize_repr, Serialize_repr}; -use serde::{Deserialize, Serialize}; -#[derive(Serialize_repr, Deserialize_repr, Debug)] -#[repr(u8)] -pub enum ActivityType { - Game = 0, - Streaming = 1, - Listening = 2, - Watching = 3, - Custom = 4, - Competing = 5 -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct ActivityTimestamps { - start: u64, - end: u64, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct ActivityEmoji { - name: String, - id: Option<String>, - animated: Option<bool> -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct Activity { - name: String, - #[serde(rename = "type")] - t: ActivityType, - - url: Option<String>, - created_at: i64, - timestamp: Option<ActivityTimestamps>, - application_id: Option<String>, - details: Option<String>, - state: Option<String>, - emoji: Option<ActivityEmoji>, - // todo: implement more -} - -#[derive(Serialize, Deserialize, Debug)] -pub enum PresenceStatus { - #[serde(rename = "online")] - Online, - #[serde(rename = "dnd")] - Dnd, - #[serde(rename = "idle")] - Idle, - #[serde(rename = "invisible")] - Invisible, - #[serde(rename = "offline")] - Offline -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct PresenceUpdate { - since: u64, - activities: Vec<Activity>, - status: PresenceStatus, - afk: bool, -} diff --git a/gateway/src/payloads/opcodes/resume.rs b/gateway/src/payloads/opcodes/resume.rs deleted file mode 100644 index e1bba91..0000000 --- a/gateway/src/payloads/opcodes/resume.rs +++ /dev/null @@ -1,8 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize)] -pub struct Resume { - pub token: String, - pub session_id: String, - pub seq: u64, -}
\ No newline at end of file diff --git a/gateway/src/shard/actions.rs b/gateway/src/shard/actions.rs deleted file mode 100644 index 39a7ca2..0000000 --- a/gateway/src/shard/actions.rs +++ /dev/null @@ -1,132 +0,0 @@ -use std::env; - -use futures::SinkExt; -use common::log::{debug, error, info}; -use serde::Serialize; -use serde_json::Value; -use std::fmt::Debug; - -use crate::{ - error::GatewayError, - payloads::{ - gateway::BaseMessage, - opcodes::{ - identify::{Identify, IdentifyProprerties}, - presence::PresenceUpdate, - resume::Resume, - OpCodes, - }, - }, -}; - -use super::Shard; - -/// Implement the available actions for nova in the gateway. -impl Shard { - /// sends a message through the websocket - pub async fn _send<T: Serialize + Debug>( - &mut self, - message: BaseMessage<T>, - ) -> Result<(), GatewayError> { - debug!("Senging message {:?}", message); - if let Some(connection) = &mut self.connection { - if let Err(e) = connection.conn.send(message).await { - error!("failed to send message {:?}", e); - Err(GatewayError::from(e)) - } else { - Ok(()) - } - } else { - Err(GatewayError::from("no open connection".to_string())) - } - } - - pub async fn _identify(&mut self) -> Result<(), GatewayError> { - if let Some(state) = self.state.clone() { - info!("Using session"); - self._send(BaseMessage { - t: None, - sequence: None, - op: OpCodes::Resume, - data: Resume { - token: self.config.token.clone(), - seq: state.sequence, - session_id: state.session_id.clone(), - }, - }) - .await - } else { - info!("Sending login"); - let mut shards: Option<[u64; 2]> = None; - if let Some(sharding) = self.config.shard.as_ref() { - shards = Some([sharding.current_shard, sharding.total_shards]); - } - self._send(BaseMessage { - t: None, - sequence: None, - op: OpCodes::Identify, - data: Identify { - token: self.config.token.clone(), - intents: self.config.intents, - properties: IdentifyProprerties { - os: env::consts::OS.to_string(), - browser: "Nova".to_string(), - device: "Nova".to_string(), - }, - shard: shards, - compress: Some(false), - large_threshold: Some(500), - presence: None, - }, - }) - .await - } - } - - pub async fn _disconnect(&mut self) {} - - /// Updates the presence of the current shard. - #[allow(dead_code)] - pub async fn presence_update(&mut self, update: PresenceUpdate) -> Result<(), GatewayError> { - self._send(BaseMessage { - t: None, - sequence: None, - op: OpCodes::PresenceUpdate, - data: update, - }) - .await - } - /// Updates the voice status of the current shard in a certain channel. - #[allow(dead_code)] - pub async fn voice_state_update(&mut self) -> Result<(), GatewayError> { - self._send(BaseMessage { - t: None, - sequence: None, - op: OpCodes::VoiceStateUpdate, - // todo: proper payload for this - data: Value::Null, - }) - .await - } - /// Ask discord for more informations about offline guild members. - #[allow(dead_code)] - pub async fn request_guild_members(&mut self) -> Result<(), GatewayError> { - self._send(BaseMessage { - t: None, - sequence: None, - op: OpCodes::RequestGuildMembers, - // todo: proper payload for this - data: Value::Null, - }) - .await - } - - pub async fn _send_heartbeat(&mut self) -> Result<(), GatewayError> { - self._send(BaseMessage { - t: None, - sequence: None, - op: OpCodes::Heartbeat, - data: self.state.as_ref().unwrap().sequence - }).await - } -} diff --git a/gateway/src/shard/connection.rs b/gateway/src/shard/connection.rs deleted file mode 100644 index 8f8ddc6..0000000 --- a/gateway/src/shard/connection.rs +++ /dev/null @@ -1,193 +0,0 @@ -use std::{cmp::min, convert::TryInto, time::Duration}; - -use crate::{connection::Connection, error::GatewayError, payloads::{ - dispatch::Dispatch, - gateway::{BaseMessage, Message}, - }, shard::state::SessionState}; - -use super::{state::ConnectionState, ConnectionWithState, Shard}; -use futures::StreamExt; -use common::log::{error, info}; -use tokio::{select, time::{Instant, interval_at, sleep}}; - -impl Shard { - pub async fn start(self: &mut Self) { - let mut reconnects = 1; - info!("Starting shard"); - - while reconnects < self.config.max_reconnects { - info!("Starting connection for shard"); - if let Err(e) = self._shard_task().await { - error!("Gateway status: {:?}", e); - } - // when the shard got disconnected, the shard task ends - reconnects += 1; - - // wait reconnects min(max(reconnects * reconnect_delay_growth_factor, reconnect_delay_minimum),reconnect_delay_maximum) - if reconnects < self.config.max_reconnects { - let time = min( - self.config.reconnect_delay_maximum, - self.config.reconnect_delay_minimum * (((reconnects - 1) as f32) * self.config.reconnect_delay_growth_factor) as usize, - ); - info!( - "The shard got disconnected, waiting for reconnect ({}ms)", - time - ); - sleep(Duration::from_millis(time.try_into().unwrap())).await; - } - } - info!( - "The shard got disconnected too many times and reached the maximum {}", - self.config.max_reconnects - ); - } - - async fn _shard_task(&mut self) -> Result<(), GatewayError> { - // create the new connection - let mut connection = Connection::new(); - connection.start().await.unwrap(); - self.connection = Some(ConnectionWithState { - conn: connection, - state: ConnectionState::new(), - }); - - loop { - if let Some(connection) = &mut self.connection { - if let Some(timer) = &mut connection.state.interval { - select!( - payload = connection.conn.next() => { - match payload { - Some(data) => match data { - Ok(message) => self._handle(&message).await, - Err(error) => { - return Err(GatewayError::from(format!("An error occured while being connected to Discord: {:?}", error).to_string())); - }, - }, - None => { - return Err(GatewayError::from("Connection terminated".to_string())); - }, - } - }, - _ = timer.tick() => match self._do_heartbeat().await { - Ok(_) => {}, - Err(error) => { - return Err(GatewayError::from(format!("An error occured while being connected to Discord: {:?}", error).to_string())); - }, - } - ) - } else { - select!( - payload = connection.conn.next() => { - match payload { - Some(data) => match data { - Ok(message) => self._handle(&message).await, - Err(error) => { - return Err(GatewayError::from(format!("An error occured while being connected to Discord: {:?}", error).to_string())); - }, - }, - None => { - return Err(GatewayError::from("Connection terminated".to_string())); - }, - } - } - ) - } - - } - } - } - - async fn _do_heartbeat(&mut self) -> Result<(), GatewayError> { - info!("heartbeat sent"); - if let Some(conn) = &mut self.connection { - if !conn.state.last_heartbeat_acknowledged { - error!("we missed a hertbeat"); - Err(GatewayError::from("a hertbeat was dropped, we need to restart the connection".to_string())) - } else { - conn.state.last_heartbeat_acknowledged = false; - conn.state.last_heartbeat_time = Instant::now(); - self._send_heartbeat().await - } - } else { - unreachable!() - } - } - - fn _util_set_seq(&mut self, seq: Option<u64>) { - if let Some(seq) = seq { - if let Some(state) = &mut self.state { - state.sequence = seq; - } - } - } - - async fn _handle(&mut self, message: &Message) { - match message { - Message::Dispatch(msg) => { - self._util_set_seq(msg.sequence); - self._dispatch(&msg).await; - } - // we need to reconnect to the gateway - Message::Reconnect(msg) => { - self._util_set_seq(msg.sequence); - info!("Gateway disconnect requested"); - self._disconnect().await; - } - Message::InvalidSession(msg) => { - self._util_set_seq(msg.sequence); - info!("invalid session"); - let data = msg.data; - if !data { - info!("Session removed"); - // reset the session data - self.state = None; - if let Err(e) = self._identify().await { - error!("Error while sending identify: {:?}", e); - } - } - } - Message::HeartbeatACK(msg) => { - info!("Heartbeat ack received"); - self._util_set_seq(msg.sequence); - if let Some(conn) = &mut self.connection { - conn.state.last_heartbeat_acknowledged = true; - let latency = Instant::now() - conn.state.last_heartbeat_time; - info!("Latency updated {}ms", latency.as_millis()); - } - } - Message::Hello(msg) => { - info!("Server hello received"); - self._util_set_seq(msg.sequence); - if let Some(conn) = &mut self.connection { - conn.state.interval = Some(interval_at( - Instant::now() + Duration::from_millis(msg.data.heartbeat_interval), - Duration::from_millis(msg.data.heartbeat_interval), - )); - } - - if let Err(e) = self._identify().await { - error!("error while sending identify: {:?}", e); - } - }, - } - } - - async fn _dispatch(&mut self, dispatch: &BaseMessage<Dispatch>) { - match &dispatch.data { - Dispatch::Ready(ready) => { - info!("Received gateway dispatch ready"); - info!( - "Logged in as {}", - ready.user.get("username").unwrap().to_string() - ); - self.state = Some(SessionState { - sequence: dispatch.sequence.unwrap(), - session_id: ready.session_id.clone(), - }); - } - Dispatch::Other(_data) => { - // todo: build dispatch & forward to nats - } - } - } -} diff --git a/gateway/src/shard/mod.rs b/gateway/src/shard/mod.rs deleted file mode 100644 index 55828d0..0000000 --- a/gateway/src/shard/mod.rs +++ /dev/null @@ -1,49 +0,0 @@ -use enumflags2::BitFlags; -use serde::{Deserialize, Serialize}; -use crate::{connection::Connection, payloads::opcodes::identify::Intents}; -use self::state::{ConnectionState, SessionState}; -mod actions; -mod connection; -mod state; - -#[derive(Debug, Deserialize, Serialize, Default, Clone)] -pub struct Sharding { - pub total_shards: u64, - pub current_shard: u64 -} - -#[derive(Debug, Deserialize, Serialize, Default, Clone)] -pub struct ShardConfig { - pub max_reconnects: usize, - pub reconnect_delay_growth_factor: f32, - pub reconnect_delay_minimum: usize, - pub reconnect_delay_maximum: usize, - pub token: String, - - pub large_threshold: Option<u64>, - pub shard: Option<Sharding>, - pub intents: BitFlags<Intents> -} - -struct ConnectionWithState { - conn: Connection, - state: ConnectionState, -} - -/// Represents a shard & all the reconnection logic related to it -pub struct Shard { - connection: Option<ConnectionWithState>, - state: Option<SessionState>, - config: ShardConfig -} - -impl Shard { - /// Creates a new shard instance - pub fn new(config: ShardConfig) -> Self { - Shard { - connection: None, - state: None, - config, - } - } -} diff --git a/gateway/src/shard/state.rs b/gateway/src/shard/state.rs deleted file mode 100644 index 34b7acc..0000000 --- a/gateway/src/shard/state.rs +++ /dev/null @@ -1,35 +0,0 @@ -use tokio::time::{Instant, Interval}; - -/// This struct represents the state of a session -#[derive(Clone, Debug)] -pub struct SessionState { - pub sequence: u64, - pub session_id: String, -} - -impl Default for SessionState { - fn default() -> Self { - Self { - sequence: Default::default(), - session_id: Default::default(), - } - } -} - -/// This struct represents the state of a connection -#[derive(Debug)] -pub struct ConnectionState { - pub last_heartbeat_acknowledged: bool, - pub last_heartbeat_time: Instant, - pub interval: Option<Interval>, - -} -impl ConnectionState { - pub fn new() -> Self { - Self { - last_heartbeat_acknowledged: true, - last_heartbeat_time: Instant::now(), - interval: None, - } - } -}
\ No newline at end of file diff --git a/gateway/src/utils.rs b/gateway/src/utils.rs deleted file mode 100644 index 48a9aed..0000000 --- a/gateway/src/utils.rs +++ /dev/null @@ -1,8 +0,0 @@ -/// Formats a url of connection to the gateway -pub fn get_gateway_url (compress: bool, encoding: &str, v: i16) -> String { - return format!( - "wss://gateway.discord.gg/?v={}&encoding={}&compress={}", - v, encoding, - if compress { "zlib-stream" } else { "" } - ); -}
\ No newline at end of file |
