Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 112 additions & 42 deletions server.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
// license that can be found in the LICENSE file.

// Package udpsocket is a simple UDP server to make a virtual secure channel with the clients
// Package udpsocket is a simple UDP server to make a virtual secure channel with the clientIDMap
package udpsocket

import (
"bytes"
"context"
"crypto/sha256"
"errors"
"fmt"
"github.com/theredrad/udpsocket/crypto"
"github.com/theredrad/udpsocket/encoding"
"github.com/theredrad/udpsocket/encoding/pb"
"io/ioutil"
"log"
"net"
"sync"
"time"

"github.com/theredrad/udpsocket/crypto"
"github.com/theredrad/udpsocket/encoding"
"github.com/theredrad/udpsocket/encoding/pb"
)

// HandlerFunc is called when a custom message type is received from the client
Expand Down Expand Up @@ -92,14 +94,15 @@ type Server struct {
asymmCrypto crypto.Asymmetric // an implementation of Asymmetric encryption to decrypt the body of the client handshake hello record
symmCrypto crypto.Symmetric // an implementation of Symmetric encryption to encrypt & decrypt records body for the client after a successful handshake
handler HandlerFunc // Handler func which is called when a custom record type received
clients map[string]*Client // Map of client with index of client ID
garbageCollectionTicker *time.Ticker // Client garbage collector ticker
garbageCollectionStop chan bool // Client garbage collector stop channel
sessionManager *sessionManager // the Session manager generates cookie & session ID
sessions map[string]*Client // Map of client with index of IP_PORT
rawRecords chan rawRecord // raw records channel
logger *log.Logger // Logger
stop chan bool // stop channel to stop listening
clientIDMap sync.Map // Map of client with index of client ID
clientIPPortMap sync.Map // Map of client ID with index of IP_PORT
clientSessionIDMap sync.Map
garbageCollectionTicker *time.Ticker // Client garbage collector ticker
garbageCollectionStop chan bool // Client garbage collector stop channel
sessionManager *sessionManager // the Session manager generates cookie & session ID
rawRecords chan rawRecord // raw records channel
logger *log.Logger // Logger
stop chan bool // stop channel to stop listening
wg *sync.WaitGroup
}

Expand All @@ -109,8 +112,9 @@ func NewServer(conn *net.UDPConn, options ...Option) (*Server, error) {
s := Server{
conn: conn,

clients: make(map[string]*Client),
sessions: make(map[string]*Client),
clientIDMap: sync.Map{},
clientIPPortMap: sync.Map{},
clientSessionIDMap: sync.Map{},

garbageCollectionStop: make(chan bool, 1),
stop: make(chan bool, 1),
Expand Down Expand Up @@ -357,7 +361,7 @@ func (s *Server) handleHandshakeRecord(ctx context.Context, addr *net.UDPAddr, r

// handlePingRecord handles ping record and sends pong response
func (s *Server) handlePingRecord(ctx context.Context, addr *net.UDPAddr, r *record) {
cl, ok := s.findClientByAddr(addr)
cl, ok := s.getClientByAddr(addr)
if !ok {
s.logger.Printf("error while authenticating ping record: %s", ErrClientAddressIsNotRegistered)
return
Expand Down Expand Up @@ -418,13 +422,16 @@ func (s *Server) handlePingRecord(ctx context.Context, addr *net.UDPAddr, r *rec

// handleCustomRecord handle custom record with authorizing the record and call the handler func if is set
func (s *Server) handleCustomRecord(ctx context.Context, addr *net.UDPAddr, r *record) {
cl, ok := s.findClientByAddr(addr)
cl, ok := s.getClientByAddr(addr)
if !ok {
s.logger.Printf("error while authenticating other type record: %s", ErrClientAddressIsNotRegistered)
s.unAuthenticated(addr)
return
}

cl.Lock()
defer cl.Unlock()

payload, err := s.symmCrypto.Decrypt(r.Body, cl.eKey)
if err != nil {
s.logger.Printf("error while decrypting other type record: %s", err)
Expand All @@ -449,9 +456,8 @@ func (s *Server) handleCustomRecord(ctx context.Context, addr *net.UDPAddr, r *r
}

now := time.Now()
cl.Lock()

cl.lastHeartbeat = &now
cl.Unlock()
}

// parseSessionID parses the session ID from the record decrypted body, the session ID must prepend to the body before encryption in the client
Expand All @@ -470,32 +476,53 @@ func (s *Server) registerClient(addr *net.UDPAddr, ID string, eKey []byte) (*Cli
}

now := time.Now()
cl := &Client{
cl := Client{
ID: ID,
sessionID: sessionID,
addr: addr,
eKey: eKey,
lastHeartbeat: &now,
}
s.clients[ID] = cl
s.sessions[fmt.Sprintf("%s_%d", addr.IP.String(), addr.Port)] = cl

return cl, nil
// lock the client to avoid manipulating the client when the indexing is in progress.
// this manipulation might happen when user is added to clientIDMap and rest of indexing is in progress
cl.Lock()
defer cl.Unlock()

s.clientIDMap.Store(ID, &cl)
s.clientIPPortMap.Store(clientAddrKey(addr), &cl)
s.clientSessionIDMap.Store(hashSessionID(cl.sessionID), &cl)

return &cl, nil
}

func (s *Server) deregisterClient(client *Client) {
client.Lock()
defer client.Unlock()

s.clientIDMap.Delete(client.ID)
s.clientSessionIDMap.Delete(hashSessionID(client.sessionID))
s.clientIPPortMap.Delete(clientAddrKey(client.addr))
}

// returns the Client by the Session ID
func (s *Server) findClientBySessionID(sessionID []byte) (*Client, bool) {
for _, client := range s.clients {
if bytes.Equal(client.sessionID, sessionID) {
return client, true
}
func (s *Server) getClientBySessionID(sessionID []byte) (*Client, bool) {
cl, ok := s.clientSessionIDMap.Load(hashSessionID(sessionID))
if !ok {
return nil, false
}

client, ok := cl.(*Client)
if !ok {
return nil, false
}
return nil, false

return client, true
}

// returns the Client ID by the session ID
func (s *Server) findClientIDBySessionID(sessionID []byte) (string, bool) {
cl, ok := s.findClientBySessionID(sessionID)
func (s *Server) getClientIDBySessionID(sessionID []byte) (string, bool) {
cl, ok := s.getClientBySessionID(sessionID)
if !ok {
return "", ok
}
Expand All @@ -504,13 +531,33 @@ func (s *Server) findClientIDBySessionID(sessionID []byte) (string, bool) {
}

// returns the Client by IP & Port
func (s *Server) findClientByAddr(addr *net.UDPAddr) (*Client, bool) {
cl, ok := s.sessions[fmt.Sprintf("%s_%d", addr.IP.String(), addr.Port)]
func (s *Server) getClientByAddr(addr *net.UDPAddr) (*Client, bool) {
cl, ok := s.clientIPPortMap.Load(clientAddrKey(addr))
if !ok {
return nil, ok
return nil, false
}

return cl, true
client, ok := cl.(*Client)
if !ok {
return nil, false
}

return client, true
}

// returns client by ID
func (s *Server) getClientByID(clientID string) (*Client, bool) {
cl, ok := s.clientIDMap.Load(clientID)
if !ok {
return nil, false
}

client, ok := cl.(*Client)
if !ok {
return nil, false
}

return client, true
}

// sendToAddr writes record bytes to the UDP address
Expand All @@ -531,7 +578,7 @@ func (s *Server) sendToClient(client *Client, typ byte, payload []byte) error {

// SendToClientByID sends bytes to the Client by ID
func (s *Server) SendToClientByID(clientID string, typ byte, payload []byte) error {
cl, ok := s.clients[clientID]
cl, ok := s.getClientByID(clientID)
if !ok {
return ErrClientNotFound
}
Expand All @@ -548,19 +595,32 @@ func (s *Server) clientGarbageCollection() {
}
break
case <-s.garbageCollectionTicker.C:
for _, c := range s.clients {
if c.lastHeartbeat != nil && time.Now().After(c.lastHeartbeat.Add(s.heartbeatExpiration)) {
delete(s.clients, c.ID)
delete(s.sessions, fmt.Sprintf("%s_%d", c.addr.IP.String(), c.addr.Port))
s.clientIDMap.Range(func(clientID, client interface{}) bool {
cl, ok := client.(*Client)
if !ok {
s.logger.Printf("error while collecting garbage client: invalid client type in the map")
return true
}
}

if cl.lastHeartbeat != nil && time.Now().After(cl.lastHeartbeat.Add(s.heartbeatExpiration)) {
s.deregisterClient(cl)
}

return true
})
}
}
}

// BroadcastToClients broadcasts bytes to all registered Clients
func (s *Server) BroadcastToClients(typ byte, payload []byte) {
for _, cl := range s.clients {
s.clientIDMap.Range(func(clientID, client interface{}) bool {
cl, ok := client.(*Client)
if !ok {
s.logger.Printf("error while collecting garbage client: invalid client type in the map")
return true
}

s.wg.Add(1)
go func(c *Client) {
defer s.wg.Done()
Expand All @@ -569,7 +629,9 @@ func (s *Server) BroadcastToClients(typ byte, payload []byte) {
s.logger.Printf("error while writing to the client: %s", err)
}
}(cl)
}

return true
})
}

func (s *Server) unAuthenticated(addr *net.UDPAddr) {
Expand Down Expand Up @@ -616,6 +678,14 @@ func parseRecord(rec []byte) (*record, error) {
}, nil
}

func clientAddrKey(addr *net.UDPAddr) string {
return fmt.Sprintf("%s_%d", addr.IP.String(), addr.Port)
}

func hashSessionID(sessionID []byte) [32]byte {
return sha256.Sum256(sessionID)
}

// ValidateSessionID compares the client session ID with the given one
func (c *Client) ValidateSessionID(sessionID []byte) bool {
if bytes.Equal(c.sessionID, sessionID) {
Expand Down