diff --git a/disgo.go b/disgo.go index 39fb0cc..ced2c94 100644 --- a/disgo.go +++ b/disgo.go @@ -9693,7 +9693,9 @@ func putSession(s *Session) { s.Endpoint = "" s.Shard = nil s.Context = nil + s.cancel = nil s.Conn = nil + s.setState(SessionStateNew) s.heartbeat = nil s.manager = nil s.client_manager = nil @@ -17182,17 +17184,59 @@ func (r *GetCurrentAuthorizationInformation) Send(bot *Client) (*CurrentAuthoriz type Session struct { Context context.Context RateLimiter RateLimiter - Shard *[2]int Conn *websocket.Conn + Shard *[2]int + cancel context.CancelFunc heartbeat *heartbeat manager *manager client_manager *SessionManager - ID string Endpoint string + ID string + state string Seq int64 + stateMutex sync.RWMutex sync.RWMutex } +// Session States represent the state of the Session's connection to Discord. +const ( + SessionStateNew = "" + + SessionStateConnecting = "connecting (before websocket connection)" + SessionStateConnectingWebsocket = "connecting (with websocket connection)" + SessionStateConnected = "connected" + + SessionStateDisconnecting = "disconnecting (purposefully)" + SessionStateDisconnectingError = "disconnecting (due to an error)" + SessionStateDisconnectingReconnect = "disconnecting (while reconnecting)" + + SessionStateDisconnectedFinal = "disconnected (after connection)" + SessionStateDisconnectedError = "disconnected (due to an error)" + SessionStateDisconnectedReconnect = "disconnected (while reconnecting)" + + SessionStateReconnecting = "reconnecting" +) + +// State returns the state of the Session's connection to Discord. +func (s *Session) State() string { + s.stateMutex.RLock() + defer s.stateMutex.RUnlock() + + return s.state +} + +// setState sets the state of a Session. +func (s *Session) setState(state string) { + s.stateMutex.Lock() + s.state = state + s.stateMutex.Unlock() +} + +// canReconnect returns whether the Session's fields are in a valid state to reconnect. +func (s *Session) canReconnect() bool { + return s.ID != "" && s.Endpoint != "" && atomic.LoadInt64(&s.Seq) != 0 +} + // Connect connects a session to the Discord Gateway (WebSocket Connection). func (s *Session) Connect(bot *Client) error { if bot == nil { @@ -17221,8 +17265,23 @@ func (s *Session) Connect(bot *Client) error { s.manager.signals <- sessionSignalConnect s.Unlock() - if err := <-s.manager.actionError; err != nil { - return err + // wait until the Session has connected + for { + select { + // Context is cancelled during connection when the manager returns an error + // or disconnects from another goroutine call. + case <-s.manager.context.Done(): + return s.manager.coroner.Wait() //nolint:wrapcheck + default: + break + } + + // Session is SessionStateConnected after connection. + // + // proof: Calling Connect() during connection cannot happen while the manager exists. + if s.State() == SessionStateConnected { + break + } } return nil @@ -17240,19 +17299,17 @@ func (s *Session) Disconnect() error { s.manager.signals <- sessionSignalDisconnect s.Unlock() - if err := <-s.manager.actionError; err != nil { - return err + // Session is disconnected from a Disconnect() call when the coroner shuts down. + if err := s.manager.coroner.Wait(); err != nil { + return err //nolint:wrapcheck } - // Reset the session. - putSession(s) - return nil } // Reconnect reconnects an already connected session to the Discord Gateway // by disconnecting the session, then connecting again. -func (s *Session) Reconnect(bot *Client) error { +func (s *Session) Reconnect() error { s.Lock() if s.manager == nil || s.State() != SessionStateConnected { s.Unlock() @@ -17263,8 +17320,32 @@ func (s *Session) Reconnect(bot *Client) error { s.manager.signals <- sessionSignalReconnect s.Unlock() - if err := <-s.manager.actionError; err != nil { - return fmt.Errorf("reconnect: %w", err) + // wait until the manager has received the sessionSignalReconnect + // or changed state to another signal. + for { + if s.State() != SessionStateConnected { + break + } + } + + // wait until the Session has reconnected + // or has experienced an error during reconnection. + for { + select { + // Context is cancelled during reconnection when the manager returns an error + // or disconnects from another goroutine call. + case <-s.manager.context.Done(): + return s.manager.coroner.Wait() //nolint:wrapcheck + default: + break + } + + // Session is SessionStateConnected after reconnection. + // + // proof: Calling Connect() during reconnection cannot happen while the manager exists. + if s.State() == SessionStateConnected { + break + } } return nil @@ -20531,24 +20612,19 @@ func (bot *Client) handle(eventname string, data json.RawMessage) { // coroner investigates when a Session's goroutines are shutdown. func (s *Session) coroner() { - // wait until all the manager goroutines is closed. - err := s.manager.coroner.Wait() - - s.Lock() - - // report the disconnection error - s.manager.actionError <- err - close(s.manager.actionError) + // wait until the manager goroutine is closed. + if err := s.manager.coroner.Wait(); err != nil { + LogSession(Logger.Error(), s.ID).Err(err).Msg("coroner manager routine error") + } - // remove the session from the client. - s.client_manager.RemoveGatewaySession(s.ID) + // Reset the session. + putSession(s) - s.logClose("coroner") - s.Unlock() + Logger.Info().Msg("closed coroner routine") } // Wait blocks until the calling Session is inactive (due to a final disconnect), -// then returns the Session's state and the disconnection error (if it exists). +// then returns the Session's state and the disconnection error (when it exists). // // If Wait() is called on a Session that isn't connected, it will return immediately // with code SessionStateNew. @@ -20811,16 +20887,11 @@ func (s *Session) onPayload(bot *Client, payload GatewayPayload) error { // in the context of onPayload, an Invalid Session occurs when an active session is invalidated. case FlagGatewayOpcodeInvalidSession: - // Remove the session from the session manager. - s.client_manager.RemoveGatewaySession(s.ID) - // wait for Discord to close the session, then complete a fresh connect. <-time.NewTimer(invalidSessionWaitTime).C - s.Lock() - defer s.Unlock() - - if err := s.initial(bot, 0); err != nil { + _ = s.Disconnect() + if err := s.Connect(bot); err != nil { return err } } @@ -20830,25 +20901,19 @@ func (s *Session) onPayload(bot *Client, payload GatewayPayload) error { // manager represents a manager of a Session's goroutines. type manager struct { - coroner errgroup.Group - signals chan uint8 - cancel context.CancelFunc - actionError chan error + context context.Context + signals chan uint8 + coroner *errgroup.Group *errgroup.Group - state string - routines sync.WaitGroup - stateMutex sync.RWMutex - pulses int32 + routines sync.WaitGroup + pulses int32 } // spawnManager spawns a tracked manager. func (s *Session) spawnManager(bot *Client) { s.manager = new(manager) - - s.Context, s.manager.cancel = context.WithCancel(context.Background()) - s.manager.Group, s.Context = errgroup.WithContext(s.Context) + s.manager.coroner, s.manager.context = errgroup.WithContext(context.Background()) s.manager.signals = make(chan uint8) - s.manager.actionError = make(chan error, 1) // spawn the manager goroutine. s.manager.coroner.Go(func() error { @@ -20860,45 +20925,6 @@ func (s *Session) spawnManager(bot *Client) { }) } -// Session States represent the state of the Session's connection to Discord. -const ( - SessionStateNew = "" - - SessionStateConnecting = "connecting (before websocket connection)" - SessionStateConnectingWebsocket = "connecting (with websocket connection)" - SessionStateConnected = "connected" - - SessionStateDisconnecting = "disconnecting (purposefully)" - SessionStateDisconnectingError = "disconnecting (due to an error)" - SessionStateDisconnectingReconnect = "disconnecting (while reconnecting)" - - SessionStateDisconnectedFinal = "disconnected (after connection)" - SessionStateDisconnectedError = "disconnected (due to an error)" - SessionStateDisconnectedReconnect = "disconnected (while reconnecting)" - - SessionStateReconnecting = "reconnecting" -) - -// State returns the state of the Session's connection to Discord. -func (s *Session) State() string { - s.manager.stateMutex.RLock() - defer s.manager.stateMutex.RUnlock() - - return s.manager.state -} - -// setState sets the state of a Session. -func (s *Session) setState(state string) { - s.manager.stateMutex.Lock() - s.manager.state = state - s.manager.stateMutex.Unlock() -} - -// canReconnect returns whether the Session's fields are in a valid state to reconnect. -func (s *Session) canReconnect() bool { - return s.ID != "" && s.Endpoint != "" && atomic.LoadInt64(&s.Seq) != 0 -} - // Session Signals represent manager signals to perform actions to the Session. const ( sessionSignalConnect = 1 @@ -20907,51 +20933,58 @@ const ( ) // manage manages a Session's goroutines. -func (s *Session) manage(bot *Client) error { +func (s *Session) manage(bot *Client) error { //nolint:maintidx // spawn the coroner once the manager routine is alive. go s.coroner() - defer func() { - if s.State() != SessionStateDisconnectedReconnect { - s.Unlock() - } + // create a temporary context for a new session (which is reset upon connection). + s.Context = context.Background() - // wait until the previous connection's manager goroutines are closed. - _ = s.manager.Wait() + var managerErr error + + defer func() { + // remove the session from the client. + s.client_manager.RemoveGatewaySession(s.ID) s.logClose("manager") }() - var managedErr error - for { select { + // <-s.Context.Done() when all managed routines are closing + // due to reconnection (while awaiting a connection signal) or + // due to an unexpected error in a managed routine. case <-s.Context.Done(): + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("received signal: <-s.Context.Done with state %q", s.State()) + + // wait until the session's manager goroutines are closed (with s.Unlocked). + // + // proof: s.manager.Wait() returns instantly when SessionStateDisconnectedReconnect (with s.Locked). + err := s.manager.Wait() + + // All session routines are closed when + // + // reconnecting (while waiting for another signal) if s.State() == SessionStateDisconnectedReconnect { break } - // wait until the previous connection's manager goroutines are closed. - err := s.manager.Wait() + // disconnecting (unexpectedly) if err != nil { - closeErr := new(websocket.CloseError) + // TODO: Use errors.As: https://github.com/coder/websocket/issues/519 + if strings.Contains(err.Error(), "failed to close WebSocket: received header with unexpected rsv bits set") { + return nil + } + closeErr := new(websocket.CloseError) if errors.As(err, closeErr) { if vErr := s.validateGatewayCloseError(closeErr); vErr == nil { // reconnect from a state where s.setState(SessionStateDisconnectedReconnect) - // manager routines must be reset - s.Context, s.manager.cancel = context.WithCancel(context.Background()) //nolint:fatcontext - s.manager.Group, s.Context = errgroup.WithContext(s.Context) - + // send a connection signal. go func() { - // send a connection signal. s.manager.signals <- sessionSignalConnect - - // read the s.manager.actionError send from a successful connection. - e := <-s.manager.actionError - LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("captured result from close event reconnect: %q", e) }() s.Lock() @@ -20959,147 +20992,198 @@ func (s *Session) manage(bot *Client) error { break // to reconnect from the connect case logic. } // vErr == nil } // errors.As - - // TODO: Use errors.As: https://github.com/coder/websocket/issues/519 - if strings.Contains(err.Error(), "failed to close WebSocket: received header with unexpected rsv bits set") { - err = nil - } } // err != nil - s.Lock() - - return err + return nil case signal := <-s.manager.signals: switch signal { case sessionSignalConnect: - if s.State() != SessionStateDisconnectedReconnect { + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("received signal: connect with state %q", s.State()) + + switch s.State() { + // SessionStateNew when Connect() on new session. + case SessionStateNew: s.Lock() - } else { - s.Unlock() - // wait until the previous connection's manager goroutines are closed. - _ = s.manager.Wait() + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msg("connecting session") - s.Lock() - } + if err := s.connect(bot); err != nil { + managerErr = ErrorSession{SessionID: s.ID, State: s.State(), Type: ErrorSessionTypeGateway, Err: err} - LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msg("connecting session") + // disconnect when error occurred after websocket connection + if s.State() == SessionStateConnectingWebsocket { + // send a disconnection signal. + go func() { + s.manager.signals <- sessionSignalDisconnect + }() - if err := s.connect(bot); err != nil { - managedErr = ErrorSession{SessionID: s.ID, State: s.State(), Type: ErrorSessionTypeGateway, Err: err} + break // to handle error after disconnection + } - switch s.State() { - case SessionStateConnectingWebsocket: - go func() { + s.Unlock() + + return managerErr + } + + s.setState(SessionStateConnected) + s.Unlock() + + // SessionStateDisconnectedReconnect when reconnecting from disconnected session. + case SessionStateDisconnectedReconnect: + // s.Lock() called during reconnection signal. + + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msg("reconnecting session") + + if err := s.connect(bot); err != nil { + managerErr = ErrorSession{SessionID: s.ID, State: s.State(), Type: ErrorSessionTypeGateway, Err: err} + + // disconnect when error occurred after websocket connection + if s.State() == SessionStateConnectingWebsocket { // send a disconnection signal. - s.manager.signals <- sessionSignalDisconnect - }() + go func() { + s.manager.signals <- sessionSignalDisconnect + }() - // case SessionStateNew, SessionStateConnecting... - default: - return managedErr + break // to handle error after disconnection + } + + s.Unlock() + + return managerErr } - break // to handle the error in the disconnect case logic. - } + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msg("connected session") + s.setState(SessionStateConnected) + s.Unlock() - s.setState(SessionStateConnected) - s.manager.actionError <- nil - s.Unlock() + default: + return fmt.Errorf("unexpected state during session connection: %v", s.State()) + } case sessionSignalDisconnect: - if managedErr == nil && s.State() != SessionStateReconnecting { - s.Lock() - } + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("received signal: disconnect with state %q", s.State()) // update the session's state and client close event code. - code := FlagClientCloseEventCodeNormal + var code int switch { - case managedErr != nil: + case managerErr != nil: + // s.Lock() called before error. + s.setState(SessionStateDisconnectingError) + code = FlagClientCloseEventCodeNormal + + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("%q session with code %d", s.State(), code) + case s.State() == SessionStateReconnecting: + // s.Lock() called during reconnection signal. + s.setState(SessionStateDisconnectingReconnect) code = FlagClientCloseEventCodeReconnect + + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("%q session with code %d", s.State(), code) + default: + s.Lock() + s.setState(SessionStateDisconnecting) - } + code = FlagClientCloseEventCodeNormal - LogSession(Logger.Info(), s.ID).Msgf("%q session with code %d", s.State(), code) + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("%q session with code %d", s.State(), code) + } // disconnect the session. if err := s.disconnect(code); err != nil { - managedErr = ErrorSession{ - SessionID: s.ID, - State: s.State(), - Type: ErrorSessionTypeGateway, - Err: ErrorSessionDisconnect{ - Action: managedErr, - Err: err, - }, - } + // validate the disconnection error. + closeErr := new(websocket.CloseError) // TODO: Use errors.As: https://github.com/coder/websocket/issues/519 if strings.Contains(err.Error(), "failed to close WebSocket: received header with unexpected rsv bits set") { - managedErr = nil + err = nil + } else if errors.As(err, closeErr) { + err = s.validateGatewayCloseError(closeErr) } - if s.State() != SessionStateDisconnectingReconnect { - return managedErr + if managerErr != nil { + s.Unlock() + + // wait until the session's manager goroutines are closed (with s.Unlocked). + _ = s.manager.Wait() + + return ErrorSession{ + SessionID: s.ID, + State: s.State(), + Type: ErrorSessionTypeGateway, + Err: ErrorSessionDisconnect{ + Action: managerErr, + Err: err, + }, + } } - // validate error when reconnecting - closeErr := new(websocket.CloseError) - if errors.As(managedErr, closeErr) { - if managedErr = s.validateGatewayCloseError(closeErr); managedErr != nil { - return managedErr + if err != nil { + s.Unlock() + + // wait until the session's manager goroutines are closed (with s.Unlocked). + _ = s.manager.Wait() + + return ErrorSession{ + SessionID: s.ID, + State: s.State(), + Type: ErrorSessionTypeGateway, + Err: ErrorSessionDisconnect{ + Action: nil, + Err: err, + }, } } } // disconnect // update the session's state. switch { - case s.State() == SessionStateDisconnectingError: + case managerErr != nil: s.setState(SessionStateDisconnectedError) - case s.State() == SessionStateDisconnectingReconnect: s.setState(SessionStateDisconnectedReconnect) - default: s.setState(SessionStateDisconnectedFinal) } - LogSession(Logger.Info(), s.ID).Msgf("%q session with code %d", s.State(), code) + // wait until the session's manager goroutines are closed (with s.Unlocked). + s.Unlock() + _ = s.manager.Wait() + + s.Lock() + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("%q session with code %d", s.State(), code) if s.State() == SessionStateDisconnectedReconnect { // allow Discord to close the session. <-time.After(time.Second) + // send a connection signal. go func() { - // send a connection signal. s.manager.signals <- sessionSignalConnect }() break } - // Destroy the manager when the bot isn't reconnecting. - if managedErr != nil { - return managedErr - } + s.Unlock() return nil case sessionSignalReconnect: + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("received signal: reconnect with state %q", s.State()) + s.Lock() s.setState(SessionStateReconnecting) + // send a disconnection signal. go func() { - // send a disconnection signal. s.manager.signals <- sessionSignalDisconnect }() - } + } // switch signal } // select } // for } @@ -21130,6 +21214,14 @@ func (s *Session) validateGatewayCloseError(closeErr *websocket.CloseError) erro // Gateway Close Event Code is unknown. default: + // when another goroutine returns an error, + // s.Conn.Close is called before s.cancel which will result in + // a CloseError with the close code that Disgo uses to reconnect. + if closeErr.Code == websocket.StatusCode(FlagClientCloseEventCodeNormal) || + closeErr.Code == websocket.StatusCode(FlagClientCloseEventCodeReconnect) { + return nil + } + LogSession(Logger.Info(), s.ID). Msgf("received unknown Gateway Close Event Code %d with reason %q", closeErr.Code, closeErr.Reason, @@ -21207,7 +21299,8 @@ func (s *Session) connect(bot *Client) error { ) // connect to the Discord Gateway Websocket. - s.Context, s.manager.cancel = context.WithCancel(context.Background()) + s.Context, s.cancel = context.WithCancel(context.Background()) + s.manager.Group, s.Context = errgroup.WithContext(s.Context) if s.Conn, _, err = websocket.Dial(s.Context, gatewayEndpoint+gatewayEndpointParams, nil); err != nil { return fmt.Errorf("error connecting to the Discord Gateway: %w", err) } @@ -21420,7 +21513,7 @@ func (s *Session) initial(bot *Client, attempt int) error { // disconnect disconnects a session from a WebSocket Connection using the given status code. func (s *Session) disconnect(code int) error { // cancel the context to kill the goroutines of the Session. - defer s.manager.cancel() + defer s.cancel() if err := s.Conn.Close(websocket.StatusCode(code), ""); err != nil { return fmt.Errorf("%w", err) @@ -21484,7 +21577,7 @@ type ShardManager interface { Disconnect() error // Reconnect reconnects to the Discord Gateway using the Shard Manager. - Reconnect(bot *Client) error + Reconnect() error } // ShardLimit contains information about sharding limits. @@ -21727,10 +21820,10 @@ VOICESERVERUPDATE: } select { - case <-vc.GatewaySession.Context.Done(): + case <-vc.GatewaySession.manager.context.Done(): vc.VoiceSession.RUnlock() - return <-vc.GatewaySession.manager.actionError + return vc.GatewaySession.manager.Wait() //nolint:wrapcheck default: vc.VoiceSession.RUnlock() //lint:ignore SA4011 break into for loop. diff --git a/shard/instance.go b/shard/instance.go index 6ddd140..6710ce5 100644 --- a/shard/instance.go +++ b/shard/instance.go @@ -125,12 +125,12 @@ func (sm *InstanceShardManager) Disconnect() error { } // Reconnect connects to the Discord Gateway using the Shard Manager. -func (sm *InstanceShardManager) Reconnect(bot *disgo.Client) error { +func (sm *InstanceShardManager) Reconnect() error { // totalShards represents the total number of shards that are connected. totalShards := len(sm.Sessions) for sessionCount := 0; sessionCount < totalShards; sessionCount++ { - if err := sm.Sessions[sessionCount].Reconnect(bot); err != nil { + if err := sm.Sessions[sessionCount].Reconnect(); err != nil { return fmt.Errorf(errShardManager, err) } } diff --git a/shard/tests/integration/shard_test.go b/shard/tests/integration/shard_test.go index 04053a0..ed25ef9 100644 --- a/shard/tests/integration/shard_test.go +++ b/shard/tests/integration/shard_test.go @@ -38,7 +38,7 @@ func TestReconnect(t *testing.T) { time.Sleep(time.Second) // reconnect to the Discord Gateway (WebSocket Session). - if err := s.Reconnect(bot); err != nil { + if err := s.Reconnect(); err != nil { t.Fatalf("%v", err) } diff --git a/wrapper/pool.go b/wrapper/pool.go index 06044ed..016b10c 100644 --- a/wrapper/pool.go +++ b/wrapper/pool.go @@ -52,7 +52,9 @@ func putSession(s *Session) { s.Endpoint = "" s.Shard = nil s.Context = nil + s.cancel = nil s.Conn = nil + s.setState(SessionStateNew) s.heartbeat = nil s.manager = nil s.client_manager = nil diff --git a/wrapper/session.go b/wrapper/session.go index 45b0617..33c1810 100644 --- a/wrapper/session.go +++ b/wrapper/session.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "sync" + "sync/atomic" "github.com/switchupcb/websocket" ) @@ -32,9 +33,19 @@ type Session struct { // Context is also used as a signal for the Session's goroutines. Context context.Context + // Cancel represents the cancellation signal for a Session Context. + cancel context.CancelFunc + // Conn represents a WebSocket Connection to the Discord Gateway. Conn *websocket.Conn + // state represents the state of the Session's connection to Discord. + state string + + // stateMutex is used to protect the Session's manager state from data races + // by providing transactional functionality. + stateMutex sync.RWMutex + // heartbeat contains the fields required to implement the heartbeat mechanism. heartbeat *heartbeat @@ -52,6 +63,45 @@ type Session struct { sync.RWMutex } +// Session States represent the state of the Session's connection to Discord. +const ( + SessionStateNew = "" + + SessionStateConnecting = "connecting (before websocket connection)" + SessionStateConnectingWebsocket = "connecting (with websocket connection)" + SessionStateConnected = "connected" + + SessionStateDisconnecting = "disconnecting (purposefully)" + SessionStateDisconnectingError = "disconnecting (due to an error)" + SessionStateDisconnectingReconnect = "disconnecting (while reconnecting)" + + SessionStateDisconnectedFinal = "disconnected (after connection)" + SessionStateDisconnectedError = "disconnected (due to an error)" + SessionStateDisconnectedReconnect = "disconnected (while reconnecting)" + + SessionStateReconnecting = "reconnecting" +) + +// State returns the state of the Session's connection to Discord. +func (s *Session) State() string { + s.stateMutex.RLock() + defer s.stateMutex.RUnlock() + + return s.state +} + +// setState sets the state of a Session. +func (s *Session) setState(state string) { + s.stateMutex.Lock() + s.state = state + s.stateMutex.Unlock() +} + +// canReconnect returns whether the Session's fields are in a valid state to reconnect. +func (s *Session) canReconnect() bool { + return s.ID != "" && s.Endpoint != "" && atomic.LoadInt64(&s.Seq) != 0 +} + // Connect connects a session to the Discord Gateway (WebSocket Connection). func (s *Session) Connect(bot *Client) error { if bot == nil { @@ -80,8 +130,23 @@ func (s *Session) Connect(bot *Client) error { s.manager.signals <- sessionSignalConnect s.Unlock() - if err := <-s.manager.actionError; err != nil { - return err + // wait until the Session has connected + for { + select { + // Context is cancelled during connection when the manager returns an error + // or disconnects from another goroutine call. + case <-s.manager.context.Done(): + return s.manager.coroner.Wait() //nolint:wrapcheck + default: + break + } + + // Session is SessionStateConnected after connection. + // + // proof: Calling Connect() during connection cannot happen while the manager exists. + if s.State() == SessionStateConnected { + break + } } return nil @@ -99,19 +164,17 @@ func (s *Session) Disconnect() error { s.manager.signals <- sessionSignalDisconnect s.Unlock() - if err := <-s.manager.actionError; err != nil { - return err + // Session is disconnected from a Disconnect() call when the coroner shuts down. + if err := s.manager.coroner.Wait(); err != nil { + return err //nolint:wrapcheck } - // Reset the session. - putSession(s) - return nil } // Reconnect reconnects an already connected session to the Discord Gateway // by disconnecting the session, then connecting again. -func (s *Session) Reconnect(bot *Client) error { +func (s *Session) Reconnect() error { s.Lock() if s.manager == nil || s.State() != SessionStateConnected { s.Unlock() @@ -122,8 +185,32 @@ func (s *Session) Reconnect(bot *Client) error { s.manager.signals <- sessionSignalReconnect s.Unlock() - if err := <-s.manager.actionError; err != nil { - return fmt.Errorf("reconnect: %w", err) + // wait until the manager has received the sessionSignalReconnect + // or changed state to another signal. + for { + if s.State() != SessionStateConnected { + break + } + } + + // wait until the Session has reconnected + // or has experienced an error during reconnection. + for { + select { + // Context is cancelled during reconnection when the manager returns an error + // or disconnects from another goroutine call. + case <-s.manager.context.Done(): + return s.manager.coroner.Wait() //nolint:wrapcheck + default: + break + } + + // Session is SessionStateConnected after reconnection. + // + // proof: Calling Connect() during reconnection cannot happen while the manager exists. + if s.State() == SessionStateConnected { + break + } } return nil diff --git a/wrapper/session_routine_coroner.go b/wrapper/session_routine_coroner.go index 6b14d01..b138820 100644 --- a/wrapper/session_routine_coroner.go +++ b/wrapper/session_routine_coroner.go @@ -2,24 +2,19 @@ package wrapper // coroner investigates when a Session's goroutines are shutdown. func (s *Session) coroner() { - // wait until all the manager goroutines is closed. - err := s.manager.coroner.Wait() - - s.Lock() - - // report the disconnection error - s.manager.actionError <- err - close(s.manager.actionError) + // wait until the manager goroutine is closed. + if err := s.manager.coroner.Wait(); err != nil { + LogSession(Logger.Error(), s.ID).Err(err).Msg("coroner manager routine error") + } - // remove the session from the client. - s.client_manager.RemoveGatewaySession(s.ID) + // Reset the session. + putSession(s) - s.logClose("coroner") - s.Unlock() + Logger.Info().Msg("closed coroner routine") } // Wait blocks until the calling Session is inactive (due to a final disconnect), -// then returns the Session's state and the disconnection error (if it exists). +// then returns the Session's state and the disconnection error (when it exists). // // If Wait() is called on a Session that isn't connected, it will return immediately // with code SessionStateNew. diff --git a/wrapper/session_routine_listener.go b/wrapper/session_routine_listener.go index 4f71912..32637a3 100644 --- a/wrapper/session_routine_listener.go +++ b/wrapper/session_routine_listener.go @@ -83,16 +83,11 @@ func (s *Session) onPayload(bot *Client, payload GatewayPayload) error { // in the context of onPayload, an Invalid Session occurs when an active session is invalidated. case FlagGatewayOpcodeInvalidSession: - // Remove the session from the session manager. - s.client_manager.RemoveGatewaySession(s.ID) - // wait for Discord to close the session, then complete a fresh connect. <-time.NewTimer(invalidSessionWaitTime).C - s.Lock() - defer s.Unlock() - - if err := s.initial(bot, 0); err != nil { + _ = s.Disconnect() + if err := s.Connect(bot); err != nil { return err } } diff --git a/wrapper/session_routine_manager.go b/wrapper/session_routine_manager.go index e19a8d8..9f9fb5c 100644 --- a/wrapper/session_routine_manager.go +++ b/wrapper/session_routine_manager.go @@ -6,7 +6,6 @@ import ( "fmt" "strings" "sync" - "sync/atomic" "time" "github.com/switchupcb/websocket" @@ -15,29 +14,19 @@ import ( // manager represents a manager of a Session's goroutines. type manager struct { - // state represents the state of the Session's connection to Discord. - state string - - // stateMutex is used to protect the Session's manager state from data races - // by providing transactional functionality. - stateMutex sync.RWMutex - // signals represents a channel of signals. signals chan uint8 // coroner represents a goroutine group to track the manager routine. - coroner errgroup.Group + coroner *errgroup.Group + + // context is used as a context for the manager routine. + context context.Context // routines represents a goroutine counter that ensures all of the Session's goroutines // are spawned prior to returning from connect(). routines sync.WaitGroup - // cancel represents the cancellation signal for a Session's Context. - cancel context.CancelFunc - - // actionError represents the error this manager detects upon a connection action (e.g., connecting, disconnecting). - actionError chan error - // pulses represents the amount of goroutines that can generate heartbeat pulses. // // pulses ensures that pulse goroutines always have a receiver channel for heartbeats @@ -79,11 +68,8 @@ type manager struct { // spawnManager spawns a tracked manager. func (s *Session) spawnManager(bot *Client) { s.manager = new(manager) - - s.Context, s.manager.cancel = context.WithCancel(context.Background()) - s.manager.Group, s.Context = errgroup.WithContext(s.Context) + s.manager.coroner, s.manager.context = errgroup.WithContext(context.Background()) s.manager.signals = make(chan uint8) - s.manager.actionError = make(chan error, 1) // spawn the manager goroutine. s.manager.coroner.Go(func() error { @@ -95,45 +81,6 @@ func (s *Session) spawnManager(bot *Client) { }) } -// Session States represent the state of the Session's connection to Discord. -const ( - SessionStateNew = "" - - SessionStateConnecting = "connecting (before websocket connection)" - SessionStateConnectingWebsocket = "connecting (with websocket connection)" - SessionStateConnected = "connected" - - SessionStateDisconnecting = "disconnecting (purposefully)" - SessionStateDisconnectingError = "disconnecting (due to an error)" - SessionStateDisconnectingReconnect = "disconnecting (while reconnecting)" - - SessionStateDisconnectedFinal = "disconnected (after connection)" - SessionStateDisconnectedError = "disconnected (due to an error)" - SessionStateDisconnectedReconnect = "disconnected (while reconnecting)" - - SessionStateReconnecting = "reconnecting" -) - -// State returns the state of the Session's connection to Discord. -func (s *Session) State() string { - s.manager.stateMutex.RLock() - defer s.manager.stateMutex.RUnlock() - - return s.manager.state -} - -// setState sets the state of a Session. -func (s *Session) setState(state string) { - s.manager.stateMutex.Lock() - s.manager.state = state - s.manager.stateMutex.Unlock() -} - -// canReconnect returns whether the Session's fields are in a valid state to reconnect. -func (s *Session) canReconnect() bool { - return s.ID != "" && s.Endpoint != "" && atomic.LoadInt64(&s.Seq) != 0 -} - // Session Signals represent manager signals to perform actions to the Session. const ( sessionSignalConnect = 1 @@ -142,51 +89,58 @@ const ( ) // manage manages a Session's goroutines. -func (s *Session) manage(bot *Client) error { +func (s *Session) manage(bot *Client) error { //nolint:maintidx // spawn the coroner once the manager routine is alive. go s.coroner() - defer func() { - if s.State() != SessionStateDisconnectedReconnect { - s.Unlock() - } + // create a temporary context for a new session (which is reset upon connection). + s.Context = context.Background() - // wait until the previous connection's manager goroutines are closed. - _ = s.manager.Wait() + var managerErr error + + defer func() { + // remove the session from the client. + s.client_manager.RemoveGatewaySession(s.ID) s.logClose("manager") }() - var managedErr error - for { select { + // <-s.Context.Done() when all managed routines are closing + // due to reconnection (while awaiting a connection signal) or + // due to an unexpected error in a managed routine. case <-s.Context.Done(): + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("received signal: <-s.Context.Done with state %q", s.State()) + + // wait until the session's manager goroutines are closed (with s.Unlocked). + // + // proof: s.manager.Wait() returns instantly when SessionStateDisconnectedReconnect (with s.Locked). + err := s.manager.Wait() + + // All session routines are closed when + // + // reconnecting (while waiting for another signal) if s.State() == SessionStateDisconnectedReconnect { break } - // wait until the previous connection's manager goroutines are closed. - err := s.manager.Wait() + // disconnecting (unexpectedly) if err != nil { - closeErr := new(websocket.CloseError) + // TODO: Use errors.As: https://github.com/coder/websocket/issues/519 + if strings.Contains(err.Error(), "failed to close WebSocket: received header with unexpected rsv bits set") { + return nil + } + closeErr := new(websocket.CloseError) if errors.As(err, closeErr) { if vErr := s.validateGatewayCloseError(closeErr); vErr == nil { // reconnect from a state where s.setState(SessionStateDisconnectedReconnect) - // manager routines must be reset - s.Context, s.manager.cancel = context.WithCancel(context.Background()) //nolint:fatcontext - s.manager.Group, s.Context = errgroup.WithContext(s.Context) - + // send a connection signal. go func() { - // send a connection signal. s.manager.signals <- sessionSignalConnect - - // read the s.manager.actionError send from a successful connection. - e := <-s.manager.actionError - LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("captured result from close event reconnect: %q", e) }() s.Lock() @@ -194,147 +148,198 @@ func (s *Session) manage(bot *Client) error { break // to reconnect from the connect case logic. } // vErr == nil } // errors.As - - // TODO: Use errors.As: https://github.com/coder/websocket/issues/519 - if strings.Contains(err.Error(), "failed to close WebSocket: received header with unexpected rsv bits set") { - err = nil - } } // err != nil - s.Lock() - - return err + return nil case signal := <-s.manager.signals: switch signal { case sessionSignalConnect: - if s.State() != SessionStateDisconnectedReconnect { + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("received signal: connect with state %q", s.State()) + + switch s.State() { + // SessionStateNew when Connect() on new session. + case SessionStateNew: s.Lock() - } else { - s.Unlock() - // wait until the previous connection's manager goroutines are closed. - _ = s.manager.Wait() + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msg("connecting session") - s.Lock() - } + if err := s.connect(bot); err != nil { + managerErr = ErrorSession{SessionID: s.ID, State: s.State(), Type: ErrorSessionTypeGateway, Err: err} - LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msg("connecting session") + // disconnect when error occurred after websocket connection + if s.State() == SessionStateConnectingWebsocket { + // send a disconnection signal. + go func() { + s.manager.signals <- sessionSignalDisconnect + }() - if err := s.connect(bot); err != nil { - managedErr = ErrorSession{SessionID: s.ID, State: s.State(), Type: ErrorSessionTypeGateway, Err: err} + break // to handle error after disconnection + } - switch s.State() { - case SessionStateConnectingWebsocket: - go func() { + s.Unlock() + + return managerErr + } + + s.setState(SessionStateConnected) + s.Unlock() + + // SessionStateDisconnectedReconnect when reconnecting from disconnected session. + case SessionStateDisconnectedReconnect: + // s.Lock() called during reconnection signal. + + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msg("reconnecting session") + + if err := s.connect(bot); err != nil { + managerErr = ErrorSession{SessionID: s.ID, State: s.State(), Type: ErrorSessionTypeGateway, Err: err} + + // disconnect when error occurred after websocket connection + if s.State() == SessionStateConnectingWebsocket { // send a disconnection signal. - s.manager.signals <- sessionSignalDisconnect - }() + go func() { + s.manager.signals <- sessionSignalDisconnect + }() + + break // to handle error after disconnection + } + + s.Unlock() - // case SessionStateNew, SessionStateConnecting... - default: - return managedErr + return managerErr } - break // to handle the error in the disconnect case logic. - } + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msg("connected session") + s.setState(SessionStateConnected) + s.Unlock() - s.setState(SessionStateConnected) - s.manager.actionError <- nil - s.Unlock() + default: + return fmt.Errorf("unexpected state during session connection: %v", s.State()) + } case sessionSignalDisconnect: - if managedErr == nil && s.State() != SessionStateReconnecting { - s.Lock() - } + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("received signal: disconnect with state %q", s.State()) // update the session's state and client close event code. - code := FlagClientCloseEventCodeNormal + var code int switch { - case managedErr != nil: + case managerErr != nil: + // s.Lock() called before error. + s.setState(SessionStateDisconnectingError) + code = FlagClientCloseEventCodeNormal + + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("%q session with code %d", s.State(), code) + case s.State() == SessionStateReconnecting: + // s.Lock() called during reconnection signal. + s.setState(SessionStateDisconnectingReconnect) code = FlagClientCloseEventCodeReconnect + + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("%q session with code %d", s.State(), code) + default: + s.Lock() + s.setState(SessionStateDisconnecting) - } + code = FlagClientCloseEventCodeNormal - LogSession(Logger.Info(), s.ID).Msgf("%q session with code %d", s.State(), code) + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("%q session with code %d", s.State(), code) + } // disconnect the session. if err := s.disconnect(code); err != nil { - managedErr = ErrorSession{ - SessionID: s.ID, - State: s.State(), - Type: ErrorSessionTypeGateway, - Err: ErrorSessionDisconnect{ - Action: managedErr, - Err: err, - }, - } + // validate the disconnection error. + closeErr := new(websocket.CloseError) // TODO: Use errors.As: https://github.com/coder/websocket/issues/519 if strings.Contains(err.Error(), "failed to close WebSocket: received header with unexpected rsv bits set") { - managedErr = nil + err = nil + } else if errors.As(err, closeErr) { + err = s.validateGatewayCloseError(closeErr) } - if s.State() != SessionStateDisconnectingReconnect { - return managedErr + if managerErr != nil { + s.Unlock() + + // wait until the session's manager goroutines are closed (with s.Unlocked). + _ = s.manager.Wait() + + return ErrorSession{ + SessionID: s.ID, + State: s.State(), + Type: ErrorSessionTypeGateway, + Err: ErrorSessionDisconnect{ + Action: managerErr, + Err: err, + }, + } } - // validate error when reconnecting - closeErr := new(websocket.CloseError) - if errors.As(managedErr, closeErr) { - if managedErr = s.validateGatewayCloseError(closeErr); managedErr != nil { - return managedErr + if err != nil { + s.Unlock() + + // wait until the session's manager goroutines are closed (with s.Unlocked). + _ = s.manager.Wait() + + return ErrorSession{ + SessionID: s.ID, + State: s.State(), + Type: ErrorSessionTypeGateway, + Err: ErrorSessionDisconnect{ + Action: nil, + Err: err, + }, } } } // disconnect // update the session's state. switch { - case s.State() == SessionStateDisconnectingError: + case managerErr != nil: s.setState(SessionStateDisconnectedError) - case s.State() == SessionStateDisconnectingReconnect: s.setState(SessionStateDisconnectedReconnect) - default: s.setState(SessionStateDisconnectedFinal) } - LogSession(Logger.Info(), s.ID).Msgf("%q session with code %d", s.State(), code) + // wait until the session's manager goroutines are closed (with s.Unlocked). + s.Unlock() + _ = s.manager.Wait() + + s.Lock() + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("%q session with code %d", s.State(), code) if s.State() == SessionStateDisconnectedReconnect { // allow Discord to close the session. <-time.After(time.Second) + // send a connection signal. go func() { - // send a connection signal. s.manager.signals <- sessionSignalConnect }() break } - // Destroy the manager when the bot isn't reconnecting. - if managedErr != nil { - return managedErr - } + s.Unlock() return nil case sessionSignalReconnect: + LogSession(Logger.Info(), s.ID).Str(LogCtxClient, bot.ApplicationID).Msgf("received signal: reconnect with state %q", s.State()) + s.Lock() s.setState(SessionStateReconnecting) + // send a disconnection signal. go func() { - // send a disconnection signal. s.manager.signals <- sessionSignalDisconnect }() - } + } // switch signal } // select } // for } @@ -365,6 +370,14 @@ func (s *Session) validateGatewayCloseError(closeErr *websocket.CloseError) erro // Gateway Close Event Code is unknown. default: + // when another goroutine returns an error, + // s.Conn.Close is called before s.cancel which will result in + // a CloseError with the close code that Disgo uses to reconnect. + if closeErr.Code == websocket.StatusCode(FlagClientCloseEventCodeNormal) || + closeErr.Code == websocket.StatusCode(FlagClientCloseEventCodeReconnect) { + return nil + } + LogSession(Logger.Info(), s.ID). Msgf("received unknown Gateway Close Event Code %d with reason %q", closeErr.Code, closeErr.Reason, diff --git a/wrapper/session_routine_manager_actions.go b/wrapper/session_routine_manager_actions.go index 8b83a8c..3eb0b50 100644 --- a/wrapper/session_routine_manager_actions.go +++ b/wrapper/session_routine_manager_actions.go @@ -10,6 +10,7 @@ import ( json "github.com/goccy/go-json" "github.com/switchupcb/disgo/wrapper/socket" "github.com/switchupcb/websocket" + "golang.org/x/sync/errgroup" ) const ( @@ -80,7 +81,8 @@ func (s *Session) connect(bot *Client) error { ) // connect to the Discord Gateway Websocket. - s.Context, s.manager.cancel = context.WithCancel(context.Background()) + s.Context, s.cancel = context.WithCancel(context.Background()) + s.manager.Group, s.Context = errgroup.WithContext(s.Context) if s.Conn, _, err = websocket.Dial(s.Context, gatewayEndpoint+gatewayEndpointParams, nil); err != nil { return fmt.Errorf("error connecting to the Discord Gateway: %w", err) } @@ -293,7 +295,7 @@ func (s *Session) initial(bot *Client, attempt int) error { // disconnect disconnects a session from a WebSocket Connection using the given status code. func (s *Session) disconnect(code int) error { // cancel the context to kill the goroutines of the Session. - defer s.manager.cancel() + defer s.cancel() if err := s.Conn.Close(websocket.StatusCode(code), ""); err != nil { return fmt.Errorf("%w", err) diff --git a/wrapper/shard.go b/wrapper/shard.go index c3a94a8..02ed0e3 100644 --- a/wrapper/shard.go +++ b/wrapper/shard.go @@ -35,7 +35,7 @@ type ShardManager interface { Disconnect() error // Reconnect reconnects to the Discord Gateway using the Shard Manager. - Reconnect(bot *Client) error + Reconnect() error } // ShardLimit contains information about sharding limits. diff --git a/wrapper/tests/integration/session_test.go b/wrapper/tests/integration/session_test.go index 7cb5217..38a4952 100644 --- a/wrapper/tests/integration/session_test.go +++ b/wrapper/tests/integration/session_test.go @@ -170,7 +170,7 @@ func TestReconnect(t *testing.T) { RECONNECT: // reconnect. - if err := s.Reconnect(bot); err != nil { + if err := s.Reconnect(); err != nil { t.Fatalf("%v", err) } diff --git a/wrapper/voice.go b/wrapper/voice.go index 269cee5..f3beec7 100644 --- a/wrapper/voice.go +++ b/wrapper/voice.go @@ -152,10 +152,10 @@ VOICESERVERUPDATE: } select { - case <-vc.GatewaySession.Context.Done(): + case <-vc.GatewaySession.manager.context.Done(): vc.VoiceSession.RUnlock() - return <-vc.GatewaySession.manager.actionError + return vc.GatewaySession.manager.Wait() //nolint:wrapcheck default: vc.VoiceSession.RUnlock() //lint:ignore SA4011 break into for loop.