From: Matthieu Date: Tue, 7 Sep 2021 14:30:24 +0000 (+0400) Subject: base for the new gateway implementation X-Git-Tag: v0.1~64^2^2~37 X-Git-Url: https://git.puffer.fish/?a=commitdiff_plain;h=5846af709a8214b2cfacb00bbeb631d394131072;p=matthieu%2Fnova.git base for the new gateway implementation --- diff --git a/gateway/BUILD b/gateway/BUILD new file mode 100644 index 0000000..4474c47 --- /dev/null +++ b/gateway/BUILD @@ -0,0 +1,12 @@ +load("@rules_rust//rust:rust.bzl", "rust_binary", "rust_library") +load("@crates//:defs.bzl", "crates_from") + +exports_files(["Cargo.toml"]) + +rust_binary( + name = "gateway", + srcs = glob(["src/**/*.rs"]), + rustc_env = {}, + deps = crates_from("Cargo.toml"), + visibility = ["//visibility:public"], +) diff --git a/gateway/BUILD.bazel b/gateway/BUILD.bazel deleted file mode 100644 index 97af3b0..0000000 --- a/gateway/BUILD.bazel +++ /dev/null @@ -1,21 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library") - -go_library( - name = "gateway_lib", - srcs = ["main.go"], - importpath = "github.com/discordnova/nova/gateway", - visibility = ["//visibility:public"], - deps = [ - "//common", - "//gateway/lib/gateway", - "//gateway/lib/gateway/compression", - "//gateway/lib/gateway/transporters", - "@com_github_rs_zerolog//log", - ], -) - -go_binary( - name = "gateway", - embed = [":gateway_lib"], - visibility = ["//visibility:public"], -) \ No newline at end of file diff --git a/gateway/Cargo.toml b/gateway/Cargo.toml new file mode 100644 index 0000000..0182309 --- /dev/null +++ b/gateway/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "gateway" +version = "0.1.0" +edition = "2018" + +[dependencies] +tokio = { version = "1", features = ["full"] } +tokio-tungstenite = { version = "*", features = ["rustls-tls"] } +url = "2.2.2" +futures-util = "0.3.17" +log = { version = "0.4", features = ["std"] } +pretty_env_logger = "0.4" +serde_json = { version = "1.0" } +serde = { version = "1.0", features = ["derive"] } +tokio-stream = "0.1.7" +async-stream = "0.3.2" +futures-core = "0.3.17" +serde_repr = "0.1" diff --git a/gateway/cargo/BUILD.bazel b/gateway/cargo/BUILD.bazel new file mode 100644 index 0000000..bf7bb77 --- /dev/null +++ b/gateway/cargo/BUILD.bazel @@ -0,0 +1,129 @@ +""" +@generated +cargo-raze generated Bazel file. + +DO NOT EDIT! Replaced on runs of cargo-raze +""" + +package(default_visibility = ["//visibility:public"]) + +licenses([ + "notice", # See individual crates for specific licenses +]) + +# Aliased targets +alias( + name = "async_stream", + actual = "@raze__async_stream__0_3_2//:async_stream", + tags = [ + "cargo-raze", + "manual", + ], +) + +alias( + name = "futures_core", + actual = "@raze__futures_core__0_3_17//:futures_core", + tags = [ + "cargo-raze", + "manual", + ], +) + +alias( + name = "futures_util", + actual = "@raze__futures_util__0_3_17//:futures_util", + tags = [ + "cargo-raze", + "manual", + ], +) + +alias( + name = "log", + actual = "@raze__log__0_4_14//:log", + tags = [ + "cargo-raze", + "manual", + ], +) + +alias( + name = "pretty_env_logger", + actual = "@raze__pretty_env_logger__0_4_0//:pretty_env_logger", + tags = [ + "cargo-raze", + "manual", + ], +) + +alias( + name = "serde", + actual = "@raze__serde__1_0_130//:serde", + tags = [ + "cargo-raze", + "manual", + ], +) + +alias( + name = "serde_json", + actual = "@raze__serde_json__1_0_67//:serde_json", + tags = [ + "cargo-raze", + "manual", + ], +) + +alias( + name = "serde_repr", + actual = "@raze__serde_repr__0_1_7//:serde_repr", + tags = [ + "cargo-raze", + "manual", + ], +) + +alias( + name = "tokio", + actual = "@raze__tokio__1_11_0//:tokio", + tags = [ + "cargo-raze", + "manual", + ], +) + +alias( + name = "tokio_stream", + actual = "@raze__tokio_stream__0_1_7//:tokio_stream", + tags = [ + "cargo-raze", + "manual", + ], +) + +alias( + name = "tokio_tungstenite", + actual = "@raze__tokio_tungstenite__0_15_0//:tokio_tungstenite", + tags = [ + "cargo-raze", + "manual", + ], +) + +alias( + name = "url", + actual = "@raze__url__2_2_2//:url", + tags = [ + "cargo-raze", + "manual", + ], +) + +# Export file for Stardoc support +exports_files( + [ + "crates.bzl", + ], + visibility = ["//visibility:public"], +) diff --git a/gateway/lib/BUILD.bazel b/gateway/lib/BUILD.bazel deleted file mode 100644 index e69de29..0000000 diff --git a/gateway/lib/gateway/BUILD.bazel b/gateway/lib/gateway/BUILD.bazel deleted file mode 100644 index 14d2bac..0000000 --- a/gateway/lib/gateway/BUILD.bazel +++ /dev/null @@ -1,24 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") - -go_library( - name = "gateway", - srcs = [ - "gateway.go", - "options.go", - ], - importpath = "github.com/discordnova/nova/gateway/lib/gateway", - visibility = ["//gateway:__subpackages__"], - deps = [ - "//common/discord/types/payloads/gateway", - "//common/discord/types/payloads/gateway/commands", - "//common/discord/types/payloads/gateway/events", - "//common/discord/types/structures", - "//common/discord/types/types", - "//common/gateway", - "@com_github_boz_go_throttle//:go-throttle", - "@com_github_gorilla_websocket//:websocket", - "@com_github_prometheus_client_golang//prometheus", - "@com_github_prometheus_client_golang//prometheus/promauto", - "@com_github_rs_zerolog//log", - ], -) diff --git a/gateway/lib/gateway/compression/BUILD.bazel b/gateway/lib/gateway/compression/BUILD.bazel deleted file mode 100644 index 1977322..0000000 --- a/gateway/lib/gateway/compression/BUILD.bazel +++ /dev/null @@ -1,13 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") - -go_library( - name = "compression", - srcs = ["json-zlib.go"], - importpath = "github.com/discordnova/nova/gateway/lib/gateway/compression", - visibility = ["//gateway:__subpackages__"], - deps = [ - "//common/discord/types/payloads/gateway", - "//common/gateway", - "@com_github_rs_zerolog//log", - ], -) diff --git a/gateway/lib/gateway/compression/json-zlib.go b/gateway/lib/gateway/compression/json-zlib.go deleted file mode 100644 index 7f64bbf..0000000 --- a/gateway/lib/gateway/compression/json-zlib.go +++ /dev/null @@ -1,81 +0,0 @@ -package compression - -import ( - "bytes" - "compress/zlib" - "encoding/json" - "fmt" - "io" - - "github.com/rs/zerolog/log" - gatewayTypes "github.com/discordnova/nova/common/discord/types/payloads/gateway" - "github.com/discordnova/nova/common/gateway" -) - -// JsonZlibCompressor is the default compression interface. -type JsonZlibCompressor struct { - buffer *bytes.Buffer - reader io.ReadCloser -} - -// NewJsonZlibCompressor creates an instance of JsonZlibCompressor -func NewJsonZlibCompressor() gateway.Compression { - return &JsonZlibCompressor{ - buffer: bytes.NewBuffer([]byte{}), - } -} - -func (compressor *JsonZlibCompressor) Reset() error { - compressor.buffer.Reset() - if compressor.reader == nil { - return nil - } - err := compressor.reader.Close() - if err != nil { - return err - } - compressor.reader = nil - return nil -} - -// GetConnectionOptions gets the required options for the gateway. -func (compressor JsonZlibCompressor) GetConnectionOptions() gateway.GatewayConnectionOptions { - // Gateway options for the discord gateway. - return gateway.GatewayConnectionOptions{ - Encoding: "json", - TransportCompression: "zlib-stream", - } -} - -// DecodeMessage decodes messages using the compressor. -func (compressor *JsonZlibCompressor) DecodeMessage(data []byte) (*gatewayTypes.Payload, error) { - - // check if the message have the zlib suffix to avoid ruining our zlib context :'( - if !bytes.HasSuffix(data, []byte{0x00, 0x00, 0xff, 0xff}) { - return nil, fmt.Errorf("the gateway failed to verify the message validity due to invalid suffix") - } - - // add the data to the buffer for the decompression. - compressor.buffer.Write(data) - - // we can't create the reader without data, so we initialize on the first decompression. - if compressor.reader == nil { - reader, err := zlib.NewReader(compressor.buffer) - if err != nil { - log.Err(err).Msgf("Failed to initialize zlib reader") - } - compressor.reader = reader - } - - // we unmarshal the reader as json - inter := gatewayTypes.Payload{} - decoder := json.NewDecoder(compressor.reader) - err := decoder.Decode(&inter) - - if err != nil { - // the unmarshalling failed - return nil, err - } - - return &inter, nil -} diff --git a/gateway/lib/gateway/gateway.go b/gateway/lib/gateway/gateway.go deleted file mode 100644 index ee9d13e..0000000 --- a/gateway/lib/gateway/gateway.go +++ /dev/null @@ -1,374 +0,0 @@ -package gateway - -import ( - "encoding/json" - "fmt" - "os" - "runtime" - "time" - - "github.com/boz/go-throttle" - gatewayTypes "github.com/discordnova/nova/common/discord/types/payloads/gateway" - "github.com/discordnova/nova/common/discord/types/payloads/gateway/commands" - "github.com/discordnova/nova/common/discord/types/payloads/gateway/events" - "github.com/discordnova/nova/common/discord/types/structures" - "github.com/discordnova/nova/common/discord/types/types" - "github.com/discordnova/nova/common/gateway" - "github.com/gorilla/websocket" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" - "github.com/rs/zerolog/log" -) - -// connectionState is a struct representing a connection state -type connectionState struct { - HeartbeatInterval uint16 - Latency int64 -} - -var ( - messagesCounter = promauto.NewCounter(prometheus.CounterOpts{ - Name: "nova_gateway_messages_processed", - Help: "The total number of processed messages", - }) - - heartbeatCounter = promauto.NewCounter(prometheus.CounterOpts{ - Name: "nova_gateway_heartbeat_sent", - Help: "The total number of heartbeat sent", - }) - - latencyGauge = promauto.NewGauge(prometheus.GaugeOpts{ - Name: "nova_gateway_latency", - Help: "The round trip latency of the gateway", - }) - - reconnectionsCounter = promauto.NewCounter(prometheus.CounterOpts{ - Name: "nova_gateway_reconnections", - Help: "the number of reconnections of the gateway", - }) - - eventsCounter = promauto.NewCounter(prometheus.CounterOpts{ - Name: "nova_gateway_events", - Help: "The various events received by Nova.", - }) -) - -// GatewayConnector represents a connection to the discord gateway for a shard -type GatewayConnector struct { - // Public State - SessionState GatewayConnectorOptionsResume // The state of the session - - // Private state - connectionState connectionState // The internal state of the gateway connection. - options GatewayConnectorOptions // The connection options. - connection *websocket.Conn // The current websocket connection. - heartbeat chan struct{} // Channel for reacting to heartbeat acks - terminate chan string // Called when a gateway disconnect is requested - updateThrottle throttle.ThrottleDriver -} - -// NewGateway creates a connector instance based on the given options. -func NewGateway(options GatewayConnectorOptions) *GatewayConnector { - return &GatewayConnector{ - options: options, - SessionState: options.ResumeSession, - } -} - -// Start is used to start or reset a connection to the gateway. -func (discord *GatewayConnector) Start() { - shouldStart := true - for shouldStart { - reconnectionsCounter.Inc() - discord.connectionState = connectionState{} - _ = discord.start() - err := discord.options.Compressor.Reset() - if err != nil { - log.Fatal().Msgf("failed to reset the compressor") - } - shouldStart = *discord.options.Restart - if shouldStart { - log.Info().Msg("waiting 10s before gateway reconnection") - time.Sleep(time.Second * 10) - } - } -} - -// start is the internal routine for starting the gateway -func (discord *GatewayConnector) start() error { - // we throttle the update function to limit the amount of session state - // presisted to the session persistence interface - discord.updateThrottle = throttle.ThrottleFunc(time.Second*5, false, func() { - if discord.options.OnSessionStateUpdated != nil { - _ = discord.options.OnSessionStateUpdated(discord.SessionState) - } - }) - - // initialize the message channels - discord.heartbeat = make(chan struct{}) - discord.terminate = make(chan string) - - // since a Compressor is given to the gateway when created, we use the Connector to get - // the compression and encoding options. - comOptions := discord.options.Compressor.GetConnectionOptions() - websocketURL := fmt.Sprintf("wss://gateway.discord.gg/?v=%d&encoding=%s&compress=%s", 6, comOptions.Encoding, comOptions.TransportCompression) - - log.Info().Msgf("connecting to the gateway at url %s", websocketURL) - // we start the connection to discord. - connection, _, err := websocket.DefaultDialer.Dial(websocketURL, nil) - if err != nil { - log.Err(err).Msgf("an error occurred while connecting to the gateway") - return err - } - discord.connection = connection - defer discord.connection.Close() - - // start listening to messages on the socket. - go discord.listen() - - msg := <-discord.terminate - log.Info().Msgf("terminating the gateway: %s", msg) - - return nil -} - -// ticker starts the loop for the heartbeat checks -func (discord *GatewayConnector) ticker(interval int) { - // sends a message to heartbeat.C every time we need to send a heartbeat - heartbeat := time.NewTicker(time.Duration(interval) * time.Millisecond) - // stores if the last heartbeat succeeded - doneLastAck := true - - // executes the given code when heartbeat.C is triggered - for range heartbeat.C { - // if the server did not send the last heartbeat - if !doneLastAck { - // we need to terminate the connection - discord.terminate <- "server missed an ack and must be disconnected" - return - } - - log.Debug().Msg("Sending a heartbeat.") - - index, _ := json.Marshal(discord.SessionState.Index) - err := discord.connection.WriteJSON(gatewayTypes.Payload{ - Op: types.GatewayOpCodeHeartbeat, - D: index, - }) - - if err != nil { - discord.terminate <- fmt.Sprintf("failed to send a heartbeat: %s", err.Error()) - return - } - - heartbeatCounter.Inc() - // wait for the ack asynchronously - go func() { - start := time.Now() - doneLastAck = false - <-discord.heartbeat - doneLastAck = true - - discord.connectionState.Latency = time.Since(start).Milliseconds() - latencyGauge.Set(float64(discord.connectionState.Latency)) - - log.Info().Msgf("heartbeat completed, latency: %dms", discord.connectionState.Latency) - }() - - } -} - -// listen listens to the messages on the gateway -func (discord *GatewayConnector) listen() { - for { - _, message, err := discord.connection.ReadMessage() - - if err != nil { - discord.terminate <- fmt.Sprintf("the connection was closed by the gateway: %s", err.Error()) - return - } - - messagesCounter.Inc() - data, err := discord.options.Compressor.DecodeMessage(message) - - if err != nil || data == nil { - log.Print(err.Error()) - continue - } - - if data.S != 0 { - discord.SessionState.Index = data.S - discord.updateState(data.S, "") - } - - discord.handleMessage(data) - } -} - -func (discord *GatewayConnector) updateState(newIndex int64, sessionId string) { - discord.SessionState.Index = newIndex - if sessionId != "" { - discord.SessionState.Session = sessionId - } - discord.updateThrottle.Trigger() -} - -func (discord *GatewayConnector) handleMessage(message *gatewayTypes.Payload) { - switch message.Op { - // call the startup function - case types.GatewayOpCodeHello: - discord.hello(message) - // notify the heartbeat goroutine that a heartbeat ack was received - case types.GatewayOpCodeHeartbeatACK: - discord.heartbeat <- struct{}{} - // handles a dispatch from the gateway - case types.GatewayOpCodeDispatch: - discord.dispatch(message) - // when the session resume fails - case types.GatewayOpCodeInvalidSession: - log.Print("failed to resume the session, reconnecting") - discord.updateState(0, "") - discord.doLogin() - // when the gateway requests a reconnect - case types.GatewayOpCodeReconnect: - log.Print("the gateway requested a reconnect") - if string(message.D) != "true" { - // we may delete the session state - discord.SessionState.Index = 0 - discord.updateState(0, "") - } - discord.terminate <- "the gateway requested a reconnect" - } -} - -func (discord *GatewayConnector) doLogin() { - var payload gatewayTypes.Payload - // if we do not have to resume a session - if discord.SessionState.Session == "" { - log.Info().Msg("using identify for authentification") - data, err := json.Marshal(commands.GatewayCommandIdentifyPayload{ - Token: *discord.options.Token, - Properties: structures.IdentifyConnectionProperties{ - OS: runtime.GOOS, - Device: "Nova Discord Client", - Browser: "Nova Discord Client", - }, - Compress: true, - LargeThreshold: 1000, - Shard: []int{ - *discord.options.SelfShard, - *discord.options.TotalShard, - }, - Presence: commands.GatewayCommandUpdateStatusPayload{}, - GuildSubscriptions: *discord.options.GuildSubs, - Intents: discord.options.Intend, - }) - - if err != nil { - return - } - - payload = gatewayTypes.Payload{ - Op: types.GatewayOpCodeIdentify, - D: data, - } - } else { - log.Info().Msg("resuming session") - data, err := json.Marshal(commands.GatewayCommandResumePayload{ - Token: *discord.options.Token, - SessionID: discord.SessionState.Session, - Seq: discord.SessionState.Index, - }) - - if err != nil { - return - } - payload = gatewayTypes.Payload{ - Op: types.GatewayOpCodeResume, - D: data, - } - } - - err := discord.connection.WriteJSON(payload) - if err != nil { - log.Err(err).Msgf("failed send the identify payload") - } -} - -func (discord *GatewayConnector) hello(message *gatewayTypes.Payload) { - - data := &events.GatewayEventHelloPayload{} - err := json.Unmarshal(message.D, &data) - if err != nil { - discord.terminate <- fmt.Sprintf("invalid payload: %s", err.Error()) - } - - // start the heartbeat goroutine - log.Debug().Msgf("hello recevied, heartbeating every %d ms", data.HeartbeatInterval) - go discord.ticker(data.HeartbeatInterval) - - // login - discord.doLogin() -} - -type NovaMessage struct { - Data json.RawMessage `json:"data"` - Tracing struct { - NodeName string `json:"node_name"` - } `json:"tracing"` -} - -func (discord *GatewayConnector) dispatch(message *gatewayTypes.Payload) { - // since this is juste a event gateway, we do not care about the content of the events - // except the ready, resumed, reconnect event we use to update the session_id, the other events are forwarded to the transporter - if message.T == "READY" { - event := events.GatewayEventReadyPayload{} - err := json.Unmarshal(message.D, &event) - - log.Info().Msgf("logged in as %s", event.User.Username) - - if err != nil { - discord.terminate <- "invalid ready event" - return - } - - discord.updateState(discord.SessionState.Index, event.SessionID) - return - } - - newName := gateway.EventNames[message.T] - - if newName == "" { - log.Error().Msgf("unknown event name: %s", newName) - return - } - - name, err := os.Hostname() - - if err != nil { - log.Err(err).Msgf("failed to get the hostname") - return - } - - data, err := json.Marshal(NovaMessage{ - Data: message.D, - Tracing: struct { - NodeName string `json:"node_name"` - }{ - NodeName: name, - }, - }) - - if err != nil { - log.Err(err).Msg("failed to serialize the outgoing nova message") - } - - discord.options.Transporter.PushChannel() <- gateway.PushData{ - Data: data, - Name: newName, - } - - if err != nil { - log.Err(err).Msg("failed to send the event to the nova event broker") - } -} diff --git a/gateway/lib/gateway/options.go b/gateway/lib/gateway/options.go deleted file mode 100644 index 3ca90a6..0000000 --- a/gateway/lib/gateway/options.go +++ /dev/null @@ -1,34 +0,0 @@ -package gateway - -import ( - "github.com/discordnova/nova/common/discord/types/types" - "github.com/discordnova/nova/common/gateway" -) - -// GatewayConnectorOptionsResume represents the options for reconnecting the gateway. -type GatewayConnectorOptionsResume struct { - Session string `json:"session_id"` // The session id of the older session we want to resume. - Index int64 `json:"index"` // The index of the last packet recevied by the older session. -} - -// GatewayConnectorOptionsSharding represents the options for sharding the gateway. -type GatewayConnectorOptionsSharding struct { - TotalShards int `json:"total_shards"` // The total amount of shards - CurrentShard int `json:"current_shard"` // The shard we want to connect to. -} - -// GatewayConnectorOptions is the options given to the GatewayConnector when creating it. -type GatewayConnectorOptions struct { - Token *string // The token of the bot - SelfShard *int // The shard of the current connector - TotalShard *int // The total count of shards - Intend types.GatewayIntents // The bitflags for the indents. - GuildSubs *bool // Should the guild_subscriptions be enabled - ResumeSession GatewayConnectorOptionsResume // Is specified, the gateway will try to resume a connection. - Compressor gateway.Compression // The compressor given to the gateway that determine the connection method and compression used. - Transporter gateway.Transporter // The interface where we send the data. - Restart *bool // Should the gateway restart upon failure. - - OnSessionStateUpdated func(state GatewayConnectorOptionsResume) error // When the session state is called, we call this function - SessionUpdateFrequency *int -} diff --git a/gateway/lib/gateway/transporters/BUILD.bazel b/gateway/lib/gateway/transporters/BUILD.bazel deleted file mode 100644 index c5fcb2f..0000000 --- a/gateway/lib/gateway/transporters/BUILD.bazel +++ /dev/null @@ -1,13 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") - -go_library( - name = "transporters", - srcs = ["rabbitmq.go"], - importpath = "github.com/discordnova/nova/gateway/lib/gateway/transporters", - visibility = ["//gateway:__subpackages__"], - deps = [ - "//common/gateway", - "@com_github_rs_zerolog//log", - "@com_github_streadway_amqp//:amqp", - ], -) diff --git a/gateway/lib/gateway/transporters/rabbitmq.go b/gateway/lib/gateway/transporters/rabbitmq.go deleted file mode 100644 index 1a163ad..0000000 --- a/gateway/lib/gateway/transporters/rabbitmq.go +++ /dev/null @@ -1,76 +0,0 @@ -package transporters - -import ( - "time" - - "github.com/discordnova/nova/common/gateway" - "github.com/rs/zerolog/log" - "github.com/streadway/amqp" -) - -type RabbitMqTransporter struct { - pullChannel chan []byte - pushChannel chan gateway.PushData -} - -// NewRabbitMqTransporter creates a rabbitmq transporter using a given url -func NewRabbitMqTransporter(url string) (gateway.Transporter, error) { - log.Info().Msg("connecting to the transporter using rabbitmq...") - conn, err := amqp.Dial(url) - - if err != nil { - return nil, err - } - - send, err := conn.Channel() - - if err != nil { - return nil, err - } - - err = send.ExchangeDeclare( - "nova_gateway_dispatch", - "topic", - true, - false, - false, - true, - nil, - ) - - if err != nil { - return nil, err - } - - pullChannel, pushChannel := make(chan []byte), make(chan gateway.PushData) - - go func() { - for { - data := <-pushChannel - send.Publish( - "nova_gateway_dispatch", - data.Name, - false, - false, - amqp.Publishing{ - Priority: 1, - Timestamp: time.Now(), - Type: data.Name, - Body: data.Data, - }, - ) - } - }() - - return &RabbitMqTransporter{ - pullChannel: pullChannel, - pushChannel: pushChannel, - }, nil -} - -func (t RabbitMqTransporter) PushChannel() chan gateway.PushData { - return t.pushChannel -} -func (t RabbitMqTransporter) PullChannel() chan []byte { - return t.pullChannel -} diff --git a/gateway/main.go b/gateway/main.go deleted file mode 100644 index d0857fc..0000000 --- a/gateway/main.go +++ /dev/null @@ -1,68 +0,0 @@ -package main - -import ( - "flag" - - "github.com/discordnova/nova/common" - "github.com/discordnova/nova/gateway/lib/gateway" - "github.com/discordnova/nova/gateway/lib/gateway/compression" - "github.com/discordnova/nova/gateway/lib/gateway/transporters" - "github.com/rs/zerolog/log" -) - -var ( - settings = gateway.GatewayConnectorOptions{ - Token: flag.String("token", "", "the discord token for the websocket connection"), - Restart: flag.Bool("restart", true, "should the gateway be restarted if an error occurs"), - GuildSubs: flag.Bool("guild-subscriptions", true, "should the guild subscription gateway flag set to true"), - SelfShard: flag.Int("shard", 0, "the shard id of this instance"), - TotalShard: flag.Int("shard-count", 1, "the total amount of shard"), - SessionUpdateFrequency: flag.Int("session-persist-frequence", 10, "the frequency of session persistence"), - } - - compressor = flag.String("compressor", "json-zlib", "the compressor used to connect") - transporter = flag.String("transporter", "rabbitmq", "the compressor used to connect") - monitoring = flag.Int("prometheus-port", 9000, "is this flag is set, a prometheus metrics endpoint will be exposed") - url = flag.String("transporter-url", "", "the url needed for rabbitmq to function") -) - -func validate(settings *gateway.GatewayConnectorOptions) { - if *settings.SelfShard > *settings.TotalShard { - log.Fatal().Msg("invalid config: the shard id must be inferior than the total shard value") - } else if *settings.SessionUpdateFrequency == 0 { - log.Fatal().Msg("invalid config: the session update frequency muse be greater than 0") - } else if *settings.Token == "" { - log.Fatal().Msg("invalid config: invalid token provided") - } else if *settings.TotalShard == 0 { - log.Fatal().Msg("invalid config: the total number of shards muse be greater than 0") - } -} - -func main() { - flag.Parse() - common.SetupLogger() - - if monitoring != nil { - go common.CreatePrometheus(*monitoring) - log.Debug().Msg("prometheus server called") - } - - if *compressor == "json-zlib" { - settings.Compressor = compression.NewJsonZlibCompressor() - } else { - log.Fatal().Msgf("unknown compressor specified: %s", *compressor) - } - - if *transporter == "rabbitmq" { - trns, err := transporters.NewRabbitMqTransporter(*url) - if err != nil { - log.Fatal().Msgf("failed to initialize the transporter: %s", err.Error()) - } - settings.Transporter = trns - } - - validate(&settings) - - gateway := gateway.NewGateway(settings) - gateway.Start() -} diff --git a/gateway/src/client/connexion.rs b/gateway/src/client/connexion.rs new file mode 100644 index 0000000..069bbf9 --- /dev/null +++ b/gateway/src/client/connexion.rs @@ -0,0 +1,218 @@ +use crate::client::payloads::{message::OpCodes, payloads::Hello}; + +use super::{ + payloads::message::MessageBase, + state::{Stage, State}, + utils::get_gateway_url, +}; +use futures_util::{ + SinkExt, StreamExt, +}; +use log::{error, info, warn}; +use std::{str::from_utf8, time::Duration}; +use tokio::{ + net::TcpStream, + select, + time::{Instant}, +}; +use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::{self, Message}}; + +#[derive(Debug)] +pub enum CloseReason { + ConnexionAlredyOpen, + ConnexionEnded, + ErrorEncountered(&'static str), + ConnexionError(tungstenite::Error), +} + +pub enum HandleResult { + Success, + Error(CloseReason), +} + +/// This struct represents a single connexion to the gateway, +/// it does not have any retry logic or reconnexion mechanism, +/// everything is handled in the Shard struct. +/// The purpose of this struct is to handle the encoding, +/// compression and other gateway-transport related stuff. +/// All the messages are send through another struct implementing +/// the MessageHandler trait. +pub struct Connexion { + state: State, + connexion: Option>>, +} + +impl Connexion { + /// Creates a new instance of a discord websocket connexion using the options + /// this is used internally by the shard struct to initialize a single + /// websocket connexion. This instance is not initialized by default. + /// a websocket connexion like this can be re-used multiple times + /// to allow reconnexion mechanisms. + pub async fn new() -> Self { + Connexion { + state: State::default(), + connexion: None, + } + } + + /// Terminate the connexion and the "start" method related to it. + async fn _terminate_websocket(&mut self, message: CloseReason) { + if let Some(connexion) = &mut self.connexion { + if let Err(err) = connexion.close(None).await { + error!("failed to close socket {}", err); + } else { + info!("closed the socket: {:?}", message) + } + } else { + warn!("a termination request was sent without a connexion openned") + } + } + + /// Initialize a connexion to the gateway + /// returns if a connexion is already present + pub async fn start(mut self) -> CloseReason { + if let Some(_) = self.connexion { + CloseReason::ConnexionAlredyOpen + } else { + // we reset the state before starting the connection + self.state = State::default(); + + let connexion_result = connect_async(get_gateway_url(false, "json", 9)).await; + // we connect outselves to the websocket server + if let Err(err) = connexion_result { + return CloseReason::ConnexionError(err) + } + self.connexion = Some(connexion_result.unwrap().0); + + // this is the loop that will maintain the whole connexion + loop { + if let Some(connexion) = &mut self.connexion { + // if we do not have a hello message received yet, then we do not use the heartbeat interval + // and we just wait for messages to arrive + if self.state.stage == Stage::Unknown { + let msg = connexion.next().await; + if let HandleResult::Error(reason) = self._handle_message(&msg).await { + return reason + } + } else { + let timer = self.state.interval.as_mut().unwrap().tick(); + select! { + msg = connexion.next() => { + if let HandleResult::Error(reason) = self._handle_message(&msg).await { + return reason + } + }, + _ = timer => self._do_heartbeat().await + } + } + } else { + return CloseReason::ConnexionEnded + } + } + } + } + + async fn _handle_message( + &mut self, + data: &Option>, + ) -> HandleResult { + if let Some(message) = data { + match message { + Ok(message) => match message { + Message::Text(text) => { + self._handle_discord_message(&text).await; + HandleResult::Success + } + Message::Binary(message) => { + self._handle_discord_message(from_utf8(message).unwrap()) + .await; + HandleResult::Success + } + Message::Close(_) => { + error!("discord connexion closed"); + HandleResult::Error(CloseReason::ConnexionEnded) + } + + _ => { + HandleResult::Error(CloseReason::ErrorEncountered("unsupported message type encountered")) + } + }, + Err(_error) => { + HandleResult::Error(CloseReason::ErrorEncountered("error while reading a message")) + } + } + } else { + HandleResult::Error(CloseReason::ErrorEncountered("error while reading a message")) + } + } + + async fn _handle_discord_message(&mut self, raw_message: &str) { + let a: Result = serde_json::from_str(raw_message); + let message = a.unwrap(); + + // handles the state + if let Some(index) = message.s { + self.state.sequence = index; + } + + match message.op { + OpCodes::Dispatch => todo!(), + OpCodes::Heartbeat => todo!(), + OpCodes::Identify => todo!(), + OpCodes::PresenceUpdate => todo!(), + OpCodes::VoiceStateUpdate => todo!(), + OpCodes::Resume => todo!(), + OpCodes::Reconnect => todo!(), + OpCodes::RequestGuildMembers => todo!(), + OpCodes::InvalidSession => todo!(), + OpCodes::Hello => { + if let Ok(hello) = serde_json::from_value::(message.d) { + info!("server sent hello {:?}", hello); + info!("heartbeating every {}ms", hello.heartbeat_interval); + self.state.interval = Some(tokio::time::interval_at( + Instant::now() + Duration::from_millis(hello.heartbeat_interval), + Duration::from_millis(hello.heartbeat_interval), + )); + self.state.stage = Stage::Initialized; + } + } + OpCodes::HeartbeatACK => { + info!( + "heartbeat acknowledged after {}ms", + (std::time::Instant::now() - self.state.last_heartbeat_time).as_millis() + ); + self.state.last_heartbeat_acknowledged = true; + } + } + } + + async fn _do_heartbeat(&mut self) { + if !self.state.last_heartbeat_acknowledged { + self._terminate_websocket(CloseReason::ErrorEncountered("the server did not acknowledged the last heartbeat")).await; + return; + } + self.state.last_heartbeat_acknowledged = false; + + info!("sending heartbeat"); + self._send( + serde_json::to_vec(&MessageBase { + t: None, + d: serde_json::to_value(self.state.sequence).unwrap(), + s: None, + op: OpCodes::Heartbeat, + }) + .unwrap(), + ) + .await; + self.state.last_heartbeat_time = std::time::Instant::now(); + } + + async fn _send(&mut self, data: Vec) { + if let Some(connexion) = &mut self.connexion { + if let Err(error) = connexion.send(Message::Binary(data)).await { + error!("failed to write to socket: {}", error); + self._terminate_websocket(CloseReason::ErrorEncountered("failed to write to the socket")).await; + } + } + } +} diff --git a/gateway/src/client/error.rs b/gateway/src/client/error.rs new file mode 100644 index 0000000..bac6894 --- /dev/null +++ b/gateway/src/client/error.rs @@ -0,0 +1,20 @@ +#[derive(Debug)] +struct MyError(String); + +impl fmt::Display for MyError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "There is an error: {}", self.0) + } +} + +impl Error for NovaError {} + +pub fn run() -> Result<(), Box> { + let condition = true; + + if condition { + return Err(Box::new(MyError("Oops".into()))); + } + + Ok(()) +} \ No newline at end of file diff --git a/gateway/src/client/mod.rs b/gateway/src/client/mod.rs new file mode 100644 index 0000000..179c40d --- /dev/null +++ b/gateway/src/client/mod.rs @@ -0,0 +1,6 @@ +pub mod connexion; +mod utils; +mod state; +mod shard; +pub mod payloads; +pub mod traits; \ No newline at end of file diff --git a/gateway/src/client/payloads/message.rs b/gateway/src/client/payloads/message.rs new file mode 100644 index 0000000..4b2a657 --- /dev/null +++ b/gateway/src/client/payloads/message.rs @@ -0,0 +1,28 @@ +use serde_json::Value; +use serde_repr::{Serialize_repr, Deserialize_repr}; +use serde::{Deserialize, Serialize}; + + +#[derive(Serialize_repr, Deserialize_repr, PartialEq, Debug)] +#[repr(u8)] +pub enum OpCodes { + Dispatch = 0, + Heartbeat = 1, + Identify = 2, + PresenceUpdate = 3, + VoiceStateUpdate = 4, + Resume = 6, + Reconnect = 7, + RequestGuildMembers = 8, + InvalidSession = 9, + Hello = 10, + HeartbeatACK = 11, +} + +#[derive(Serialize, Deserialize, PartialEq, Debug)] +pub struct MessageBase { + pub t: Option, + pub s: Option, + pub op: OpCodes, + pub d: Value +} diff --git a/gateway/src/client/payloads/mod.rs b/gateway/src/client/payloads/mod.rs new file mode 100644 index 0000000..d0a5e38 --- /dev/null +++ b/gateway/src/client/payloads/mod.rs @@ -0,0 +1,2 @@ +pub mod payloads; +pub mod message; \ No newline at end of file diff --git a/gateway/src/client/payloads/payloads.rs b/gateway/src/client/payloads/payloads.rs new file mode 100644 index 0000000..bcbdeb0 --- /dev/null +++ b/gateway/src/client/payloads/payloads.rs @@ -0,0 +1,10 @@ +use serde::{Serialize, Deserialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct Hello { + #[serde(rename = "heartbeat_interval")] + pub heartbeat_interval: u64 +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct HeartbeatACK {} \ No newline at end of file diff --git a/gateway/src/client/shard.rs b/gateway/src/client/shard.rs new file mode 100644 index 0000000..fb1ceda --- /dev/null +++ b/gateway/src/client/shard.rs @@ -0,0 +1,6 @@ + + + +struct Shard { + +} \ No newline at end of file diff --git a/gateway/src/client/state.rs b/gateway/src/client/state.rs new file mode 100644 index 0000000..553fea7 --- /dev/null +++ b/gateway/src/client/state.rs @@ -0,0 +1,29 @@ +use std::time::Instant; +use tokio::time::Interval; + +#[derive(PartialEq)] +pub enum Stage { + Unknown, + Initialized, + LoggedIn, +} + +pub struct State { + pub stage: Stage, + pub sequence: i64, + pub last_heartbeat_acknowledged: bool, + pub last_heartbeat_time: Instant, + pub interval: Option, +} + +impl State { + pub fn default() -> Self { + State { + sequence: 0, + interval: None, + stage: Stage::Unknown, + last_heartbeat_acknowledged: true, + last_heartbeat_time: std::time::Instant::now(), + } + } +} diff --git a/gateway/src/client/traits/message_handler.rs b/gateway/src/client/traits/message_handler.rs new file mode 100644 index 0000000..a5bfd20 --- /dev/null +++ b/gateway/src/client/traits/message_handler.rs @@ -0,0 +1,3 @@ +/// This trait is used by the Connexion struct +/// It implements a basic interface for handling events. +pub trait MessageHandler {} \ No newline at end of file diff --git a/gateway/src/client/traits/mod.rs b/gateway/src/client/traits/mod.rs new file mode 100644 index 0000000..98d0c32 --- /dev/null +++ b/gateway/src/client/traits/mod.rs @@ -0,0 +1 @@ +pub mod message_handler; \ No newline at end of file diff --git a/gateway/src/client/utils.rs b/gateway/src/client/utils.rs new file mode 100644 index 0000000..023b6b9 --- /dev/null +++ b/gateway/src/client/utils.rs @@ -0,0 +1,7 @@ +pub fn get_gateway_url (compress: bool, encoding: &str, v: i16) -> String { + return format!( + "wss://gateway.discord.gg/?v={}&encoding={}&compress={}", + v, encoding, + if compress { "zlib-stream" } else { "" } + ); +} \ No newline at end of file diff --git a/gateway/src/main.rs b/gateway/src/main.rs new file mode 100644 index 0000000..8af7505 --- /dev/null +++ b/gateway/src/main.rs @@ -0,0 +1,18 @@ +use client::traits::message_handler::MessageHandler; +extern crate serde_json; + +mod client; + +struct Handler {} +impl MessageHandler for Handler {} + +#[tokio::main] +async fn main() { + pretty_env_logger::init(); + for _ in 0..1 { + tokio::spawn(async move { + let con = client::connexion::Connexion::new().await; + con.start().await; + }).await.unwrap(); + } +} \ No newline at end of file