--- /dev/null
+load("@rules_rust//rust:rust.bzl", "rust_binary", "rust_library")
+load("@crates//:defs.bzl", "crates_from")
+
+exports_files(["Cargo.toml"])
+
+rust_binary(
+ name = "gateway",
+ srcs = glob(["src/**/*.rs"]),
+ rustc_env = {},
+ deps = crates_from("Cargo.toml"),
+ visibility = ["//visibility:public"],
+)
+++ /dev/null
-load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library")
-
-go_library(
- name = "gateway_lib",
- srcs = ["main.go"],
- importpath = "github.com/discordnova/nova/gateway",
- visibility = ["//visibility:public"],
- deps = [
- "//common",
- "//gateway/lib/gateway",
- "//gateway/lib/gateway/compression",
- "//gateway/lib/gateway/transporters",
- "@com_github_rs_zerolog//log",
- ],
-)
-
-go_binary(
- name = "gateway",
- embed = [":gateway_lib"],
- visibility = ["//visibility:public"],
-)
\ No newline at end of file
--- /dev/null
+[package]
+name = "gateway"
+version = "0.1.0"
+edition = "2018"
+
+[dependencies]
+tokio = { version = "1", features = ["full"] }
+tokio-tungstenite = { version = "*", features = ["rustls-tls"] }
+url = "2.2.2"
+futures-util = "0.3.17"
+log = { version = "0.4", features = ["std"] }
+pretty_env_logger = "0.4"
+serde_json = { version = "1.0" }
+serde = { version = "1.0", features = ["derive"] }
+tokio-stream = "0.1.7"
+async-stream = "0.3.2"
+futures-core = "0.3.17"
+serde_repr = "0.1"
--- /dev/null
+"""
+@generated
+cargo-raze generated Bazel file.
+
+DO NOT EDIT! Replaced on runs of cargo-raze
+"""
+
+package(default_visibility = ["//visibility:public"])
+
+licenses([
+ "notice", # See individual crates for specific licenses
+])
+
+# Aliased targets
+alias(
+ name = "async_stream",
+ actual = "@raze__async_stream__0_3_2//:async_stream",
+ tags = [
+ "cargo-raze",
+ "manual",
+ ],
+)
+
+alias(
+ name = "futures_core",
+ actual = "@raze__futures_core__0_3_17//:futures_core",
+ tags = [
+ "cargo-raze",
+ "manual",
+ ],
+)
+
+alias(
+ name = "futures_util",
+ actual = "@raze__futures_util__0_3_17//:futures_util",
+ tags = [
+ "cargo-raze",
+ "manual",
+ ],
+)
+
+alias(
+ name = "log",
+ actual = "@raze__log__0_4_14//:log",
+ tags = [
+ "cargo-raze",
+ "manual",
+ ],
+)
+
+alias(
+ name = "pretty_env_logger",
+ actual = "@raze__pretty_env_logger__0_4_0//:pretty_env_logger",
+ tags = [
+ "cargo-raze",
+ "manual",
+ ],
+)
+
+alias(
+ name = "serde",
+ actual = "@raze__serde__1_0_130//:serde",
+ tags = [
+ "cargo-raze",
+ "manual",
+ ],
+)
+
+alias(
+ name = "serde_json",
+ actual = "@raze__serde_json__1_0_67//:serde_json",
+ tags = [
+ "cargo-raze",
+ "manual",
+ ],
+)
+
+alias(
+ name = "serde_repr",
+ actual = "@raze__serde_repr__0_1_7//:serde_repr",
+ tags = [
+ "cargo-raze",
+ "manual",
+ ],
+)
+
+alias(
+ name = "tokio",
+ actual = "@raze__tokio__1_11_0//:tokio",
+ tags = [
+ "cargo-raze",
+ "manual",
+ ],
+)
+
+alias(
+ name = "tokio_stream",
+ actual = "@raze__tokio_stream__0_1_7//:tokio_stream",
+ tags = [
+ "cargo-raze",
+ "manual",
+ ],
+)
+
+alias(
+ name = "tokio_tungstenite",
+ actual = "@raze__tokio_tungstenite__0_15_0//:tokio_tungstenite",
+ tags = [
+ "cargo-raze",
+ "manual",
+ ],
+)
+
+alias(
+ name = "url",
+ actual = "@raze__url__2_2_2//:url",
+ tags = [
+ "cargo-raze",
+ "manual",
+ ],
+)
+
+# Export file for Stardoc support
+exports_files(
+ [
+ "crates.bzl",
+ ],
+ visibility = ["//visibility:public"],
+)
+++ /dev/null
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
-
-go_library(
- name = "gateway",
- srcs = [
- "gateway.go",
- "options.go",
- ],
- importpath = "github.com/discordnova/nova/gateway/lib/gateway",
- visibility = ["//gateway:__subpackages__"],
- deps = [
- "//common/discord/types/payloads/gateway",
- "//common/discord/types/payloads/gateway/commands",
- "//common/discord/types/payloads/gateway/events",
- "//common/discord/types/structures",
- "//common/discord/types/types",
- "//common/gateway",
- "@com_github_boz_go_throttle//:go-throttle",
- "@com_github_gorilla_websocket//:websocket",
- "@com_github_prometheus_client_golang//prometheus",
- "@com_github_prometheus_client_golang//prometheus/promauto",
- "@com_github_rs_zerolog//log",
- ],
-)
+++ /dev/null
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
-
-go_library(
- name = "compression",
- srcs = ["json-zlib.go"],
- importpath = "github.com/discordnova/nova/gateway/lib/gateway/compression",
- visibility = ["//gateway:__subpackages__"],
- deps = [
- "//common/discord/types/payloads/gateway",
- "//common/gateway",
- "@com_github_rs_zerolog//log",
- ],
-)
+++ /dev/null
-package compression
-
-import (
- "bytes"
- "compress/zlib"
- "encoding/json"
- "fmt"
- "io"
-
- "github.com/rs/zerolog/log"
- gatewayTypes "github.com/discordnova/nova/common/discord/types/payloads/gateway"
- "github.com/discordnova/nova/common/gateway"
-)
-
-// JsonZlibCompressor is the default compression interface.
-type JsonZlibCompressor struct {
- buffer *bytes.Buffer
- reader io.ReadCloser
-}
-
-// NewJsonZlibCompressor creates an instance of JsonZlibCompressor
-func NewJsonZlibCompressor() gateway.Compression {
- return &JsonZlibCompressor{
- buffer: bytes.NewBuffer([]byte{}),
- }
-}
-
-func (compressor *JsonZlibCompressor) Reset() error {
- compressor.buffer.Reset()
- if compressor.reader == nil {
- return nil
- }
- err := compressor.reader.Close()
- if err != nil {
- return err
- }
- compressor.reader = nil
- return nil
-}
-
-// GetConnectionOptions gets the required options for the gateway.
-func (compressor JsonZlibCompressor) GetConnectionOptions() gateway.GatewayConnectionOptions {
- // Gateway options for the discord gateway.
- return gateway.GatewayConnectionOptions{
- Encoding: "json",
- TransportCompression: "zlib-stream",
- }
-}
-
-// DecodeMessage decodes messages using the compressor.
-func (compressor *JsonZlibCompressor) DecodeMessage(data []byte) (*gatewayTypes.Payload, error) {
-
- // check if the message have the zlib suffix to avoid ruining our zlib context :'(
- if !bytes.HasSuffix(data, []byte{0x00, 0x00, 0xff, 0xff}) {
- return nil, fmt.Errorf("the gateway failed to verify the message validity due to invalid suffix")
- }
-
- // add the data to the buffer for the decompression.
- compressor.buffer.Write(data)
-
- // we can't create the reader without data, so we initialize on the first decompression.
- if compressor.reader == nil {
- reader, err := zlib.NewReader(compressor.buffer)
- if err != nil {
- log.Err(err).Msgf("Failed to initialize zlib reader")
- }
- compressor.reader = reader
- }
-
- // we unmarshal the reader as json
- inter := gatewayTypes.Payload{}
- decoder := json.NewDecoder(compressor.reader)
- err := decoder.Decode(&inter)
-
- if err != nil {
- // the unmarshalling failed
- return nil, err
- }
-
- return &inter, nil
-}
+++ /dev/null
-package gateway
-
-import (
- "encoding/json"
- "fmt"
- "os"
- "runtime"
- "time"
-
- "github.com/boz/go-throttle"
- gatewayTypes "github.com/discordnova/nova/common/discord/types/payloads/gateway"
- "github.com/discordnova/nova/common/discord/types/payloads/gateway/commands"
- "github.com/discordnova/nova/common/discord/types/payloads/gateway/events"
- "github.com/discordnova/nova/common/discord/types/structures"
- "github.com/discordnova/nova/common/discord/types/types"
- "github.com/discordnova/nova/common/gateway"
- "github.com/gorilla/websocket"
- "github.com/prometheus/client_golang/prometheus"
- "github.com/prometheus/client_golang/prometheus/promauto"
- "github.com/rs/zerolog/log"
-)
-
-// connectionState is a struct representing a connection state
-type connectionState struct {
- HeartbeatInterval uint16
- Latency int64
-}
-
-var (
- messagesCounter = promauto.NewCounter(prometheus.CounterOpts{
- Name: "nova_gateway_messages_processed",
- Help: "The total number of processed messages",
- })
-
- heartbeatCounter = promauto.NewCounter(prometheus.CounterOpts{
- Name: "nova_gateway_heartbeat_sent",
- Help: "The total number of heartbeat sent",
- })
-
- latencyGauge = promauto.NewGauge(prometheus.GaugeOpts{
- Name: "nova_gateway_latency",
- Help: "The round trip latency of the gateway",
- })
-
- reconnectionsCounter = promauto.NewCounter(prometheus.CounterOpts{
- Name: "nova_gateway_reconnections",
- Help: "the number of reconnections of the gateway",
- })
-
- eventsCounter = promauto.NewCounter(prometheus.CounterOpts{
- Name: "nova_gateway_events",
- Help: "The various events received by Nova.",
- })
-)
-
-// GatewayConnector represents a connection to the discord gateway for a shard
-type GatewayConnector struct {
- // Public State
- SessionState GatewayConnectorOptionsResume // The state of the session
-
- // Private state
- connectionState connectionState // The internal state of the gateway connection.
- options GatewayConnectorOptions // The connection options.
- connection *websocket.Conn // The current websocket connection.
- heartbeat chan struct{} // Channel for reacting to heartbeat acks
- terminate chan string // Called when a gateway disconnect is requested
- updateThrottle throttle.ThrottleDriver
-}
-
-// NewGateway creates a connector instance based on the given options.
-func NewGateway(options GatewayConnectorOptions) *GatewayConnector {
- return &GatewayConnector{
- options: options,
- SessionState: options.ResumeSession,
- }
-}
-
-// Start is used to start or reset a connection to the gateway.
-func (discord *GatewayConnector) Start() {
- shouldStart := true
- for shouldStart {
- reconnectionsCounter.Inc()
- discord.connectionState = connectionState{}
- _ = discord.start()
- err := discord.options.Compressor.Reset()
- if err != nil {
- log.Fatal().Msgf("failed to reset the compressor")
- }
- shouldStart = *discord.options.Restart
- if shouldStart {
- log.Info().Msg("waiting 10s before gateway reconnection")
- time.Sleep(time.Second * 10)
- }
- }
-}
-
-// start is the internal routine for starting the gateway
-func (discord *GatewayConnector) start() error {
- // we throttle the update function to limit the amount of session state
- // presisted to the session persistence interface
- discord.updateThrottle = throttle.ThrottleFunc(time.Second*5, false, func() {
- if discord.options.OnSessionStateUpdated != nil {
- _ = discord.options.OnSessionStateUpdated(discord.SessionState)
- }
- })
-
- // initialize the message channels
- discord.heartbeat = make(chan struct{})
- discord.terminate = make(chan string)
-
- // since a Compressor is given to the gateway when created, we use the Connector to get
- // the compression and encoding options.
- comOptions := discord.options.Compressor.GetConnectionOptions()
- websocketURL := fmt.Sprintf("wss://gateway.discord.gg/?v=%d&encoding=%s&compress=%s", 6, comOptions.Encoding, comOptions.TransportCompression)
-
- log.Info().Msgf("connecting to the gateway at url %s", websocketURL)
- // we start the connection to discord.
- connection, _, err := websocket.DefaultDialer.Dial(websocketURL, nil)
- if err != nil {
- log.Err(err).Msgf("an error occurred while connecting to the gateway")
- return err
- }
- discord.connection = connection
- defer discord.connection.Close()
-
- // start listening to messages on the socket.
- go discord.listen()
-
- msg := <-discord.terminate
- log.Info().Msgf("terminating the gateway: %s", msg)
-
- return nil
-}
-
-// ticker starts the loop for the heartbeat checks
-func (discord *GatewayConnector) ticker(interval int) {
- // sends a message to heartbeat.C every time we need to send a heartbeat
- heartbeat := time.NewTicker(time.Duration(interval) * time.Millisecond)
- // stores if the last heartbeat succeeded
- doneLastAck := true
-
- // executes the given code when heartbeat.C is triggered
- for range heartbeat.C {
- // if the server did not send the last heartbeat
- if !doneLastAck {
- // we need to terminate the connection
- discord.terminate <- "server missed an ack and must be disconnected"
- return
- }
-
- log.Debug().Msg("Sending a heartbeat.")
-
- index, _ := json.Marshal(discord.SessionState.Index)
- err := discord.connection.WriteJSON(gatewayTypes.Payload{
- Op: types.GatewayOpCodeHeartbeat,
- D: index,
- })
-
- if err != nil {
- discord.terminate <- fmt.Sprintf("failed to send a heartbeat: %s", err.Error())
- return
- }
-
- heartbeatCounter.Inc()
- // wait for the ack asynchronously
- go func() {
- start := time.Now()
- doneLastAck = false
- <-discord.heartbeat
- doneLastAck = true
-
- discord.connectionState.Latency = time.Since(start).Milliseconds()
- latencyGauge.Set(float64(discord.connectionState.Latency))
-
- log.Info().Msgf("heartbeat completed, latency: %dms", discord.connectionState.Latency)
- }()
-
- }
-}
-
-// listen listens to the messages on the gateway
-func (discord *GatewayConnector) listen() {
- for {
- _, message, err := discord.connection.ReadMessage()
-
- if err != nil {
- discord.terminate <- fmt.Sprintf("the connection was closed by the gateway: %s", err.Error())
- return
- }
-
- messagesCounter.Inc()
- data, err := discord.options.Compressor.DecodeMessage(message)
-
- if err != nil || data == nil {
- log.Print(err.Error())
- continue
- }
-
- if data.S != 0 {
- discord.SessionState.Index = data.S
- discord.updateState(data.S, "")
- }
-
- discord.handleMessage(data)
- }
-}
-
-func (discord *GatewayConnector) updateState(newIndex int64, sessionId string) {
- discord.SessionState.Index = newIndex
- if sessionId != "" {
- discord.SessionState.Session = sessionId
- }
- discord.updateThrottle.Trigger()
-}
-
-func (discord *GatewayConnector) handleMessage(message *gatewayTypes.Payload) {
- switch message.Op {
- // call the startup function
- case types.GatewayOpCodeHello:
- discord.hello(message)
- // notify the heartbeat goroutine that a heartbeat ack was received
- case types.GatewayOpCodeHeartbeatACK:
- discord.heartbeat <- struct{}{}
- // handles a dispatch from the gateway
- case types.GatewayOpCodeDispatch:
- discord.dispatch(message)
- // when the session resume fails
- case types.GatewayOpCodeInvalidSession:
- log.Print("failed to resume the session, reconnecting")
- discord.updateState(0, "")
- discord.doLogin()
- // when the gateway requests a reconnect
- case types.GatewayOpCodeReconnect:
- log.Print("the gateway requested a reconnect")
- if string(message.D) != "true" {
- // we may delete the session state
- discord.SessionState.Index = 0
- discord.updateState(0, "")
- }
- discord.terminate <- "the gateway requested a reconnect"
- }
-}
-
-func (discord *GatewayConnector) doLogin() {
- var payload gatewayTypes.Payload
- // if we do not have to resume a session
- if discord.SessionState.Session == "" {
- log.Info().Msg("using identify for authentification")
- data, err := json.Marshal(commands.GatewayCommandIdentifyPayload{
- Token: *discord.options.Token,
- Properties: structures.IdentifyConnectionProperties{
- OS: runtime.GOOS,
- Device: "Nova Discord Client",
- Browser: "Nova Discord Client",
- },
- Compress: true,
- LargeThreshold: 1000,
- Shard: []int{
- *discord.options.SelfShard,
- *discord.options.TotalShard,
- },
- Presence: commands.GatewayCommandUpdateStatusPayload{},
- GuildSubscriptions: *discord.options.GuildSubs,
- Intents: discord.options.Intend,
- })
-
- if err != nil {
- return
- }
-
- payload = gatewayTypes.Payload{
- Op: types.GatewayOpCodeIdentify,
- D: data,
- }
- } else {
- log.Info().Msg("resuming session")
- data, err := json.Marshal(commands.GatewayCommandResumePayload{
- Token: *discord.options.Token,
- SessionID: discord.SessionState.Session,
- Seq: discord.SessionState.Index,
- })
-
- if err != nil {
- return
- }
- payload = gatewayTypes.Payload{
- Op: types.GatewayOpCodeResume,
- D: data,
- }
- }
-
- err := discord.connection.WriteJSON(payload)
- if err != nil {
- log.Err(err).Msgf("failed send the identify payload")
- }
-}
-
-func (discord *GatewayConnector) hello(message *gatewayTypes.Payload) {
-
- data := &events.GatewayEventHelloPayload{}
- err := json.Unmarshal(message.D, &data)
- if err != nil {
- discord.terminate <- fmt.Sprintf("invalid payload: %s", err.Error())
- }
-
- // start the heartbeat goroutine
- log.Debug().Msgf("hello recevied, heartbeating every %d ms", data.HeartbeatInterval)
- go discord.ticker(data.HeartbeatInterval)
-
- // login
- discord.doLogin()
-}
-
-type NovaMessage struct {
- Data json.RawMessage `json:"data"`
- Tracing struct {
- NodeName string `json:"node_name"`
- } `json:"tracing"`
-}
-
-func (discord *GatewayConnector) dispatch(message *gatewayTypes.Payload) {
- // since this is juste a event gateway, we do not care about the content of the events
- // except the ready, resumed, reconnect event we use to update the session_id, the other events are forwarded to the transporter
- if message.T == "READY" {
- event := events.GatewayEventReadyPayload{}
- err := json.Unmarshal(message.D, &event)
-
- log.Info().Msgf("logged in as %s", event.User.Username)
-
- if err != nil {
- discord.terminate <- "invalid ready event"
- return
- }
-
- discord.updateState(discord.SessionState.Index, event.SessionID)
- return
- }
-
- newName := gateway.EventNames[message.T]
-
- if newName == "" {
- log.Error().Msgf("unknown event name: %s", newName)
- return
- }
-
- name, err := os.Hostname()
-
- if err != nil {
- log.Err(err).Msgf("failed to get the hostname")
- return
- }
-
- data, err := json.Marshal(NovaMessage{
- Data: message.D,
- Tracing: struct {
- NodeName string `json:"node_name"`
- }{
- NodeName: name,
- },
- })
-
- if err != nil {
- log.Err(err).Msg("failed to serialize the outgoing nova message")
- }
-
- discord.options.Transporter.PushChannel() <- gateway.PushData{
- Data: data,
- Name: newName,
- }
-
- if err != nil {
- log.Err(err).Msg("failed to send the event to the nova event broker")
- }
-}
+++ /dev/null
-package gateway
-
-import (
- "github.com/discordnova/nova/common/discord/types/types"
- "github.com/discordnova/nova/common/gateway"
-)
-
-// GatewayConnectorOptionsResume represents the options for reconnecting the gateway.
-type GatewayConnectorOptionsResume struct {
- Session string `json:"session_id"` // The session id of the older session we want to resume.
- Index int64 `json:"index"` // The index of the last packet recevied by the older session.
-}
-
-// GatewayConnectorOptionsSharding represents the options for sharding the gateway.
-type GatewayConnectorOptionsSharding struct {
- TotalShards int `json:"total_shards"` // The total amount of shards
- CurrentShard int `json:"current_shard"` // The shard we want to connect to.
-}
-
-// GatewayConnectorOptions is the options given to the GatewayConnector when creating it.
-type GatewayConnectorOptions struct {
- Token *string // The token of the bot
- SelfShard *int // The shard of the current connector
- TotalShard *int // The total count of shards
- Intend types.GatewayIntents // The bitflags for the indents.
- GuildSubs *bool // Should the guild_subscriptions be enabled
- ResumeSession GatewayConnectorOptionsResume // Is specified, the gateway will try to resume a connection.
- Compressor gateway.Compression // The compressor given to the gateway that determine the connection method and compression used.
- Transporter gateway.Transporter // The interface where we send the data.
- Restart *bool // Should the gateway restart upon failure.
-
- OnSessionStateUpdated func(state GatewayConnectorOptionsResume) error // When the session state is called, we call this function
- SessionUpdateFrequency *int
-}
+++ /dev/null
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
-
-go_library(
- name = "transporters",
- srcs = ["rabbitmq.go"],
- importpath = "github.com/discordnova/nova/gateway/lib/gateway/transporters",
- visibility = ["//gateway:__subpackages__"],
- deps = [
- "//common/gateway",
- "@com_github_rs_zerolog//log",
- "@com_github_streadway_amqp//:amqp",
- ],
-)
+++ /dev/null
-package transporters
-
-import (
- "time"
-
- "github.com/discordnova/nova/common/gateway"
- "github.com/rs/zerolog/log"
- "github.com/streadway/amqp"
-)
-
-type RabbitMqTransporter struct {
- pullChannel chan []byte
- pushChannel chan gateway.PushData
-}
-
-// NewRabbitMqTransporter creates a rabbitmq transporter using a given url
-func NewRabbitMqTransporter(url string) (gateway.Transporter, error) {
- log.Info().Msg("connecting to the transporter using rabbitmq...")
- conn, err := amqp.Dial(url)
-
- if err != nil {
- return nil, err
- }
-
- send, err := conn.Channel()
-
- if err != nil {
- return nil, err
- }
-
- err = send.ExchangeDeclare(
- "nova_gateway_dispatch",
- "topic",
- true,
- false,
- false,
- true,
- nil,
- )
-
- if err != nil {
- return nil, err
- }
-
- pullChannel, pushChannel := make(chan []byte), make(chan gateway.PushData)
-
- go func() {
- for {
- data := <-pushChannel
- send.Publish(
- "nova_gateway_dispatch",
- data.Name,
- false,
- false,
- amqp.Publishing{
- Priority: 1,
- Timestamp: time.Now(),
- Type: data.Name,
- Body: data.Data,
- },
- )
- }
- }()
-
- return &RabbitMqTransporter{
- pullChannel: pullChannel,
- pushChannel: pushChannel,
- }, nil
-}
-
-func (t RabbitMqTransporter) PushChannel() chan gateway.PushData {
- return t.pushChannel
-}
-func (t RabbitMqTransporter) PullChannel() chan []byte {
- return t.pullChannel
-}
+++ /dev/null
-package main
-
-import (
- "flag"
-
- "github.com/discordnova/nova/common"
- "github.com/discordnova/nova/gateway/lib/gateway"
- "github.com/discordnova/nova/gateway/lib/gateway/compression"
- "github.com/discordnova/nova/gateway/lib/gateway/transporters"
- "github.com/rs/zerolog/log"
-)
-
-var (
- settings = gateway.GatewayConnectorOptions{
- Token: flag.String("token", "", "the discord token for the websocket connection"),
- Restart: flag.Bool("restart", true, "should the gateway be restarted if an error occurs"),
- GuildSubs: flag.Bool("guild-subscriptions", true, "should the guild subscription gateway flag set to true"),
- SelfShard: flag.Int("shard", 0, "the shard id of this instance"),
- TotalShard: flag.Int("shard-count", 1, "the total amount of shard"),
- SessionUpdateFrequency: flag.Int("session-persist-frequence", 10, "the frequency of session persistence"),
- }
-
- compressor = flag.String("compressor", "json-zlib", "the compressor used to connect")
- transporter = flag.String("transporter", "rabbitmq", "the compressor used to connect")
- monitoring = flag.Int("prometheus-port", 9000, "is this flag is set, a prometheus metrics endpoint will be exposed")
- url = flag.String("transporter-url", "", "the url needed for rabbitmq to function")
-)
-
-func validate(settings *gateway.GatewayConnectorOptions) {
- if *settings.SelfShard > *settings.TotalShard {
- log.Fatal().Msg("invalid config: the shard id must be inferior than the total shard value")
- } else if *settings.SessionUpdateFrequency == 0 {
- log.Fatal().Msg("invalid config: the session update frequency muse be greater than 0")
- } else if *settings.Token == "" {
- log.Fatal().Msg("invalid config: invalid token provided")
- } else if *settings.TotalShard == 0 {
- log.Fatal().Msg("invalid config: the total number of shards muse be greater than 0")
- }
-}
-
-func main() {
- flag.Parse()
- common.SetupLogger()
-
- if monitoring != nil {
- go common.CreatePrometheus(*monitoring)
- log.Debug().Msg("prometheus server called")
- }
-
- if *compressor == "json-zlib" {
- settings.Compressor = compression.NewJsonZlibCompressor()
- } else {
- log.Fatal().Msgf("unknown compressor specified: %s", *compressor)
- }
-
- if *transporter == "rabbitmq" {
- trns, err := transporters.NewRabbitMqTransporter(*url)
- if err != nil {
- log.Fatal().Msgf("failed to initialize the transporter: %s", err.Error())
- }
- settings.Transporter = trns
- }
-
- validate(&settings)
-
- gateway := gateway.NewGateway(settings)
- gateway.Start()
-}
--- /dev/null
+use crate::client::payloads::{message::OpCodes, payloads::Hello};
+
+use super::{
+ payloads::message::MessageBase,
+ state::{Stage, State},
+ utils::get_gateway_url,
+};
+use futures_util::{
+ SinkExt, StreamExt,
+};
+use log::{error, info, warn};
+use std::{str::from_utf8, time::Duration};
+use tokio::{
+ net::TcpStream,
+ select,
+ time::{Instant},
+};
+use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::{self, Message}};
+
+#[derive(Debug)]
+pub enum CloseReason {
+ ConnexionAlredyOpen,
+ ConnexionEnded,
+ ErrorEncountered(&'static str),
+ ConnexionError(tungstenite::Error),
+}
+
+pub enum HandleResult {
+ Success,
+ Error(CloseReason),
+}
+
+/// This struct represents a single connexion to the gateway,
+/// it does not have any retry logic or reconnexion mechanism,
+/// everything is handled in the Shard struct.
+/// The purpose of this struct is to handle the encoding,
+/// compression and other gateway-transport related stuff.
+/// All the messages are send through another struct implementing
+/// the MessageHandler trait.
+pub struct Connexion {
+ state: State,
+ connexion: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
+}
+
+impl Connexion {
+ /// Creates a new instance of a discord websocket connexion using the options
+ /// this is used internally by the shard struct to initialize a single
+ /// websocket connexion. This instance is not initialized by default.
+ /// a websocket connexion like this can be re-used multiple times
+ /// to allow reconnexion mechanisms.
+ pub async fn new() -> Self {
+ Connexion {
+ state: State::default(),
+ connexion: None,
+ }
+ }
+
+ /// Terminate the connexion and the "start" method related to it.
+ async fn _terminate_websocket(&mut self, message: CloseReason) {
+ if let Some(connexion) = &mut self.connexion {
+ if let Err(err) = connexion.close(None).await {
+ error!("failed to close socket {}", err);
+ } else {
+ info!("closed the socket: {:?}", message)
+ }
+ } else {
+ warn!("a termination request was sent without a connexion openned")
+ }
+ }
+
+ /// Initialize a connexion to the gateway
+ /// returns if a connexion is already present
+ pub async fn start(mut self) -> CloseReason {
+ if let Some(_) = self.connexion {
+ CloseReason::ConnexionAlredyOpen
+ } else {
+ // we reset the state before starting the connection
+ self.state = State::default();
+
+ let connexion_result = connect_async(get_gateway_url(false, "json", 9)).await;
+ // we connect outselves to the websocket server
+ if let Err(err) = connexion_result {
+ return CloseReason::ConnexionError(err)
+ }
+ self.connexion = Some(connexion_result.unwrap().0);
+
+ // this is the loop that will maintain the whole connexion
+ loop {
+ if let Some(connexion) = &mut self.connexion {
+ // if we do not have a hello message received yet, then we do not use the heartbeat interval
+ // and we just wait for messages to arrive
+ if self.state.stage == Stage::Unknown {
+ let msg = connexion.next().await;
+ if let HandleResult::Error(reason) = self._handle_message(&msg).await {
+ return reason
+ }
+ } else {
+ let timer = self.state.interval.as_mut().unwrap().tick();
+ select! {
+ msg = connexion.next() => {
+ if let HandleResult::Error(reason) = self._handle_message(&msg).await {
+ return reason
+ }
+ },
+ _ = timer => self._do_heartbeat().await
+ }
+ }
+ } else {
+ return CloseReason::ConnexionEnded
+ }
+ }
+ }
+ }
+
+ async fn _handle_message(
+ &mut self,
+ data: &Option<Result<Message, tokio_tungstenite::tungstenite::Error>>,
+ ) -> HandleResult {
+ if let Some(message) = data {
+ match message {
+ Ok(message) => match message {
+ Message::Text(text) => {
+ self._handle_discord_message(&text).await;
+ HandleResult::Success
+ }
+ Message::Binary(message) => {
+ self._handle_discord_message(from_utf8(message).unwrap())
+ .await;
+ HandleResult::Success
+ }
+ Message::Close(_) => {
+ error!("discord connexion closed");
+ HandleResult::Error(CloseReason::ConnexionEnded)
+ }
+
+ _ => {
+ HandleResult::Error(CloseReason::ErrorEncountered("unsupported message type encountered"))
+ }
+ },
+ Err(_error) => {
+ HandleResult::Error(CloseReason::ErrorEncountered("error while reading a message"))
+ }
+ }
+ } else {
+ HandleResult::Error(CloseReason::ErrorEncountered("error while reading a message"))
+ }
+ }
+
+ async fn _handle_discord_message(&mut self, raw_message: &str) {
+ let a: Result<MessageBase, serde_json::Error> = serde_json::from_str(raw_message);
+ let message = a.unwrap();
+
+ // handles the state
+ if let Some(index) = message.s {
+ self.state.sequence = index;
+ }
+
+ match message.op {
+ OpCodes::Dispatch => todo!(),
+ OpCodes::Heartbeat => todo!(),
+ OpCodes::Identify => todo!(),
+ OpCodes::PresenceUpdate => todo!(),
+ OpCodes::VoiceStateUpdate => todo!(),
+ OpCodes::Resume => todo!(),
+ OpCodes::Reconnect => todo!(),
+ OpCodes::RequestGuildMembers => todo!(),
+ OpCodes::InvalidSession => todo!(),
+ OpCodes::Hello => {
+ if let Ok(hello) = serde_json::from_value::<Hello>(message.d) {
+ info!("server sent hello {:?}", hello);
+ info!("heartbeating every {}ms", hello.heartbeat_interval);
+ self.state.interval = Some(tokio::time::interval_at(
+ Instant::now() + Duration::from_millis(hello.heartbeat_interval),
+ Duration::from_millis(hello.heartbeat_interval),
+ ));
+ self.state.stage = Stage::Initialized;
+ }
+ }
+ OpCodes::HeartbeatACK => {
+ info!(
+ "heartbeat acknowledged after {}ms",
+ (std::time::Instant::now() - self.state.last_heartbeat_time).as_millis()
+ );
+ self.state.last_heartbeat_acknowledged = true;
+ }
+ }
+ }
+
+ async fn _do_heartbeat(&mut self) {
+ if !self.state.last_heartbeat_acknowledged {
+ self._terminate_websocket(CloseReason::ErrorEncountered("the server did not acknowledged the last heartbeat")).await;
+ return;
+ }
+ self.state.last_heartbeat_acknowledged = false;
+
+ info!("sending heartbeat");
+ self._send(
+ serde_json::to_vec(&MessageBase {
+ t: None,
+ d: serde_json::to_value(self.state.sequence).unwrap(),
+ s: None,
+ op: OpCodes::Heartbeat,
+ })
+ .unwrap(),
+ )
+ .await;
+ self.state.last_heartbeat_time = std::time::Instant::now();
+ }
+
+ async fn _send(&mut self, data: Vec<u8>) {
+ if let Some(connexion) = &mut self.connexion {
+ if let Err(error) = connexion.send(Message::Binary(data)).await {
+ error!("failed to write to socket: {}", error);
+ self._terminate_websocket(CloseReason::ErrorEncountered("failed to write to the socket")).await;
+ }
+ }
+ }
+}
--- /dev/null
+#[derive(Debug)]
+struct MyError(String);
+
+impl fmt::Display for MyError {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "There is an error: {}", self.0)
+ }
+}
+
+impl Error for NovaError {}
+
+pub fn run() -> Result<(), Box<dyn Error>> {
+ let condition = true;
+
+ if condition {
+ return Err(Box::new(MyError("Oops".into())));
+ }
+
+ Ok(())
+}
\ No newline at end of file
--- /dev/null
+pub mod connexion;
+mod utils;
+mod state;
+mod shard;
+pub mod payloads;
+pub mod traits;
\ No newline at end of file
--- /dev/null
+use serde_json::Value;
+use serde_repr::{Serialize_repr, Deserialize_repr};
+use serde::{Deserialize, Serialize};
+
+
+#[derive(Serialize_repr, Deserialize_repr, PartialEq, Debug)]
+#[repr(u8)]
+pub enum OpCodes {
+ Dispatch = 0,
+ Heartbeat = 1,
+ Identify = 2,
+ PresenceUpdate = 3,
+ VoiceStateUpdate = 4,
+ Resume = 6,
+ Reconnect = 7,
+ RequestGuildMembers = 8,
+ InvalidSession = 9,
+ Hello = 10,
+ HeartbeatACK = 11,
+}
+
+#[derive(Serialize, Deserialize, PartialEq, Debug)]
+pub struct MessageBase {
+ pub t: Option<String>,
+ pub s: Option<i64>,
+ pub op: OpCodes,
+ pub d: Value
+}
--- /dev/null
+pub mod payloads;
+pub mod message;
\ No newline at end of file
--- /dev/null
+use serde::{Serialize, Deserialize};
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Hello {
+ #[serde(rename = "heartbeat_interval")]
+ pub heartbeat_interval: u64
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct HeartbeatACK {}
\ No newline at end of file
--- /dev/null
+
+
+
+struct Shard {
+
+}
\ No newline at end of file
--- /dev/null
+use std::time::Instant;
+use tokio::time::Interval;
+
+#[derive(PartialEq)]
+pub enum Stage {
+ Unknown,
+ Initialized,
+ LoggedIn,
+}
+
+pub struct State {
+ pub stage: Stage,
+ pub sequence: i64,
+ pub last_heartbeat_acknowledged: bool,
+ pub last_heartbeat_time: Instant,
+ pub interval: Option<Interval>,
+}
+
+impl State {
+ pub fn default() -> Self {
+ State {
+ sequence: 0,
+ interval: None,
+ stage: Stage::Unknown,
+ last_heartbeat_acknowledged: true,
+ last_heartbeat_time: std::time::Instant::now(),
+ }
+ }
+}
--- /dev/null
+/// This trait is used by the Connexion<H> struct
+/// It implements a basic interface for handling events.
+pub trait MessageHandler {}
\ No newline at end of file
--- /dev/null
+pub mod message_handler;
\ No newline at end of file
--- /dev/null
+pub fn get_gateway_url (compress: bool, encoding: &str, v: i16) -> String {
+ return format!(
+ "wss://gateway.discord.gg/?v={}&encoding={}&compress={}",
+ v, encoding,
+ if compress { "zlib-stream" } else { "" }
+ );
+}
\ No newline at end of file
--- /dev/null
+use client::traits::message_handler::MessageHandler;
+extern crate serde_json;
+
+mod client;
+
+struct Handler {}
+impl MessageHandler for Handler {}
+
+#[tokio::main]
+async fn main() {
+ pretty_env_logger::init();
+ for _ in 0..1 {
+ tokio::spawn(async move {
+ let con = client::connexion::Connexion::new().await;
+ con.start().await;
+ }).await.unwrap();
+ }
+}
\ No newline at end of file