Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 76 additions & 4 deletions server/control/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,15 @@ func newClientServer(
return nil, fmt.Errorf("client peers store open: %w", err)
}

connsMsgs, _, err := conns.Snapshot()
connsMsgs, connsOffset, err := conns.Snapshot()
if err != nil {
return nil, fmt.Errorf("client snapshot: %w", err)
}

connsCache := map[ClientConnKey]chan struct{}{}
reactivate := map[ClientConnKey][]ClientPeerKey{}
for _, msg := range connsMsgs {
connsCache[msg.Key] = make(chan struct{})
reactivate[msg.Key] = []ClientPeerKey{}
}

Expand Down Expand Up @@ -134,6 +136,9 @@ func newClientServer(
conns: conns,
peers: peers,

connsCache: connsCache,
connsOffset: connsOffset,

peersCache: peersCache,
peersOffset: peersOffset,

Expand All @@ -154,6 +159,10 @@ type clientServer struct {
conns logc.KV[ClientConnKey, ClientConnValue]
peers logc.KV[ClientPeerKey, ClientPeerValue]

connsCache map[ClientConnKey]chan struct{}
connsOffset int64
connsMu sync.RWMutex

peersCache map[cacheKey][]*pbclient.RemotePeer
peersOffset int64
peersMu sync.RWMutex
Expand All @@ -162,12 +171,24 @@ type clientServer struct {
reactivateMu sync.RWMutex
}

func (s *clientServer) connected(id ClientID, auth ClientAuthentication, remote net.Addr, metadata string) error {
func (s *clientServer) connected(ctx context.Context, id ClientID, auth ClientAuthentication, remote net.Addr, metadata string) error {
s.reactivateMu.Lock()
delete(s.reactivate, ClientConnKey{id})
s.reactivateMu.Unlock()

return s.conns.Put(ClientConnKey{id}, ClientConnValue{auth, remote.String(), metadata})
key := ClientConnKey{id}
if connCh := s.cachedConn(key); connCh != nil {
s.logger.Debug("client connection still active, waiting for connection close", "client-id", id, "addr", remote, "metadata", metadata)
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(10 * time.Second):
return fmt.Errorf("another connection still active: %s", id)
case <-connCh:
}
}

return s.conns.Put(key, ClientConnValue{auth, remote.String(), metadata})
}

func (s *clientServer) disconnected(id ClientID) error {
Expand All @@ -182,6 +203,13 @@ func (s *clientServer) revoke(endpoint model.Endpoint, role model.Role, id Clien
return s.peers.Del(ClientPeerKey{endpoint, role, id})
}

func (s *clientServer) cachedConn(key ClientConnKey) chan struct{} {
s.connsMu.RLock()
defer s.connsMu.RUnlock()

return s.connsCache[key]
}

func (s *clientServer) cachedPeers(endpoint model.Endpoint, role model.Role) ([]*pbclient.RemotePeer, int64) {
s.peersMu.RLock()
defer s.peersMu.RUnlock()
Expand Down Expand Up @@ -243,6 +271,7 @@ func (s *clientServer) run(ctx context.Context) error {
for _, ingress := range s.ingresses {
g.Go(reliable.Bind(ingress, s.runListener))
}
g.Go(s.runClientsCache)
g.Go(s.runPeerCache)
g.Go(s.runCleaner)

Expand Down Expand Up @@ -315,6 +344,49 @@ func (s *clientServer) runListener(ctx context.Context, ingress Ingress) error {
}
}

func (s *clientServer) runClientsCache(ctx context.Context) error {
update := func(msg logc.Message[ClientConnKey, ClientConnValue]) error {
s.connsMu.Lock()
defer s.connsMu.Unlock()

connCh := s.connsCache[msg.Key]
if msg.Delete {
if connCh != nil {
close(connCh)
delete(s.connsCache, msg.Key)
}
} else {
if connCh == nil {
s.connsCache[msg.Key] = make(chan struct{})
}
}

s.connsOffset = msg.Offset + 1
return nil
}

for {
s.connsMu.RLock()
offset := s.connsOffset
s.connsMu.RUnlock()

msgs, nextOffset, err := s.conns.Consume(ctx, offset)
if err != nil {
return err
}

for _, msg := range msgs {
if err := update(msg); err != nil {
return err
}
}

s.connsMu.Lock()
s.connsOffset = nextOffset
s.connsMu.Unlock()
}
}

func (s *clientServer) runPeerCache(ctx context.Context) error {
update := func(msg logc.Message[ClientPeerKey, ClientPeerValue]) error {
s.peersMu.Lock()
Expand Down Expand Up @@ -465,7 +537,7 @@ func (c *clientConn) runErr(ctx context.Context) error {
c.logger.Info("client connected", "addr", c.conn.RemoteAddr(), "metadata", c.metadata)
defer c.logger.Info("client disconnected", "addr", c.conn.RemoteAddr(), "metadata", c.metadata)

if err := c.server.connected(c.id, c.auth, c.conn.RemoteAddr(), c.metadata); err != nil {
if err := c.server.connected(ctx, c.id, c.auth, c.conn.RemoteAddr(), c.metadata); err != nil {
return err
}
defer func() {
Expand Down