diff --git a/_examples/command/README.md b/_examples/command/README.md index 2cc05ea..e767615 100644 --- a/_examples/command/README.md +++ b/_examples/command/README.md @@ -78,8 +78,15 @@ Program executed successfully. **I am receiving `Invalid interaction application command` from Discord when I send an interaction.** -Discord propagates registered Global Application Commands. As a result, it can take time to add or update a Global Application Command. In addition, the user's client must reload the commands that are available to them, so that the user's client selects the new propagated Global Application Command _(with a new ID and Token)_. In contrast, Guild Application Commands _(registered via `CreateGuildApplicationCommand`)_ are updated instantly. Due to this behavior, use Guild Application Commands to test your application without waiting for propagation. For more information, read the [Discord API Documentation](https://discord.com/developers/docs/interactions/application-commands#registering-a-command). +Discord propagates registered Global Application Commands. So, it can take time to add or update a Global Application Command. In addition, the user's client must reload the commands that are available to them, so that the user's client selects the new propagated Global Application Command _(with a new ID and Token)_. In contrast, Guild Application Commands _(registered via `CreateGuildApplicationCommand`)_ are updated instantly. Due to this behavior, use Guild Application Commands to test your application without waiting for propagation. For more information, read the [Discord API Documentation](https://discord.com/developers/docs/interactions/application-commands#registering-a-command). **I am receiving a nil pointer dereference when the Bot's Application Command is used in a DM or Guild.** -When an Application Command is used in a direct message, the `Interaction.User` field is provided, while the `Interaction.Member` is **NOT**. When an Application Command is used in a guild, the `Interaction.Member` field is provided, while the `Interaction.User` is **NOT**. For the sake of simplicity, these examples assume that you will use your command in a Direct Message Channel. To protect against this behavior in production-level code, ensure that the `Interaction.User` or `Interaction.Member` is `!= nil` before referencing their fields. \ No newline at end of file +These examples create commands for usage in a Direct Message Channel. Confirm the `Interaction.User` or `Interaction.Member` is `!= nil` before referencing their fields to protect against this behavior in production-level code. + +_Here is more information about this behavior._ + +When an Application Command is used in a direct message, the `Interaction.User` field is provided, while the `Interaction.Member` is **NOT**. + +When an Application Command is used in a guild, the `Interaction.Member` field is provided, while the `Interaction.User` is **NOT**. + diff --git a/_examples/command/autocomplete/main.go b/_examples/command/autocomplete/main.go index 10745a6..3923298 100644 --- a/_examples/command/autocomplete/main.go +++ b/_examples/command/autocomplete/main.go @@ -90,7 +90,7 @@ func main() { // Add an event handler to the bot. // - // ensure that the event handler is added to the bot. + // confirm the event handler is added to the bot. if err := bot.Handle(disgo.FlagGatewayEventNameInteractionCreate, func(i *disgo.InteractionCreate) { log.Println("Received interaction.") diff --git a/_examples/command/followup/main.go b/_examples/command/followup/main.go index d353fe4..bfc5149 100644 --- a/_examples/command/followup/main.go +++ b/_examples/command/followup/main.go @@ -59,7 +59,7 @@ func main() { // Add an event handler to the bot. // - // ensure that the event handler is added to the bot. + // confirm the event handler is added to the bot. if err := bot.Handle(disgo.FlagGatewayEventNameInteractionCreate, func(i *disgo.InteractionCreate) { log.Printf("followup called by %s.", i.Interaction.User.Username) diff --git a/_examples/command/localization/main.go b/_examples/command/localization/main.go index 30e726e..7290802 100644 --- a/_examples/command/localization/main.go +++ b/_examples/command/localization/main.go @@ -103,7 +103,7 @@ func main() { // Add an event handler to the bot. // - // ensure that the event handler is added to the bot. + // confirm the event handler is added to the bot. if err := bot.Handle(disgo.FlagGatewayEventNameInteractionCreate, func(i *disgo.InteractionCreate) { log.Printf("hello called by %s.", i.Interaction.User.Username) diff --git a/_examples/command/main.go b/_examples/command/main.go index 7d13419..1d6c2fe 100644 --- a/_examples/command/main.go +++ b/_examples/command/main.go @@ -62,7 +62,7 @@ func main() { // Add an event handler to the bot. // - // ensure that the event handler is added to the bot. + // confirm the event handler is added to the bot. if err := bot.Handle(disgo.FlagGatewayEventNameInteractionCreate, func(i *disgo.InteractionCreate) { log.Printf("main called by %s.", i.Interaction.User.Username) diff --git a/_examples/command/subcommand/main.go b/_examples/command/subcommand/main.go index 4c4d27c..974c005 100644 --- a/_examples/command/subcommand/main.go +++ b/_examples/command/subcommand/main.go @@ -174,7 +174,7 @@ func main() { // Add an event handler to the bot. // - // ensure that the event handler is added to the bot. + // confirm the event handler is added to the bot. if err := bot.Handle(disgo.FlagGatewayEventNameInteractionCreate, func(i *disgo.InteractionCreate) { log.Printf("calculate called by %s.", i.Interaction.User.Username) diff --git a/_examples/image/avatar/main.go b/_examples/image/avatar/main.go index 2a1ba11..e535f0e 100644 --- a/_examples/image/avatar/main.go +++ b/_examples/image/avatar/main.go @@ -28,7 +28,7 @@ func main() { // parse the command line flags. flag.Parse() - // ensure that the program has the necessary data to succeed. + // confirm the program has the necessary data to succeed. if token == "" { fmt.Println("The bot's token must be set, but is currently empty.") diff --git a/_examples/message/components/main.go b/_examples/message/components/main.go index 5e9cd55..342e11f 100644 --- a/_examples/message/components/main.go +++ b/_examples/message/components/main.go @@ -26,7 +26,7 @@ func main() { // parse the command line flags. flag.Parse() - // ensure that the program has the necessary data to succeed. + // confirm the program has the necessary data to succeed. if token == "" { log.Println("The bot's token must be set, but is currently empty.") @@ -46,7 +46,7 @@ func main() { Config: disgo.DefaultConfig(), } - // ensure that the bot has access to the channel. + // confirm the bot has access to the channel. // // This is useful for the validation of this program, but unnecessary. getChannelRequest := disgo.GetChannel{ChannelID: *channelID} diff --git a/_examples/message/send/main.go b/_examples/message/send/main.go index cf8ee4e..567934b 100644 --- a/_examples/message/send/main.go +++ b/_examples/message/send/main.go @@ -28,7 +28,7 @@ func main() { // parse the command line flags. flag.Parse() - // ensure that the program has the necessary data to succeed. + // confirm the program has the necessary data to succeed. if token == "" { log.Println("The bot's token must be set, but is currently empty.") @@ -69,7 +69,7 @@ func main() { Config: disgo.DefaultConfig(), } - // ensure that the bot has access to the channel. + // confirm the bot has access to the channel. // // This is useful for the validation of this program, but unnecessary. getChannelRequest := disgo.GetChannel{ChannelID: *channelID} diff --git a/disgo.go b/disgo.go index 7af7b29..39fb0cc 100644 --- a/disgo.go +++ b/disgo.go @@ -8193,19 +8193,29 @@ func (e ErrorRequest) Error() string { e.ClientID, e.CorrelationID, e.RouteID, e.ResourceID, e.Endpoint, e.Err).Error() } +// ErrorStatusCode represents an HTTP Request error that occurs when an unexpected response is returned. +type ErrorStatusCode struct { + // StatusCode represents the HTTP Status Code received from a response. + StatusCode int +} + // Status Code Error Messages. const ( errStatusCodeKnown = "status code %d: %v" errStatusCodeUnknown = "status code %d: unknown status code error from Discord" ) -// StatusCodeError handles a Discord API HTTP Status Code and returns the relevant error message. -func StatusCodeError(status int) error { +func (e ErrorStatusCode) Error() string { + return fmt.Sprintf("STATUS CODE ERROR: status code: %q: msg: %v", e.StatusCode, StatusCodeError(e.StatusCode)) +} + +// StatusCodeError returns the relevant message for a Discord API HTTP Status Code. +func StatusCodeError(status int) string { if msg, ok := HTTPResponseCodes[status]; ok { - return fmt.Errorf(errStatusCodeKnown, status, msg) + return fmt.Sprintf(errStatusCodeKnown, status, msg) } - return fmt.Errorf(errStatusCodeUnknown, status) + return fmt.Sprintf(errStatusCodeUnknown, status) } // JSON Error Code Messages. @@ -8280,26 +8290,6 @@ func (e ErrorEvent) Error() string { e.ClientID, e.Event, e.Action, e.Err).Error() } -// Discord Gateway Error Messages -const ( - errNoSessionManager = `The client must contain a non-nil SessionManager struct to connect to the Discord Gateway. - -Set the *Client.SessionManager using one of the following methods. - ---- 1 - -bot := &disgo.Client{ -... -Sessions: disgo.NewSessionManager(), -} - ---- 2 - -bot.Sessions = disgo.NewSessionManager() - -` -) - // ErrorSession represents a WebSocket Session error that occurs during an active session. type ErrorSession struct { // Err represents the error that occurred. @@ -8307,35 +8297,38 @@ type ErrorSession struct { // SessionID represents the ID of the Session. SessionID string -} -func (e ErrorSession) Error() string { - return fmt.Errorf("SESSION ERROR: session %q: error: %w", e.SessionID, e.Err).Error() + // State represents the state of the session. + State string + + // Type represents the type of connection (e.g., Discord Gateway, Discord Voice). + Type string } const ( - ErrConnectionSession = "Discord Gateway" - ErrConnectionSessionVoice = "Discord Voice" + ErrorSessionTypeGateway = "Discord Gateway" + ErrorSessionTypeVoice = "Discord Voice" ) -// ErrorDisconnect represents a disconnection error that occurs when -// an attempt to gracefully disconnect from a connection fails. -type ErrorDisconnect struct { +func (e ErrorSession) Error() string { + return fmt.Errorf("SESSION ERROR: %q session %q: state: %q error: %w", e.Type, e.SessionID, e.State, e.Err).Error() +} + +// ErrorSessionDisconnect represents a disconnection error that occurs when +// an attempt to gracefully disconnect from a session fails. +type ErrorSessionDisconnect struct { // Action represents the error that prompted the disconnection (if applicable). Action error // Err represents the error that occurred while disconnecting. Err error - - // Connection represents the name of the connection. - Connection string } -func (e ErrorDisconnect) Error() string { - return fmt.Errorf("error disconnecting from %q\n"+ +func (e ErrorSessionDisconnect) Error() string { + return fmt.Errorf( "\tDisconnect(): %v\n"+ - "\treason: %w\n", - e.Connection, e.Err, e.Action, + "\treason: %w\n", + e.Err, e.Action, ).Error() //lint:ignore ST1005 readability } @@ -10693,7 +10686,9 @@ SEND: goto RATELIMIT } - return StatusCodeError(response.StatusCode()) + return ErrorStatusCode{ + StatusCode: response.StatusCode(), + } } // parse the rate limit response data for `retry_after`. @@ -10759,7 +10754,9 @@ SEND: goto RATELIMIT } - return StatusCodeError(fasthttp.StatusTooManyRequests) + return ErrorStatusCode{ + StatusCode: fasthttp.StatusTooManyRequests, + } // retry the request on a bad gateway server error. case fasthttp.StatusBadGateway: @@ -10769,10 +10766,14 @@ SEND: goto RATELIMIT } - return StatusCodeError(fasthttp.StatusBadGateway) + return ErrorStatusCode{ + StatusCode: fasthttp.StatusBadGateway, + } default: - return StatusCodeError(response.StatusCode()) + return ErrorStatusCode{ + StatusCode: response.StatusCode(), + } } } @@ -17177,12 +17178,6 @@ func (r *GetCurrentAuthorizationInformation) Send(bot *Client) (*CurrentAuthoriz return result, nil } -const ( - gatewayEndpointParams = "?v=" + VersionDiscordAPI + "&encoding=json" - invalidSessionWaitTime = 1 * time.Second - maxIdentifyLargeThreshold = 250 -) - // Session represents a Discord Gateway WebSocket Session. type Session struct { Context context.Context @@ -17198,432 +17193,142 @@ type Session struct { sync.RWMutex } -// isConnected returns whether the session is connected. -func (s *Session) isConnected() bool { - if s.Context == nil { - return false - } - - select { - case <-s.Context.Done(): - return false - default: - return true - } -} - -// canReconnect determines whether the session is in a valid state to reconnect. -func (s *Session) canReconnect() bool { - return s.ID != "" && s.Endpoint != "" && atomic.LoadInt64(&s.Seq) != 0 -} - // Connect connects a session to the Discord Gateway (WebSocket Connection). func (s *Session) Connect(bot *Client) error { - s.Lock() - defer s.Unlock() - - LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msg("connecting session") - - return s.connect(bot) -} + if bot == nil { + return errors.New("cannot connect session using a nil Client") + } -// connect connects a session to a WebSocket Connection. -func (s *Session) connect(bot *Client) error { if bot.Sessions == nil { - return fmt.Errorf("%q", errNoSessionManager) + bot.Sessions = NewSessionManager() } - s.client_manager = bot.Sessions - if bot.Handlers == nil { bot.Handlers = new(Handlers) } - if s.isConnected() { + s.Lock() + s.client_manager = bot.Sessions + + if s.manager != nil && s.State() == SessionStateConnected { + s.Unlock() + return fmt.Errorf("session %q is already connected", s.ID) } - var err error - - // request a valid Gateway URL endpoint and response from the Discord API. - gatewayEndpoint := s.Endpoint - var response *GetGatewayBotResponse + s.spawnManager(bot) - if bot.Config.Gateway.ShardManager != nil { - if response, err = bot.Config.Gateway.ShardManager.SetLimit(bot); err != nil { - return fmt.Errorf("shardmanager: %w", err) - } - } else { - if gatewayEndpoint == "" || !s.canReconnect() { - gateway := GetGatewayBot{} - response, err = gateway.Send(bot) - if err != nil { - return fmt.Errorf("error getting the Gateway API Endpoint: %w", err) - } + s.manager.signals <- sessionSignalConnect + s.Unlock() - gatewayEndpoint = response.URL - } + if err := <-s.manager.actionError; err != nil { + return err } - // set the maximum allowed (Identify) concurrency rate limit. - // - // https://discord.com/developers/docs/topics/gateway#rate-limiting - if response != nil { - bot.Config.Gateway.RateLimiter.StartTx() + return nil +} - identifyBucket := bot.Config.Gateway.RateLimiter.GetBucketFromID(FlagGatewaySendEventNameIdentify) - if identifyBucket == nil { - identifyBucket = getBucket() - bot.Config.Gateway.RateLimiter.SetBucketFromID(FlagGatewaySendEventNameIdentify, identifyBucket) - } +// Disconnect disconnects a session from the Discord Gateway. +func (s *Session) Disconnect() error { + s.Lock() + if s.manager == nil || s.State() != SessionStateConnected { + s.Unlock() - identifyBucket.Limit = int16(response.SessionStartLimit.MaxConcurrency) //nolint:gosec // disable G115 + return errors.New("cannot disconnect session that isn't connected") + } - if identifyBucket.Expiry.IsZero() { - identifyBucket.Remaining = identifyBucket.Limit - identifyBucket.Expiry = time.Now().Add(FlagGlobalRateLimitIdentifyInterval) - } + s.manager.signals <- sessionSignalDisconnect + s.Unlock() - bot.Config.Gateway.RateLimiter.EndTx() + if err := <-s.manager.actionError; err != nil { + return err } - // connect to the Discord Gateway Websocket. - s.manager = new(manager) - s.Context, s.manager.cancel = context.WithCancel(context.Background()) - if s.Conn, _, err = websocket.Dial(s.Context, gatewayEndpoint+gatewayEndpointParams, nil); err != nil { - return fmt.Errorf("error connecting to the Discord Gateway: %w", err) - } + // Reset the session. + putSession(s) - // set up the Session's Rate Limiter (applied per WebSocket Connection). - // https://discord.com/developers/docs/topics/gateway#rate-limiting - s.RateLimiter = &RateLimit{ //nolint:exhaustruct - ids: make(map[string]string, totalGatewayBucketsPerConnection), - buckets: make(map[string]*Bucket, totalGatewayBucketsPerConnection), - } + return nil +} - s.RateLimiter.SetBucket( - GlobalRateLimitRouteID, &Bucket{ //nolint:exhaustruct - Limit: FlagGlobalRateLimitGateway, - Remaining: FlagGlobalRateLimitGateway, - Expiry: time.Now().Add(FlagGlobalRateLimitGatewayInterval), - }, - ) +// Reconnect reconnects an already connected session to the Discord Gateway +// by disconnecting the session, then connecting again. +func (s *Session) Reconnect(bot *Client) error { + s.Lock() + if s.manager == nil || s.State() != SessionStateConnected { + s.Unlock() - // handle the incoming Hello event upon connecting to the Gateway. - hello := new(Hello) - if err := readEvent(s, hello); err != nil { - err = fmt.Errorf("error reading initial Hello event: %w", err) - sessionErr := ErrorSession{SessionID: s.ID, Err: err} - if disconnectErr := s.disconnect(FlagClientCloseEventCodeNormal); disconnectErr != nil { - sessionErr.Err = ErrorDisconnect{ - Action: err, - Err: disconnectErr, - Connection: ErrConnectionSession, - } - } + return errors.New("cannot reconnect session that isn't connected") + } - return sessionErr + s.manager.signals <- sessionSignalReconnect + s.Unlock() + + if err := <-s.manager.actionError; err != nil { + return fmt.Errorf("reconnect: %w", err) } - for _, handler := range bot.Handlers.Hello { - go handler(hello) + return nil +} + +// SendEvent sends an Opcode 1 Heartbeat event to the Discord Gateway. +func (c *Heartbeat) SendEvent(bot *Client, session *Session) error { + if err := writeEvent(bot, session, FlagGatewayOpcodeHeartbeat, FlagGatewaySendEventNameHeartbeat, c); err != nil { + return err } - // begin sending heartbeat payloads every heartbeat_interval ms. - ms := time.Millisecond * time.Duration(hello.HeartbeatInterval) - s.heartbeat = &heartbeat{ - interval: ms, - ticker: time.NewTicker(ms), - send: make(chan Heartbeat), + return nil +} - // add a HeartbeatACK to the HeartbeatACK channel to prevent - // the length of the HeartbeatACK channel from being 0 immediately, - // which results in an attempt to reconnect. - acks: 1, +// SendEvent sends an Opcode 2 Identify event to the Discord Gateway. +func (c *Identify) SendEvent(bot *Client, session *Session) error { + if err := writeEvent(bot, session, FlagGatewayOpcodeIdentify, FlagGatewaySendEventNameIdentify, c); err != nil { + return err } - // create a goroutine group for the Session. - s.manager.Group, s.manager.signal = errgroup.WithContext(s.Context) - s.manager.err = make(chan error, 1) + return nil +} - // spawn the heartbeat pulse goroutine. - s.manager.routines.Add(1) - atomic.AddInt32(&s.manager.pulses, 1) - s.manager.Go(func() error { - s.pulse() - return nil - }) +// SendEvent sends an Opcode 3 UpdatePresence event to the Discord Gateway. +func (c *GatewayPresenceUpdate) SendEvent(bot *Client, session *Session) error { + if err := writeEvent(bot, session, FlagGatewayOpcodePresenceUpdate, FlagGatewaySendEventNameUpdatePresence, c); err != nil { + return err + } - // spawn the heartbeat beat goroutine. - s.manager.routines.Add(1) - s.manager.Go(func() error { - if err := s.beat(bot); err != nil { - return ErrorSession{ - SessionID: s.ID, - Err: fmt.Errorf("heartbeat: %w", err), - } - } + return nil +} - return nil - }) +// SendEvent sends an Opcode 4 UpdateVoiceState event to the Discord Gateway. +func (c *GatewayVoiceStateUpdate) SendEvent(bot *Client, session *Session) error { + if err := writeEvent(bot, session, FlagGatewayOpcodeVoiceStateUpdate, FlagGatewaySendEventNameUpdateVoiceState, c); err != nil { + return err + } - // send the initial Identify or Resumed packet. - if err := s.initial(bot, 0); err != nil { - sessionErr := ErrorSession{SessionID: s.ID, Err: err} - if disconnectErr := s.disconnect(FlagClientCloseEventCodeNormal); disconnectErr != nil { - sessionErr.Err = ErrorDisconnect{ - Action: err, - Err: disconnectErr, - Connection: ErrConnectionSession, - } - } + return nil +} - return sessionErr +// SendEvent sends an Opcode 6 Resume event to the Discord Gateway. +func (c *Resume) SendEvent(bot *Client, session *Session) error { + if err := writeEvent(bot, session, FlagGatewayOpcodeResume, FlagGatewaySendEventNameResume, c); err != nil { + return err } - // spawn the event listener listen goroutine. - s.manager.routines.Add(1) - s.manager.Go(func() error { - if err := s.listen(bot); err != nil { - return ErrorSession{ - SessionID: s.ID, - Err: fmt.Errorf("listen: %w", err), - } - } + return nil +} - return nil - }) +// SendEvent sends an Opcode 8 RequestGuildMembers event to the Discord Gateway. +func (c *RequestGuildMembers) SendEvent(bot *Client, session *Session) error { + if err := writeEvent(bot, session, FlagGatewayOpcodeRequestGuildMembers, FlagGatewaySendEventNameRequestGuildMembers, c); err != nil { + return err + } - // spawn the manager goroutine. - s.manager.routines.Add(1) - go s.manage(bot) + return nil +} - // ensure that the Session's goroutines are spawned. - s.manager.routines.Wait() - - return nil -} - -// initial sends the initial Identify or Resume packet required to connect to the Gateway, -// then handles the incoming Ready or Resumed packet that indicates a successful connection. -func (s *Session) initial(bot *Client, attempt int) error { - if !s.canReconnect() { - // send an Opcode 2 Identify to the Discord Gateway. - identify := Identify{ - Token: bot.Authentication.Token, - Properties: IdentifyConnectionProperties{ - OS: runtime.GOOS, - Browser: module, - Device: module, - }, - Compress: Pointer(true), - LargeThreshold: Pointer(maxIdentifyLargeThreshold), - Shard: s.Shard, - Presence: bot.Config.Gateway.GatewayPresenceUpdate, - Intents: bot.Config.Gateway.Intents, - } - - if err := identify.SendEvent(bot, s); err != nil { - return err - } - } else { - // send an Opcode 6 Resume to the Discord Gateway to reconnect the session. - resume := Resume{ - Token: bot.Authentication.Token, - SessionID: s.ID, - Seq: atomic.LoadInt64(&s.Seq), - } - - if err := resume.SendEvent(bot, s); err != nil { - return err - } - } - - // handle the incoming Ready, Resumed or Replayed event (or Opcode 9 Invalid Session). - payload := getPayload() - defer putPayload(payload) - if err := socket.Read(s.Context, s.Conn, payload); err != nil { - return fmt.Errorf("error reading initial payload: %w", err) - } - - LogPayload(LogSession(Logger.Info(), s.ID), payload.Op, payload.Data).Msg("received initial payload") - - switch payload.Op { - case FlagGatewayOpcodeDispatch: - switch { - // When a connection is successful, the Discord Gateway will respond with a Ready event. - case *payload.EventName == FlagGatewayEventNameReady: - ready := new(Ready) - if err := json.Unmarshal(payload.Data, ready); err != nil { - return fmt.Errorf("error reading ready event: %w", err) - } - - LogSession(Logger.Info(), ready.SessionID).Msg("received Ready event") - - // Configure the session. - s.ID = ready.SessionID - atomic.StoreInt64(&s.Seq, 0) - s.Endpoint = ready.ResumeGatewayURL - - // Store the session in the session manager. - s.client_manager.Gateway.Store(s.ID, s) - - if bot.Config.Gateway.ShardManager != nil { - bot.Config.Gateway.ShardManager.Ready(bot, s, ready) - } - - for _, handler := range bot.Handlers.Ready { - go handler(ready) - } - - // When a reconnection is successful, the Discord Gateway will respond - // by replaying all missed events in order, finalized by a Resumed event. - case *payload.EventName == FlagGatewayEventNameResumed: - LogSession(Logger.Info(), s.ID).Msg("received Resumed event") - - // Store the session in the session manager. - s.client_manager.Gateway.Store(s.ID, s) - - for _, handler := range bot.Handlers.Resumed { - go handler(&Resumed{}) - } - - // When a reconnection is successful, the Discord Gateway will respond - // by replaying all missed events in order, finalized by a Resumed event. - default: - // handle the initial payload(s) until a Resumed event is encountered. - go bot.handle(*payload.EventName, payload.Data) - - for { - replayed := new(GatewayPayload) - if err := socket.Read(s.Context, s.Conn, replayed); err != nil { - return fmt.Errorf("error replaying events: %w", err) - } - - if replayed.Op == FlagGatewayOpcodeDispatch && *replayed.EventName == FlagGatewayEventNameResumed { - LogSession(Logger.Info(), s.ID).Msg("received Resumed event") - - // Store the session in the session manager. - s.client_manager.Gateway.Store(s.ID, s) - - for _, handler := range bot.Handlers.Resumed { - go handler(&Resumed{}) - } - - return nil - } - - go bot.handle(*payload.EventName, payload.Data) - } - } - - // When the maximum concurrency limit has been reached while connecting, or when - // the session does NOT reconnect in time, the Discord Gateway send an Opcode 9 Invalid Session. - case FlagGatewayOpcodeInvalidSession: - // Remove the session from the session manager. - s.client_manager.RemoveGatewaySession(s.ID) - - if attempt < 1 { - // wait for Discord to close the session, then complete a fresh connect. - <-time.NewTimer(invalidSessionWaitTime).C - - s.ID = "" - atomic.StoreInt64(&s.Seq, 0) - if err := s.initial(bot, attempt+1); err != nil { - return err - } - - return nil - } - - return fmt.Errorf("session %q couldn't connect to the Discord Gateway or has invalidated an active session", s.ID) - default: - return fmt.Errorf("session %q received payload %d during connection which is unexpected", s.ID, payload.Op) - } - - return nil -} - -// Disconnect disconnects a session from the Discord Gateway. -func (s *Session) Disconnect() error { - s.Lock() - - if !s.isConnected() { - s.Unlock() - - return fmt.Errorf("session %q is already disconnected", s.ID) - } - - id := s.ID - LogSession(Logger.Info(), id).Msgf("disconnecting session with code %d", FlagClientCloseEventCodeNormal) - - s.manager.signal = context.WithValue(s.manager.signal, keySignal, signalDisconnect) - - if err := s.disconnect(FlagClientCloseEventCodeNormal); err != nil { - s.Unlock() - - return ErrorDisconnect{ - Connection: ErrConnectionSession, - Action: nil, - Err: err, - } - } - - s.Unlock() - - if err := <-s.manager.err; err != nil { - return err - } - - putSession(s) - - LogSession(Logger.Info(), id).Msgf("disconnected session with code %d", FlagClientCloseEventCodeNormal) - - return nil -} - -// disconnect disconnects a session from a WebSocket Connection using the given status code. -func (s *Session) disconnect(code int) error { - // cancel the context to kill the goroutines of the Session. - defer s.manager.cancel() - - // Remove the session from the session manager. - s.client_manager.RemoveGatewaySession(s.ID) - - if err := s.Conn.Close(websocket.StatusCode(code), ""); err != nil { - return fmt.Errorf("%w", err) - } - - return nil -} - -// Reconnect reconnects an already connected session to the Discord Gateway -// by disconnecting the session, then connecting again. -func (s *Session) Reconnect(bot *Client) error { - s.reconnect(bot, "reconnecting") - - if err := <-s.manager.err; err != nil { - return err - } - - // connect to the Discord Gateway again. - if err := s.Connect(bot); err != nil { - return fmt.Errorf("error reconnecting session %q: %w", s.ID, err) - } - - return nil -} - -// readEvent is a helper function for reading events from the WebSocket Session. -func readEvent(s *Session, dst any) error { - payload := new(GatewayPayload) - if err := socket.Read(s.Context, s.Conn, payload); err != nil { - return fmt.Errorf("readEvent: %w", err) - } - - if err := json.Unmarshal(payload.Data, dst); err != nil { - return fmt.Errorf("readEvent: %w", err) - } +// SendEvent sends an Opcode 31 RequestSoundboardSounds event to the Discord Gateway. +func (c *RequestSoundboardSounds) SendEvent(bot *Client, session *Session) error { + if err := writeEvent(bot, session, FlagGatewayOpcodeRequestSoundboardSounds, FlagGatewaySendEventNameRequestSoundboardSounds, c); err != nil { + return err + } return nil } @@ -17754,69 +17459,6 @@ SEND: return nil } -// SendEvent sends an Opcode 1 Heartbeat event to the Discord Gateway. -func (c *Heartbeat) SendEvent(bot *Client, session *Session) error { - if err := writeEvent(bot, session, FlagGatewayOpcodeHeartbeat, FlagGatewaySendEventNameHeartbeat, c); err != nil { - return err - } - - return nil -} - -// SendEvent sends an Opcode 2 Identify event to the Discord Gateway. -func (c *Identify) SendEvent(bot *Client, session *Session) error { - if err := writeEvent(bot, session, FlagGatewayOpcodeIdentify, FlagGatewaySendEventNameIdentify, c); err != nil { - return err - } - - return nil -} - -// SendEvent sends an Opcode 3 UpdatePresence event to the Discord Gateway. -func (c *GatewayPresenceUpdate) SendEvent(bot *Client, session *Session) error { - if err := writeEvent(bot, session, FlagGatewayOpcodePresenceUpdate, FlagGatewaySendEventNameUpdatePresence, c); err != nil { - return err - } - - return nil -} - -// SendEvent sends an Opcode 4 UpdateVoiceState event to the Discord Gateway. -func (c *GatewayVoiceStateUpdate) SendEvent(bot *Client, session *Session) error { - if err := writeEvent(bot, session, FlagGatewayOpcodeVoiceStateUpdate, FlagGatewaySendEventNameUpdateVoiceState, c); err != nil { - return err - } - - return nil -} - -// SendEvent sends an Opcode 6 Resume event to the Discord Gateway. -func (c *Resume) SendEvent(bot *Client, session *Session) error { - if err := writeEvent(bot, session, FlagGatewayOpcodeResume, FlagGatewaySendEventNameResume, c); err != nil { - return err - } - - return nil -} - -// SendEvent sends an Opcode 8 RequestGuildMembers event to the Discord Gateway. -func (c *RequestGuildMembers) SendEvent(bot *Client, session *Session) error { - if err := writeEvent(bot, session, FlagGatewayOpcodeRequestGuildMembers, FlagGatewaySendEventNameRequestGuildMembers, c); err != nil { - return err - } - - return nil -} - -// SendEvent sends an Opcode 31 RequestSoundboardSounds event to the Discord Gateway. -func (c *RequestSoundboardSounds) SendEvent(bot *Client, session *Session) error { - if err := writeEvent(bot, session, FlagGatewayOpcodeRequestSoundboardSounds, FlagGatewaySendEventNameRequestSoundboardSounds, c); err != nil { - return err - } - - return nil -} - // Handlers represents a bot's event handlers. type Handlers struct { Hello []func(*Hello) @@ -20887,6 +20529,42 @@ func (bot *Client) handle(eventname string, data json.RawMessage) { } } +// coroner investigates when a Session's goroutines are shutdown. +func (s *Session) coroner() { + // wait until all the manager goroutines is closed. + err := s.manager.coroner.Wait() + + s.Lock() + + // report the disconnection error + s.manager.actionError <- err + close(s.manager.actionError) + + // remove the session from the client. + s.client_manager.RemoveGatewaySession(s.ID) + + s.logClose("coroner") + s.Unlock() +} + +// Wait blocks until the calling Session is inactive (due to a final disconnect), +// then returns the Session's state and the disconnection error (if it exists). +// +// If Wait() is called on a Session that isn't connected, it will return immediately +// with code SessionStateNew. +// +// A disconnected session is reset and placed into a memory pool, +// so do NOT modify a Session after it disconnects. +func (s *Session) Wait() (string, error) { + if s.State() == SessionStateNew { + return SessionStateNew, nil + } + + err := s.manager.coroner.Wait() + + return s.State(), err +} + // heartbeat represents the heartbeat mechanism for a Session. type heartbeat struct { ticker *time.Ticker @@ -20908,7 +20586,7 @@ func (s *Session) Monitor() uint32 { func (s *Session) beat(bot *Client) error { s.manager.routines.Done() - // ensure that all pulse routines are closed prior to closing. + // confirm all pulse routines are closed prior to closing. defer func() { for { select { @@ -20934,7 +20612,7 @@ func (s *Session) beat(bot *Client) error { if atomic.LoadUint32(&s.heartbeat.acks) == 0 { s.Unlock() - s.reconnect(bot, "attempting to reconnect session due to no HeartbeatACK") + s.reconnect("attempting to reconnect session due to no HeartbeatACK") return nil } @@ -21025,7 +20703,7 @@ func (s *Session) respond(data json.RawMessage) error { s.Lock() - // ensure that the heartbeat routine has not been closed. + // confirm the heartbeat routine has not been closed. if atomic.LoadInt32(&s.manager.pulses) <= 1 { s.Unlock() @@ -21050,6 +20728,14 @@ func (s *Session) respond(data json.RawMessage) error { return nil } +// decrementPulses safely decrements the pulses counter. +func (s *Session) decrementPulses() { + s.Lock() + defer s.Unlock() + + atomic.AddInt32(&s.manager.pulses, -1) +} + // listen listens to the connection for payloads from the Discord Gateway. func (s *Session) listen(bot *Client) error { s.manager.routines.Done() @@ -21070,16 +20756,20 @@ func (s *Session) listen(bot *Client) error { } s.Lock() - defer s.Unlock() defer s.logClose("listen") + defer s.Unlock() - select { - case <-s.Context.Done(): - return nil + if s.Context != nil { + select { + case <-s.Context.Done(): + return nil - default: - return err + default: + return err + } } + + return nil } // onPayload handles an Discord Gateway Payload. @@ -21115,7 +20805,7 @@ func (s *Session) onPayload(bot *Client, payload GatewayPayload) error { // occurs when the Discord Gateway is shutting down the connection, while signalling the client to reconnect. case FlagGatewayOpcodeReconnect: - s.reconnect(bot, "reconnecting session due to Opcode 7 Reconnect") + s.reconnect("reconnecting session due to Opcode 7 Reconnect") return nil @@ -21138,40 +20828,280 @@ func (s *Session) onPayload(bot *Client, payload GatewayPayload) error { return nil } -// signal represents a manager Context Signal. -type signal string +// manager represents a manager of a Session's goroutines. +type manager struct { + coroner errgroup.Group + signals chan uint8 + cancel context.CancelFunc + actionError chan error + *errgroup.Group + state string + routines sync.WaitGroup + stateMutex sync.RWMutex + pulses int32 +} -// manager Context Signals. +// spawnManager spawns a tracked manager. +func (s *Session) spawnManager(bot *Client) { + s.manager = new(manager) + + s.Context, s.manager.cancel = context.WithCancel(context.Background()) + s.manager.Group, s.Context = errgroup.WithContext(s.Context) + s.manager.signals = make(chan uint8) + s.manager.actionError = make(chan error, 1) + + // spawn the manager goroutine. + s.manager.coroner.Go(func() error { + if err := s.manage(bot); err != nil { + return fmt.Errorf("manager: %w", err) + } + + return nil + }) +} + +// Session States represent the state of the Session's connection to Discord. const ( - // keySignal represents the Context key for a manager's signals. - keySignal = signal("signal") + SessionStateNew = "" - // keyReason represents the Context key for a manager's reason for disconnection. - keyReason = signal("reason") + SessionStateConnecting = "connecting (before websocket connection)" + SessionStateConnectingWebsocket = "connecting (with websocket connection)" + SessionStateConnected = "connected" - // signalDisconnect indicates that a disconnection was called purposefully. - signalDisconnect = 1 + SessionStateDisconnecting = "disconnecting (purposefully)" + SessionStateDisconnectingError = "disconnecting (due to an error)" + SessionStateDisconnectingReconnect = "disconnecting (while reconnecting)" - // signalReconnect signals the manager to reconnect upon a successful disconnection. - signalReconnect = 2 + SessionStateDisconnectedFinal = "disconnected (after connection)" + SessionStateDisconnectedError = "disconnected (due to an error)" + SessionStateDisconnectedReconnect = "disconnected (while reconnecting)" + + SessionStateReconnecting = "reconnecting" ) -// manager represents a manager of a Session's goroutines. -type manager struct { - signal context.Context - cancel context.CancelFunc - err chan error - *errgroup.Group - routines sync.WaitGroup - pulses int32 +// State returns the state of the Session's connection to Discord. +func (s *Session) State() string { + s.manager.stateMutex.RLock() + defer s.manager.stateMutex.RUnlock() + + return s.manager.state } -// decrementPulses safely decrements the pulses counter of a Session manager. -func (s *Session) decrementPulses() { - s.Lock() - defer s.Unlock() +// setState sets the state of a Session. +func (s *Session) setState(state string) { + s.manager.stateMutex.Lock() + s.manager.state = state + s.manager.stateMutex.Unlock() +} - atomic.AddInt32(&s.manager.pulses, -1) +// canReconnect returns whether the Session's fields are in a valid state to reconnect. +func (s *Session) canReconnect() bool { + return s.ID != "" && s.Endpoint != "" && atomic.LoadInt64(&s.Seq) != 0 +} + +// Session Signals represent manager signals to perform actions to the Session. +const ( + sessionSignalConnect = 1 + sessionSignalDisconnect = 2 + sessionSignalReconnect = 3 +) + +// manage manages a Session's goroutines. +func (s *Session) manage(bot *Client) error { + // spawn the coroner once the manager routine is alive. + go s.coroner() + + defer func() { + if s.State() != SessionStateDisconnectedReconnect { + s.Unlock() + } + + // wait until the previous connection's manager goroutines are closed. + _ = s.manager.Wait() + + s.logClose("manager") + }() + + var managedErr error + + for { + select { + case <-s.Context.Done(): + if s.State() == SessionStateDisconnectedReconnect { + break + } + + // wait until the previous connection's manager goroutines are closed. + err := s.manager.Wait() + if err != nil { + closeErr := new(websocket.CloseError) + + if errors.As(err, closeErr) { + if vErr := s.validateGatewayCloseError(closeErr); vErr == nil { + // reconnect from a state where + s.setState(SessionStateDisconnectedReconnect) + + // manager routines must be reset + s.Context, s.manager.cancel = context.WithCancel(context.Background()) //nolint:fatcontext + s.manager.Group, s.Context = errgroup.WithContext(s.Context) + + go func() { + // send a connection signal. + s.manager.signals <- sessionSignalConnect + + // read the s.manager.actionError send from a successful connection. + e := <-s.manager.actionError + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("captured result from close event reconnect: %q", e) + }() + + s.Lock() + + break // to reconnect from the connect case logic. + } // vErr == nil + } // errors.As + + // TODO: Use errors.As: https://github.com/coder/websocket/issues/519 + if strings.Contains(err.Error(), "failed to close WebSocket: received header with unexpected rsv bits set") { + err = nil + } + } // err != nil + + s.Lock() + + return err + + case signal := <-s.manager.signals: + switch signal { + case sessionSignalConnect: + if s.State() != SessionStateDisconnectedReconnect { + s.Lock() + } else { + s.Unlock() + + // wait until the previous connection's manager goroutines are closed. + _ = s.manager.Wait() + + s.Lock() + } + + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msg("connecting session") + + if err := s.connect(bot); err != nil { + managedErr = ErrorSession{SessionID: s.ID, State: s.State(), Type: ErrorSessionTypeGateway, Err: err} + + switch s.State() { + case SessionStateConnectingWebsocket: + go func() { + // send a disconnection signal. + s.manager.signals <- sessionSignalDisconnect + }() + + // case SessionStateNew, SessionStateConnecting... + default: + return managedErr + } + + break // to handle the error in the disconnect case logic. + } + + s.setState(SessionStateConnected) + s.manager.actionError <- nil + s.Unlock() + + case sessionSignalDisconnect: + if managedErr == nil && s.State() != SessionStateReconnecting { + s.Lock() + } + + // update the session's state and client close event code. + code := FlagClientCloseEventCodeNormal + + switch { + case managedErr != nil: + s.setState(SessionStateDisconnectingError) + case s.State() == SessionStateReconnecting: + s.setState(SessionStateDisconnectingReconnect) + code = FlagClientCloseEventCodeReconnect + default: + s.setState(SessionStateDisconnecting) + } + + LogSession(Logger.Info(), s.ID).Msgf("%q session with code %d", s.State(), code) + + // disconnect the session. + if err := s.disconnect(code); err != nil { + managedErr = ErrorSession{ + SessionID: s.ID, + State: s.State(), + Type: ErrorSessionTypeGateway, + Err: ErrorSessionDisconnect{ + Action: managedErr, + Err: err, + }, + } + + // TODO: Use errors.As: https://github.com/coder/websocket/issues/519 + if strings.Contains(err.Error(), "failed to close WebSocket: received header with unexpected rsv bits set") { + managedErr = nil + } + + if s.State() != SessionStateDisconnectingReconnect { + return managedErr + } + + // validate error when reconnecting + closeErr := new(websocket.CloseError) + if errors.As(managedErr, closeErr) { + if managedErr = s.validateGatewayCloseError(closeErr); managedErr != nil { + return managedErr + } + } + } // disconnect + + // update the session's state. + switch { + case s.State() == SessionStateDisconnectingError: + s.setState(SessionStateDisconnectedError) + + case s.State() == SessionStateDisconnectingReconnect: + s.setState(SessionStateDisconnectedReconnect) + + default: + s.setState(SessionStateDisconnectedFinal) + } + + LogSession(Logger.Info(), s.ID).Msgf("%q session with code %d", s.State(), code) + + if s.State() == SessionStateDisconnectedReconnect { + // allow Discord to close the session. + <-time.After(time.Second) + + go func() { + // send a connection signal. + s.manager.signals <- sessionSignalConnect + }() + + break + } + + // Destroy the manager when the bot isn't reconnecting. + if managedErr != nil { + return managedErr + } + + return nil + + case sessionSignalReconnect: + s.Lock() + s.setState(SessionStateReconnecting) + + go func() { + // send a disconnection signal. + s.manager.signals <- sessionSignalDisconnect + }() + } + } // select + } // for } // logClose safely logs the close of a Session's goroutine. @@ -21179,225 +21109,346 @@ func (s *Session) logClose(routine string) { LogSession(Logger.Info(), s.ID).Msgf("closed %s routine", routine) } -// reconnect spawns a goroutine for reconnection which prompts the manager -// to reconnect upon a disconnection. -func (s *Session) reconnect(bot *Client, reason string) { +// validateGatewayCloseError validates a WebSocket CloseError +// and returns whether to reconnect (when error == nil). +func (s *Session) validateGatewayCloseError(closeErr *websocket.CloseError) error { + code, ok := GatewayCloseEventCodes[int(closeErr.Code)] + + switch ok { + // Gateway Close Event Code is known. + case true: + LogSession(Logger.Info(), s.ID). + Msgf("received Gateway Close Event Code %d %s: %s", + code.Code, code.Description, code.Explanation, + ) + + if code.Reconnect { + return nil + } + + return closeErr + + // Gateway Close Event Code is unknown. + default: + LogSession(Logger.Info(), s.ID). + Msgf("received unknown Gateway Close Event Code %d with reason %q", + closeErr.Code, closeErr.Reason, + ) + + return closeErr + } +} + +const ( + gatewayEndpointParams = "?v=" + VersionDiscordAPI + "&encoding=json" + invalidSessionWaitTime = 1 * time.Second + maxIdentifyLargeThreshold = 250 +) + +// connect connects a session to a WebSocket Connection. +func (s *Session) connect(bot *Client) error { + var err error + + // request a valid Gateway URL endpoint and response from the Discord API. + gatewayEndpoint := s.Endpoint + var response *GetGatewayBotResponse + + if bot.Config.Gateway.ShardManager != nil { + if response, err = bot.Config.Gateway.ShardManager.SetLimit(bot); err != nil { + return fmt.Errorf("shardmanager: %w", err) + } + } else { + if gatewayEndpoint == "" || !s.canReconnect() { + gateway := GetGatewayBot{} + response, err = gateway.Send(bot) + if err != nil { + return fmt.Errorf("error getting the Gateway API Endpoint: %w", err) + } + + gatewayEndpoint = response.URL + } + } + + // set the maximum allowed (Identify) concurrency rate limit for the bot. + // + // https://discord.com/developers/docs/topics/gateway#rate-limiting + if response != nil { + bot.Config.Gateway.RateLimiter.StartTx() + + identifyBucket := bot.Config.Gateway.RateLimiter.GetBucketFromID(FlagGatewaySendEventNameIdentify) + if identifyBucket == nil { + identifyBucket = getBucket() + bot.Config.Gateway.RateLimiter.SetBucketFromID(FlagGatewaySendEventNameIdentify, identifyBucket) + } + + identifyBucket.Limit = int16(response.SessionStartLimit.MaxConcurrency) //nolint:gosec // disable G115 + + if identifyBucket.Expiry.IsZero() { + identifyBucket.Remaining = identifyBucket.Limit + identifyBucket.Expiry = time.Now().Add(FlagGlobalRateLimitIdentifyInterval) + } + + bot.Config.Gateway.RateLimiter.EndTx() + } + + // set up the Session's Rate Limiter (applied per WebSocket Connection). + // https://discord.com/developers/docs/topics/gateway#rate-limiting + s.RateLimiter = &RateLimit{ //nolint:exhaustruct + ids: make(map[string]string, totalGatewayBucketsPerConnection), + buckets: make(map[string]*Bucket, totalGatewayBucketsPerConnection), + } + + s.RateLimiter.SetBucket( + GlobalRateLimitRouteID, &Bucket{ //nolint:exhaustruct + Limit: FlagGlobalRateLimitGateway, + Remaining: FlagGlobalRateLimitGateway, + Expiry: time.Now().Add(FlagGlobalRateLimitGatewayInterval), + }, + ) + + // connect to the Discord Gateway Websocket. + s.Context, s.manager.cancel = context.WithCancel(context.Background()) + if s.Conn, _, err = websocket.Dial(s.Context, gatewayEndpoint+gatewayEndpointParams, nil); err != nil { + return fmt.Errorf("error connecting to the Discord Gateway: %w", err) + } + + s.setState(SessionStateConnectingWebsocket) + + // handle the incoming Hello event upon connecting to the Gateway. + hello := new(Hello) + if err := readEvent(s, hello); err != nil { + return fmt.Errorf("error reading initial Hello event: %w", err) + } + + for _, handler := range bot.Handlers.Hello { + go handler(hello) + } + + // begin sending heartbeat payloads every heartbeat_interval ms. + ms := time.Millisecond * time.Duration(hello.HeartbeatInterval) + s.heartbeat = &heartbeat{ + interval: ms, + ticker: time.NewTicker(ms), + send: make(chan Heartbeat), + + // add a HeartbeatACK to the HeartbeatACK channel to prevent + // the length of the HeartbeatACK channel from being 0 immediately, + // which results in an attempt to reconnect. + acks: 1, + } + + // spawn the heartbeat pulse goroutine. + s.manager.routines.Add(1) s.manager.Go(func() error { - s.Lock() - defer s.logClose("reconnect") - defer s.Unlock() + atomic.AddInt32(&s.manager.pulses, 1) + s.pulse() - LogSession(Logger.Info(), s.ID).Msg(reason) + return nil + }) - s.manager.signal = context.WithValue(s.manager.signal, keySignal, signalReconnect) - if err := s.disconnect(FlagClientCloseEventCodeReconnect); err != nil { - return fmt.Errorf("reconnect: %w", err) + // spawn the heartbeat beat goroutine. + s.manager.routines.Add(1) + s.manager.Go(func() error { + if err := s.beat(bot); err != nil { + return fmt.Errorf("heartbeat: %w", err) } - // connect to the Discord Gateway again. - s.Context = nil - if err := s.connect(bot); err != nil { - return fmt.Errorf("reconnect: %w", err) + return nil + }) + + // send the initial Identify or Resumed packet. + if err := s.initial(bot, 0); err != nil { + return fmt.Errorf("initial: %w", err) + } + + // spawn the event listener listen goroutine. + s.manager.routines.Add(1) + s.manager.Go(func() error { + if err := s.listen(bot); err != nil { + return fmt.Errorf("listen: %w", err) } return nil }) + + // confirm the Session's goroutines are spawned. + s.manager.routines.Wait() + + return nil } -// manage manages a Session's goroutines. -func (s *Session) manage(bot *Client) { - s.manager.routines.Done() - defer func() { - s.Lock() - s.logClose("manager") - s.Unlock() - }() +// initial sends the initial Identify or Resume packet required to connect to the Gateway, +// then handles the incoming Ready or Resumed packet that indicates a successful connection. +func (s *Session) initial(bot *Client, attempt int) error { + if !s.canReconnect() { + // send an Opcode 2 Identify to the Discord Gateway. + identify := Identify{ + Token: bot.Authentication.Token, + Properties: IdentifyConnectionProperties{ + OS: runtime.GOOS, + Browser: module, + Device: module, + }, + Compress: Pointer(true), + LargeThreshold: Pointer(maxIdentifyLargeThreshold), + Shard: s.Shard, + Presence: bot.Config.Gateway.GatewayPresenceUpdate, + Intents: bot.Config.Gateway.Intents, + } - // wait until all of a Session's goroutines are closed. - err := s.manager.Wait() - s.Lock() - defer s.Unlock() + if err := identify.SendEvent(bot, s); err != nil { + return err + } + } else { + // send an Opcode 6 Resume to the Discord Gateway to reconnect the session. + resume := Resume{ + Token: bot.Authentication.Token, + SessionID: s.ID, + Seq: atomic.LoadInt64(&s.Seq), + } - // log the reason for disconnection (if applicable). - if reason := s.manager.signal.Value(keyReason); reason != nil { - LogSession(Logger.Info(), s.ID).Msgf("%v", reason) + if err := resume.SendEvent(bot, s); err != nil { + return err + } + } + + // handle the incoming Ready, Resumed or Replayed event (or Opcode 9 Invalid Session). + payload := getPayload() + defer putPayload(payload) + if err := socket.Read(s.Context, s.Conn, payload); err != nil { + return fmt.Errorf("error reading initial payload: %w", err) } - // when a signal is provided, it indicates that the disconnection was purposeful. - signal := s.manager.signal.Value(keySignal) - switch signal { - case signalDisconnect: - LogSession(Logger.Info(), s.ID).Msg("successfully disconnected") - - s.manager.err <- nil - - return + LogPayload(LogSession(Logger.Info(), s.ID), payload.Op, payload.Data).Msg("received initial payload") - case signalReconnect: - LogSession(Logger.Info(), s.ID).Msg("successfully disconnected (while reconnecting)") + switch payload.Op { + case FlagGatewayOpcodeDispatch: + switch { + // When a connection is successful, the Discord Gateway will respond with a Ready event. + case *payload.EventName == FlagGatewayEventNameReady: + ready := new(Ready) + if err := json.Unmarshal(payload.Data, ready); err != nil { + return fmt.Errorf("error reading ready event: %w", err) + } - // allow Discord to close the session. - <-time.After(time.Second) + LogSession(Logger.Info(), ready.SessionID).Msg("received Ready event") - s.manager.err <- nil + // Configure the session. + s.ID = ready.SessionID + atomic.StoreInt64(&s.Seq, 0) + s.Endpoint = ready.ResumeGatewayURL - return - } + // Store the session in the session manager. + s.client_manager.Gateway.Store(s.ID, s) - // when an error caused goroutines to close, manage the state of disconnection. - if err != nil { - disconnectErr := new(ErrorDisconnect) - closeErr := new(websocket.CloseError) - switch { - // when an error occurs from a purposeful disconnection. - case errors.As(err, disconnectErr): - s.manager.err <- err + if bot.Config.Gateway.ShardManager != nil { + bot.Config.Gateway.ShardManager.Ready(bot, s, ready) + } - // when an error occurs from a WebSocket Close Error. - case errors.As(err, closeErr): - if bot == nil { - s.manager.err <- fmt.Errorf("gateway websocket close error, but unable to reconnect: %w", err) + for _, handler := range bot.Handlers.Ready { + go handler(ready) } - s.manager.err <- s.handleGatewayCloseError(bot, closeErr) + // When a reconnection is successful, the Discord Gateway will respond + // by replaying all missed events in order, finalized by a Resumed event. + case *payload.EventName == FlagGatewayEventNameResumed: + LogSession(Logger.Info(), s.ID).Msg("received Resumed event") - default: - if cErr := s.Conn.Close(websocket.StatusCode(FlagClientCloseEventCodeAway), ""); cErr != nil { - s.manager.err <- ErrorDisconnect{ - Action: err, - Err: cErr, - Connection: ErrConnectionSession, - } + // Store the session in the session manager. + s.client_manager.Gateway.Store(s.ID, s) - return + for _, handler := range bot.Handlers.Resumed { + go handler(&Resumed{}) } - s.manager.err <- err - } + // When a reconnection is successful, the Discord Gateway will respond + // by replaying all missed events in order, finalized by a Resumed event. + default: + // handle the initial payload(s) until a Resumed event is encountered. + go bot.handle(*payload.EventName, payload.Data) - return - } + for { + replayed := new(GatewayPayload) + if err := socket.Read(s.Context, s.Conn, replayed); err != nil { + return fmt.Errorf("error replaying events: %w", err) + } - s.manager.err <- nil -} + if replayed.Op == FlagGatewayOpcodeDispatch && *replayed.EventName == FlagGatewayEventNameResumed { + LogSession(Logger.Info(), s.ID).Msg("received Resumed event") -// handleGatewayCloseError handles a WebSocket CloseError. -func (s *Session) handleGatewayCloseError(bot *Client, closeErr *websocket.CloseError) error { - code, ok := GatewayCloseEventCodes[int(closeErr.Code)] - switch ok { - // Gateway Close Event Code is known. - case true: - LogSession(Logger.Info(), s.ID). - Msgf("received Gateway Close Event Code %d %s: %s", - code.Code, code.Description, code.Explanation, - ) + // Store the session in the session manager. + s.client_manager.Gateway.Store(s.ID, s) - if code.Reconnect { - s.reconnect(bot, fmt.Sprintf("reconnecting due to Gateway Close Event Code %d", code.Code)) + for _, handler := range bot.Handlers.Resumed { + go handler(&Resumed{}) + } - return nil + return nil + } + + go bot.handle(*payload.EventName, payload.Data) + } } - return closeErr + // When the maximum concurrency limit has been reached while connecting, or when + // the session does NOT reconnect in time, the Discord Gateway send an Opcode 9 Invalid Session. + case FlagGatewayOpcodeInvalidSession: + // Remove the session from the session manager. + s.client_manager.RemoveGatewaySession(s.ID) - // Gateway Close Event Code is unknown. - default: + if attempt < 1 { + // wait for Discord to close the session, then complete a fresh connect. + <-time.NewTimer(invalidSessionWaitTime).C + + s.ID = "" + atomic.StoreInt64(&s.Seq, 0) + if err := s.initial(bot, attempt+1); err != nil { + return err + } - // when another goroutine calls disconnect(), - // s.Conn.Close is called before s.cancel which will result in - // a CloseError with the close code that Disgo uses to reconnect. - if closeErr.Code == websocket.StatusCode(FlagClientCloseEventCodeReconnect) { return nil } - LogSession(Logger.Info(), s.ID). - Msgf("received unknown Gateway Close Event Code %d with reason %q", - closeErr.Code, closeErr.Reason, - ) - - return closeErr + return fmt.Errorf("session %q couldn't connect to the Discord Gateway or has invalidated an active session", s.ID) + default: + return fmt.Errorf("session %q received unexpected payload %d during connection", s.ID, payload.Op) } -} - -const ( - // SignalNone indicates that Wait() was called on an already disconnected session. - SignalNone = 0 - - // SignalDisconnect indicates that a disconnection was called purposefully. - SignalDisconnect = signalDisconnect - - // SignalReconnect indicates that a disconnection was called purposefully in order to reconnect. - SignalReconnect = signalReconnect - - // SignalError indicates that a disconnection occurred as an error. - SignalError = 3 - // SignalDisconnectError indicates that a disconnection was called purposefully (for any reason), - // but the Session experienced an error while disconnecting. - SignalDisconnectError = 4 + return nil +} - // SignalUndefined indicates that a disconnection occurred in an undefined manner. - // - // This signal should NEVER be returned: If it is, report it. - SignalUndefined = 5 -) +// disconnect disconnects a session from a WebSocket Connection using the given status code. +func (s *Session) disconnect(code int) error { + // cancel the context to kill the goroutines of the Session. + defer s.manager.cancel() -// Wait blocks until the calling Session has disconnected, then returns the reason -// (disgo.SignalReason) for disconnecting and the disconnection error (if it exists). -// -// If Wait() is called on a Session that isn't connected, it will return immediately -// with code SignalNone. -// -// It's NOT recommended to modify a Session after it has disconnected, -// since it will be cleared and placed into a memory pool shortly after. -func (s *Session) Wait() (int, error) { - if !s.isConnected() { - return SignalNone, nil + if err := s.Conn.Close(websocket.StatusCode(code), ""); err != nil { + return fmt.Errorf("%w", err) } - // NOTE: Wait() is equivalent to the s.manage() s.manager.Wait() handling logic, - // but without the management of the disconnection state, - // and without the usage of a channel that tells another goroutine to unblock. - // - // wait until all of a Session's goroutines are closed. - err := s.manager.Wait() - s.Lock() - defer s.Unlock() - - // when a signal is provided, it indicates that the disconnection was purposeful. - signal := s.manager.signal.Value(keySignal) - switch signal { - case signalDisconnect: - return SignalDisconnect, nil - - case signalReconnect: - return SignalReconnect, nil - } + return nil +} - // when an error caused goroutines to close. - if err != nil { - disconnectErr := new(ErrorDisconnect) - closeErr := new(websocket.CloseError) - switch { - // when an error occurs from a purposeful disconnection. - case errors.As(err, disconnectErr): - if signal != nil { - if signalValue, ok := signal.(int); ok { - return signalValue, err //nolint:wrapcheck - } - } +// reconnect spawns a goroutine for reconnection which prompts the manager +// to reconnect upon a disconnection. +func (s *Session) reconnect(reason string) { + LogSession(Logger.Info(), s.ID).Msg(reason) - return SignalDisconnectError, err //nolint:wrapcheck + s.manager.signals <- sessionSignalReconnect +} - // when an error occurs from a WebSocket Close Error. - case errors.As(err, closeErr): - return SignalError, s.handleGatewayCloseError(nil, closeErr) - } +// readEvent is a helper function for reading events from the WebSocket Session. +func readEvent(s *Session, dst any) error { + payload := new(GatewayPayload) + if err := socket.Read(s.Context, s.Conn, payload); err != nil { + return fmt.Errorf("readEvent: %w", err) + } - return SignalError, err //nolint:wrapcheck + if err := json.Unmarshal(payload.Data, dst); err != nil { + return fmt.Errorf("readEvent: %w", err) } - return SignalUndefined, nil + return nil } // ShardManager represents an interface for Shard Management. @@ -21618,7 +21669,7 @@ func (vc *VoiceChannelConnection) Connect(bot *Client) error { return errors.New("ConnectVoice: Voice ChannelID must be non-nil and non-empty to connect to voice channel") } - if vc.GatewaySession == nil || !vc.GatewaySession.isConnected() { + if vc.GatewaySession == nil || vc.GatewaySession.State() != SessionStateConnected { return errors.New("ConnectVoice: Session must be connected to the Discord Gateway to connect to voice channel") } @@ -21679,7 +21730,7 @@ VOICESERVERUPDATE: case <-vc.GatewaySession.Context.Done(): vc.VoiceSession.RUnlock() - return <-vc.GatewaySession.manager.err + return <-vc.GatewaySession.manager.actionError default: vc.VoiceSession.RUnlock() //lint:ignore SA4011 break into for loop. @@ -21834,12 +21885,11 @@ func (s *VoiceSession) connect(bot *Client, vc *VoiceChannelConnection) error { hello := new(VoiceHello) if err := readEventVoice(s, hello); err != nil { err = fmt.Errorf("error reading initial VoiceHello event: %w", err) - sessionErr := ErrorSession{SessionID: s.ID, Err: err} + sessionErr := ErrorSession{SessionID: s.ID, Err: err} //nolint:exhaustruct // voice needs refactor if disconnectErr := s.disconnect(FlagClientCloseEventCodeNormal); disconnectErr != nil { - sessionErr.Err = ErrorDisconnect{ - Action: err, - Err: disconnectErr, - Connection: ErrConnectionSessionVoice, + sessionErr.Err = ErrorSessionDisconnect{ + Action: err, + Err: disconnectErr, } } @@ -21879,7 +21929,7 @@ func (s *VoiceSession) connect(bot *Client, vc *VoiceChannelConnection) error { s.manager.routines.Add(1) s.manager.Go(func() error { if err := s.beat(); err != nil { - return ErrorSession{ + return ErrorSession{ //nolint:exhaustruct // voice needs refactor SessionID: s.ID, Err: fmt.Errorf("heartbeat: %w", err), } @@ -21890,12 +21940,11 @@ func (s *VoiceSession) connect(bot *Client, vc *VoiceChannelConnection) error { // send the initial Identify or Resumed packet. if err := s.initial(bot, vc); err != nil { - sessionErr := ErrorSession{SessionID: s.ID, Err: err} + sessionErr := ErrorSession{SessionID: s.ID, Err: err} //nolint:exhaustruct // voice needs refactor if disconnectErr := s.disconnect(FlagClientCloseEventCodeNormal); disconnectErr != nil { - sessionErr.Err = ErrorDisconnect{ - Action: err, - Err: disconnectErr, - Connection: ErrConnectionSessionVoice, + sessionErr.Err = ErrorSessionDisconnect{ + Action: err, + Err: disconnectErr, } } @@ -21906,7 +21955,7 @@ func (s *VoiceSession) connect(bot *Client, vc *VoiceChannelConnection) error { s.manager.routines.Add(1) s.manager.Go(func() error { if err := s.listen(bot); err != nil { - return ErrorSession{ + return ErrorSession{ //nolint:exhaustruct // voice needs refactor SessionID: s.ID, Err: fmt.Errorf("listen: %w", err), } @@ -21919,7 +21968,7 @@ func (s *VoiceSession) connect(bot *Client, vc *VoiceChannelConnection) error { s.manager.routines.Add(1) go s.manage() - // ensure that the Session's goroutines are spawned. + // confirm the Session's goroutines are spawned. s.manager.routines.Wait() return nil @@ -22007,8 +22056,6 @@ func (s *VoiceSession) disconnect(code int) error { id := s.ID LogSession(Logger.Info(), id).Msgf("disconnecting voice session with code %d", FlagClientCloseEventCodeNormal) - s.manager.signal = context.WithValue(s.manager.signal, keySignal, signalDisconnect) - // cancel the context to kill the goroutines of the Voice Session. defer s.manager.cancel() @@ -22037,29 +22084,6 @@ func readEventVoice(s *VoiceSession, dst any) error { return nil } -// writeEventVoice is a helper function for writing voice events to the WebSocket Session. -func writeEventVoice(s *VoiceSession, op int, name string, dst any) error { - LogCommandVoice(log.Trace(), op, name).Msg("sending voice server command") - - // write the event to the WebSocket Connection. - event, err := json.Marshal(dst) - if err != nil { - return fmt.Errorf("writeEvent: %w", err) - } - - if err = socket.Write(s.Context, s.Conn, websocket.MessageText, - VoicePayload{ - Op: op, - Data: event, - }); err != nil { - return fmt.Errorf("writeEvent: %w", err) - } - - LogCommandVoice(log.Trace(), op, name).Msg("sending voice server command") - - return nil -} - // SendEvent sends an Opcode 0 Identify event to the Discord Voice Server. func (c *VoiceIdentify) SendEvent(session *VoiceSession) error { if err := writeEventVoice(session, FlagVoiceOpcodeIdentify, FlagVoiceSendEventNameIdentify, c); err != nil { @@ -22105,6 +22129,29 @@ func (c *VoiceResume) SendEvent(session *VoiceSession) error { return nil } +// writeEventVoice is a helper function for writing voice events to the WebSocket Session. +func writeEventVoice(s *VoiceSession, op int, name string, dst any) error { + LogCommandVoice(log.Trace(), op, name).Msg("sending voice server command") + + // write the event to the WebSocket Connection. + event, err := json.Marshal(dst) + if err != nil { + return fmt.Errorf("writeEvent: %w", err) + } + + if err = socket.Write(s.Context, s.Conn, websocket.MessageText, + VoicePayload{ + Op: op, + Data: event, + }); err != nil { + return fmt.Errorf("writeEvent: %w", err) + } + + LogCommandVoice(log.Trace(), op, name).Msg("sending voice server command") + + return nil +} + // VoiceHandlers represents a voice channel connection's event handlers. type VoiceHandlers struct { VoiceReady []func(*VoiceReady) @@ -22372,7 +22419,7 @@ func (s *VoiceSession) Monitor() uint32 { func (s *VoiceSession) beat() error { s.manager.routines.Done() - // ensure that all pulse routines are closed prior to closing. + // confirm all pulse routines are closed prior to closing. defer func() { for { if s.heartbeat == nil { @@ -22560,7 +22607,6 @@ func (s *VoiceSession) reconnect(reason string) { LogSession(Logger.Info(), s.ID).Msg(reason) - s.manager.signal = context.WithValue(s.manager.signal, keySignal, signalReconnect) if err := s.disconnect(FlagClientCloseEventCodeReconnect); err != nil { return fmt.Errorf("reconnect: %w", err) } @@ -22583,35 +22629,9 @@ func (s *VoiceSession) manage() { s.Lock() defer s.Unlock() - // log the reason for disconnection (if applicable). - if reason := s.manager.signal.Value(keyReason); reason != nil { - LogSession(Logger.Info(), s.ID).Msgf("%v", reason) - } - - // when a signal is provided, it indicates that the disconnection was purposeful. - signal := s.manager.signal.Value(keySignal) - switch signal { - case signalDisconnect: - LogSession(Logger.Info(), s.ID).Msg("successfully disconnected") - - s.manager.err <- nil - - return - - case signalReconnect: - LogSession(Logger.Info(), s.ID).Msg("successfully disconnected (while reconnecting)") - - // allow Discord to close the session. - <-time.After(time.Second) - - s.manager.err <- nil - - return - } - // when an error caused goroutines to close, manage the state of disconnection. if err != nil { - disconnectErr := new(ErrorDisconnect) + disconnectErr := new(ErrorSessionDisconnect) closeErr := new(websocket.CloseError) switch { // when an error occurs from a purposeful disconnection. @@ -22624,10 +22644,9 @@ func (s *VoiceSession) manage() { default: if cErr := s.Conn.Close(websocket.StatusCode(FlagClientCloseEventCodeAway), ""); cErr != nil { - s.manager.err <- ErrorDisconnect{ - Action: err, - Err: cErr, - Connection: ErrConnectionSessionVoice, + s.manager.err <- ErrorSessionDisconnect{ + Action: err, + Err: cErr, } return @@ -22673,61 +22692,3 @@ func (s *VoiceSession) handleGatewayCloseError(closeErr *websocket.CloseError) e return closeErr } } - -// Wait blocks until the calling Voice Session has disconnected, then returns the reason -// (disgo.SignalReason) for disconnecting and the disconnection error (if it exists). -// -// If Wait() is called on a Voice Session that isn't connected, it will return immediately -// with code SignalNone. -// -// It's NOT recommended to modify a Voice Session after it has disconnected, -// since it will be cleared and placed into a memory pool shortly after. -func (s *VoiceSession) Wait() (int, error) { - if !s.isConnected() { - return SignalNone, nil - } - - // NOTE: Wait() is equivalent to the s.manage() s.manager.Wait() handling logic, - // but without the management of the disconnection state, - // and without the usage of a channel that tells another goroutine to unblock. - // - // wait until all of a Session's goroutines are closed. - err := s.manager.Wait() - s.Lock() - defer s.Unlock() - - // when a signal is provided, it indicates that the disconnection was purposeful. - signal := s.manager.signal.Value(keySignal) - switch signal { - case signalDisconnect: - return SignalDisconnect, nil - - case signalReconnect: - return SignalReconnect, nil - } - - // when an error caused goroutines to close. - if err != nil { - disconnectErr := new(ErrorDisconnect) - closeErr := new(websocket.CloseError) - switch { - // when an error occurs from a purposeful disconnection. - case errors.As(err, disconnectErr): - if signal != nil { - if signalValue, ok := signal.(int); ok { - return signalValue, err //nolint:wrapcheck - } - } - - return SignalDisconnectError, err //nolint:wrapcheck - - // when an error occurs from a WebSocket Close Error. - case errors.As(err, closeErr): - return SignalError, s.handleGatewayCloseError(closeErr) - } - - return SignalError, err //nolint:wrapcheck - } - - return SignalUndefined, nil -} diff --git a/wrapper/errors.go b/wrapper/errors.go index 9291007..78beb89 100644 --- a/wrapper/errors.go +++ b/wrapper/errors.go @@ -37,19 +37,29 @@ func (e ErrorRequest) Error() string { e.ClientID, e.CorrelationID, e.RouteID, e.ResourceID, e.Endpoint, e.Err).Error() } +// ErrorStatusCode represents an HTTP Request error that occurs when an unexpected response is returned. +type ErrorStatusCode struct { + // StatusCode represents the HTTP Status Code received from a response. + StatusCode int +} + // Status Code Error Messages. const ( errStatusCodeKnown = "status code %d: %v" errStatusCodeUnknown = "status code %d: unknown status code error from Discord" ) -// StatusCodeError handles a Discord API HTTP Status Code and returns the relevant error message. -func StatusCodeError(status int) error { +func (e ErrorStatusCode) Error() string { + return fmt.Sprintf("STATUS CODE ERROR: status code: %q: msg: %v", e.StatusCode, StatusCodeError(e.StatusCode)) +} + +// StatusCodeError returns the relevant message for a Discord API HTTP Status Code. +func StatusCodeError(status int) string { if msg, ok := HTTPResponseCodes[status]; ok { - return fmt.Errorf(errStatusCodeKnown, status, msg) + return fmt.Sprintf(errStatusCodeKnown, status, msg) } - return fmt.Errorf(errStatusCodeUnknown, status) + return fmt.Sprintf(errStatusCodeUnknown, status) } // JSON Error Code Messages. @@ -124,26 +134,6 @@ func (e ErrorEvent) Error() string { e.ClientID, e.Event, e.Action, e.Err).Error() } -// Discord Gateway Error Messages -const ( - errNoSessionManager = `The client must contain a non-nil SessionManager struct to connect to the Discord Gateway. - - Set the *Client.SessionManager using one of the following methods. - - --- 1 - - bot := &disgo.Client{ - ... - Sessions: disgo.NewSessionManager(), - } - - --- 2 - - bot.Sessions = disgo.NewSessionManager() - - ` -) - // ErrorSession represents a WebSocket Session error that occurs during an active session. type ErrorSession struct { // Err represents the error that occurred. @@ -151,34 +141,37 @@ type ErrorSession struct { // SessionID represents the ID of the Session. SessionID string -} -func (e ErrorSession) Error() string { - return fmt.Errorf("SESSION ERROR: session %q: error: %w", e.SessionID, e.Err).Error() + // State represents the state of the session. + State string + + // Type represents the type of connection (e.g., Discord Gateway, Discord Voice). + Type string } const ( - ErrConnectionSession = "Discord Gateway" - ErrConnectionSessionVoice = "Discord Voice" + ErrorSessionTypeGateway = "Discord Gateway" + ErrorSessionTypeVoice = "Discord Voice" ) -// ErrorDisconnect represents a disconnection error that occurs when -// an attempt to gracefully disconnect from a connection fails. -type ErrorDisconnect struct { +func (e ErrorSession) Error() string { + return fmt.Errorf("SESSION ERROR: %q session %q: state: %q error: %w", e.Type, e.SessionID, e.State, e.Err).Error() +} + +// ErrorSessionDisconnect represents a disconnection error that occurs when +// an attempt to gracefully disconnect from a session fails. +type ErrorSessionDisconnect struct { // Action represents the error that prompted the disconnection (if applicable). Action error // Err represents the error that occurred while disconnecting. Err error - - // Connection represents the name of the connection. - Connection string } -func (e ErrorDisconnect) Error() string { - return fmt.Errorf("error disconnecting from %q\n"+ +func (e ErrorSessionDisconnect) Error() string { + return fmt.Errorf( "\tDisconnect(): %v\n"+ - "\treason: %w\n", - e.Connection, e.Err, e.Action, + "\treason: %w\n", + e.Err, e.Action, ).Error() //lint:ignore ST1005 readability } diff --git a/wrapper/request.go b/wrapper/request.go index a53f7e6..82c7a4c 100644 --- a/wrapper/request.go +++ b/wrapper/request.go @@ -311,7 +311,9 @@ SEND: goto RATELIMIT } - return StatusCodeError(response.StatusCode()) + return ErrorStatusCode{ + StatusCode: response.StatusCode(), + } } // parse the rate limit response data for `retry_after`. @@ -377,7 +379,9 @@ SEND: goto RATELIMIT } - return StatusCodeError(fasthttp.StatusTooManyRequests) + return ErrorStatusCode{ + StatusCode: fasthttp.StatusTooManyRequests, + } // retry the request on a bad gateway server error. case fasthttp.StatusBadGateway: @@ -387,10 +391,14 @@ SEND: goto RATELIMIT } - return StatusCodeError(fasthttp.StatusBadGateway) + return ErrorStatusCode{ + StatusCode: fasthttp.StatusBadGateway, + } default: - return StatusCodeError(response.StatusCode()) + return ErrorStatusCode{ + StatusCode: response.StatusCode(), + } } } diff --git a/wrapper/session.go b/wrapper/session.go index a51c312..45b0617 100644 --- a/wrapper/session.go +++ b/wrapper/session.go @@ -2,22 +2,11 @@ package wrapper import ( "context" + "errors" "fmt" - "runtime" "sync" - "sync/atomic" - "time" - json "github.com/goccy/go-json" - "github.com/switchupcb/disgo/wrapper/socket" "github.com/switchupcb/websocket" - "golang.org/x/sync/errgroup" -) - -const ( - gatewayEndpointParams = "?v=" + VersionDiscordAPI + "&encoding=json" - invalidSessionWaitTime = 1 * time.Second - maxIdentifyLargeThreshold = 250 ) // Session represents a Discord Gateway WebSocket Session. @@ -63,347 +52,36 @@ type Session struct { sync.RWMutex } -// isConnected returns whether the session is connected. -func (s *Session) isConnected() bool { - if s.Context == nil { - return false - } - - select { - case <-s.Context.Done(): - return false - default: - return true - } -} - -// canReconnect determines whether the session is in a valid state to reconnect. -func (s *Session) canReconnect() bool { - return s.ID != "" && s.Endpoint != "" && atomic.LoadInt64(&s.Seq) != 0 -} - // Connect connects a session to the Discord Gateway (WebSocket Connection). func (s *Session) Connect(bot *Client) error { - s.Lock() - defer s.Unlock() - - LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msg("connecting session") - - return s.connect(bot) -} + if bot == nil { + return errors.New("cannot connect session using a nil Client") + } -// connect connects a session to a WebSocket Connection. -func (s *Session) connect(bot *Client) error { if bot.Sessions == nil { - return fmt.Errorf("%q", errNoSessionManager) + bot.Sessions = NewSessionManager() } - s.client_manager = bot.Sessions - if bot.Handlers == nil { bot.Handlers = new(Handlers) } - if s.isConnected() { - return fmt.Errorf("session %q is already connected", s.ID) - } - - var err error - - // request a valid Gateway URL endpoint and response from the Discord API. - gatewayEndpoint := s.Endpoint - var response *GetGatewayBotResponse - - if bot.Config.Gateway.ShardManager != nil { - if response, err = bot.Config.Gateway.ShardManager.SetLimit(bot); err != nil { - return fmt.Errorf("shardmanager: %w", err) - } - } else { - if gatewayEndpoint == "" || !s.canReconnect() { - gateway := GetGatewayBot{} - response, err = gateway.Send(bot) - if err != nil { - return fmt.Errorf("error getting the Gateway API Endpoint: %w", err) - } - - gatewayEndpoint = response.URL - } - } - - // set the maximum allowed (Identify) concurrency rate limit. - // - // https://discord.com/developers/docs/topics/gateway#rate-limiting - if response != nil { - bot.Config.Gateway.RateLimiter.StartTx() - - identifyBucket := bot.Config.Gateway.RateLimiter.GetBucketFromID(FlagGatewaySendEventNameIdentify) - if identifyBucket == nil { - identifyBucket = getBucket() - bot.Config.Gateway.RateLimiter.SetBucketFromID(FlagGatewaySendEventNameIdentify, identifyBucket) - } - - identifyBucket.Limit = int16(response.SessionStartLimit.MaxConcurrency) //nolint:gosec // disable G115 - - if identifyBucket.Expiry.IsZero() { - identifyBucket.Remaining = identifyBucket.Limit - identifyBucket.Expiry = time.Now().Add(FlagGlobalRateLimitIdentifyInterval) - } - - bot.Config.Gateway.RateLimiter.EndTx() - } - - // connect to the Discord Gateway Websocket. - s.manager = new(manager) - s.Context, s.manager.cancel = context.WithCancel(context.Background()) - if s.Conn, _, err = websocket.Dial(s.Context, gatewayEndpoint+gatewayEndpointParams, nil); err != nil { - return fmt.Errorf("error connecting to the Discord Gateway: %w", err) - } - - // set up the Session's Rate Limiter (applied per WebSocket Connection). - // https://discord.com/developers/docs/topics/gateway#rate-limiting - s.RateLimiter = &RateLimit{ //nolint:exhaustruct - ids: make(map[string]string, totalGatewayBucketsPerConnection), - buckets: make(map[string]*Bucket, totalGatewayBucketsPerConnection), - } - - s.RateLimiter.SetBucket( - GlobalRateLimitRouteID, &Bucket{ //nolint:exhaustruct - Limit: FlagGlobalRateLimitGateway, - Remaining: FlagGlobalRateLimitGateway, - Expiry: time.Now().Add(FlagGlobalRateLimitGatewayInterval), - }, - ) - - // handle the incoming Hello event upon connecting to the Gateway. - hello := new(Hello) - if err := readEvent(s, hello); err != nil { - err = fmt.Errorf("error reading initial Hello event: %w", err) - sessionErr := ErrorSession{SessionID: s.ID, Err: err} - if disconnectErr := s.disconnect(FlagClientCloseEventCodeNormal); disconnectErr != nil { - sessionErr.Err = ErrorDisconnect{ - Action: err, - Err: disconnectErr, - Connection: ErrConnectionSession, - } - } - - return sessionErr - } - - for _, handler := range bot.Handlers.Hello { - go handler(hello) - } + s.Lock() + s.client_manager = bot.Sessions - // begin sending heartbeat payloads every heartbeat_interval ms. - ms := time.Millisecond * time.Duration(hello.HeartbeatInterval) - s.heartbeat = &heartbeat{ - interval: ms, - ticker: time.NewTicker(ms), - send: make(chan Heartbeat), - - // add a HeartbeatACK to the HeartbeatACK channel to prevent - // the length of the HeartbeatACK channel from being 0 immediately, - // which results in an attempt to reconnect. - acks: 1, - } + if s.manager != nil && s.State() == SessionStateConnected { + s.Unlock() - // create a goroutine group for the Session. - s.manager.Group, s.manager.signal = errgroup.WithContext(s.Context) - s.manager.err = make(chan error, 1) - - // spawn the heartbeat pulse goroutine. - s.manager.routines.Add(1) - atomic.AddInt32(&s.manager.pulses, 1) - s.manager.Go(func() error { - s.pulse() - return nil - }) - - // spawn the heartbeat beat goroutine. - s.manager.routines.Add(1) - s.manager.Go(func() error { - if err := s.beat(bot); err != nil { - return ErrorSession{ - SessionID: s.ID, - Err: fmt.Errorf("heartbeat: %w", err), - } - } - - return nil - }) - - // send the initial Identify or Resumed packet. - if err := s.initial(bot, 0); err != nil { - sessionErr := ErrorSession{SessionID: s.ID, Err: err} - if disconnectErr := s.disconnect(FlagClientCloseEventCodeNormal); disconnectErr != nil { - sessionErr.Err = ErrorDisconnect{ - Action: err, - Err: disconnectErr, - Connection: ErrConnectionSession, - } - } - - return sessionErr + return fmt.Errorf("session %q is already connected", s.ID) } - // spawn the event listener listen goroutine. - s.manager.routines.Add(1) - s.manager.Go(func() error { - if err := s.listen(bot); err != nil { - return ErrorSession{ - SessionID: s.ID, - Err: fmt.Errorf("listen: %w", err), - } - } - - return nil - }) - - // spawn the manager goroutine. - s.manager.routines.Add(1) - go s.manage(bot) - - // ensure that the Session's goroutines are spawned. - s.manager.routines.Wait() + s.spawnManager(bot) - return nil -} - -// initial sends the initial Identify or Resume packet required to connect to the Gateway, -// then handles the incoming Ready or Resumed packet that indicates a successful connection. -func (s *Session) initial(bot *Client, attempt int) error { - if !s.canReconnect() { - // send an Opcode 2 Identify to the Discord Gateway. - identify := Identify{ - Token: bot.Authentication.Token, - Properties: IdentifyConnectionProperties{ - OS: runtime.GOOS, - Browser: module, - Device: module, - }, - Compress: Pointer(true), - LargeThreshold: Pointer(maxIdentifyLargeThreshold), - Shard: s.Shard, - Presence: bot.Config.Gateway.GatewayPresenceUpdate, - Intents: bot.Config.Gateway.Intents, - } - - if err := identify.SendEvent(bot, s); err != nil { - return err - } - } else { - // send an Opcode 6 Resume to the Discord Gateway to reconnect the session. - resume := Resume{ - Token: bot.Authentication.Token, - SessionID: s.ID, - Seq: atomic.LoadInt64(&s.Seq), - } - - if err := resume.SendEvent(bot, s); err != nil { - return err - } - } - - // handle the incoming Ready, Resumed or Replayed event (or Opcode 9 Invalid Session). - payload := getPayload() - defer putPayload(payload) - if err := socket.Read(s.Context, s.Conn, payload); err != nil { - return fmt.Errorf("error reading initial payload: %w", err) - } + s.manager.signals <- sessionSignalConnect + s.Unlock() - LogPayload(LogSession(Logger.Info(), s.ID), payload.Op, payload.Data).Msg("received initial payload") - - switch payload.Op { - case FlagGatewayOpcodeDispatch: - switch { - // When a connection is successful, the Discord Gateway will respond with a Ready event. - case *payload.EventName == FlagGatewayEventNameReady: - ready := new(Ready) - if err := json.Unmarshal(payload.Data, ready); err != nil { - return fmt.Errorf("error reading ready event: %w", err) - } - - LogSession(Logger.Info(), ready.SessionID).Msg("received Ready event") - - // Configure the session. - s.ID = ready.SessionID - atomic.StoreInt64(&s.Seq, 0) - s.Endpoint = ready.ResumeGatewayURL - - // Store the session in the session manager. - s.client_manager.Gateway.Store(s.ID, s) - - if bot.Config.Gateway.ShardManager != nil { - bot.Config.Gateway.ShardManager.Ready(bot, s, ready) - } - - for _, handler := range bot.Handlers.Ready { - go handler(ready) - } - - // When a reconnection is successful, the Discord Gateway will respond - // by replaying all missed events in order, finalized by a Resumed event. - case *payload.EventName == FlagGatewayEventNameResumed: - LogSession(Logger.Info(), s.ID).Msg("received Resumed event") - - // Store the session in the session manager. - s.client_manager.Gateway.Store(s.ID, s) - - for _, handler := range bot.Handlers.Resumed { - go handler(&Resumed{}) - } - - // When a reconnection is successful, the Discord Gateway will respond - // by replaying all missed events in order, finalized by a Resumed event. - default: - // handle the initial payload(s) until a Resumed event is encountered. - go bot.handle(*payload.EventName, payload.Data) - - for { - replayed := new(GatewayPayload) - if err := socket.Read(s.Context, s.Conn, replayed); err != nil { - return fmt.Errorf("error replaying events: %w", err) - } - - if replayed.Op == FlagGatewayOpcodeDispatch && *replayed.EventName == FlagGatewayEventNameResumed { - LogSession(Logger.Info(), s.ID).Msg("received Resumed event") - - // Store the session in the session manager. - s.client_manager.Gateway.Store(s.ID, s) - - for _, handler := range bot.Handlers.Resumed { - go handler(&Resumed{}) - } - - return nil - } - - go bot.handle(*payload.EventName, payload.Data) - } - } - - // When the maximum concurrency limit has been reached while connecting, or when - // the session does NOT reconnect in time, the Discord Gateway send an Opcode 9 Invalid Session. - case FlagGatewayOpcodeInvalidSession: - // Remove the session from the session manager. - s.client_manager.RemoveGatewaySession(s.ID) - - if attempt < 1 { - // wait for Discord to close the session, then complete a fresh connect. - <-time.NewTimer(invalidSessionWaitTime).C - - s.ID = "" - atomic.StoreInt64(&s.Seq, 0) - if err := s.initial(bot, attempt+1); err != nil { - return err - } - - return nil - } - - return fmt.Errorf("session %q couldn't connect to the Discord Gateway or has invalidated an active session", s.ID) - default: - return fmt.Errorf("session %q received payload %d during connection which is unexpected", s.ID, payload.Op) + if err := <-s.manager.actionError; err != nil { + return err } return nil @@ -412,209 +90,41 @@ func (s *Session) initial(bot *Client, attempt int) error { // Disconnect disconnects a session from the Discord Gateway. func (s *Session) Disconnect() error { s.Lock() - - if !s.isConnected() { + if s.manager == nil || s.State() != SessionStateConnected { s.Unlock() - return fmt.Errorf("session %q is already disconnected", s.ID) - } - - id := s.ID - LogSession(Logger.Info(), id).Msgf("disconnecting session with code %d", FlagClientCloseEventCodeNormal) - - s.manager.signal = context.WithValue(s.manager.signal, keySignal, signalDisconnect) - - if err := s.disconnect(FlagClientCloseEventCodeNormal); err != nil { - s.Unlock() - - return ErrorDisconnect{ - Connection: ErrConnectionSession, - Action: nil, - Err: err, - } + return errors.New("cannot disconnect session that isn't connected") } + s.manager.signals <- sessionSignalDisconnect s.Unlock() - if err := <-s.manager.err; err != nil { + if err := <-s.manager.actionError; err != nil { return err } + // Reset the session. putSession(s) - LogSession(Logger.Info(), id).Msgf("disconnected session with code %d", FlagClientCloseEventCodeNormal) - - return nil -} - -// disconnect disconnects a session from a WebSocket Connection using the given status code. -func (s *Session) disconnect(code int) error { - // cancel the context to kill the goroutines of the Session. - defer s.manager.cancel() - - // Remove the session from the session manager. - s.client_manager.RemoveGatewaySession(s.ID) - - if err := s.Conn.Close(websocket.StatusCode(code), ""); err != nil { - return fmt.Errorf("%w", err) - } - return nil } // Reconnect reconnects an already connected session to the Discord Gateway // by disconnecting the session, then connecting again. func (s *Session) Reconnect(bot *Client) error { - s.reconnect(bot, "reconnecting") - - if err := <-s.manager.err; err != nil { - return err - } - - // connect to the Discord Gateway again. - if err := s.Connect(bot); err != nil { - return fmt.Errorf("error reconnecting session %q: %w", s.ID, err) - } - - return nil -} - -// readEvent is a helper function for reading events from the WebSocket Session. -func readEvent(s *Session, dst any) error { - payload := new(GatewayPayload) - if err := socket.Read(s.Context, s.Conn, payload); err != nil { - return fmt.Errorf("readEvent: %w", err) - } - - if err := json.Unmarshal(payload.Data, dst); err != nil { - return fmt.Errorf("readEvent: %w", err) - } - - return nil -} - -// writeEvent is a helper function for writing events to the WebSocket Session. -func writeEvent(bot *Client, s *Session, op int, name string, dst any) error { -RATELIMIT: - // a single send event is PROCESSED at any point in time. - s.RateLimiter.Lock() - - LogCommand(LogSession(Logger.Trace(), s.ID), bot.ApplicationID, op, name).Msg("processing gateway command") - - for { - s.RateLimiter.StartTx() - - globalBucket := s.RateLimiter.GetBucket(GlobalRateLimitRouteID, "") - - // reset the Global Rate Limit Bucket when the current Bucket has passed its expiry. - if isExpired(globalBucket) { - globalBucket.Reset(time.Now().Add(time.Minute)) - } - - // stop waiting when the Global Rate Limit Bucket is NOT empty. - if isNotEmpty(globalBucket) { - switch op { - // Identify is also bound by the max_concurrency rate limit. - case FlagGatewayOpcodeIdentify: - bot.Config.Gateway.RateLimiter.StartTx() - - identifyBucket := bot.Config.Gateway.RateLimiter.GetBucketFromID(FlagGatewaySendEventNameIdentify) - - if isNotEmpty(identifyBucket) { - if globalBucket != nil { - if globalBucket.Remaining == FlagGlobalRateLimitGateway { - globalBucket.Reset(time.Now().Add(time.Minute)) - } - - globalBucket.Remaining-- - } - - if identifyBucket != nil { - identifyBucket.Remaining-- - } - - bot.Config.Gateway.RateLimiter.EndTx() - s.RateLimiter.EndTx() - - goto SEND - } - - if isExpired(identifyBucket) { - if globalBucket != nil { - if globalBucket.Remaining == FlagGlobalRateLimitGateway { - globalBucket.Reset(time.Now().Add(time.Minute)) - } - - globalBucket.Remaining-- - } - - if identifyBucket != nil { - identifyBucket.Reset(time.Now().Add(FlagGlobalRateLimitIdentifyInterval)) - identifyBucket.Remaining-- - } - - bot.Config.Gateway.RateLimiter.EndTx() - s.RateLimiter.EndTx() - - goto SEND - } - - var wait time.Time - if identifyBucket != nil { - wait = identifyBucket.Expiry - } - - // do NOT block other send events due to a Send Event Rate Limit. - bot.Config.Gateway.RateLimiter.EndTx() - s.RateLimiter.EndTx() - s.RateLimiter.Unlock() - - // reduce CPU usage by blocking the current goroutine - // until it's eligible for action. - if identifyBucket != nil { - <-time.After(time.Until(wait)) - } - - goto RATELIMIT - - default: - if globalBucket != nil { - if globalBucket.Remaining == FlagGlobalRateLimitGateway { - globalBucket.Reset(time.Now().Add(time.Minute)) - } - - globalBucket.Remaining-- - } - - s.RateLimiter.EndTx() - - goto SEND - } - } + s.Lock() + if s.manager == nil || s.State() != SessionStateConnected { + s.Unlock() - s.RateLimiter.EndTx() + return errors.New("cannot reconnect session that isn't connected") } -SEND: - s.RateLimiter.Unlock() - - LogCommand(LogSession(Logger.Trace(), s.ID), bot.ApplicationID, op, name).Msg("sending gateway command") - - // write the event to the WebSocket Connection. - event, err := json.Marshal(dst) - if err != nil { - return fmt.Errorf("writeEvent: %w", err) - } + s.manager.signals <- sessionSignalReconnect + s.Unlock() - if err = socket.Write(s.Context, s.Conn, websocket.MessageBinary, - GatewayPayload{ //nolint:exhaustruct - Op: op, - Data: event, - }); err != nil { - return fmt.Errorf("writeEvent: %w", err) + if err := <-s.manager.actionError; err != nil { + return fmt.Errorf("reconnect: %w", err) } - LogCommand(LogSession(Logger.Trace(), s.ID), bot.ApplicationID, op, name).Msg("sent gateway command") - return nil } diff --git a/wrapper/session_command_write_event.go b/wrapper/session_command_write_event.go new file mode 100644 index 0000000..7817a62 --- /dev/null +++ b/wrapper/session_command_write_event.go @@ -0,0 +1,136 @@ +package wrapper + +import ( + "fmt" + "time" + + json "github.com/goccy/go-json" + "github.com/switchupcb/disgo/wrapper/socket" + "github.com/switchupcb/websocket" +) + +// writeEvent is a helper function for writing events to the WebSocket Session. +func writeEvent(bot *Client, s *Session, op int, name string, dst any) error { +RATELIMIT: + // a single send event is PROCESSED at any point in time. + s.RateLimiter.Lock() + + LogCommand(LogSession(Logger.Trace(), s.ID), bot.ApplicationID, op, name).Msg("processing gateway command") + + for { + s.RateLimiter.StartTx() + + globalBucket := s.RateLimiter.GetBucket(GlobalRateLimitRouteID, "") + + // reset the Global Rate Limit Bucket when the current Bucket has passed its expiry. + if isExpired(globalBucket) { + globalBucket.Reset(time.Now().Add(time.Minute)) + } + + // stop waiting when the Global Rate Limit Bucket is NOT empty. + if isNotEmpty(globalBucket) { + switch op { + // Identify is also bound by the max_concurrency rate limit. + case FlagGatewayOpcodeIdentify: + bot.Config.Gateway.RateLimiter.StartTx() + + identifyBucket := bot.Config.Gateway.RateLimiter.GetBucketFromID(FlagGatewaySendEventNameIdentify) + + if isNotEmpty(identifyBucket) { + if globalBucket != nil { + if globalBucket.Remaining == FlagGlobalRateLimitGateway { + globalBucket.Reset(time.Now().Add(time.Minute)) + } + + globalBucket.Remaining-- + } + + if identifyBucket != nil { + identifyBucket.Remaining-- + } + + bot.Config.Gateway.RateLimiter.EndTx() + s.RateLimiter.EndTx() + + goto SEND + } + + if isExpired(identifyBucket) { + if globalBucket != nil { + if globalBucket.Remaining == FlagGlobalRateLimitGateway { + globalBucket.Reset(time.Now().Add(time.Minute)) + } + + globalBucket.Remaining-- + } + + if identifyBucket != nil { + identifyBucket.Reset(time.Now().Add(FlagGlobalRateLimitIdentifyInterval)) + identifyBucket.Remaining-- + } + + bot.Config.Gateway.RateLimiter.EndTx() + s.RateLimiter.EndTx() + + goto SEND + } + + var wait time.Time + if identifyBucket != nil { + wait = identifyBucket.Expiry + } + + // do NOT block other send events due to a Send Event Rate Limit. + bot.Config.Gateway.RateLimiter.EndTx() + s.RateLimiter.EndTx() + s.RateLimiter.Unlock() + + // reduce CPU usage by blocking the current goroutine + // until it's eligible for action. + if identifyBucket != nil { + <-time.After(time.Until(wait)) + } + + goto RATELIMIT + + default: + if globalBucket != nil { + if globalBucket.Remaining == FlagGlobalRateLimitGateway { + globalBucket.Reset(time.Now().Add(time.Minute)) + } + + globalBucket.Remaining-- + } + + s.RateLimiter.EndTx() + + goto SEND + } + } + + s.RateLimiter.EndTx() + } + +SEND: + s.RateLimiter.Unlock() + + LogCommand(LogSession(Logger.Trace(), s.ID), bot.ApplicationID, op, name).Msg("sending gateway command") + + // write the event to the WebSocket Connection. + event, err := json.Marshal(dst) + if err != nil { + return fmt.Errorf("writeEvent: %w", err) + } + + if err = socket.Write(s.Context, s.Conn, websocket.MessageBinary, + GatewayPayload{ //nolint:exhaustruct + Op: op, + Data: event, + }); err != nil { + return fmt.Errorf("writeEvent: %w", err) + } + + LogCommand(LogSession(Logger.Trace(), s.ID), bot.ApplicationID, op, name).Msg("sent gateway command") + + return nil +} diff --git a/wrapper/session_manager.go b/wrapper/session_manager.go deleted file mode 100644 index d78271b..0000000 --- a/wrapper/session_manager.go +++ /dev/null @@ -1,327 +0,0 @@ -package wrapper - -import ( - "context" - "errors" - "fmt" - "sync" - "sync/atomic" - "time" - - "github.com/switchupcb/websocket" - "golang.org/x/sync/errgroup" -) - -// signal represents a manager Context Signal. -type signal string - -// manager Context Signals. -const ( - // keySignal represents the Context key for a manager's signals. - keySignal = signal("signal") - - // keyReason represents the Context key for a manager's reason for disconnection. - keyReason = signal("reason") - - // signalDisconnect indicates that a disconnection was called purposefully. - signalDisconnect = 1 - - // signalReconnect signals the manager to reconnect upon a successful disconnection. - signalReconnect = 2 -) - -// manager represents a manager of a Session's goroutines. -type manager struct { - // routines represents a goroutine counter that ensures all of the Session's goroutines - // are spawned prior to returning from connect(). - routines sync.WaitGroup - - // cancel represents the cancellation signal for a Session's Context. - cancel context.CancelFunc - - // signal represents the Context Signal for a Session upon disconnection. - signal context.Context - - // err represents the error that this manager detected upon the closing of a Session's goroutines. - err chan error - - // pulses represents the amount of goroutines that can generate heartbeat pulses. - // - // pulses ensures that pulse goroutines always have a receiver channel for heartbeats - // by preventing the heartbeat goroutine from closing before other pulse goroutines. - pulses int32 - - // errgroup ensures that all of the Session's goroutines are closed prior to returning - // from Disconnect(). - // - // IMPLEMENTATION - // A Session's Context is cancelled to indicate a disconnection: - // 1. Context is canceled (via function call or error). - // 2. Goroutines read s.Context.Done() and close accordingly. - // 3. errgroup.Wait() is called to block until all goroutines are closed. - // 4. errgroup.Wait() result is returned once all goroutines are closed. - // - // As a result of 3, disconnection must NEVER occur on a Session's goroutine. - // Otherwise, errorgroup.Wait() blocks the goroutine it's waiting on to be closed. - // In other words, disconnection MUST occur on another goroutine. - // - // ERRGROUP - // errgroup manages a Session's goroutines: listen, heartbeat, pulse, respond. - // - // Upon connection, an (unmanaged) manager goroutine is used to monitor errgroup.Wait(). - // - // When a disconnection is called purposefully, s.Conn and s.Context is closed. - // This results in the eventual closing of a Session's goroutines. - // When errgroup.Wait() returns nil, it indicates a successful disconnection. - // Otherwise, a DisconnectError will be returned. - // - // When an error occurs in a Session's goroutines, errgroup cancels the Session's context. - // This results in the eventual closing of a Session's goroutines. - // When errgroup.Wait() returns err (origin error), the state of the disconnection is managed - // (since s.Conn may or may not need closing). - // When managing the state of disconnection is successful, the manager routine returns err. - // Otherwise, a DisconnectError (which includes err) will be returned. - // - // The above indicates that the manager manages the STATE of disconnection, while disconnect() - // performs the ACTION of disconnection. - // - // This implementation allows a caller of disconnect() to use its return value to await disconnection. - // For example, a channel can be used to receive the value that the manager routine sends. - // Disconnect() is modified in this way to allow the end-user (developer) to only return from Disconnect() - // when disconnection is fully completed (with goroutines closed). - *errgroup.Group -} - -// decrementPulses safely decrements the pulses counter of a Session manager. -func (s *Session) decrementPulses() { - s.Lock() - defer s.Unlock() - - atomic.AddInt32(&s.manager.pulses, -1) -} - -// logClose safely logs the close of a Session's goroutine. -func (s *Session) logClose(routine string) { - LogSession(Logger.Info(), s.ID).Msgf("closed %s routine", routine) -} - -// reconnect spawns a goroutine for reconnection which prompts the manager -// to reconnect upon a disconnection. -func (s *Session) reconnect(bot *Client, reason string) { - s.manager.Go(func() error { - s.Lock() - defer s.logClose("reconnect") - defer s.Unlock() - - LogSession(Logger.Info(), s.ID).Msg(reason) - - s.manager.signal = context.WithValue(s.manager.signal, keySignal, signalReconnect) - if err := s.disconnect(FlagClientCloseEventCodeReconnect); err != nil { - return fmt.Errorf("reconnect: %w", err) - } - - // connect to the Discord Gateway again. - s.Context = nil - if err := s.connect(bot); err != nil { - return fmt.Errorf("reconnect: %w", err) - } - - return nil - }) -} - -// manage manages a Session's goroutines. -func (s *Session) manage(bot *Client) { - s.manager.routines.Done() - defer func() { - s.Lock() - s.logClose("manager") - s.Unlock() - }() - - // wait until all of a Session's goroutines are closed. - err := s.manager.Wait() - s.Lock() - defer s.Unlock() - - // log the reason for disconnection (if applicable). - if reason := s.manager.signal.Value(keyReason); reason != nil { - LogSession(Logger.Info(), s.ID).Msgf("%v", reason) - } - - // when a signal is provided, it indicates that the disconnection was purposeful. - signal := s.manager.signal.Value(keySignal) - switch signal { - case signalDisconnect: - LogSession(Logger.Info(), s.ID).Msg("successfully disconnected") - - s.manager.err <- nil - - return - - case signalReconnect: - LogSession(Logger.Info(), s.ID).Msg("successfully disconnected (while reconnecting)") - - // allow Discord to close the session. - <-time.After(time.Second) - - s.manager.err <- nil - - return - } - - // when an error caused goroutines to close, manage the state of disconnection. - if err != nil { - disconnectErr := new(ErrorDisconnect) - closeErr := new(websocket.CloseError) - switch { - // when an error occurs from a purposeful disconnection. - case errors.As(err, disconnectErr): - s.manager.err <- err - - // when an error occurs from a WebSocket Close Error. - case errors.As(err, closeErr): - if bot == nil { - s.manager.err <- fmt.Errorf("gateway websocket close error, but unable to reconnect: %w", err) - } - - s.manager.err <- s.handleGatewayCloseError(bot, closeErr) - - default: - if cErr := s.Conn.Close(websocket.StatusCode(FlagClientCloseEventCodeAway), ""); cErr != nil { - s.manager.err <- ErrorDisconnect{ - Action: err, - Err: cErr, - Connection: ErrConnectionSession, - } - - return - } - - s.manager.err <- err - } - - return - } - - s.manager.err <- nil -} - -// handleGatewayCloseError handles a WebSocket CloseError. -func (s *Session) handleGatewayCloseError(bot *Client, closeErr *websocket.CloseError) error { - code, ok := GatewayCloseEventCodes[int(closeErr.Code)] - switch ok { - // Gateway Close Event Code is known. - case true: - LogSession(Logger.Info(), s.ID). - Msgf("received Gateway Close Event Code %d %s: %s", - code.Code, code.Description, code.Explanation, - ) - - if code.Reconnect { - s.reconnect(bot, fmt.Sprintf("reconnecting due to Gateway Close Event Code %d", code.Code)) - - return nil - } - - return closeErr - - // Gateway Close Event Code is unknown. - default: - - // when another goroutine calls disconnect(), - // s.Conn.Close is called before s.cancel which will result in - // a CloseError with the close code that Disgo uses to reconnect. - if closeErr.Code == websocket.StatusCode(FlagClientCloseEventCodeReconnect) { - return nil - } - - LogSession(Logger.Info(), s.ID). - Msgf("received unknown Gateway Close Event Code %d with reason %q", - closeErr.Code, closeErr.Reason, - ) - - return closeErr - } -} - -const ( - // SignalNone indicates that Wait() was called on an already disconnected session. - SignalNone = 0 - - // SignalDisconnect indicates that a disconnection was called purposefully. - SignalDisconnect = signalDisconnect - - // SignalReconnect indicates that a disconnection was called purposefully in order to reconnect. - SignalReconnect = signalReconnect - - // SignalError indicates that a disconnection occurred as an error. - SignalError = 3 - - // SignalDisconnectError indicates that a disconnection was called purposefully (for any reason), - // but the Session experienced an error while disconnecting. - SignalDisconnectError = 4 - - // SignalUndefined indicates that a disconnection occurred in an undefined manner. - // - // This signal should NEVER be returned: If it is, report it. - SignalUndefined = 5 -) - -// Wait blocks until the calling Session has disconnected, then returns the reason -// (disgo.SignalReason) for disconnecting and the disconnection error (if it exists). -// -// If Wait() is called on a Session that isn't connected, it will return immediately -// with code SignalNone. -// -// It's NOT recommended to modify a Session after it has disconnected, -// since it will be cleared and placed into a memory pool shortly after. -func (s *Session) Wait() (int, error) { - if !s.isConnected() { - return SignalNone, nil - } - - // NOTE: Wait() is equivalent to the s.manage() s.manager.Wait() handling logic, - // but without the management of the disconnection state, - // and without the usage of a channel that tells another goroutine to unblock. - // - // wait until all of a Session's goroutines are closed. - err := s.manager.Wait() - s.Lock() - defer s.Unlock() - - // when a signal is provided, it indicates that the disconnection was purposeful. - signal := s.manager.signal.Value(keySignal) - switch signal { - case signalDisconnect: - return SignalDisconnect, nil - - case signalReconnect: - return SignalReconnect, nil - } - - // when an error caused goroutines to close. - if err != nil { - disconnectErr := new(ErrorDisconnect) - closeErr := new(websocket.CloseError) - switch { - // when an error occurs from a purposeful disconnection. - case errors.As(err, disconnectErr): - if signal != nil { - if signalValue, ok := signal.(int); ok { - return signalValue, err //nolint:wrapcheck - } - } - - return SignalDisconnectError, err //nolint:wrapcheck - - // when an error occurs from a WebSocket Close Error. - case errors.As(err, closeErr): - return SignalError, s.handleGatewayCloseError(nil, closeErr) - } - - return SignalError, err //nolint:wrapcheck - } - - return SignalUndefined, nil -} diff --git a/wrapper/session_routine_coroner.go b/wrapper/session_routine_coroner.go new file mode 100644 index 0000000..6b14d01 --- /dev/null +++ b/wrapper/session_routine_coroner.go @@ -0,0 +1,37 @@ +package wrapper + +// coroner investigates when a Session's goroutines are shutdown. +func (s *Session) coroner() { + // wait until all the manager goroutines is closed. + err := s.manager.coroner.Wait() + + s.Lock() + + // report the disconnection error + s.manager.actionError <- err + close(s.manager.actionError) + + // remove the session from the client. + s.client_manager.RemoveGatewaySession(s.ID) + + s.logClose("coroner") + s.Unlock() +} + +// Wait blocks until the calling Session is inactive (due to a final disconnect), +// then returns the Session's state and the disconnection error (if it exists). +// +// If Wait() is called on a Session that isn't connected, it will return immediately +// with code SessionStateNew. +// +// A disconnected session is reset and placed into a memory pool, +// so do NOT modify a Session after it disconnects. +func (s *Session) Wait() (string, error) { + if s.State() == SessionStateNew { + return SessionStateNew, nil + } + + err := s.manager.coroner.Wait() + + return s.State(), err +} diff --git a/wrapper/session_heartbeat.go b/wrapper/session_routine_heartbeat.go similarity index 92% rename from wrapper/session_heartbeat.go rename to wrapper/session_routine_heartbeat.go index f9eb23c..e2bc4f6 100644 --- a/wrapper/session_heartbeat.go +++ b/wrapper/session_routine_heartbeat.go @@ -36,7 +36,7 @@ func (s *Session) Monitor() uint32 { func (s *Session) beat(bot *Client) error { s.manager.routines.Done() - // ensure that all pulse routines are closed prior to closing. + // confirm all pulse routines are closed prior to closing. defer func() { for { select { @@ -62,7 +62,7 @@ func (s *Session) beat(bot *Client) error { if atomic.LoadUint32(&s.heartbeat.acks) == 0 { s.Unlock() - s.reconnect(bot, "attempting to reconnect session due to no HeartbeatACK") + s.reconnect("attempting to reconnect session due to no HeartbeatACK") return nil } @@ -153,7 +153,7 @@ func (s *Session) respond(data json.RawMessage) error { s.Lock() - // ensure that the heartbeat routine has not been closed. + // confirm the heartbeat routine has not been closed. if atomic.LoadInt32(&s.manager.pulses) <= 1 { s.Unlock() @@ -177,3 +177,11 @@ func (s *Session) respond(data json.RawMessage) error { return nil } + +// decrementPulses safely decrements the pulses counter. +func (s *Session) decrementPulses() { + s.Lock() + defer s.Unlock() + + atomic.AddInt32(&s.manager.pulses, -1) +} diff --git a/wrapper/session_listener.go b/wrapper/session_routine_listener.go similarity index 92% rename from wrapper/session_listener.go rename to wrapper/session_routine_listener.go index 4c85f9a..4f71912 100644 --- a/wrapper/session_listener.go +++ b/wrapper/session_routine_listener.go @@ -28,16 +28,20 @@ func (s *Session) listen(bot *Client) error { } s.Lock() - defer s.Unlock() defer s.logClose("listen") + defer s.Unlock() - select { - case <-s.Context.Done(): - return nil + if s.Context != nil { + select { + case <-s.Context.Done(): + return nil - default: - return err + default: + return err + } } + + return nil } // onPayload handles an Discord Gateway Payload. @@ -73,7 +77,7 @@ func (s *Session) onPayload(bot *Client, payload GatewayPayload) error { // occurs when the Discord Gateway is shutting down the connection, while signalling the client to reconnect. case FlagGatewayOpcodeReconnect: - s.reconnect(bot, "reconnecting session due to Opcode 7 Reconnect") + s.reconnect("reconnecting session due to Opcode 7 Reconnect") return nil diff --git a/wrapper/session_routine_manager.go b/wrapper/session_routine_manager.go new file mode 100644 index 0000000..e19a8d8 --- /dev/null +++ b/wrapper/session_routine_manager.go @@ -0,0 +1,375 @@ +package wrapper + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/switchupcb/websocket" + "golang.org/x/sync/errgroup" +) + +// manager represents a manager of a Session's goroutines. +type manager struct { + // state represents the state of the Session's connection to Discord. + state string + + // stateMutex is used to protect the Session's manager state from data races + // by providing transactional functionality. + stateMutex sync.RWMutex + + // signals represents a channel of signals. + signals chan uint8 + + // coroner represents a goroutine group to track the manager routine. + coroner errgroup.Group + + // routines represents a goroutine counter that ensures all of the Session's goroutines + // are spawned prior to returning from connect(). + routines sync.WaitGroup + + // cancel represents the cancellation signal for a Session's Context. + cancel context.CancelFunc + + // actionError represents the error this manager detects upon a connection action (e.g., connecting, disconnecting). + actionError chan error + + // pulses represents the amount of goroutines that can generate heartbeat pulses. + // + // pulses ensures that pulse goroutines always have a receiver channel for heartbeats + // by preventing the heartbeat goroutine from closing before other pulse goroutines. + pulses int32 + + // errgroup ensures all of the Session's goroutines are closed prior to returning + // from Disconnect(). + // + // IMPLEMENTATION + // A session is managed by multiple goroutine groups. + // + // 1. The "coroner" is an unmanaged routine which reports an error from the manager when the Session is no longer active. + // a. The "coroner" is responsible for returning an error to a session Disconnect() call. + // b. The "coroner" is not responsible for returning an error to a session Connect() or Reconnect() call, + // because the coroner only receives an error from a manager when the manager is dead. + // + // 2. The "manager" is a tracked routine (by the coroner) which manages the state of the connection to Discord. + // a. The "manager" is responsible for returning an error to a session Connect() or Reconnect() call, + // because the manager cannot shutdown after these calls are made (in comparison to a final Disconnect()). + // b. The manager manages a Session's goroutines: listen, heartbeat, pulse, respond. + // + // USING CONTEXT CANCELLATION (to deactivate the session): + // 1. Context is cancelled (via function call or error in a goroutine). + // 2. Goroutines read s.Context.Done() and close accordingly. + // 3. errgroup.Wait() is called from a manager to block until all goroutines are closed. + // 4. errgroup.Wait() is called from a coroner to return a result once manager goroutine is dead. + // + // + // USING ERRGROUPS (to close goroutines). + // s.Conn and s.Context is closed when a disconnection is called purposefully. + // - This results in the eventual closing of a Session's goroutines. + // - A successful disconnection has occurred when errgroup.Wait() returns nil. + // - Otherwise, an error is returned. + // + *errgroup.Group +} + +// spawnManager spawns a tracked manager. +func (s *Session) spawnManager(bot *Client) { + s.manager = new(manager) + + s.Context, s.manager.cancel = context.WithCancel(context.Background()) + s.manager.Group, s.Context = errgroup.WithContext(s.Context) + s.manager.signals = make(chan uint8) + s.manager.actionError = make(chan error, 1) + + // spawn the manager goroutine. + s.manager.coroner.Go(func() error { + if err := s.manage(bot); err != nil { + return fmt.Errorf("manager: %w", err) + } + + return nil + }) +} + +// Session States represent the state of the Session's connection to Discord. +const ( + SessionStateNew = "" + + SessionStateConnecting = "connecting (before websocket connection)" + SessionStateConnectingWebsocket = "connecting (with websocket connection)" + SessionStateConnected = "connected" + + SessionStateDisconnecting = "disconnecting (purposefully)" + SessionStateDisconnectingError = "disconnecting (due to an error)" + SessionStateDisconnectingReconnect = "disconnecting (while reconnecting)" + + SessionStateDisconnectedFinal = "disconnected (after connection)" + SessionStateDisconnectedError = "disconnected (due to an error)" + SessionStateDisconnectedReconnect = "disconnected (while reconnecting)" + + SessionStateReconnecting = "reconnecting" +) + +// State returns the state of the Session's connection to Discord. +func (s *Session) State() string { + s.manager.stateMutex.RLock() + defer s.manager.stateMutex.RUnlock() + + return s.manager.state +} + +// setState sets the state of a Session. +func (s *Session) setState(state string) { + s.manager.stateMutex.Lock() + s.manager.state = state + s.manager.stateMutex.Unlock() +} + +// canReconnect returns whether the Session's fields are in a valid state to reconnect. +func (s *Session) canReconnect() bool { + return s.ID != "" && s.Endpoint != "" && atomic.LoadInt64(&s.Seq) != 0 +} + +// Session Signals represent manager signals to perform actions to the Session. +const ( + sessionSignalConnect = 1 + sessionSignalDisconnect = 2 + sessionSignalReconnect = 3 +) + +// manage manages a Session's goroutines. +func (s *Session) manage(bot *Client) error { + // spawn the coroner once the manager routine is alive. + go s.coroner() + + defer func() { + if s.State() != SessionStateDisconnectedReconnect { + s.Unlock() + } + + // wait until the previous connection's manager goroutines are closed. + _ = s.manager.Wait() + + s.logClose("manager") + }() + + var managedErr error + + for { + select { + case <-s.Context.Done(): + if s.State() == SessionStateDisconnectedReconnect { + break + } + + // wait until the previous connection's manager goroutines are closed. + err := s.manager.Wait() + if err != nil { + closeErr := new(websocket.CloseError) + + if errors.As(err, closeErr) { + if vErr := s.validateGatewayCloseError(closeErr); vErr == nil { + // reconnect from a state where + s.setState(SessionStateDisconnectedReconnect) + + // manager routines must be reset + s.Context, s.manager.cancel = context.WithCancel(context.Background()) //nolint:fatcontext + s.manager.Group, s.Context = errgroup.WithContext(s.Context) + + go func() { + // send a connection signal. + s.manager.signals <- sessionSignalConnect + + // read the s.manager.actionError send from a successful connection. + e := <-s.manager.actionError + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("captured result from close event reconnect: %q", e) + }() + + s.Lock() + + break // to reconnect from the connect case logic. + } // vErr == nil + } // errors.As + + // TODO: Use errors.As: https://github.com/coder/websocket/issues/519 + if strings.Contains(err.Error(), "failed to close WebSocket: received header with unexpected rsv bits set") { + err = nil + } + } // err != nil + + s.Lock() + + return err + + case signal := <-s.manager.signals: + switch signal { + case sessionSignalConnect: + if s.State() != SessionStateDisconnectedReconnect { + s.Lock() + } else { + s.Unlock() + + // wait until the previous connection's manager goroutines are closed. + _ = s.manager.Wait() + + s.Lock() + } + + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msg("connecting session") + + if err := s.connect(bot); err != nil { + managedErr = ErrorSession{SessionID: s.ID, State: s.State(), Type: ErrorSessionTypeGateway, Err: err} + + switch s.State() { + case SessionStateConnectingWebsocket: + go func() { + // send a disconnection signal. + s.manager.signals <- sessionSignalDisconnect + }() + + // case SessionStateNew, SessionStateConnecting... + default: + return managedErr + } + + break // to handle the error in the disconnect case logic. + } + + s.setState(SessionStateConnected) + s.manager.actionError <- nil + s.Unlock() + + case sessionSignalDisconnect: + if managedErr == nil && s.State() != SessionStateReconnecting { + s.Lock() + } + + // update the session's state and client close event code. + code := FlagClientCloseEventCodeNormal + + switch { + case managedErr != nil: + s.setState(SessionStateDisconnectingError) + case s.State() == SessionStateReconnecting: + s.setState(SessionStateDisconnectingReconnect) + code = FlagClientCloseEventCodeReconnect + default: + s.setState(SessionStateDisconnecting) + } + + LogSession(Logger.Info(), s.ID).Msgf("%q session with code %d", s.State(), code) + + // disconnect the session. + if err := s.disconnect(code); err != nil { + managedErr = ErrorSession{ + SessionID: s.ID, + State: s.State(), + Type: ErrorSessionTypeGateway, + Err: ErrorSessionDisconnect{ + Action: managedErr, + Err: err, + }, + } + + // TODO: Use errors.As: https://github.com/coder/websocket/issues/519 + if strings.Contains(err.Error(), "failed to close WebSocket: received header with unexpected rsv bits set") { + managedErr = nil + } + + if s.State() != SessionStateDisconnectingReconnect { + return managedErr + } + + // validate error when reconnecting + closeErr := new(websocket.CloseError) + if errors.As(managedErr, closeErr) { + if managedErr = s.validateGatewayCloseError(closeErr); managedErr != nil { + return managedErr + } + } + } // disconnect + + // update the session's state. + switch { + case s.State() == SessionStateDisconnectingError: + s.setState(SessionStateDisconnectedError) + + case s.State() == SessionStateDisconnectingReconnect: + s.setState(SessionStateDisconnectedReconnect) + + default: + s.setState(SessionStateDisconnectedFinal) + } + + LogSession(Logger.Info(), s.ID).Msgf("%q session with code %d", s.State(), code) + + if s.State() == SessionStateDisconnectedReconnect { + // allow Discord to close the session. + <-time.After(time.Second) + + go func() { + // send a connection signal. + s.manager.signals <- sessionSignalConnect + }() + + break + } + + // Destroy the manager when the bot isn't reconnecting. + if managedErr != nil { + return managedErr + } + + return nil + + case sessionSignalReconnect: + s.Lock() + s.setState(SessionStateReconnecting) + + go func() { + // send a disconnection signal. + s.manager.signals <- sessionSignalDisconnect + }() + } + } // select + } // for +} + +// logClose safely logs the close of a Session's goroutine. +func (s *Session) logClose(routine string) { + LogSession(Logger.Info(), s.ID).Msgf("closed %s routine", routine) +} + +// validateGatewayCloseError validates a WebSocket CloseError +// and returns whether to reconnect (when error == nil). +func (s *Session) validateGatewayCloseError(closeErr *websocket.CloseError) error { + code, ok := GatewayCloseEventCodes[int(closeErr.Code)] + + switch ok { + // Gateway Close Event Code is known. + case true: + LogSession(Logger.Info(), s.ID). + Msgf("received Gateway Close Event Code %d %s: %s", + code.Code, code.Description, code.Explanation, + ) + + if code.Reconnect { + return nil + } + + return closeErr + + // Gateway Close Event Code is unknown. + default: + LogSession(Logger.Info(), s.ID). + Msgf("received unknown Gateway Close Event Code %d with reason %q", + closeErr.Code, closeErr.Reason, + ) + + return closeErr + } +} diff --git a/wrapper/session_routine_manager_actions.go b/wrapper/session_routine_manager_actions.go new file mode 100644 index 0000000..8b83a8c --- /dev/null +++ b/wrapper/session_routine_manager_actions.go @@ -0,0 +1,325 @@ +package wrapper + +import ( + "context" + "fmt" + "runtime" + "sync/atomic" + "time" + + json "github.com/goccy/go-json" + "github.com/switchupcb/disgo/wrapper/socket" + "github.com/switchupcb/websocket" +) + +const ( + gatewayEndpointParams = "?v=" + VersionDiscordAPI + "&encoding=json" + invalidSessionWaitTime = 1 * time.Second + maxIdentifyLargeThreshold = 250 +) + +// connect connects a session to a WebSocket Connection. +func (s *Session) connect(bot *Client) error { + var err error + + // request a valid Gateway URL endpoint and response from the Discord API. + gatewayEndpoint := s.Endpoint + var response *GetGatewayBotResponse + + if bot.Config.Gateway.ShardManager != nil { + if response, err = bot.Config.Gateway.ShardManager.SetLimit(bot); err != nil { + return fmt.Errorf("shardmanager: %w", err) + } + } else { + if gatewayEndpoint == "" || !s.canReconnect() { + gateway := GetGatewayBot{} + response, err = gateway.Send(bot) + if err != nil { + return fmt.Errorf("error getting the Gateway API Endpoint: %w", err) + } + + gatewayEndpoint = response.URL + } + } + + // set the maximum allowed (Identify) concurrency rate limit for the bot. + // + // https://discord.com/developers/docs/topics/gateway#rate-limiting + if response != nil { + bot.Config.Gateway.RateLimiter.StartTx() + + identifyBucket := bot.Config.Gateway.RateLimiter.GetBucketFromID(FlagGatewaySendEventNameIdentify) + if identifyBucket == nil { + identifyBucket = getBucket() + bot.Config.Gateway.RateLimiter.SetBucketFromID(FlagGatewaySendEventNameIdentify, identifyBucket) + } + + identifyBucket.Limit = int16(response.SessionStartLimit.MaxConcurrency) //nolint:gosec // disable G115 + + if identifyBucket.Expiry.IsZero() { + identifyBucket.Remaining = identifyBucket.Limit + identifyBucket.Expiry = time.Now().Add(FlagGlobalRateLimitIdentifyInterval) + } + + bot.Config.Gateway.RateLimiter.EndTx() + } + + // set up the Session's Rate Limiter (applied per WebSocket Connection). + // https://discord.com/developers/docs/topics/gateway#rate-limiting + s.RateLimiter = &RateLimit{ //nolint:exhaustruct + ids: make(map[string]string, totalGatewayBucketsPerConnection), + buckets: make(map[string]*Bucket, totalGatewayBucketsPerConnection), + } + + s.RateLimiter.SetBucket( + GlobalRateLimitRouteID, &Bucket{ //nolint:exhaustruct + Limit: FlagGlobalRateLimitGateway, + Remaining: FlagGlobalRateLimitGateway, + Expiry: time.Now().Add(FlagGlobalRateLimitGatewayInterval), + }, + ) + + // connect to the Discord Gateway Websocket. + s.Context, s.manager.cancel = context.WithCancel(context.Background()) + if s.Conn, _, err = websocket.Dial(s.Context, gatewayEndpoint+gatewayEndpointParams, nil); err != nil { + return fmt.Errorf("error connecting to the Discord Gateway: %w", err) + } + + s.setState(SessionStateConnectingWebsocket) + + // handle the incoming Hello event upon connecting to the Gateway. + hello := new(Hello) + if err := readEvent(s, hello); err != nil { + return fmt.Errorf("error reading initial Hello event: %w", err) + } + + for _, handler := range bot.Handlers.Hello { + go handler(hello) + } + + // begin sending heartbeat payloads every heartbeat_interval ms. + ms := time.Millisecond * time.Duration(hello.HeartbeatInterval) + s.heartbeat = &heartbeat{ + interval: ms, + ticker: time.NewTicker(ms), + send: make(chan Heartbeat), + + // add a HeartbeatACK to the HeartbeatACK channel to prevent + // the length of the HeartbeatACK channel from being 0 immediately, + // which results in an attempt to reconnect. + acks: 1, + } + + // spawn the heartbeat pulse goroutine. + s.manager.routines.Add(1) + s.manager.Go(func() error { + atomic.AddInt32(&s.manager.pulses, 1) + s.pulse() + + return nil + }) + + // spawn the heartbeat beat goroutine. + s.manager.routines.Add(1) + s.manager.Go(func() error { + if err := s.beat(bot); err != nil { + return fmt.Errorf("heartbeat: %w", err) + } + + return nil + }) + + // send the initial Identify or Resumed packet. + if err := s.initial(bot, 0); err != nil { + return fmt.Errorf("initial: %w", err) + } + + // spawn the event listener listen goroutine. + s.manager.routines.Add(1) + s.manager.Go(func() error { + if err := s.listen(bot); err != nil { + return fmt.Errorf("listen: %w", err) + } + + return nil + }) + + // confirm the Session's goroutines are spawned. + s.manager.routines.Wait() + + return nil +} + +// initial sends the initial Identify or Resume packet required to connect to the Gateway, +// then handles the incoming Ready or Resumed packet that indicates a successful connection. +func (s *Session) initial(bot *Client, attempt int) error { + if !s.canReconnect() { + // send an Opcode 2 Identify to the Discord Gateway. + identify := Identify{ + Token: bot.Authentication.Token, + Properties: IdentifyConnectionProperties{ + OS: runtime.GOOS, + Browser: module, + Device: module, + }, + Compress: Pointer(true), + LargeThreshold: Pointer(maxIdentifyLargeThreshold), + Shard: s.Shard, + Presence: bot.Config.Gateway.GatewayPresenceUpdate, + Intents: bot.Config.Gateway.Intents, + } + + if err := identify.SendEvent(bot, s); err != nil { + return err + } + } else { + // send an Opcode 6 Resume to the Discord Gateway to reconnect the session. + resume := Resume{ + Token: bot.Authentication.Token, + SessionID: s.ID, + Seq: atomic.LoadInt64(&s.Seq), + } + + if err := resume.SendEvent(bot, s); err != nil { + return err + } + } + + // handle the incoming Ready, Resumed or Replayed event (or Opcode 9 Invalid Session). + payload := getPayload() + defer putPayload(payload) + if err := socket.Read(s.Context, s.Conn, payload); err != nil { + return fmt.Errorf("error reading initial payload: %w", err) + } + + LogPayload(LogSession(Logger.Info(), s.ID), payload.Op, payload.Data).Msg("received initial payload") + + switch payload.Op { + case FlagGatewayOpcodeDispatch: + switch { + // When a connection is successful, the Discord Gateway will respond with a Ready event. + case *payload.EventName == FlagGatewayEventNameReady: + ready := new(Ready) + if err := json.Unmarshal(payload.Data, ready); err != nil { + return fmt.Errorf("error reading ready event: %w", err) + } + + LogSession(Logger.Info(), ready.SessionID).Msg("received Ready event") + + // Configure the session. + s.ID = ready.SessionID + atomic.StoreInt64(&s.Seq, 0) + s.Endpoint = ready.ResumeGatewayURL + + // Store the session in the session manager. + s.client_manager.Gateway.Store(s.ID, s) + + if bot.Config.Gateway.ShardManager != nil { + bot.Config.Gateway.ShardManager.Ready(bot, s, ready) + } + + for _, handler := range bot.Handlers.Ready { + go handler(ready) + } + + // When a reconnection is successful, the Discord Gateway will respond + // by replaying all missed events in order, finalized by a Resumed event. + case *payload.EventName == FlagGatewayEventNameResumed: + LogSession(Logger.Info(), s.ID).Msg("received Resumed event") + + // Store the session in the session manager. + s.client_manager.Gateway.Store(s.ID, s) + + for _, handler := range bot.Handlers.Resumed { + go handler(&Resumed{}) + } + + // When a reconnection is successful, the Discord Gateway will respond + // by replaying all missed events in order, finalized by a Resumed event. + default: + // handle the initial payload(s) until a Resumed event is encountered. + go bot.handle(*payload.EventName, payload.Data) + + for { + replayed := new(GatewayPayload) + if err := socket.Read(s.Context, s.Conn, replayed); err != nil { + return fmt.Errorf("error replaying events: %w", err) + } + + if replayed.Op == FlagGatewayOpcodeDispatch && *replayed.EventName == FlagGatewayEventNameResumed { + LogSession(Logger.Info(), s.ID).Msg("received Resumed event") + + // Store the session in the session manager. + s.client_manager.Gateway.Store(s.ID, s) + + for _, handler := range bot.Handlers.Resumed { + go handler(&Resumed{}) + } + + return nil + } + + go bot.handle(*payload.EventName, payload.Data) + } + } + + // When the maximum concurrency limit has been reached while connecting, or when + // the session does NOT reconnect in time, the Discord Gateway send an Opcode 9 Invalid Session. + case FlagGatewayOpcodeInvalidSession: + // Remove the session from the session manager. + s.client_manager.RemoveGatewaySession(s.ID) + + if attempt < 1 { + // wait for Discord to close the session, then complete a fresh connect. + <-time.NewTimer(invalidSessionWaitTime).C + + s.ID = "" + atomic.StoreInt64(&s.Seq, 0) + if err := s.initial(bot, attempt+1); err != nil { + return err + } + + return nil + } + + return fmt.Errorf("session %q couldn't connect to the Discord Gateway or has invalidated an active session", s.ID) + default: + return fmt.Errorf("session %q received unexpected payload %d during connection", s.ID, payload.Op) + } + + return nil +} + +// disconnect disconnects a session from a WebSocket Connection using the given status code. +func (s *Session) disconnect(code int) error { + // cancel the context to kill the goroutines of the Session. + defer s.manager.cancel() + + if err := s.Conn.Close(websocket.StatusCode(code), ""); err != nil { + return fmt.Errorf("%w", err) + } + + return nil +} + +// reconnect spawns a goroutine for reconnection which prompts the manager +// to reconnect upon a disconnection. +func (s *Session) reconnect(reason string) { + LogSession(Logger.Info(), s.ID).Msg(reason) + + s.manager.signals <- sessionSignalReconnect +} + +// readEvent is a helper function for reading events from the WebSocket Session. +func readEvent(s *Session, dst any) error { + payload := new(GatewayPayload) + if err := socket.Read(s.Context, s.Conn, payload); err != nil { + return fmt.Errorf("readEvent: %w", err) + } + + if err := json.Unmarshal(payload.Data, dst); err != nil { + return fmt.Errorf("readEvent: %w", err) + } + + return nil +} diff --git a/wrapper/tests/integration/ratelimit_test.go b/wrapper/tests/integration/ratelimit_test.go index dca1684..386d7b5 100644 --- a/wrapper/tests/integration/ratelimit_test.go +++ b/wrapper/tests/integration/ratelimit_test.go @@ -62,7 +62,7 @@ func TestRequestGlobalRateLimit(t *testing.T) { t.Fatalf("%v", err) } - // ensure that the next test starts with a full bucket. + // confirm the next test starts with a full bucket. time.After(time.Second * 2) } @@ -117,7 +117,7 @@ func TestRequestRouteRateLimit(t *testing.T) { t.Fatalf("%v", err) } - // ensure that the next test starts with a full bucket. + // confirm the next test starts with a full bucket. time.After(time.Second * 2) } diff --git a/wrapper/tests/integration/session_test.go b/wrapper/tests/integration/session_test.go index 94ea9c9..7cb5217 100644 --- a/wrapper/tests/integration/session_test.go +++ b/wrapper/tests/integration/session_test.go @@ -9,35 +9,8 @@ import ( . "github.com/switchupcb/disgo" ) -// TestSessionManager tests the Session Manager check at the start of a Session Connect() call. -func TestSessionManager(t *testing.T) { - zerolog.SetGlobalLevel(zerolog.DebugLevel) - - bot := &Client{ - Authentication: BotToken(os.Getenv("TOKEN")), - Config: DefaultConfig(), - Handlers: new(Handlers), - } - - s := NewSession() - - // connecting to a connected session should result in an error. - err := s.Connect(bot) - if err == nil { - // disconnect from the Discord Gateway (WebSocket Connection). - if err := s.Disconnect(); err != nil { - t.Fatalf("%v", err) - } - - // allow Discord to close the session. - <-time.After(time.Second * 5) - - t.Fatalf("expected error while connecting to with a bot without a SessionManager") - } -} - // TestConnect tests Connect(), Disconnect(), heartbeat(), listen(), and onPayload() -// in order to ensure that WebSocket functionality works. +// in order to confirm WebSocket functionality works. func TestConnect(t *testing.T) { zerolog.SetGlobalLevel(zerolog.DebugLevel) @@ -143,7 +116,7 @@ DISCONNECT: } // TestReconnect tests Connect(), Disconnect(), heartbeat(), listen(), and onPayload() -// in order to ensure that WebSocket reconnection functionality works. +// in order to confirm WebSocket reconnection functionality works. func TestReconnect(t *testing.T) { zerolog.SetGlobalLevel(zerolog.DebugLevel) diff --git a/wrapper/tests/integration/voice_test.go b/wrapper/tests/integration/voice_test.go index e8c11d8..54d45c8 100644 --- a/wrapper/tests/integration/voice_test.go +++ b/wrapper/tests/integration/voice_test.go @@ -7,7 +7,7 @@ import ( "time" "github.com/rs/zerolog" - . "github.com/switchupcb/disgo/wrapper" + . "github.com/switchupcb/disgo" ) func TestConnectVoice(t *testing.T) { diff --git a/wrapper/voice.go b/wrapper/voice.go index 05f604c..269cee5 100644 --- a/wrapper/voice.go +++ b/wrapper/voice.go @@ -94,7 +94,7 @@ func (vc *VoiceChannelConnection) Connect(bot *Client) error { return errors.New("ConnectVoice: Voice ChannelID must be non-nil and non-empty to connect to voice channel") } - if vc.GatewaySession == nil || !vc.GatewaySession.isConnected() { + if vc.GatewaySession == nil || vc.GatewaySession.State() != SessionStateConnected { return errors.New("ConnectVoice: Session must be connected to the Discord Gateway to connect to voice channel") } @@ -155,7 +155,7 @@ VOICESERVERUPDATE: case <-vc.GatewaySession.Context.Done(): vc.VoiceSession.RUnlock() - return <-vc.GatewaySession.manager.err + return <-vc.GatewaySession.manager.actionError default: vc.VoiceSession.RUnlock() //lint:ignore SA4011 break into for loop. diff --git a/wrapper/voice_session.go b/wrapper/voice_session.go index 1f7968e..e6ebda7 100644 --- a/wrapper/voice_session.go +++ b/wrapper/voice_session.go @@ -8,7 +8,6 @@ import ( "time" json "github.com/goccy/go-json" - "github.com/rs/zerolog/log" "github.com/switchupcb/disgo/wrapper/socket" "github.com/switchupcb/websocket" "golang.org/x/sync/errgroup" @@ -98,12 +97,11 @@ func (s *VoiceSession) connect(bot *Client, vc *VoiceChannelConnection) error { hello := new(VoiceHello) if err := readEventVoice(s, hello); err != nil { err = fmt.Errorf("error reading initial VoiceHello event: %w", err) - sessionErr := ErrorSession{SessionID: s.ID, Err: err} + sessionErr := ErrorSession{SessionID: s.ID, Err: err} //nolint:exhaustruct // voice needs refactor if disconnectErr := s.disconnect(FlagClientCloseEventCodeNormal); disconnectErr != nil { - sessionErr.Err = ErrorDisconnect{ - Action: err, - Err: disconnectErr, - Connection: ErrConnectionSessionVoice, + sessionErr.Err = ErrorSessionDisconnect{ + Action: err, + Err: disconnectErr, } } @@ -143,7 +141,7 @@ func (s *VoiceSession) connect(bot *Client, vc *VoiceChannelConnection) error { s.manager.routines.Add(1) s.manager.Go(func() error { if err := s.beat(); err != nil { - return ErrorSession{ + return ErrorSession{ //nolint:exhaustruct // voice needs refactor SessionID: s.ID, Err: fmt.Errorf("heartbeat: %w", err), } @@ -154,12 +152,11 @@ func (s *VoiceSession) connect(bot *Client, vc *VoiceChannelConnection) error { // send the initial Identify or Resumed packet. if err := s.initial(bot, vc); err != nil { - sessionErr := ErrorSession{SessionID: s.ID, Err: err} + sessionErr := ErrorSession{SessionID: s.ID, Err: err} //nolint:exhaustruct // voice needs refactor if disconnectErr := s.disconnect(FlagClientCloseEventCodeNormal); disconnectErr != nil { - sessionErr.Err = ErrorDisconnect{ - Action: err, - Err: disconnectErr, - Connection: ErrConnectionSessionVoice, + sessionErr.Err = ErrorSessionDisconnect{ + Action: err, + Err: disconnectErr, } } @@ -170,7 +167,7 @@ func (s *VoiceSession) connect(bot *Client, vc *VoiceChannelConnection) error { s.manager.routines.Add(1) s.manager.Go(func() error { if err := s.listen(bot); err != nil { - return ErrorSession{ + return ErrorSession{ //nolint:exhaustruct // voice needs refactor SessionID: s.ID, Err: fmt.Errorf("listen: %w", err), } @@ -183,7 +180,7 @@ func (s *VoiceSession) connect(bot *Client, vc *VoiceChannelConnection) error { s.manager.routines.Add(1) go s.manage() - // ensure that the Session's goroutines are spawned. + // confirm the Session's goroutines are spawned. s.manager.routines.Wait() return nil @@ -271,8 +268,6 @@ func (s *VoiceSession) disconnect(code int) error { id := s.ID LogSession(Logger.Info(), id).Msgf("disconnecting voice session with code %d", FlagClientCloseEventCodeNormal) - s.manager.signal = context.WithValue(s.manager.signal, keySignal, signalDisconnect) - // cancel the context to kill the goroutines of the Voice Session. defer s.manager.cancel() @@ -300,26 +295,3 @@ func readEventVoice(s *VoiceSession, dst any) error { return nil } - -// writeEventVoice is a helper function for writing voice events to the WebSocket Session. -func writeEventVoice(s *VoiceSession, op int, name string, dst any) error { - LogCommandVoice(log.Trace(), op, name).Msg("sending voice server command") - - // write the event to the WebSocket Connection. - event, err := json.Marshal(dst) - if err != nil { - return fmt.Errorf("writeEvent: %w", err) - } - - if err = socket.Write(s.Context, s.Conn, websocket.MessageText, - VoicePayload{ - Op: op, - Data: event, - }); err != nil { - return fmt.Errorf("writeEvent: %w", err) - } - - LogCommandVoice(log.Trace(), op, name).Msg("sending voice server command") - - return nil -} diff --git a/wrapper/voice_session_command_write_event.go b/wrapper/voice_session_command_write_event.go new file mode 100644 index 0000000..21c6d1c --- /dev/null +++ b/wrapper/voice_session_command_write_event.go @@ -0,0 +1,33 @@ +package wrapper + +import ( + "fmt" + + json "github.com/goccy/go-json" + "github.com/rs/zerolog/log" + "github.com/switchupcb/disgo/wrapper/socket" + "github.com/switchupcb/websocket" +) + +// writeEventVoice is a helper function for writing voice events to the WebSocket Session. +func writeEventVoice(s *VoiceSession, op int, name string, dst any) error { + LogCommandVoice(log.Trace(), op, name).Msg("sending voice server command") + + // write the event to the WebSocket Connection. + event, err := json.Marshal(dst) + if err != nil { + return fmt.Errorf("writeEvent: %w", err) + } + + if err = socket.Write(s.Context, s.Conn, websocket.MessageText, + VoicePayload{ + Op: op, + Data: event, + }); err != nil { + return fmt.Errorf("writeEvent: %w", err) + } + + LogCommandVoice(log.Trace(), op, name).Msg("sending voice server command") + + return nil +} diff --git a/wrapper/voice_session_heartbeat.go b/wrapper/voice_session_routine_heartbeat.go similarity index 98% rename from wrapper/voice_session_heartbeat.go rename to wrapper/voice_session_routine_heartbeat.go index 05c6f91..57af31b 100644 --- a/wrapper/voice_session_heartbeat.go +++ b/wrapper/voice_session_routine_heartbeat.go @@ -33,7 +33,7 @@ func (s *VoiceSession) Monitor() uint32 { func (s *VoiceSession) beat() error { s.manager.routines.Done() - // ensure that all pulse routines are closed prior to closing. + // confirm all pulse routines are closed prior to closing. defer func() { for { if s.heartbeat == nil { diff --git a/wrapper/voice_session_listener.go b/wrapper/voice_session_routine_listener.go similarity index 100% rename from wrapper/voice_session_listener.go rename to wrapper/voice_session_routine_listener.go diff --git a/wrapper/voice_session_manager.go b/wrapper/voice_session_routine_manager.go similarity index 67% rename from wrapper/voice_session_manager.go rename to wrapper/voice_session_routine_manager.go index 309a22a..ea85397 100644 --- a/wrapper/voice_session_manager.go +++ b/wrapper/voice_session_routine_manager.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "sync" - "time" "github.com/switchupcb/websocket" "golang.org/x/sync/errgroup" @@ -88,7 +87,6 @@ func (s *VoiceSession) reconnect(reason string) { LogSession(Logger.Info(), s.ID).Msg(reason) - s.manager.signal = context.WithValue(s.manager.signal, keySignal, signalReconnect) if err := s.disconnect(FlagClientCloseEventCodeReconnect); err != nil { return fmt.Errorf("reconnect: %w", err) } @@ -111,35 +109,9 @@ func (s *VoiceSession) manage() { s.Lock() defer s.Unlock() - // log the reason for disconnection (if applicable). - if reason := s.manager.signal.Value(keyReason); reason != nil { - LogSession(Logger.Info(), s.ID).Msgf("%v", reason) - } - - // when a signal is provided, it indicates that the disconnection was purposeful. - signal := s.manager.signal.Value(keySignal) - switch signal { - case signalDisconnect: - LogSession(Logger.Info(), s.ID).Msg("successfully disconnected") - - s.manager.err <- nil - - return - - case signalReconnect: - LogSession(Logger.Info(), s.ID).Msg("successfully disconnected (while reconnecting)") - - // allow Discord to close the session. - <-time.After(time.Second) - - s.manager.err <- nil - - return - } - // when an error caused goroutines to close, manage the state of disconnection. if err != nil { - disconnectErr := new(ErrorDisconnect) + disconnectErr := new(ErrorSessionDisconnect) closeErr := new(websocket.CloseError) switch { // when an error occurs from a purposeful disconnection. @@ -152,10 +124,9 @@ func (s *VoiceSession) manage() { default: if cErr := s.Conn.Close(websocket.StatusCode(FlagClientCloseEventCodeAway), ""); cErr != nil { - s.manager.err <- ErrorDisconnect{ - Action: err, - Err: cErr, - Connection: ErrConnectionSessionVoice, + s.manager.err <- ErrorSessionDisconnect{ + Action: err, + Err: cErr, } return @@ -201,61 +172,3 @@ func (s *VoiceSession) handleGatewayCloseError(closeErr *websocket.CloseError) e return closeErr } } - -// Wait blocks until the calling Voice Session has disconnected, then returns the reason -// (disgo.SignalReason) for disconnecting and the disconnection error (if it exists). -// -// If Wait() is called on a Voice Session that isn't connected, it will return immediately -// with code SignalNone. -// -// It's NOT recommended to modify a Voice Session after it has disconnected, -// since it will be cleared and placed into a memory pool shortly after. -func (s *VoiceSession) Wait() (int, error) { - if !s.isConnected() { - return SignalNone, nil - } - - // NOTE: Wait() is equivalent to the s.manage() s.manager.Wait() handling logic, - // but without the management of the disconnection state, - // and without the usage of a channel that tells another goroutine to unblock. - // - // wait until all of a Session's goroutines are closed. - err := s.manager.Wait() - s.Lock() - defer s.Unlock() - - // when a signal is provided, it indicates that the disconnection was purposeful. - signal := s.manager.signal.Value(keySignal) - switch signal { - case signalDisconnect: - return SignalDisconnect, nil - - case signalReconnect: - return SignalReconnect, nil - } - - // when an error caused goroutines to close. - if err != nil { - disconnectErr := new(ErrorDisconnect) - closeErr := new(websocket.CloseError) - switch { - // when an error occurs from a purposeful disconnection. - case errors.As(err, disconnectErr): - if signal != nil { - if signalValue, ok := signal.(int); ok { - return signalValue, err //nolint:wrapcheck - } - } - - return SignalDisconnectError, err //nolint:wrapcheck - - // when an error occurs from a WebSocket Close Error. - case errors.As(err, closeErr): - return SignalError, s.handleGatewayCloseError(closeErr) - } - - return SignalError, err //nolint:wrapcheck - } - - return SignalUndefined, nil -}