From a26454697d5cdf7bbc780299e524a614a6a6e0e6 Mon Sep 17 00:00:00 2001 From: TheRedRad Date: Fri, 27 Dec 2024 20:45:27 +0100 Subject: [PATCH] fix: (issue #2) race condition issue on server clients --- server.go | 154 +++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 112 insertions(+), 42 deletions(-) diff --git a/server.go b/server.go index 67c0c1c..0db2401 100644 --- a/server.go +++ b/server.go @@ -1,21 +1,23 @@ // license that can be found in the LICENSE file. -// Package udpsocket is a simple UDP server to make a virtual secure channel with the clients +// Package udpsocket is a simple UDP server to make a virtual secure channel with the clientIDMap package udpsocket import ( "bytes" "context" + "crypto/sha256" "errors" "fmt" - "github.com/theredrad/udpsocket/crypto" - "github.com/theredrad/udpsocket/encoding" - "github.com/theredrad/udpsocket/encoding/pb" "io/ioutil" "log" "net" "sync" "time" + + "github.com/theredrad/udpsocket/crypto" + "github.com/theredrad/udpsocket/encoding" + "github.com/theredrad/udpsocket/encoding/pb" ) // HandlerFunc is called when a custom message type is received from the client @@ -92,14 +94,15 @@ type Server struct { asymmCrypto crypto.Asymmetric // an implementation of Asymmetric encryption to decrypt the body of the client handshake hello record symmCrypto crypto.Symmetric // an implementation of Symmetric encryption to encrypt & decrypt records body for the client after a successful handshake handler HandlerFunc // Handler func which is called when a custom record type received - clients map[string]*Client // Map of client with index of client ID - garbageCollectionTicker *time.Ticker // Client garbage collector ticker - garbageCollectionStop chan bool // Client garbage collector stop channel - sessionManager *sessionManager // the Session manager generates cookie & session ID - sessions map[string]*Client // Map of client with index of IP_PORT - rawRecords chan rawRecord // raw records channel - logger *log.Logger // Logger - stop chan bool // stop channel to stop listening + clientIDMap sync.Map // Map of client with index of client ID + clientIPPortMap sync.Map // Map of client ID with index of IP_PORT + clientSessionIDMap sync.Map + garbageCollectionTicker *time.Ticker // Client garbage collector ticker + garbageCollectionStop chan bool // Client garbage collector stop channel + sessionManager *sessionManager // the Session manager generates cookie & session ID + rawRecords chan rawRecord // raw records channel + logger *log.Logger // Logger + stop chan bool // stop channel to stop listening wg *sync.WaitGroup } @@ -109,8 +112,9 @@ func NewServer(conn *net.UDPConn, options ...Option) (*Server, error) { s := Server{ conn: conn, - clients: make(map[string]*Client), - sessions: make(map[string]*Client), + clientIDMap: sync.Map{}, + clientIPPortMap: sync.Map{}, + clientSessionIDMap: sync.Map{}, garbageCollectionStop: make(chan bool, 1), stop: make(chan bool, 1), @@ -357,7 +361,7 @@ func (s *Server) handleHandshakeRecord(ctx context.Context, addr *net.UDPAddr, r // handlePingRecord handles ping record and sends pong response func (s *Server) handlePingRecord(ctx context.Context, addr *net.UDPAddr, r *record) { - cl, ok := s.findClientByAddr(addr) + cl, ok := s.getClientByAddr(addr) if !ok { s.logger.Printf("error while authenticating ping record: %s", ErrClientAddressIsNotRegistered) return @@ -418,13 +422,16 @@ func (s *Server) handlePingRecord(ctx context.Context, addr *net.UDPAddr, r *rec // handleCustomRecord handle custom record with authorizing the record and call the handler func if is set func (s *Server) handleCustomRecord(ctx context.Context, addr *net.UDPAddr, r *record) { - cl, ok := s.findClientByAddr(addr) + cl, ok := s.getClientByAddr(addr) if !ok { s.logger.Printf("error while authenticating other type record: %s", ErrClientAddressIsNotRegistered) s.unAuthenticated(addr) return } + cl.Lock() + defer cl.Unlock() + payload, err := s.symmCrypto.Decrypt(r.Body, cl.eKey) if err != nil { s.logger.Printf("error while decrypting other type record: %s", err) @@ -449,9 +456,8 @@ func (s *Server) handleCustomRecord(ctx context.Context, addr *net.UDPAddr, r *r } now := time.Now() - cl.Lock() + cl.lastHeartbeat = &now - cl.Unlock() } // parseSessionID parses the session ID from the record decrypted body, the session ID must prepend to the body before encryption in the client @@ -470,32 +476,53 @@ func (s *Server) registerClient(addr *net.UDPAddr, ID string, eKey []byte) (*Cli } now := time.Now() - cl := &Client{ + cl := Client{ ID: ID, sessionID: sessionID, addr: addr, eKey: eKey, lastHeartbeat: &now, } - s.clients[ID] = cl - s.sessions[fmt.Sprintf("%s_%d", addr.IP.String(), addr.Port)] = cl - return cl, nil + // lock the client to avoid manipulating the client when the indexing is in progress. + // this manipulation might happen when user is added to clientIDMap and rest of indexing is in progress + cl.Lock() + defer cl.Unlock() + + s.clientIDMap.Store(ID, &cl) + s.clientIPPortMap.Store(clientAddrKey(addr), &cl) + s.clientSessionIDMap.Store(hashSessionID(cl.sessionID), &cl) + + return &cl, nil +} + +func (s *Server) deregisterClient(client *Client) { + client.Lock() + defer client.Unlock() + + s.clientIDMap.Delete(client.ID) + s.clientSessionIDMap.Delete(hashSessionID(client.sessionID)) + s.clientIPPortMap.Delete(clientAddrKey(client.addr)) } // returns the Client by the Session ID -func (s *Server) findClientBySessionID(sessionID []byte) (*Client, bool) { - for _, client := range s.clients { - if bytes.Equal(client.sessionID, sessionID) { - return client, true - } +func (s *Server) getClientBySessionID(sessionID []byte) (*Client, bool) { + cl, ok := s.clientSessionIDMap.Load(hashSessionID(sessionID)) + if !ok { + return nil, false + } + + client, ok := cl.(*Client) + if !ok { + return nil, false } - return nil, false + + return client, true } // returns the Client ID by the session ID -func (s *Server) findClientIDBySessionID(sessionID []byte) (string, bool) { - cl, ok := s.findClientBySessionID(sessionID) +func (s *Server) getClientIDBySessionID(sessionID []byte) (string, bool) { + cl, ok := s.getClientBySessionID(sessionID) if !ok { return "", ok } @@ -504,13 +531,33 @@ func (s *Server) findClientIDBySessionID(sessionID []byte) (string, bool) { } // returns the Client by IP & Port -func (s *Server) findClientByAddr(addr *net.UDPAddr) (*Client, bool) { - cl, ok := s.sessions[fmt.Sprintf("%s_%d", addr.IP.String(), addr.Port)] +func (s *Server) getClientByAddr(addr *net.UDPAddr) (*Client, bool) { + cl, ok := s.clientIPPortMap.Load(clientAddrKey(addr)) if !ok { - return nil, ok + return nil, false } - return cl, true + client, ok := cl.(*Client) + if !ok { + return nil, false + } + + return client, true +} + +// returns client by ID +func (s *Server) getClientByID(clientID string) (*Client, bool) { + cl, ok := s.clientIDMap.Load(clientID) + if !ok { + return nil, false + } + + client, ok := cl.(*Client) + if !ok { + return nil, false + } + + return client, true } // sendToAddr writes record bytes to the UDP address @@ -531,7 +578,7 @@ func (s *Server) sendToClient(client *Client, typ byte, payload []byte) error { // SendToClientByID sends bytes to the Client by ID func (s *Server) SendToClientByID(clientID string, typ byte, payload []byte) error { - cl, ok := s.clients[clientID] + cl, ok := s.getClientByID(clientID) if !ok { return ErrClientNotFound } @@ -548,19 +595,32 @@ func (s *Server) clientGarbageCollection() { } break case <-s.garbageCollectionTicker.C: - for _, c := range s.clients { - if c.lastHeartbeat != nil && time.Now().After(c.lastHeartbeat.Add(s.heartbeatExpiration)) { - delete(s.clients, c.ID) - delete(s.sessions, fmt.Sprintf("%s_%d", c.addr.IP.String(), c.addr.Port)) + s.clientIDMap.Range(func(clientID, client interface{}) bool { + cl, ok := client.(*Client) + if !ok { + s.logger.Printf("error while collecting garbage client: invalid client type in the map") + return true } - } + + if cl.lastHeartbeat != nil && time.Now().After(cl.lastHeartbeat.Add(s.heartbeatExpiration)) { + s.deregisterClient(cl) + } + + return true + }) } } } // BroadcastToClients broadcasts bytes to all registered Clients func (s *Server) BroadcastToClients(typ byte, payload []byte) { - for _, cl := range s.clients { + s.clientIDMap.Range(func(clientID, client interface{}) bool { + cl, ok := client.(*Client) + if !ok { + s.logger.Printf("error while collecting garbage client: invalid client type in the map") + return true + } + s.wg.Add(1) go func(c *Client) { defer s.wg.Done() @@ -569,7 +629,9 @@ func (s *Server) BroadcastToClients(typ byte, payload []byte) { s.logger.Printf("error while writing to the client: %s", err) } }(cl) - } + + return true + }) } func (s *Server) unAuthenticated(addr *net.UDPAddr) { @@ -616,6 +678,14 @@ func parseRecord(rec []byte) (*record, error) { }, nil } +func clientAddrKey(addr *net.UDPAddr) string { + return fmt.Sprintf("%s_%d", addr.IP.String(), addr.Port) +} + +func hashSessionID(sessionID []byte) [32]byte { + return sha256.Sum256(sessionID) +} + // ValidateSessionID compares the client session ID with the given one func (c *Client) ValidateSessionID(sessionID []byte) bool { if bytes.Equal(c.sessionID, sessionID) {