Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
443 changes: 268 additions & 175 deletions disgo.go

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions shard/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,12 @@ func (sm *InstanceShardManager) Disconnect() error {
}

// Reconnect connects to the Discord Gateway using the Shard Manager.
func (sm *InstanceShardManager) Reconnect(bot *disgo.Client) error {
func (sm *InstanceShardManager) Reconnect() error {
// totalShards represents the total number of shards that are connected.
totalShards := len(sm.Sessions)

for sessionCount := 0; sessionCount < totalShards; sessionCount++ {
if err := sm.Sessions[sessionCount].Reconnect(bot); err != nil {
if err := sm.Sessions[sessionCount].Reconnect(); err != nil {
return fmt.Errorf(errShardManager, err)
}
}
Expand Down
2 changes: 1 addition & 1 deletion shard/tests/integration/shard_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestReconnect(t *testing.T) {
time.Sleep(time.Second)

// reconnect to the Discord Gateway (WebSocket Session).
if err := s.Reconnect(bot); err != nil {
if err := s.Reconnect(); err != nil {
t.Fatalf("%v", err)
}

Expand Down
2 changes: 2 additions & 0 deletions wrapper/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ func putSession(s *Session) {
s.Endpoint = ""
s.Shard = nil
s.Context = nil
s.cancel = nil
s.Conn = nil
s.setState(SessionStateNew)
s.heartbeat = nil
s.manager = nil
s.client_manager = nil
Expand Down
107 changes: 97 additions & 10 deletions wrapper/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"sync"
"sync/atomic"

"github.com/switchupcb/websocket"
)
Expand Down Expand Up @@ -32,9 +33,19 @@ type Session struct {
// Context is also used as a signal for the Session's goroutines.
Context context.Context

// Cancel represents the cancellation signal for a Session Context.
cancel context.CancelFunc

// Conn represents a WebSocket Connection to the Discord Gateway.
Conn *websocket.Conn

// state represents the state of the Session's connection to Discord.
state string

// stateMutex is used to protect the Session's manager state from data races
// by providing transactional functionality.
stateMutex sync.RWMutex

// heartbeat contains the fields required to implement the heartbeat mechanism.
heartbeat *heartbeat

Expand All @@ -52,6 +63,45 @@ type Session struct {
sync.RWMutex
}

// Session States represent the state of the Session's connection to Discord.
const (
SessionStateNew = ""

SessionStateConnecting = "connecting (before websocket connection)"
SessionStateConnectingWebsocket = "connecting (with websocket connection)"
SessionStateConnected = "connected"

SessionStateDisconnecting = "disconnecting (purposefully)"
SessionStateDisconnectingError = "disconnecting (due to an error)"
SessionStateDisconnectingReconnect = "disconnecting (while reconnecting)"

SessionStateDisconnectedFinal = "disconnected (after connection)"
SessionStateDisconnectedError = "disconnected (due to an error)"
SessionStateDisconnectedReconnect = "disconnected (while reconnecting)"

SessionStateReconnecting = "reconnecting"
)

// State returns the state of the Session's connection to Discord.
func (s *Session) State() string {
s.stateMutex.RLock()
defer s.stateMutex.RUnlock()

return s.state
}

// setState sets the state of a Session.
func (s *Session) setState(state string) {
s.stateMutex.Lock()
s.state = state
s.stateMutex.Unlock()
}

// canReconnect returns whether the Session's fields are in a valid state to reconnect.
func (s *Session) canReconnect() bool {
return s.ID != "" && s.Endpoint != "" && atomic.LoadInt64(&s.Seq) != 0
}

// Connect connects a session to the Discord Gateway (WebSocket Connection).
func (s *Session) Connect(bot *Client) error {
if bot == nil {
Expand Down Expand Up @@ -80,8 +130,23 @@ func (s *Session) Connect(bot *Client) error {
s.manager.signals <- sessionSignalConnect
s.Unlock()

if err := <-s.manager.actionError; err != nil {
return err
// wait until the Session has connected
for {
select {
// Context is cancelled during connection when the manager returns an error
// or disconnects from another goroutine call.
case <-s.manager.context.Done():
return s.manager.coroner.Wait() //nolint:wrapcheck
default:
break
}

// Session is SessionStateConnected after connection.
//
// proof: Calling Connect() during connection cannot happen while the manager exists.
if s.State() == SessionStateConnected {
break
}
}

return nil
Expand All @@ -99,19 +164,17 @@ func (s *Session) Disconnect() error {
s.manager.signals <- sessionSignalDisconnect
s.Unlock()

if err := <-s.manager.actionError; err != nil {
return err
// Session is disconnected from a Disconnect() call when the coroner shuts down.
if err := s.manager.coroner.Wait(); err != nil {
return err //nolint:wrapcheck
}

// Reset the session.
putSession(s)

return nil
}

// Reconnect reconnects an already connected session to the Discord Gateway
// by disconnecting the session, then connecting again.
func (s *Session) Reconnect(bot *Client) error {
func (s *Session) Reconnect() error {
s.Lock()
if s.manager == nil || s.State() != SessionStateConnected {
s.Unlock()
Expand All @@ -122,8 +185,32 @@ func (s *Session) Reconnect(bot *Client) error {
s.manager.signals <- sessionSignalReconnect
s.Unlock()

if err := <-s.manager.actionError; err != nil {
return fmt.Errorf("reconnect: %w", err)
// wait until the manager has received the sessionSignalReconnect
// or changed state to another signal.
for {
if s.State() != SessionStateConnected {
break
}
}

// wait until the Session has reconnected
// or has experienced an error during reconnection.
for {
select {
// Context is cancelled during reconnection when the manager returns an error
// or disconnects from another goroutine call.
case <-s.manager.context.Done():
return s.manager.coroner.Wait() //nolint:wrapcheck
default:
break
}

// Session is SessionStateConnected after reconnection.
//
// proof: Calling Connect() during reconnection cannot happen while the manager exists.
if s.State() == SessionStateConnected {
break
}
}

return nil
Expand Down
21 changes: 8 additions & 13 deletions wrapper/session_routine_coroner.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,19 @@ package wrapper

// coroner investigates when a Session's goroutines are shutdown.
func (s *Session) coroner() {
// wait until all the manager goroutines is closed.
err := s.manager.coroner.Wait()

s.Lock()

// report the disconnection error
s.manager.actionError <- err
close(s.manager.actionError)
// wait until the manager goroutine is closed.
if err := s.manager.coroner.Wait(); err != nil {
LogSession(Logger.Error(), s.ID).Err(err).Msg("coroner manager routine error")
}

// remove the session from the client.
s.client_manager.RemoveGatewaySession(s.ID)
// Reset the session.
putSession(s)

s.logClose("coroner")
s.Unlock()
Logger.Info().Msg("closed coroner routine")
}

// Wait blocks until the calling Session is inactive (due to a final disconnect),
// then returns the Session's state and the disconnection error (if it exists).
// then returns the Session's state and the disconnection error (when it exists).
//
// If Wait() is called on a Session that isn't connected, it will return immediately
// with code SessionStateNew.
Expand Down
9 changes: 2 additions & 7 deletions wrapper/session_routine_listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,11 @@ func (s *Session) onPayload(bot *Client, payload GatewayPayload) error {

// in the context of onPayload, an Invalid Session occurs when an active session is invalidated.
case FlagGatewayOpcodeInvalidSession:
// Remove the session from the session manager.
s.client_manager.RemoveGatewaySession(s.ID)

// wait for Discord to close the session, then complete a fresh connect.
<-time.NewTimer(invalidSessionWaitTime).C

s.Lock()
defer s.Unlock()

if err := s.initial(bot, 0); err != nil {
_ = s.Disconnect()
if err := s.Connect(bot); err != nil {
return err
}
}
Expand Down
Loading
Loading