diff --git a/server/control/clients.go b/server/control/clients.go index 30a81ea..9144aa8 100644 --- a/server/control/clients.go +++ b/server/control/clients.go @@ -64,13 +64,15 @@ func newClientServer( return nil, fmt.Errorf("client peers store open: %w", err) } - connsMsgs, _, err := conns.Snapshot() + connsMsgs, connsOffset, err := conns.Snapshot() if err != nil { return nil, fmt.Errorf("client snapshot: %w", err) } + connsCache := map[ClientConnKey]chan struct{}{} reactivate := map[ClientConnKey][]ClientPeerKey{} for _, msg := range connsMsgs { + connsCache[msg.Key] = make(chan struct{}) reactivate[msg.Key] = []ClientPeerKey{} } @@ -134,6 +136,9 @@ func newClientServer( conns: conns, peers: peers, + connsCache: connsCache, + connsOffset: connsOffset, + peersCache: peersCache, peersOffset: peersOffset, @@ -154,6 +159,10 @@ type clientServer struct { conns logc.KV[ClientConnKey, ClientConnValue] peers logc.KV[ClientPeerKey, ClientPeerValue] + connsCache map[ClientConnKey]chan struct{} + connsOffset int64 + connsMu sync.RWMutex + peersCache map[cacheKey][]*pbclient.RemotePeer peersOffset int64 peersMu sync.RWMutex @@ -162,12 +171,24 @@ type clientServer struct { reactivateMu sync.RWMutex } -func (s *clientServer) connected(id ClientID, auth ClientAuthentication, remote net.Addr, metadata string) error { +func (s *clientServer) connected(ctx context.Context, id ClientID, auth ClientAuthentication, remote net.Addr, metadata string) error { s.reactivateMu.Lock() delete(s.reactivate, ClientConnKey{id}) s.reactivateMu.Unlock() - return s.conns.Put(ClientConnKey{id}, ClientConnValue{auth, remote.String(), metadata}) + key := ClientConnKey{id} + if connCh := s.cachedConn(key); connCh != nil { + s.logger.Debug("client connection still active, waiting for connection close", "client-id", id, "addr", remote, "metadata", metadata) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(10 * time.Second): + return fmt.Errorf("another connection still active: %s", id) + case <-connCh: + } + } + + return s.conns.Put(key, ClientConnValue{auth, remote.String(), metadata}) } func (s *clientServer) disconnected(id ClientID) error { @@ -182,6 +203,13 @@ func (s *clientServer) revoke(endpoint model.Endpoint, role model.Role, id Clien return s.peers.Del(ClientPeerKey{endpoint, role, id}) } +func (s *clientServer) cachedConn(key ClientConnKey) chan struct{} { + s.connsMu.RLock() + defer s.connsMu.RUnlock() + + return s.connsCache[key] +} + func (s *clientServer) cachedPeers(endpoint model.Endpoint, role model.Role) ([]*pbclient.RemotePeer, int64) { s.peersMu.RLock() defer s.peersMu.RUnlock() @@ -243,6 +271,7 @@ func (s *clientServer) run(ctx context.Context) error { for _, ingress := range s.ingresses { g.Go(reliable.Bind(ingress, s.runListener)) } + g.Go(s.runClientsCache) g.Go(s.runPeerCache) g.Go(s.runCleaner) @@ -315,6 +344,49 @@ func (s *clientServer) runListener(ctx context.Context, ingress Ingress) error { } } +func (s *clientServer) runClientsCache(ctx context.Context) error { + update := func(msg logc.Message[ClientConnKey, ClientConnValue]) error { + s.connsMu.Lock() + defer s.connsMu.Unlock() + + connCh := s.connsCache[msg.Key] + if msg.Delete { + if connCh != nil { + close(connCh) + delete(s.connsCache, msg.Key) + } + } else { + if connCh == nil { + s.connsCache[msg.Key] = make(chan struct{}) + } + } + + s.connsOffset = msg.Offset + 1 + return nil + } + + for { + s.connsMu.RLock() + offset := s.connsOffset + s.connsMu.RUnlock() + + msgs, nextOffset, err := s.conns.Consume(ctx, offset) + if err != nil { + return err + } + + for _, msg := range msgs { + if err := update(msg); err != nil { + return err + } + } + + s.connsMu.Lock() + s.connsOffset = nextOffset + s.connsMu.Unlock() + } +} + func (s *clientServer) runPeerCache(ctx context.Context) error { update := func(msg logc.Message[ClientPeerKey, ClientPeerValue]) error { s.peersMu.Lock() @@ -465,7 +537,7 @@ func (c *clientConn) runErr(ctx context.Context) error { c.logger.Info("client connected", "addr", c.conn.RemoteAddr(), "metadata", c.metadata) defer c.logger.Info("client disconnected", "addr", c.conn.RemoteAddr(), "metadata", c.metadata) - if err := c.server.connected(c.id, c.auth, c.conn.RemoteAddr(), c.metadata); err != nil { + if err := c.server.connected(ctx, c.id, c.auth, c.conn.RemoteAddr(), c.metadata); err != nil { return err } defer func() {