diff --git a/ably/realtime/connection.py b/ably/realtime/connection.py index a810ea3a..907f56a5 100644 --- a/ably/realtime/connection.py +++ b/ably/realtime/connection.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import functools import logging from typing import TYPE_CHECKING @@ -64,7 +65,7 @@ async def close(self) -> None: connection without an explicit call to connect() """ self.connection_manager.request_state(ConnectionState.CLOSING) - await self.once_async(ConnectionState.CLOSED) + await self._when_state(ConnectionState.CLOSED) # RTN13 async def ping(self) -> float: @@ -86,6 +87,13 @@ async def ping(self) -> float: """ return await self.__connection_manager.ping() + def _when_state(self, state: ConnectionState): + if self.state == state: + fut = asyncio.get_event_loop().create_future() + fut.set_result(None) + return fut + return self.once_async(state) + def _on_state_update(self, state_change: ConnectionStateChange) -> None: log.info(f'Connection state changing from {self.state} to {state_change.current}') self.__state = state_change.current diff --git a/ably/realtime/connectionmanager.py b/ably/realtime/connectionmanager.py index d555bb9b..01a0735b 100644 --- a/ably/realtime/connectionmanager.py +++ b/ably/realtime/connectionmanager.py @@ -135,6 +135,17 @@ def enact_state_change(self, state: ConnectionState, reason: AblyException | Non self.__state = state if reason: self.__error_reason = reason + + # RTN16d: Clear connection state when entering SUSPENDED or terminal states + if state == ConnectionState.SUSPENDED or state in ( + ConnectionState.CLOSED, + ConnectionState.FAILED + ): + self.__connection_details = None + self.connection_id = None + self.__connection_key = None + self.msg_serial = 0 + self._emit('connectionstate', ConnectionStateChange(current_state, state, state, reason)) def check_connection(self) -> bool: @@ -157,6 +168,10 @@ async def __get_transport_params(self) -> dict: # RTN2a: Set format to msgpack if use_binary_protocol is enabled if self.options.use_binary_protocol: params["format"] = "msgpack" + + # Add any custom transport params from options + params.update(self.options.transport_params) + return params async def close_impl(self) -> None: @@ -165,13 +180,23 @@ async def close_impl(self) -> None: self.cancel_suspend_timer() self.start_transition_timer(ConnectionState.CLOSING, fail_state=ConnectionState.CLOSED) if self.transport: - await self.transport.dispose() + # Try to send protocol CLOSE message in the background + asyncio.create_task(self.transport.close()) + # Yield to event loop to give the close message a chance to send + await asyncio.sleep(0) + await self.transport.dispose() # Dispose transport resources if self.connect_base_task: self.connect_base_task.cancel() if self.disconnect_transport_task: await self.disconnect_transport_task self.cancel_retry_timer() + # Clear connection details to prevent resume on next connect + # When explicitly closed, we want a fresh connection, not a resume + self.__connection_details = None + self.connection_id = None + self.msg_serial = 0 + self.notify_state(ConnectionState.CLOSED) async def send_protocol_message(self, protocol_message: dict) -> None: @@ -648,7 +673,6 @@ def on_suspend_timer_expire() -> None: AblyException("Connection to server unavailable", 400, 80002) ) self.__fail_state = ConnectionState.SUSPENDED - self.__connection_details = None self.suspend_timer = Timer(Defaults.connection_state_ttl, on_suspend_timer_expire) diff --git a/ably/realtime/presencemap.py b/ably/realtime/presencemap.py new file mode 100644 index 00000000..9c5adace --- /dev/null +++ b/ably/realtime/presencemap.py @@ -0,0 +1,351 @@ +""" +PresenceMap - Manages the state of presence members on a channel. + +This module implements RTP2 presence map requirements from the Ably specification. +""" + +import logging +from typing import Callable, Dict, List, Optional, Tuple + +from ably.types.presence import PresenceAction, PresenceMessage + +logger = logging.getLogger(__name__) + + +def _is_newer(item: PresenceMessage, existing: PresenceMessage) -> bool: + """ + Compare two presence messages for newness (RTP2b). + + RTP2b1: If either presence message has a connectionId which is not an initial + substring of its id, compare them by timestamp numerically. This will be the + case when one of them is a 'synthesized leave' event. + + RTP2b1a: If the timestamps compare equal, the newly-incoming message is + considered newer than the existing one. + + RTP2b2: Else split the id of both presence messages (format: connid:msgSerial:index) + and compare them first by msgSerial numerically, then by index numerically, + larger being newer in both cases. + + Args: + item: The incoming presence message + existing: The existing presence message in the map + + Returns: + True if item is newer than existing, False otherwise + + Raises: + ValueError: If message ids cannot be parsed for comparison + """ + # RTP2b1: if either is synthesized, compare by timestamp + if item.is_synthesized() or existing.is_synthesized(): + # RTP2b1a: if equal, prefer the newly-arrived one (item) + if item.timestamp is None and existing.timestamp is None: + return True + if item.timestamp is None: + return False + if existing.timestamp is None: + return True + return item.timestamp >= existing.timestamp + + # RTP2b2: compare by msgSerial and index + # parse_id will raise ValueError if id format is invalid + item_parts = item.parse_id() + existing_parts = existing.parse_id() + + if item_parts['msgSerial'] == existing_parts['msgSerial']: + return item_parts['index'] > existing_parts['index'] + else: + return item_parts['msgSerial'] > existing_parts['msgSerial'] + + +class PresenceMap: + """ + Manages the state of presence members on a channel. + + Maintains a map of members keyed by memberKey (connectionId:clientId). + Handles newness comparison, SYNC operations, and member filtering. + + Implements RTP2 specification requirements. + """ + + def __init__( + self, + member_key_fn: Callable[[PresenceMessage], str], + is_newer_fn: Optional[Callable[[PresenceMessage, PresenceMessage], bool]] = None, + logger_instance: Optional[logging.Logger] = None + ): + """ + Initialize a new PresenceMap. + + Args: + member_key_fn: Function to extract member key from a PresenceMessage + is_newer_fn: Optional custom function for newness comparison (default: _is_newer) + logger_instance: Optional logger instance (default: module logger) + """ + self._map: Dict[str, PresenceMessage] = {} + self._residual_members: Optional[Dict[str, PresenceMessage]] = None + self._sync_in_progress = False + self._member_key_fn = member_key_fn + self._is_newer_fn = is_newer_fn or _is_newer + self._logger = logger_instance or logger + self._sync_complete_callbacks: List[Callable[[], None]] = [] + + @property + def sync_in_progress(self) -> bool: + """Returns True if a SYNC operation is currently in progress.""" + return self._sync_in_progress + + def get(self, key: str) -> Optional[PresenceMessage]: + """ + Get a presence member by key. + + Args: + key: The member key (connectionId:clientId) + + Returns: + The PresenceMessage if found, None otherwise + """ + return self._map.get(key) + + def put(self, item: PresenceMessage) -> bool: + """ + Add or update a presence member (RTP2d). + + For ENTER, UPDATE, or PRESENT actions, the message is stored in the map + with action set to PRESENT (if it passes the newness check). + + Args: + item: The presence message to add/update + + Returns: + True if the item was added/updated, False if rejected due to newness check + """ + # RTP2d: ENTER, UPDATE, PRESENT all get stored as PRESENT + if item.action in (PresenceAction.ENTER, PresenceAction.UPDATE, PresenceAction.PRESENT): + # Create a copy with action set to PRESENT + item_to_store = PresenceMessage( + id=item.id, + action=PresenceAction.PRESENT, + client_id=item.client_id, + connection_id=item.connection_id, + data=item.data, + encoding=item.encoding, + timestamp=item.timestamp, + extras=item.extras + ) + else: + item_to_store = item + + key = self._member_key_fn(item_to_store) + if not key: + self._logger.warning("PresenceMap.put: item has no member key, ignoring") + return False + + # If we're in a sync, mark this member as seen (remove from residual) + if self._residual_members is not None and key in self._residual_members: + del self._residual_members[key] + + # Check newness against existing member + existing = self._map.get(key) + if existing and not self._is_newer_fn(item_to_store, existing): + self._logger.debug(f"PresenceMap.put: incoming message for {key} is not newer, ignoring") + return False + + self._map[key] = item_to_store + self._logger.debug(f"PresenceMap.put: added/updated member {key}") + return True + + def remove(self, item: PresenceMessage) -> bool: + """ + Remove a presence member (RTP2h). + + During a SYNC, the member is marked as ABSENT rather than removed. + Outside of SYNC, the member is removed from the map. + + Args: + item: The presence message with LEAVE action + + Returns: + True if a member was removed/marked absent, False if no action taken + """ + key = self._member_key_fn(item) + if not key: + return False + + existing = self._map.get(key) + if not existing: + return False + + # Check newness (RTP2h requires newness check) + if not self._is_newer_fn(item, existing): + self._logger.debug(f"PresenceMap.remove: incoming message for {key} is not newer, ignoring") + return False + + # RTP2h2: During SYNC, mark as ABSENT instead of removing + if self._sync_in_progress: + absent_item = PresenceMessage( + id=item.id, + action=PresenceAction.ABSENT, + client_id=item.client_id, + connection_id=item.connection_id, + data=item.data, + encoding=item.encoding, + timestamp=item.timestamp, + extras=item.extras + ) + self._map[key] = absent_item + self._logger.debug(f"PresenceMap.remove: marked member {key} as ABSENT (sync in progress)") + else: + # RTP2h1: Outside of SYNC, remove the member + del self._map[key] + self._logger.debug(f"PresenceMap.remove: removed member {key}") + + return True + + def values(self) -> List[PresenceMessage]: + """ + Get all presence members (excluding ABSENT members). + + Returns: + List of all PRESENT members + """ + return [ + msg for msg in self._map.values() + if msg.action != PresenceAction.ABSENT + ] + + def list( + self, + client_id: Optional[str] = None, + connection_id: Optional[str] = None + ) -> List[PresenceMessage]: + """ + Get presence members with optional filtering (RTP11). + + Args: + client_id: Optional filter by client ID + connection_id: Optional filter by connection ID + + Returns: + List of matching PRESENT members + """ + result = [] + for msg in self._map.values(): + # Skip ABSENT members + if msg.action == PresenceAction.ABSENT: + continue + + # Apply filters + if client_id and msg.client_id != client_id: + continue + if connection_id and msg.connection_id != connection_id: + continue + + result.append(msg) + + return result + + def start_sync(self) -> None: + """ + Start a SYNC operation (RTP18). + + Captures current members as residual members to track which ones + are not seen during the sync. + """ + self._logger.info(f"PresenceMap.start_sync: starting sync (in_progress={self._sync_in_progress})") + + # May be called multiple times while a sync is in progress + if not self._sync_in_progress: + # Copy current map as residual members + self._residual_members = dict(self._map) + self._sync_in_progress = True + self._logger.debug(f"PresenceMap.start_sync: captured {len(self._residual_members)} residual members") + + def end_sync(self) -> Tuple[List[PresenceMessage], List[PresenceMessage]]: + """ + End a SYNC operation (RTP18, RTP19). + + Removes ABSENT members and returns lists of members that should have + synthesized leave events emitted. + + Returns: + Tuple of (residual_members, absent_members) that need LEAVE events + """ + self._logger.info(f"PresenceMap.end_sync: ending sync (in_progress={self._sync_in_progress})") + + residual_list: List[PresenceMessage] = [] + absent_list: List[PresenceMessage] = [] + + if self._sync_in_progress: + # Collect ABSENT members and remove them from map (RTP2h2b) + keys_to_remove = [] + for key, msg in self._map.items(): + if msg.action == PresenceAction.ABSENT: + absent_list.append(msg) + keys_to_remove.append(key) + + for key in keys_to_remove: + del self._map[key] + + # Collect residual members (members present at start but not seen during sync) + # These need synthesized LEAVE events (RTP19) + if self._residual_members: + residual_list = list(self._residual_members.values()) + # Remove residual members from map + for key in self._residual_members.keys(): + if key in self._map: + del self._map[key] + + self._residual_members = None + self._sync_in_progress = False + self._logger.debug( + f"PresenceMap.end_sync: removed {len(absent_list)} absent members, " + f"{len(residual_list)} residual members" + ) + + # Notify callbacks that sync is complete + for callback in self._sync_complete_callbacks: + try: + callback() + except Exception as e: + self._logger.error(f"Error in sync complete callback: {e}") + self._sync_complete_callbacks.clear() + + return residual_list, absent_list + + def wait_sync(self, callback: Callable[[], None]) -> None: + """ + Wait for SYNC to complete, calling callback when done. + + If sync is not in progress, callback is called immediately. + + Args: + callback: Function to call when sync completes + """ + if not self._sync_in_progress: + callback() + else: + self._sync_complete_callbacks.append(callback) + + def clear(self) -> None: + """ + Clear all members and reset sync state. + + Used when channel enters DETACHED or FAILED state (RTP5a). + Invokes any pending sync callbacks before clearing to ensure + waiting Futures are resolved and callers are not left blocked. + """ + # Notify any callbacks waiting for sync to complete + # This ensures Futures created by _wait_for_sync() are resolved + for callback in self._sync_complete_callbacks: + try: + callback() + except Exception as e: + self._logger.error(f"Error in sync complete callback during clear: {e}") + + self._map.clear() + self._residual_members = None + self._sync_in_progress = False + self._sync_complete_callbacks.clear() + self._logger.debug("PresenceMap.clear: cleared all members") diff --git a/ably/realtime/realtime_channel.py b/ably/realtime/realtime_channel.py index f75b8129..fa6f396d 100644 --- a/ably/realtime/realtime_channel.py +++ b/ably/realtime/realtime_channel.py @@ -12,6 +12,7 @@ from ably.types.flags import Flag, has_flag from ably.types.message import Message from ably.types.mixins import DecodingContext +from ably.types.presence import PresenceMessage from ably.util.eventemitter import EventEmitter from ably.util.exceptions import AblyException, IncompatibleClientIdException from ably.util.helper import Timer, is_callable_or_coroutine, validate_message_size @@ -136,6 +137,10 @@ def __init__(self, realtime: AblyRealtime, name: str, channel_options: ChannelOp # will be disrupted if the user called .off() to remove all listeners self.__internal_state_emitter = EventEmitter() + # Initialize presence for this channel + from ably.realtime.realtimepresence import RealtimePresence + self.__presence = RealtimePresence(self) + # Pass channel options as dictionary to parent Channel class Channel.__init__(self, realtime, name, self.__channel_options.to_dict()) @@ -529,6 +534,7 @@ def _on_message(self, proto_msg: dict) -> None: error = proto_msg.get("error") exception = None resumed = False + has_presence = False self.__attach_serial = channel_serial self.__channel_serial = channel_serial @@ -539,6 +545,8 @@ def _on_message(self, proto_msg: dict) -> None: if flags: resumed = has_flag(flags, Flag.RESUMED) + # RTP1: Check for HAS_PRESENCE flag + has_presence = has_flag(flags, Flag.HAS_PRESENCE) # RTL12 if self.state == ChannelState.ATTACHED: @@ -546,7 +554,7 @@ def _on_message(self, proto_msg: dict) -> None: state_change = ChannelStateChange(self.state, ChannelState.ATTACHED, resumed, exception) self._emit("update", state_change) elif self.state == ChannelState.ATTACHING: - self._notify_state(ChannelState.ATTACHED, resumed=resumed) + self._notify_state(ChannelState.ATTACHED, resumed=resumed, has_presence=has_presence) else: log.warn("RealtimeChannel._on_message(): ATTACHED received while not attaching") elif action == ProtocolMessageAction.DETACHED: @@ -570,6 +578,17 @@ def _on_message(self, proto_msg: dict) -> None: log.error(f"Message processing error {e}. Skip messages {proto_msg.get('messages')}") for message in messages: self.__message_emitter._emit(message.name, message) + elif action == ProtocolMessageAction.PRESENCE: + # Handle PRESENCE messages + presence_messages = proto_msg.get('presence', []) + decoded_presence = PresenceMessage.from_encoded_array(presence_messages, cipher=self.cipher) + self.__presence.set_presence(decoded_presence, is_sync=False) + elif action == ProtocolMessageAction.SYNC: + # Handle SYNC messages (RTP18) + presence_messages = proto_msg.get('presence', []) + decoded_presence = PresenceMessage.from_encoded_array(presence_messages, cipher=self.cipher) + sync_channel_serial = proto_msg.get('channelSerial') + self.__presence.set_presence(decoded_presence, is_sync=True, sync_channel_serial=sync_channel_serial) elif action == ProtocolMessageAction.ERROR: error = AblyException.from_dict(proto_msg.get('error')) self._notify_state(ChannelState.FAILED, reason=error) @@ -580,7 +599,7 @@ def _request_state(self, state: ChannelState) -> None: self._check_pending_state() def _notify_state(self, state: ChannelState, reason: AblyException | None = None, - resumed: bool = False) -> None: + resumed: bool = False, has_presence: bool = False) -> None: log.debug(f'RealtimeChannel._notify_state(): state = {state}') self.__clear_state_timer() @@ -618,6 +637,9 @@ def _notify_state(self, state: ChannelState, reason: AblyException | None = None self._emit(state, state_change) self.__internal_state_emitter._emit(state, state_change) + # RTP5: Notify presence of channel state change + self.__presence.act_on_channel_state(state, has_presence=has_presence, error=reason) + def _send_message(self, msg: dict) -> None: asyncio.create_task(self.__realtime.connection.connection_manager.send_protocol_message(msg)) @@ -708,6 +730,11 @@ def params(self) -> dict[str, str]: """Get channel parameters""" return self.__params + @property + def presence(self): + """Get the RealtimePresence object for this channel""" + return self.__presence + def _start_decode_failure_recovery(self, error: AblyException) -> None: """Start RTL18 decode failure recovery procedure""" diff --git a/ably/realtime/realtimepresence.py b/ably/realtime/realtimepresence.py new file mode 100644 index 00000000..2702846d --- /dev/null +++ b/ably/realtime/realtimepresence.py @@ -0,0 +1,790 @@ +""" +RealtimePresence - Manages presence operations on a realtime channel. + +This module implements presence functionality for realtime channels, +including enter/leave operations, presence state management, and SYNC handling. +""" + +from __future__ import annotations + +import asyncio +import logging +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +from ably.realtime.connection import ConnectionState +from ably.realtime.presencemap import PresenceMap +from ably.types.channelstate import ChannelState, ChannelStateChange +from ably.types.presence import PresenceAction, PresenceMessage +from ably.util.eventemitter import EventEmitter +from ably.util.exceptions import AblyException + +if TYPE_CHECKING: + from ably.realtime.realtime_channel import RealtimeChannel + +log = logging.getLogger(__name__) + + +def _get_client_id(presence: RealtimePresence) -> str | None: + """Get the clientId for the current connection.""" + # Use auth.client_id if available (set after CONNECTED), + # otherwise fall back to auth_options.client_id + return presence.channel.ably.auth.client_id or presence.channel.ably.auth.auth_options.client_id + + +def _is_anonymous_or_wildcard(presence: RealtimePresence) -> bool: + """Check if the client is anonymous or has wildcard clientId (RTP8j).""" + realtime = presence.channel.ably + client_id = _get_client_id(presence) + + # If not currently connected, we can't assume we're anonymous + if realtime.connection.state != ConnectionState.CONNECTED: + return False + + return not client_id or client_id == '*' + + +class RealtimePresence(EventEmitter): + """ + Manages presence operations on a realtime channel. + + Enables clients to subscribe to presence events and to enter, update, + and leave presence on a channel. + + Attributes + ---------- + channel : RealtimeChannel + The channel this presence object belongs to + sync_complete : bool + True if the initial SYNC operation has completed (RTP13) + """ + + def __init__(self, channel: RealtimeChannel): + """ + Initialize a new RealtimePresence instance. + + Args: + channel: The RealtimeChannel this presence belongs to + """ + super().__init__() + self.channel = channel + self.sync_complete = False + + # RTP2: Main presence map keyed by memberKey (connectionId:clientId) + self.members = PresenceMap( + member_key_fn=lambda msg: msg.member_key + ) + + # RTP17: Internal presence map for own members, keyed by clientId only + self._my_members = PresenceMap( + member_key_fn=lambda msg: msg.client_id + ) + + # EventEmitter for presence subscriptions + self._subscriptions = EventEmitter() + + # RTP16: Queue for pending presence messages + self._pending_presence: list[dict] = [] + + async def enter(self, data: Any = None) -> None: + """ + Enter this client into the channel's presence (RTP8). + + Args: + data: Optional data to associate with this presence member + + Raises: + AblyException: If clientId is not specified or channel state prevents entering + """ + # RTP8j: Check for anonymous or wildcard client + if _is_anonymous_or_wildcard(self): + raise AblyException( + 'clientId must be specified to enter a presence channel', + 400, 40012 + ) + + return await self._enter_or_update_client(None, None, data, PresenceAction.ENTER) + + async def update(self, data: Any = None) -> None: + """ + Update this client's presence data (RTP9). + + If the client is not already entered, this will enter the client. + + Args: + data: Optional data to associate with this presence member + + Raises: + AblyException: If clientId is not specified or channel state prevents updating + """ + # RTP9e: In all other ways, identical to enter + if _is_anonymous_or_wildcard(self): + raise AblyException( + 'clientId must be specified to update presence data', + 400, 40012 + ) + + return await self._enter_or_update_client(None, None, data, PresenceAction.UPDATE) + + async def leave(self, data: Any = None) -> None: + """ + Leave this client from the channel's presence (RTP10). + + Args: + data: Optional data to send with the leave message + + Raises: + AblyException: If clientId is not specified or channel state prevents leaving + """ + if _is_anonymous_or_wildcard(self): + raise AblyException( + 'clientId must have been specified to enter or leave a presence channel', + 400, 40012 + ) + + return await self._leave_client(None, data) + + async def enter_client(self, client_id: str, data: Any = None) -> None: + """ + Enter into presence on behalf of another clientId (RTP14). + + This allows a single client with suitable permissions to register + presence on behalf of multiple clients. + + Args: + client_id: The clientId to enter + data: Optional data to associate with this presence member + + Raises: + AblyException: If channel state prevents entering or clientId mismatch + """ + return await self._enter_or_update_client(None, client_id, data, PresenceAction.ENTER) + + async def update_client(self, client_id: str, data: Any = None) -> None: + """ + Update presence on behalf of another clientId (RTP15). + + Args: + client_id: The clientId to update + data: Optional data to associate with this presence member + + Raises: + AblyException: If channel state prevents updating or clientId mismatch + """ + return await self._enter_or_update_client(None, client_id, data, PresenceAction.UPDATE) + + async def leave_client(self, client_id: str, data: Any = None) -> None: + """ + Leave presence on behalf of another clientId (RTP15). + + Args: + client_id: The clientId to leave + data: Optional data to send with the leave message + + Raises: + AblyException: If channel state prevents leaving or clientId mismatch + """ + return await self._leave_client(client_id, data) + + async def _enter_or_update_client( + self, + id: str | None, + client_id: str | None, + data: Any, + action: int + ) -> None: + """ + Internal method to handle enter/update operations. + + Args: + id: Optional presence message id + client_id: Optional clientId (if None, uses connection's clientId) + data: Optional data payload + action: The presence action (ENTER or UPDATE) + + Raises: + AblyException: If connection/channel state prevents operation or clientId mismatch + """ + channel = self.channel + + # Check connection state + if channel.ably.connection.state not in [ + ConnectionState.CONNECTING, + ConnectionState.CONNECTED, + ConnectionState.DISCONNECTED + ]: + raise AblyException( + f'Unable to {PresenceAction._action_name(action).lower()} presence channel; ' + f'connection state = {channel.ably.connection.state}', + 400, 90001 + ) + + action_name = PresenceAction._action_name(action).lower() + + log.info( + f'RealtimePresence.{action_name}(): ' + f'channel = {channel.name}, ' + f'clientId = {client_id or "(implicit) " + str(_get_client_id(self))}' + ) + + # RTP15f: Check clientId mismatch (wildcard '*' is allowed to enter on behalf of any client) + if client_id is not None and not self.channel.ably.auth.can_assume_client_id(client_id): + raise AblyException( + f'Unable to {action_name} presence channel with clientId {client_id} ' + f'as it does not match the current clientId {self.channel.ably.auth.client_id}', + 400, 40012 + ) + + # RTP8c: Use connection's clientId if not explicitly provided + effective_client_id = client_id if client_id is not None else _get_client_id(self) + + # Create presence message + presence_msg = PresenceMessage( + id=id, + action=action, + client_id=effective_client_id, + data=data + ) + + # Encrypt if cipher is configured + if channel.cipher: + presence_msg.encrypt(channel.cipher) + + # Convert to wire format + wire_msg = presence_msg.to_encoded(binary=channel.ably.options.use_binary_protocol) + + # RTP8d/RTP8g: Handle based on channel state + if channel.state == ChannelState.ATTACHED: + # Send immediately + return await self._send_presence([wire_msg]) + elif channel.state in [ChannelState.INITIALIZED, ChannelState.DETACHED]: + # RTP8d: Implicitly attach + asyncio.create_task(channel.attach()) + # Queue the message + return await self._queue_presence(wire_msg) + elif channel.state == ChannelState.ATTACHING: + # Queue the message + return await self._queue_presence(wire_msg) + else: + # RTP8g: DETACHED, FAILED, etc. + raise AblyException( + f'Unable to {action_name} presence channel while in {channel.state} state', + 400, 90001 + ) + + async def _leave_client(self, client_id: str | None, data: Any = None) -> None: + """ + Internal method to handle leave operations. + + Args: + client_id: Optional clientId (if None, uses connection's clientId) + data: Optional data payload + + Raises: + AblyException: If connection/channel state prevents operation or clientId mismatch + """ + channel = self.channel + + # Check connection state + if channel.ably.connection.state not in [ + ConnectionState.CONNECTING, + ConnectionState.CONNECTED, + ConnectionState.DISCONNECTED + ]: + raise AblyException( + f'Unable to leave presence channel; ' + f'connection state = {channel.ably.connection.state}', + 400, 90001 + ) + + log.info( + f'RealtimePresence.leave(): ' + f'channel = {channel.name}, ' + f'clientId = {client_id or _get_client_id(self)}' + ) + + # RTP15f: Check clientId mismatch (wildcard '*' is allowed to leave on behalf of any client) + if client_id is not None and not self.channel.ably.auth.can_assume_client_id(client_id): + raise AblyException( + f'Unable to leave presence channel with clientId {client_id} ' + f'as it does not match the current clientId {self.channel.ably.auth.client_id}', + 400, 40012 + ) + + # RTP10c: Use connection's clientId if not explicitly provided + effective_client_id = client_id if client_id is not None else _get_client_id(self) + + # Create presence message + presence_msg = PresenceMessage( + action=PresenceAction.LEAVE, + client_id=effective_client_id, + data=data + ) + + # Encrypt if cipher is configured + if channel.cipher: + presence_msg.encrypt(channel.cipher) + + # Convert to wire format + wire_msg = presence_msg.to_encoded(binary=channel.ably.options.use_binary_protocol) + + # RTP10e: Handle based on channel state + if channel.state == ChannelState.ATTACHED: + # Send immediately + return await self._send_presence([wire_msg]) + elif channel.state == ChannelState.ATTACHING: + # Queue the message + return await self._queue_presence(wire_msg) + elif channel.state in [ChannelState.INITIALIZED, ChannelState.FAILED]: + # RTP10e: Don't attach just to leave + raise AblyException( + 'Unable to leave presence channel (incompatible state)', + 400, 90001 + ) + else: + raise AblyException( + f'Unable to leave presence channel while in {channel.state} state', + 400, 90001 + ) + + async def _send_presence(self, presence_messages: list[dict]) -> None: + """ + Send presence messages to the server. + + Args: + presence_messages: List of encoded presence messages to send + """ + from ably.transport.websockettransport import ProtocolMessageAction + + protocol_msg = { + 'action': ProtocolMessageAction.PRESENCE, + 'channel': self.channel.name, + 'presence': presence_messages + } + + await self.channel.ably.connection.connection_manager.send_protocol_message(protocol_msg) + + async def _queue_presence(self, wire_msg: dict) -> None: + """ + Queue a presence message to be sent when channel attaches. + + Args: + wire_msg: Encoded presence message to queue + """ + future = asyncio.Future() + + self._pending_presence.append({ + 'presence': wire_msg, + 'future': future + }) + + return await future + + async def get( + self, + wait_for_sync: bool = True, + client_id: str | None = None, + connection_id: str | None = None + ) -> list[PresenceMessage]: + """ + Get the current presence members on this channel (RTP11). + + Args: + wait_for_sync: If True, waits for SYNC to complete before returning (default: True) + client_id: Optional filter by clientId + connection_id: Optional filter by connectionId + + Returns: + List of current presence members + + Raises: + AblyException: If channel state prevents getting presence + """ + # RTP11d: Handle SUSPENDED state + if self.channel.state == ChannelState.SUSPENDED: + if wait_for_sync: + raise AblyException( + 'Presence state is out of sync due to channel being in the SUSPENDED state', + 400, 91005 + ) + else: + # Return current members without waiting + return self.members.list(client_id=client_id, connection_id=connection_id) + + # RTP11b: Implicitly attach if needed + if self.channel.state in [ChannelState.INITIALIZED, ChannelState.DETACHED]: + await self.channel.attach() + elif self.channel.state in [ChannelState.DETACHING, ChannelState.FAILED]: + raise AblyException( + f'Unable to get presence; channel state = {self.channel.state}', + 400, 90001 + ) + + # If channel is still attaching, wait for it to become ATTACHED + if self.channel.state == ChannelState.ATTACHING: + # Wait for channel to reach ATTACHED state + state_change = await self.channel._RealtimeChannel__internal_state_emitter.once_async() + if state_change.current != ChannelState.ATTACHED: + raise AblyException( + f'Unable to get presence; channel state = {state_change.current}', + 400, 90001 + ) + + # Wait for sync if requested and a sync is actually in progress + # If sync_complete is already True OR no sync is in progress, don't wait + if wait_for_sync and not self.sync_complete and self.members.sync_in_progress: + await self._wait_for_sync() + + return self.members.list(client_id=client_id, connection_id=connection_id) + + async def _wait_for_sync(self) -> None: + """Wait for presence SYNC to complete.""" + if self.sync_complete: + return + + # Use the PresenceMap's wait_sync mechanism + future = asyncio.Future() + + def on_sync_complete(): + if not future.done(): + future.set_result(None) + + self.members.wait_sync(on_sync_complete) + + # Wait for the sync to complete + await future + + async def subscribe(self, *args) -> None: + """ + Subscribe to presence events on this channel (RTP6). + + Args: + *args: Either (listener) or (event, listener) or (events, listener) + - listener: Callback for all presence events + - event: Specific event name ('enter', 'leave', 'update', 'present') + - events: List of event names + - listener: Callback for specified events + + Raises: + AblyException: If channel state prevents subscription + """ + # RTP6d: Implicitly attach + if self.channel.state in [ChannelState.INITIALIZED, ChannelState.DETACHED, ChannelState.DETACHING]: + asyncio.create_task(self.channel.attach()) + + # Parse arguments: similar to channel subscribe + if len(args) == 1: + # subscribe(listener) + listener = args[0] + self._subscriptions.on(listener) + elif len(args) == 2: + # subscribe(event, listener) + event = args[0] + listener = args[1] + self._subscriptions.on(event, listener) + else: + raise ValueError('Invalid subscribe arguments') + + def unsubscribe(self, *args) -> None: + """ + Unsubscribe from presence events on this channel (RTP7). + + Args: + *args: Either (), (listener), or (event, listener) + - (): Unsubscribe all listeners + - listener: Unsubscribe this specific listener + - event, listener: Unsubscribe listener for specific event + """ + if len(args) == 0: + # unsubscribe() - remove all + self._subscriptions.off() + elif len(args) == 1: + # unsubscribe(listener) + listener = args[0] + self._subscriptions.off(listener) + elif len(args) == 2: + # unsubscribe(event, listener) + event = args[0] + listener = args[1] + self._subscriptions.off(event, listener) + else: + raise ValueError('Invalid unsubscribe arguments') + + def set_presence( + self, + presence_set: list[PresenceMessage], + is_sync: bool, + sync_channel_serial: str | None = None + ) -> None: + """ + Process incoming presence messages from the server (Phase 3 - RTP2, RTP18). + + Args: + presence_set: List of presence messages received + is_sync: True if this is part of a SYNC operation + sync_channel_serial: Optional sync cursor for tracking sync progress + """ + log.info( + f'RealtimePresence.set_presence(): ' + f'received presence for {len(presence_set)} members; ' + f'syncChannelSerial = {sync_channel_serial}' + ) + + conn_id = self.channel.ably.connection.connection_manager.connection_id + broadcast_messages = [] + + # RTP18: Handle SYNC + if is_sync: + self.members.start_sync() + # Parse sync cursor if present + if sync_channel_serial: + # Format: : + parts = sync_channel_serial.split(':', 1) + sync_cursor = parts[1] if len(parts) > 1 else None + else: + sync_cursor = None + else: + sync_cursor = None + + # Process each presence message + for presence in presence_set: + if presence.action == PresenceAction.LEAVE: + # RTP2h: Handle LEAVE + if self.members.remove(presence): + broadcast_messages.append(presence) + + # RTP17b: Update internal presence map (not synthesized) + if presence.connection_id == conn_id and not presence.is_synthesized(): + self._my_members.remove(presence) + + elif presence.action in ( + PresenceAction.ENTER, + PresenceAction.PRESENT, + PresenceAction.UPDATE + ): + # RTP2d: Handle ENTER/PRESENT/UPDATE + if self.members.put(presence): + broadcast_messages.append(presence) + + # RTP17b: Update internal presence map + if presence.connection_id == conn_id: + self._my_members.put(presence) + + # RTP18b/RTP18c: End sync if cursor is empty or no channelSerial + if is_sync and (not sync_channel_serial or not sync_cursor): + residual, absent = self.members.end_sync() + self.sync_complete = True + + # RTP19: Emit synthesized leave events for residual members + for member in residual + absent: + synthesized_leave = PresenceMessage( + action=PresenceAction.LEAVE, + client_id=member.client_id, + connection_id=member.connection_id, + data=member.data, + encoding=member.encoding, + timestamp=datetime.now(timezone.utc) + ) + broadcast_messages.append(synthesized_leave) + + # Broadcast messages to subscribers + for presence in broadcast_messages: + action_name = PresenceAction._action_name(presence.action).lower() + self._subscriptions._emit(action_name, presence) + + def on_attached(self, has_presence: bool = False) -> None: + """ + Handle channel ATTACHED event (RTP5b). + + Args: + has_presence: True if server will send SYNC + """ + log.info( + f'RealtimePresence.on_attached(): ' + f'channel = {self.channel.name}, hasPresence = {has_presence}' + ) + + # RTP1: Handle presence sync flag + if has_presence: + self.members.start_sync() + self.sync_complete = False + else: + # RTP19a: No presence on channel, synthesize leaves for existing members + self._synthesize_leaves(self.members.values()) + self.members.clear() + self.sync_complete = True + # Also end sync in case one was started + if self.members.sync_in_progress: + self.members.end_sync() + + # RTP17i: Re-enter own members + self._ensure_my_members_present() + + # RTP5b: Send pending presence messages + asyncio.create_task(self._send_pending_presence()) + + def _ensure_my_members_present(self) -> None: + """ + Re-enter own presence members after attach (RTP17g). + """ + conn_id = self.channel.ably.connection.connection_manager.connection_id + + for _client_id, entry in list(self._my_members._map.items()): + log.info( + f'RealtimePresence._ensure_my_members_present(): ' + f'auto-reentering clientId "{entry.client_id}"' + ) + + # RTP17g1: Suppress id if connectionId has changed + msg_id = entry.id if entry.connection_id == conn_id else None + + # Create task to re-enter - use default args to bind loop variables + asyncio.create_task( + self._reenter_member(msg_id, entry.client_id, entry.data) + ) + + async def _reenter_member(self, msg_id: str | None, client_id: str, data: Any) -> None: + """ + Helper method to re-enter a member (RTP17g). + + Args: + msg_id: Optional message ID + client_id: The client ID to re-enter + data: The presence data + """ + try: + await self._enter_or_update_client( + msg_id, + client_id, + data, + PresenceAction.ENTER + ) + except AblyException as e: + log.error( + f'RealtimePresence._reenter_member(): ' + f'auto-reenter failed: {e}' + ) + # RTP17e: Emit update event with error + state_change = ChannelStateChange( + previous=self.channel.state, + current=self.channel.state, + resumed=False, + reason=e + ) + self.channel._emit("update", state_change) + + async def _send_pending_presence(self) -> None: + """ + Send pending presence messages after channel attaches (RTP5b). + """ + if not self._pending_presence: + return + + log.info( + f'RealtimePresence._send_pending_presence(): ' + f'sending {len(self._pending_presence)} queued messages' + ) + + pending = self._pending_presence + self._pending_presence = [] + + # Send all pending messages + presence_array = [item['presence'] for item in pending] + + try: + await self._send_presence(presence_array) + # Resolve all futures AFTER send completes + for item in pending: + if not item['future'].done(): + item['future'].set_result(None) + except Exception as e: + # Reject all futures + for item in pending: + if not item['future'].done(): + item['future'].set_exception(e) + + def _synthesize_leaves(self, members: list[PresenceMessage]) -> None: + """ + Emit synthesized leave events for members (RTP19, RTP19a). + + Args: + members: List of members to synthesize leaves for + """ + for member in members: + synthesized_leave = PresenceMessage( + action=PresenceAction.LEAVE, + client_id=member.client_id, + connection_id=member.connection_id, + data=member.data, + encoding=member.encoding, + timestamp=datetime.now(timezone.utc) + ) + self._subscriptions._emit('leave', synthesized_leave) + + def act_on_channel_state( + self, + state: ChannelState, + has_presence: bool = False, + error: AblyException | None = None + ) -> None: + """ + React to channel state changes (RTP5). + + Args: + state: The new channel state + has_presence: Whether the channel has presence (for ATTACHED) + error: Optional error associated with state change + """ + if state == ChannelState.ATTACHED: + self.on_attached(has_presence) + elif state in (ChannelState.DETACHED, ChannelState.FAILED): + # RTP5a: Clear maps and fail pending + self._my_members.clear() + self.members.clear() + self.sync_complete = False + self._fail_pending_presence(error) + elif state == ChannelState.SUSPENDED: + # RTP5f: Fail pending but keep members, reset sync state + self.sync_complete = False # Sync state is no longer valid + self._fail_pending_presence(error) + + def _fail_pending_presence(self, error: AblyException | None = None) -> None: + """ + Fail all pending presence messages. + + Args: + error: The error to reject with + """ + if not self._pending_presence: + return + + log.info( + f'RealtimePresence._fail_pending_presence(): ' + f'failing {len(self._pending_presence)} queued messages' + ) + + pending = self._pending_presence + self._pending_presence = [] + + exception = error or AblyException('Presence operation failed', 400, 90001) + + for item in pending: + if not item['future'].done(): + item['future'].set_exception(exception) + + +# Helper for PresenceAction to convert action to string +def _action_name_impl(action: int) -> str: + """Convert presence action to string name.""" + names = { + PresenceAction.ABSENT: 'absent', + PresenceAction.PRESENT: 'present', + PresenceAction.ENTER: 'enter', + PresenceAction.LEAVE: 'leave', + PresenceAction.UPDATE: 'update', + } + return names.get(action, f'unknown({action})') + + +# Monkey-patch the helper onto PresenceAction +PresenceAction._action_name = staticmethod(_action_name_impl) diff --git a/ably/transport/websockettransport.py b/ably/transport/websockettransport.py index 450cd364..d75345d4 100644 --- a/ably/transport/websockettransport.py +++ b/ably/transport/websockettransport.py @@ -183,7 +183,9 @@ async def on_protocol_message(self, msg): elif action in ( ProtocolMessageAction.ATTACHED, ProtocolMessageAction.DETACHED, - ProtocolMessageAction.MESSAGE + ProtocolMessageAction.MESSAGE, + ProtocolMessageAction.PRESENCE, + ProtocolMessageAction.SYNC ): self.connection_manager.on_channel_message(msg) diff --git a/ably/types/options.py b/ably/types/options.py index f15b3656..8804b3b9 100644 --- a/ably/types/options.py +++ b/ably/types/options.py @@ -32,7 +32,7 @@ def __init__(self, client_id=None, log_level=0, tls=True, rest_host=None, realti fallback_retry_timeout=None, disconnected_retry_timeout=None, idempotent_rest_publishing=None, loop=None, auto_connect=True, suspended_retry_timeout=None, connectivity_check_url=None, channel_retry_timeout=Defaults.channel_retry_timeout, add_request_ids=False, - vcdiff_decoder: VCDiffDecoder = None, **kwargs): + vcdiff_decoder: VCDiffDecoder = None, transport_params=None, **kwargs): super().__init__(**kwargs) @@ -96,6 +96,7 @@ def __init__(self, client_id=None, log_level=0, tls=True, rest_host=None, realti self.__fallback_realtime_host = None self.__add_request_ids = add_request_ids self.__vcdiff_decoder = vcdiff_decoder + self.__transport_params = transport_params or {} self.__rest_hosts = self.__get_rest_hosts() self.__realtime_hosts = self.__get_realtime_hosts() @@ -282,6 +283,10 @@ def add_request_ids(self): def vcdiff_decoder(self): return self.__vcdiff_decoder + @property + def transport_params(self): + return self.__transport_params + def __get_rest_hosts(self): """ Return the list of hosts as they should be tried. First comes the main diff --git a/ably/types/presence.py b/ably/types/presence.py index c32c634e..723ceacc 100644 --- a/ably/types/presence.py +++ b/ably/types/presence.py @@ -1,8 +1,13 @@ +import base64 +import json from datetime import datetime, timedelta from urllib import parse from ably.http.paginatedresult import PaginatedResult from ably.types.mixins import EncodeDataMixin +from ably.types.typedbuffer import TypedBuffer +from ably.util.crypto import CipherData +from ably.util.exceptions import AblyException def _ms_since_epoch(dt): @@ -38,12 +43,13 @@ def __init__(self, extras=None, # TP3i (functionality not specified) ): + super().__init__(encoding or '') + self.__id = id self.__action = action self.__client_id = client_id self.__connection_id = connection_id self.__data = data - self.__encoding = encoding self.__timestamp = timestamp self.__member_key = member_key self.__extras = extras @@ -68,10 +74,6 @@ def connection_id(self): def data(self): return self.__data - @property - def encoding(self): - return self.__encoding - @property def timestamp(self): return self.__timestamp @@ -85,6 +87,121 @@ def member_key(self): def extras(self): return self.__extras + def is_synthesized(self): + """ + Check if message is synthesized (RTP2b1). + A message is synthesized if its connectionId is not an initial substring of its id. + This happens with synthesized leave events sent by realtime to indicate + a connection disconnected unexpectedly. + """ + if not self.id or not self.connection_id: + return False + return not self.id.startswith(self.connection_id + ':') + + def parse_id(self): + """ + Parse id into components (connId, msgSerial, index) for RTP2b2 comparison. + Expected format: connId:msgSerial:index (e.g., "aaaaaa:0:0") + + Returns: + dict with 'msgSerial' and 'index' as integers + + Raises: + ValueError: If id is missing or has invalid format + """ + if not self.id: + raise ValueError("Cannot parse id: id is None or empty") + + parts = self.id.split(':') + + try: + return { + 'msgSerial': int(parts[1]), + 'index': int(parts[2]) + } + except (ValueError, IndexError) as e: + raise ValueError(f"Cannot parse id: invalid msgSerial or index in '{self.id}'") from e + + def encrypt(self, channel_cipher): + """ + Encrypt the presence message data using the provided cipher. + Similar to Message.encrypt(). + """ + if isinstance(self.data, CipherData): + return + + elif isinstance(self.data, str): + self._encoding_array.append('utf-8') + + if isinstance(self.data, dict) or isinstance(self.data, list): + self._encoding_array.append('json') + self._encoding_array.append('utf-8') + + typed_data = TypedBuffer.from_obj(self.data) + if typed_data.buffer is None: + return + encrypted_data = channel_cipher.encrypt(typed_data.buffer) + self.__data = CipherData(encrypted_data, typed_data.type, + cipher_type=channel_cipher.cipher_type) + + def to_encoded(self, binary=False): + """ + Convert to wire protocol format for sending. + + Handles proper encoding of data including JSON serialization, + base64 encoding for binary data, and encryption support. + """ + data = self.data + data_type = None + encoding = self._encoding_array[:] + + # Handle different data types and build encoding string + if isinstance(data, (dict, list)): + encoding.append('json') + data = json.dumps(data) + data = str(data) + elif isinstance(data, str) and not binary: + pass + elif not binary and isinstance(data, (bytearray, bytes)): + data = base64.b64encode(data).decode('ascii') + encoding.append('base64') + elif isinstance(data, CipherData): + encoding.append(data.encoding_str) + data_type = data.type + if not binary: + data = base64.b64encode(data.buffer).decode('ascii') + encoding.append('base64') + else: + data = data.buffer + elif binary and isinstance(data, bytearray): + data = bytes(data) + + if not (isinstance(data, (bytes, str, list, dict, bytearray)) or data is None): + raise AblyException("Invalid data payload", 400, 40011) + + result = { + 'action': self.action, + } + + if self.id: + result['id'] = self.id + if self.client_id: + result['clientId'] = self.client_id + if self.connection_id: + result['connectionId'] = self.connection_id + if data is not None: + result['data'] = data + if data_type: + result['type'] = data_type + if encoding: + result['encoding'] = '/'.join(encoding).strip('/') + if self.extras: + result['extras'] = self.extras + if self.timestamp: + result['timestamp'] = _ms_since_epoch(self.timestamp) + + return result + @staticmethod def from_encoded(obj, cipher=None, context=None): id = obj.get('id') @@ -112,6 +229,13 @@ def from_encoded(obj, cipher=None, context=None): **decoded_data ) + @staticmethod + def from_encoded_array(encoded_array, cipher=None, context=None): + """ + Decode array of presence messages. + """ + return [PresenceMessage.from_encoded(item, cipher, context) for item in encoded_array] + class Presence: def __init__(self, channel): diff --git a/test/ably/realtime/presencemap_test.py b/test/ably/realtime/presencemap_test.py new file mode 100644 index 00000000..043baeb0 --- /dev/null +++ b/test/ably/realtime/presencemap_test.py @@ -0,0 +1,772 @@ +""" +Unit tests for PresenceMap implementation. + +Tests RTP2 specification requirements for presence map operations. +""" + +from datetime import datetime + +import pytest + +from ably.realtime.presencemap import PresenceMap, _is_newer +from ably.types.presence import PresenceAction, PresenceMessage +from test.ably.utils import BaseAsyncTestCase + + +class TestPresenceMessageHelpers(BaseAsyncTestCase): + """Test helper methods on PresenceMessage (RTP2b support).""" + + def test_is_synthesized_with_matching_connection_id(self): + """Test that normal messages are not synthesized.""" + msg = PresenceMessage( + id='connection123:0:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.PRESENT + ) + assert not msg.is_synthesized() + + def test_is_synthesized_with_non_matching_connection_id(self): + """Test that synthesized leave events are detected (RTP2b1).""" + msg = PresenceMessage( + id='different:0:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.LEAVE + ) + assert msg.is_synthesized() + + def test_is_synthesized_without_id(self): + """Test that messages without id are not considered synthesized.""" + msg = PresenceMessage( + connection_id='connection123', + client_id='client1', + action=PresenceAction.PRESENT + ) + assert not msg.is_synthesized() + + def test_parse_id_valid(self): + """Test parsing valid presence message id (RTP2b2).""" + msg = PresenceMessage( + id='connection123:42:7', + connection_id='connection123', + client_id='client1', + action=PresenceAction.PRESENT + ) + parsed = msg.parse_id() + assert parsed['msgSerial'] == 42 + assert parsed['index'] == 7 + + def test_parse_id_without_id(self): + """Test parsing message without id raises ValueError.""" + msg = PresenceMessage( + connection_id='connection123', + client_id='client1', + action=PresenceAction.PRESENT + ) + with pytest.raises(ValueError) as context: + msg.parse_id() + assert "id is None or empty" in str(context.value) + + def test_parse_id_invalid_format(self): + """Test parsing invalid id format raises ValueError.""" + msg = PresenceMessage( + id='invalid', + connection_id='connection123', + client_id='client1', + action=PresenceAction.PRESENT + ) + with pytest.raises(ValueError) as context: + msg.parse_id() + assert "invalid msgSerial or index" in str(context.value) + + def test_parse_id_non_numeric_parts(self): + """Test parsing id with non-numeric msgSerial/index raises ValueError.""" + msg = PresenceMessage( + id='connection123:abc:def', + connection_id='connection123', + client_id='client1', + action=PresenceAction.PRESENT + ) + with pytest.raises(ValueError) as context: + msg.parse_id() + assert "invalid msgSerial or index" in str(context.value) + + def test_member_key_property(self): + """Test member_key property (TP3h).""" + msg = PresenceMessage( + connection_id='connection123', + client_id='client1', + action=PresenceAction.PRESENT + ) + assert msg.member_key == 'connection123:client1' + + def test_member_key_without_connection_id(self): + """Test member_key when connection_id is missing.""" + msg = PresenceMessage( + client_id='client1', + action=PresenceAction.PRESENT + ) + assert msg.member_key is None + + def test_to_encoded(self): + """Test converting message to wire format.""" + msg = PresenceMessage( + id='connection123:0:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.ENTER, + data='test data', + timestamp=datetime(2024, 1, 1, 12, 0, 0) + ) + encoded = msg.to_encoded() + assert encoded['action'] == PresenceAction.ENTER + assert encoded['id'] == 'connection123:0:0' + assert encoded['connectionId'] == 'connection123' + assert encoded['clientId'] == 'client1' + assert encoded['data'] == 'test data' + assert 'timestamp' in encoded + + def test_to_encoded_with_dict_data(self): + """Test converting message with dict data (should be JSON serialized).""" + msg = PresenceMessage( + id='connection123:0:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.ENTER, + data={'key': 'value', 'number': 42} + ) + encoded = msg.to_encoded() + assert encoded['data'] == '{"key": "value", "number": 42}' + assert encoded['encoding'] == 'json' + + def test_to_encoded_with_list_data(self): + """Test converting message with list data (should be JSON serialized).""" + msg = PresenceMessage( + id='connection123:0:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.ENTER, + data=['item1', 'item2', 3] + ) + encoded = msg.to_encoded() + assert encoded['data'] == '["item1", "item2", 3]' + assert encoded['encoding'] == 'json' + + def test_to_encoded_with_binary_data(self): + """Test converting message with binary data (should be base64 encoded).""" + import base64 + binary_data = b'binary data here' + msg = PresenceMessage( + id='connection123:0:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.ENTER, + data=binary_data + ) + encoded = msg.to_encoded() + assert encoded['data'] == base64.b64encode(binary_data).decode('ascii') + assert encoded['encoding'] == 'base64' + + def test_to_encoded_with_bytearray_data(self): + """Test converting message with bytearray data (should be base64 encoded).""" + import base64 + binary_data = bytearray(b'bytearray data') + msg = PresenceMessage( + id='connection123:0:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.ENTER, + data=binary_data + ) + encoded = msg.to_encoded() + assert encoded['data'] == base64.b64encode(binary_data).decode('ascii') + assert encoded['encoding'] == 'base64' + + def test_to_encoded_with_existing_encoding(self): + """Test that existing encoding is preserved and appended to.""" + msg = PresenceMessage( + id='connection123:0:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.ENTER, + data=b'test', + encoding='utf-8' + ) + encoded = msg.to_encoded() + assert 'utf-8' in encoded['encoding'] + assert 'base64' in encoded['encoding'] + assert encoded['encoding'] == 'utf-8/base64' + + def test_to_encoded_binary_mode(self): + """Test converting message in binary mode (no base64 encoding).""" + binary_data = b'binary data' + msg = PresenceMessage( + id='connection123:0:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.ENTER, + data=binary_data + ) + encoded = msg.to_encoded(binary=True) + assert encoded['data'] == binary_data + assert 'encoding' not in encoded # No base64 added in binary mode + + def test_from_encoded_array(self): + """Test decoding array of presence messages.""" + encoded_array = [ + { + 'id': 'conn1:0:0', + 'action': PresenceAction.ENTER, + 'clientId': 'client1', + 'connectionId': 'conn1', + 'data': 'data1' + }, + { + 'id': 'conn2:0:0', + 'action': PresenceAction.PRESENT, + 'clientId': 'client2', + 'connectionId': 'conn2', + 'data': 'data2' + } + ] + messages = PresenceMessage.from_encoded_array(encoded_array) + assert len(messages) == 2 + assert messages[0].client_id == 'client1' + assert messages[1].client_id == 'client2' + + +class TestNewnessComparison(BaseAsyncTestCase): + """Test newness comparison logic (RTP2b).""" + + def test_synthesized_message_newer_by_timestamp(self): + """Test RTP2b1: synthesized messages compared by timestamp.""" + older = PresenceMessage( + id='different:0:0', # Synthesized (doesn't match connection_id) + connection_id='connection123', + client_id='client1', + action=PresenceAction.LEAVE, + timestamp=datetime(2024, 1, 1, 12, 0, 0) + ) + newer = PresenceMessage( + id='connection123:5:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.PRESENT, + timestamp=datetime(2024, 1, 1, 12, 0, 1) + ) + assert _is_newer(newer, older) + assert not _is_newer(older, newer) + + def test_synthesized_equal_timestamp_incoming_wins(self): + """Test RTP2b1a: equal timestamps, incoming is newer.""" + timestamp = datetime(2024, 1, 1, 12, 0, 0) + existing = PresenceMessage( + id='different:0:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.LEAVE, + timestamp=timestamp + ) + incoming = PresenceMessage( + id='other:0:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.LEAVE, + timestamp=timestamp + ) + # Incoming should be considered newer (>=) + assert _is_newer(incoming, existing) + + def test_normal_message_newer_by_msg_serial(self): + """Test RTP2b2: normal messages compared by msgSerial.""" + older = PresenceMessage( + id='connection123:5:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.PRESENT, + timestamp=datetime(2024, 1, 1, 12, 0, 0) + ) + newer = PresenceMessage( + id='connection123:10:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.PRESENT, + timestamp=datetime(2024, 1, 1, 11, 0, 0) # Earlier timestamp doesn't matter + ) + assert _is_newer(newer, older) + assert not _is_newer(older, newer) + + def test_normal_message_newer_by_index(self): + """Test RTP2b2: when msgSerial equal, compare by index.""" + older = PresenceMessage( + id='connection123:5:2', + connection_id='connection123', + client_id='client1', + action=PresenceAction.PRESENT + ) + newer = PresenceMessage( + id='connection123:5:3', + connection_id='connection123', + client_id='client1', + action=PresenceAction.PRESENT + ) + assert _is_newer(newer, older) + assert not _is_newer(older, newer) + + def test_normal_message_same_serial_and_index(self): + """Test equal msgSerial and index - incoming is not newer.""" + msg1 = PresenceMessage( + id='connection123:5:3', + connection_id='connection123', + client_id='client1', + action=PresenceAction.PRESENT + ) + msg2 = PresenceMessage( + id='connection123:5:3', + connection_id='connection123', + client_id='client1', + action=PresenceAction.PRESENT + ) + # Index not greater, so not newer + assert not _is_newer(msg2, msg1) + + +class TestPresenceMapBasicOperations(BaseAsyncTestCase): + """Test basic PresenceMap operations.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Set up test fixtures.""" + self.presence_map = PresenceMap( + member_key_fn=lambda msg: msg.member_key + ) + yield + + def test_put_enter_message(self): + """Test RTP2d: ENTER message stored as PRESENT.""" + msg = PresenceMessage( + id='connection123:0:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.ENTER, + data='test' + ) + result = self.presence_map.put(msg) + assert result is True + + stored = self.presence_map.get('connection123:client1') + assert stored is not None + assert stored.action == PresenceAction.PRESENT + assert stored.client_id == 'client1' + assert stored.data == 'test' + + def test_put_update_message(self): + """Test RTP2d: UPDATE message stored as PRESENT.""" + msg = PresenceMessage( + id='connection123:0:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.UPDATE, + data='updated' + ) + result = self.presence_map.put(msg) + assert result is True + + stored = self.presence_map.get('connection123:client1') + assert stored.action == PresenceAction.PRESENT + + def test_put_rejects_older_message(self): + """Test RTP2a: older messages are rejected.""" + newer = PresenceMessage( + id='connection123:10:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.ENTER + ) + older = PresenceMessage( + id='connection123:5:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.UPDATE + ) + + # Add newer first + self.presence_map.put(newer) + # Try to add older - should be rejected + result = self.presence_map.put(older) + assert result is False + + # Should still have the newer one + stored = self.presence_map.get('connection123:client1') + assert stored.parse_id()['msgSerial'] == 10 + + def test_put_accepts_newer_message(self): + """Test RTP2a: newer messages replace older ones.""" + older = PresenceMessage( + id='connection123:5:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.ENTER, + data='old' + ) + newer = PresenceMessage( + id='connection123:10:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.UPDATE, + data='new' + ) + + self.presence_map.put(older) + result = self.presence_map.put(newer) + assert result is True + + stored = self.presence_map.get('connection123:client1') + assert stored.data == 'new' + assert stored.parse_id()['msgSerial'] == 10 + + def test_remove_member(self): + """Test RTP2h1: LEAVE removes member outside of sync.""" + enter = PresenceMessage( + id='connection123:0:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.ENTER + ) + leave = PresenceMessage( + id='connection123:1:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.LEAVE + ) + + self.presence_map.put(enter) + result = self.presence_map.remove(leave) + assert result is True + + # Member should be removed + assert self.presence_map.get('connection123:client1') is None + + def test_remove_rejects_older_leave(self): + """Test RTP2h: LEAVE must pass newness check.""" + newer = PresenceMessage( + id='connection123:10:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.PRESENT + ) + older_leave = PresenceMessage( + id='connection123:5:0', + connection_id='connection123', + client_id='client1', + action=PresenceAction.LEAVE + ) + + self.presence_map.put(newer) + result = self.presence_map.remove(older_leave) + assert result is False + + # Member should still be present + assert self.presence_map.get('connection123:client1') is not None + + def test_values_excludes_absent(self): + """Test that values() excludes ABSENT members.""" + msg1 = PresenceMessage( + id='conn1:0:0', + connection_id='conn1', + client_id='client1', + action=PresenceAction.PRESENT + ) + msg2 = PresenceMessage( + id='conn2:0:0', + connection_id='conn2', + client_id='client2', + action=PresenceAction.PRESENT + ) + + self.presence_map.put(msg1) + self.presence_map.put(msg2) + + # Manually add an ABSENT member (happens during sync) + absent = PresenceMessage( + id='conn3:0:0', + connection_id='conn3', + client_id='client3', + action=PresenceAction.ABSENT + ) + self.presence_map._map[absent.member_key] = absent + + values = self.presence_map.values() + assert len(values) == 2 + assert all(msg.action == PresenceAction.PRESENT for msg in values) + + def test_list_with_client_id_filter(self): + """Test RTP11c2: list with clientId filter.""" + msg1 = PresenceMessage( + id='conn1:0:0', + connection_id='conn1', + client_id='client1', + action=PresenceAction.PRESENT + ) + msg2 = PresenceMessage( + id='conn2:0:0', + connection_id='conn2', + client_id='client2', + action=PresenceAction.PRESENT + ) + + self.presence_map.put(msg1) + self.presence_map.put(msg2) + + result = self.presence_map.list(client_id='client1') + assert len(result) == 1 + assert result[0].client_id == 'client1' + + def test_list_with_connection_id_filter(self): + """Test RTP11c3: list with connectionId filter.""" + msg1 = PresenceMessage( + id='conn1:0:0', + connection_id='conn1', + client_id='client1', + action=PresenceAction.PRESENT + ) + msg2 = PresenceMessage( + id='conn1:0:1', + connection_id='conn1', + client_id='client2', + action=PresenceAction.PRESENT + ) + msg3 = PresenceMessage( + id='conn2:0:0', + connection_id='conn2', + client_id='client3', + action=PresenceAction.PRESENT + ) + + self.presence_map.put(msg1) + self.presence_map.put(msg2) + self.presence_map.put(msg3) + + result = self.presence_map.list(connection_id='conn1') + assert len(result) == 2 + assert all(msg.connection_id == 'conn1' for msg in result) + + def test_clear(self): + """Test RTP5a: clear removes all members.""" + msg1 = PresenceMessage( + id='conn1:0:0', + connection_id='conn1', + client_id='client1', + action=PresenceAction.PRESENT + ) + self.presence_map.put(msg1) + self.presence_map.clear() + + assert len(self.presence_map.values()) == 0 + assert not self.presence_map.sync_in_progress + + +class TestPresenceMapSyncOperations(BaseAsyncTestCase): + """Test SYNC operations (RTP18, RTP19).""" + + @pytest.fixture(autouse=True) + def setup(self): + """Set up test fixtures.""" + self.presence_map = PresenceMap( + member_key_fn=lambda msg: msg.member_key + ) + yield + + def test_start_sync(self): + """Test RTP18: start_sync captures residual members.""" + msg1 = PresenceMessage( + id='conn1:0:0', + connection_id='conn1', + client_id='client1', + action=PresenceAction.PRESENT + ) + msg2 = PresenceMessage( + id='conn2:0:0', + connection_id='conn2', + client_id='client2', + action=PresenceAction.PRESENT + ) + + self.presence_map.put(msg1) + self.presence_map.put(msg2) + + self.presence_map.start_sync() + assert self.presence_map.sync_in_progress is True + assert self.presence_map._residual_members is not None + assert len(self.presence_map._residual_members) == 2 + + def test_put_during_sync_removes_from_residual(self): + """Test that members seen during sync are removed from residual.""" + msg1 = PresenceMessage( + id='conn1:0:0', + connection_id='conn1', + client_id='client1', + action=PresenceAction.PRESENT + ) + + self.presence_map.put(msg1) + self.presence_map.start_sync() + + # Update the same member during sync + msg1_update = PresenceMessage( + id='conn1:1:0', + connection_id='conn1', + client_id='client1', + action=PresenceAction.PRESENT, + data='updated' + ) + self.presence_map.put(msg1_update) + + # Member should be removed from residual + assert 'conn1:client1' not in self.presence_map._residual_members + + def test_remove_during_sync_marks_absent(self): + """Test RTP2h2: LEAVE during sync marks member as ABSENT.""" + msg1 = PresenceMessage( + id='conn1:0:0', + connection_id='conn1', + client_id='client1', + action=PresenceAction.PRESENT + ) + + self.presence_map.put(msg1) + self.presence_map.start_sync() + + leave = PresenceMessage( + id='conn1:1:0', + connection_id='conn1', + client_id='client1', + action=PresenceAction.LEAVE + ) + result = self.presence_map.remove(leave) + assert result is True + + # Should be marked ABSENT, not removed + stored = self.presence_map.get('conn1:client1') + assert stored is not None + assert stored.action == PresenceAction.ABSENT + + def test_end_sync_removes_absent_members(self): + """Test RTP2h2b: end_sync removes ABSENT members.""" + msg1 = PresenceMessage( + id='conn1:0:0', + connection_id='conn1', + client_id='client1', + action=PresenceAction.PRESENT + ) + + self.presence_map.put(msg1) + self.presence_map.start_sync() + + leave = PresenceMessage( + id='conn1:1:0', + connection_id='conn1', + client_id='client1', + action=PresenceAction.LEAVE + ) + self.presence_map.remove(leave) + + residual, absent = self.presence_map.end_sync() + + # Member should be removed after sync + assert self.presence_map.get('conn1:client1') is None + assert not self.presence_map.sync_in_progress + + def test_end_sync_returns_residual_members(self): + """Test RTP19: end_sync returns residual members for leave synthesis.""" + msg1 = PresenceMessage( + id='conn1:0:0', + connection_id='conn1', + client_id='client1', + action=PresenceAction.PRESENT + ) + msg2 = PresenceMessage( + id='conn2:0:0', + connection_id='conn2', + client_id='client2', + action=PresenceAction.PRESENT + ) + + # Add two members + self.presence_map.put(msg1) + self.presence_map.put(msg2) + + self.presence_map.start_sync() + + # Only see msg1 during sync + msg1_update = PresenceMessage( + id='conn1:1:0', + connection_id='conn1', + client_id='client1', + action=PresenceAction.PRESENT + ) + self.presence_map.put(msg1_update) + + # End sync - msg2 should be in residual + residual, absent = self.presence_map.end_sync() + + assert len(residual) == 1 + assert residual[0].client_id == 'client2' + + # msg2 should be removed from map + assert self.presence_map.get('conn2:client2') is None + # msg1 should still be present + assert self.presence_map.get('conn1:client1') is not None + + def test_start_sync_multiple_times(self): + """Test that start_sync can be called multiple times during sync.""" + msg1 = PresenceMessage( + id='conn1:0:0', + connection_id='conn1', + client_id='client1', + action=PresenceAction.PRESENT + ) + + self.presence_map.put(msg1) + self.presence_map.start_sync() + + initial_residual = self.presence_map._residual_members + + # Call start_sync again - should not reset residual + self.presence_map.start_sync() + assert self.presence_map._residual_members is initial_residual + + def test_clear_invokes_sync_callbacks(self): + """ + Test that clear() invokes pending sync callbacks to prevent hanging. + + This ensures that if get() is waiting for sync and the channel + transitions to DETACHED/FAILED, the waiting Future is resolved + and the caller is not left blocked. + """ + msg1 = PresenceMessage( + id='conn1:0:0', + connection_id='conn1', + client_id='client1', + action=PresenceAction.PRESENT + ) + + self.presence_map.put(msg1) + self.presence_map.start_sync() + + # Register a callback as if _wait_for_sync() was called + callback_invoked = False + + def sync_callback(): + nonlocal callback_invoked + callback_invoked = True + + self.presence_map.wait_sync(sync_callback) + + # Clear should invoke the callback + self.presence_map.clear() + + assert callback_invoked, "clear() should invoke pending sync callbacks" + assert not self.presence_map.sync_in_progress + assert len(self.presence_map.values()) == 0 diff --git a/test/ably/realtime/realtimepresence_test.py b/test/ably/realtime/realtimepresence_test.py new file mode 100644 index 00000000..e7073983 --- /dev/null +++ b/test/ably/realtime/realtimepresence_test.py @@ -0,0 +1,886 @@ +""" +Integration tests for RealtimePresence. + +These tests verify presence functionality with real Ably connections, +testing enter/leave/update operations, presence subscriptions, and SYNC behavior. +""" + +import asyncio + +import pytest + +from ably.realtime.connection import ConnectionState +from ably.types.channelstate import ChannelState +from ably.types.presence import PresenceAction +from ably.util.exceptions import AblyException +from test.ably.testapp import TestApp +from test.ably.utils import BaseAsyncTestCase + + +async def force_suspended(client): + client.connection.connection_manager.request_state(ConnectionState.DISCONNECTED) + + await client.connection._when_state(ConnectionState.DISCONNECTED) + + client.connection.connection_manager.notify_state( + ConnectionState.SUSPENDED, + AblyException("Connection to server unavailable", 400, 80002) + ) + + await client.connection._when_state(ConnectionState.SUSPENDED) + + +@pytest.mark.parametrize('use_binary_protocol', [True, False], ids=['msgpack', 'json']) +class TestRealtimePresenceBasics(BaseAsyncTestCase): + """Test basic presence operations: enter, leave, update.""" + + @pytest.fixture(autouse=True) + async def setup(self, use_binary_protocol): + """Set up test fixtures.""" + self.test_vars = await TestApp.get_test_vars() + self.use_binary_protocol = use_binary_protocol + + self.client1 = await TestApp.get_ably_realtime( + client_id='client1', + use_binary_protocol=use_binary_protocol + ) + self.client2 = await TestApp.get_ably_realtime( + client_id='client2', + use_binary_protocol=use_binary_protocol + ) + + yield + + await self.client1.close() + await self.client2.close() + + async def test_presence_enter_without_attach(self): + """ + Test RTP8d: Enter presence without prior attach (implicit attach). + """ + channel_name = self.get_channel_name('enter_without_attach') + + # Client 1 listens for presence + channel1 = self.client1.channels.get(channel_name) + + presence_received = asyncio.Future() + + def on_presence(msg): + if msg.action == PresenceAction.ENTER and msg.client_id == 'client2': + presence_received.set_result(msg) + + await channel1.presence.subscribe(on_presence) + + # Client 2 enters without attaching first + channel2 = self.client2.channels.get(channel_name) + assert channel2.state == ChannelState.INITIALIZED + + await channel2.presence.enter('test data') + + # Should receive presence event + msg = await asyncio.wait_for(presence_received, timeout=5.0) + assert msg.client_id == 'client2' + assert msg.data == 'test data' + assert msg.action == PresenceAction.ENTER + + async def test_presence_enter_with_callback(self): + """ + Test RTP8b: Enter with callback - callback should be called on success. + """ + channel_name = self.get_channel_name('enter_with_callback') + + channel = self.client1.channels.get(channel_name) + await channel.attach() + + # Enter presence - should succeed + await channel.presence.enter('test data') + + # Verify member is present + members = await channel.presence.get() + assert len(members) == 1 + assert members[0].client_id == 'client1' + assert members[0].data == 'test data' + + async def test_presence_enter_and_leave(self): + """ + Test RTP10: Enter and leave presence, await leave event. + """ + channel_name = self.get_channel_name('enter_and_leave') + + channel1 = self.client1.channels.get(channel_name) + channel2 = self.client2.channels.get(channel_name) + + await channel1.attach() + + # Track events + events = [] + + def on_presence(msg): + events.append((msg.action, msg.client_id)) + + await channel1.presence.subscribe(on_presence) + + # Client 2 enters + await channel2.presence.enter('enter data') + + # Wait for enter event + await asyncio.sleep(0.5) + assert (PresenceAction.ENTER, 'client2') in events + + # Client 2 leaves + await channel2.presence.leave() + + # Wait for leave event + await asyncio.sleep(0.5) + assert (PresenceAction.LEAVE, 'client2') in events + + async def test_presence_enter_update(self): + """ + Test RTP9: Update presence data. + """ + channel_name = self.get_channel_name('enter_update') + + channel1 = self.client1.channels.get(channel_name) + channel2 = self.client2.channels.get(channel_name) + + await channel1.attach() + + # Track update events + updates = [] + + def on_update(msg): + if msg.action == PresenceAction.UPDATE: + updates.append(msg.data) + + await channel1.presence.subscribe('update', on_update) + + # Client 2 enters then updates + await channel2.presence.enter('original data') + await asyncio.sleep(0.3) + + await channel2.presence.update('updated data') + + # Wait for update event + await asyncio.sleep(0.5) + assert 'updated data' in updates + + async def test_presence_anonymous_client_error(self): + """ + Test RTP8j: Anonymous clients cannot enter presence. + """ + # Create client without clientId + client = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) + await client.connection.once_async('connected') + + channel = client.channels.get(self.get_channel_name('anonymous')) + + try: + await channel.presence.enter('data') + pytest.fail('Should have raised exception for anonymous client') + except Exception as e: + assert 'clientId must be specified' in str(e) + finally: + await client.close() + + +@pytest.mark.parametrize('use_binary_protocol', [True, False], ids=['msgpack', 'json']) +class TestRealtimePresenceGet(BaseAsyncTestCase): + """Test presence.get() functionality.""" + + @pytest.fixture(autouse=True) + async def setup(self, use_binary_protocol): + """Set up test fixtures.""" + self.test_vars = await TestApp.get_test_vars() + self.use_binary_protocol = use_binary_protocol + + self.client1 = await TestApp.get_ably_realtime( + client_id='client1', + use_binary_protocol=use_binary_protocol + ) + self.client2 = await TestApp.get_ably_realtime( + client_id='client2', + use_binary_protocol=use_binary_protocol + ) + + yield + + await self.client1.close() + await self.client2.close() + + async def test_presence_enter_get(self): + """ + Test RTP11a: Enter presence and get members. + """ + channel_name = self.get_channel_name('enter_get') + + channel1 = self.client1.channels.get(channel_name) + channel2 = self.client2.channels.get(channel_name) + + # Client 1 enters + await channel1.presence.enter('test data') + + # Wait for presence to sync + await asyncio.sleep(0.5) + + # Client 2 gets presence + members = await channel2.presence.get() + + assert len(members) == 1 + assert members[0].client_id == 'client1' + assert members[0].data == 'test data' + assert members[0].action == PresenceAction.PRESENT + + async def test_presence_get_unattached(self): + """ + Test RTP11b: Get presence on unattached channel (should attach and wait for sync). + """ + channel_name = self.get_channel_name('get_unattached') + + # Client 1 enters + channel1 = self.client1.channels.get(channel_name) + await channel1.presence.enter('test data') + + # Wait for presence + await asyncio.sleep(0.5) + + # Client 2 gets without attaching first + channel2 = self.client2.channels.get(channel_name) + assert channel2.state == ChannelState.INITIALIZED + + members = await channel2.presence.get() + + # Channel should now be attached + assert channel2.state == ChannelState.ATTACHED + assert len(members) == 1 + assert members[0].client_id == 'client1' + + async def test_presence_enter_leave_get(self): + """ + Test RTP11a + RTP10c: Enter, leave, then get (should be empty). + """ + channel_name = self.get_channel_name('enter_leave_get') + + channel1 = self.client1.channels.get(channel_name) + channel2 = self.client2.channels.get(channel_name) + + # Client 1 enters then leaves + await channel1.presence.enter('test data') + await asyncio.sleep(0.3) + await channel1.presence.leave() + + # Wait for leave to process + await asyncio.sleep(0.5) + + # Client 2 gets presence + members = await channel2.presence.get() + + assert len(members) == 0 + + +@pytest.mark.parametrize('use_binary_protocol', [True, False], ids=['msgpack', 'json']) +class TestRealtimePresenceSubscribe(BaseAsyncTestCase): + """Test presence.subscribe() functionality.""" + + @pytest.fixture(autouse=True) + async def setup(self, use_binary_protocol): + """Set up test fixtures.""" + self.test_vars = await TestApp.get_test_vars() + self.use_binary_protocol = use_binary_protocol + + self.client1 = await TestApp.get_ably_realtime( + client_id='client1', + use_binary_protocol=use_binary_protocol + ) + self.client2 = await TestApp.get_ably_realtime( + client_id='client2', + use_binary_protocol=use_binary_protocol + ) + + yield + + await self.client1.close() + await self.client2.close() + + async def test_presence_subscribe_unattached(self): + """ + Test RTP6d: Subscribe on unattached channel should implicitly attach. + """ + channel_name = self.get_channel_name('subscribe_unattached') + + channel1 = self.client1.channels.get(channel_name) + + received = asyncio.Future() + + def on_presence(msg): + if msg.client_id == 'client2': + received.set_result(msg) + + # Subscribe without attaching first + assert channel1.state == ChannelState.INITIALIZED + await channel1.presence.subscribe(on_presence) + + # Should implicitly attach + await asyncio.sleep(0.5) + assert channel1.state == ChannelState.ATTACHED + + # Client 2 enters + channel2 = self.client2.channels.get(channel_name) + await channel2.presence.enter('data') + + # Should receive event + msg = await asyncio.wait_for(received, timeout=5.0) + assert msg.client_id == 'client2' + + async def test_presence_message_action(self): + """ + Test RTP8c: PresenceMessage should have correct action string. + """ + channel_name = self.get_channel_name('message_action') + + channel1 = self.client1.channels.get(channel_name) + + received = asyncio.Future() + + def on_presence(msg): + received.set_result(msg) + + await channel1.presence.subscribe(on_presence) + await channel1.presence.enter() + + msg = await asyncio.wait_for(received, timeout=5.0) + assert msg.action == PresenceAction.ENTER + + +@pytest.mark.parametrize('use_binary_protocol', [True, False], ids=['msgpack', 'json']) +class TestRealtimePresenceEnterClient(BaseAsyncTestCase): + """Test enterClient/updateClient/leaveClient functionality.""" + + @pytest.fixture(autouse=True) + async def setup(self, use_binary_protocol): + """Set up test fixtures.""" + self.test_vars = await TestApp.get_test_vars() + self.use_binary_protocol = use_binary_protocol + + # Use wildcard auth for enterClient + self.client = await TestApp.get_ably_realtime( + client_id='*', + use_binary_protocol=use_binary_protocol + ) + + yield + + await self.client.close() + + async def test_enter_client_multiple(self): + """ + Test RTP14/RTP15: Enter multiple clients on one connection. + """ + channel_name = self.get_channel_name('enter_client_multiple') + channel = self.client.channels.get(channel_name) + + # Enter multiple clients + for i in range(5): + await channel.presence.enter_client(f'test_client_{i}', f'data_{i}') + + # Wait for presence to sync + await asyncio.sleep(0.5) + + # Get all members + members = await channel.presence.get() + + assert len(members) == 5 + client_ids = {m.client_id for m in members} + assert all(f'test_client_{i}' in client_ids for i in range(5)) + + async def test_update_client(self): + """ + Test RTP15: Update client presence data. + """ + channel_name = self.get_channel_name('update_client') + channel = self.client.channels.get(channel_name) + + # Enter client + await channel.presence.enter_client('test_client', 'original data') + await asyncio.sleep(0.3) + + # Update client + await channel.presence.update_client('test_client', 'updated data') + await asyncio.sleep(0.3) + + # Get member + members = await channel.presence.get(client_id='test_client') + + assert len(members) == 1 + assert members[0].data == 'updated data' + + async def test_leave_client(self): + """ + Test RTP15: Leave client presence. + """ + channel_name = self.get_channel_name('leave_client') + channel = self.client.channels.get(channel_name) + + # Enter multiple clients + await channel.presence.enter_client('client1', 'data1') + await channel.presence.enter_client('client2', 'data2') + await asyncio.sleep(0.3) + + # Leave one client + await channel.presence.leave_client('client1') + await asyncio.sleep(0.5) + + # Only client2 should remain + members = await channel.presence.get() + + assert len(members) == 1 + assert members[0].client_id == 'client2' + + +@pytest.mark.parametrize('use_binary_protocol', [True, False], ids=['msgpack', 'json']) +class TestRealtimePresenceConnectionLifecycle(BaseAsyncTestCase): + """Test presence behavior during connection lifecycle events.""" + + @pytest.fixture(autouse=True) + async def setup(self, use_binary_protocol): + """Set up test fixtures.""" + self.test_vars = await TestApp.get_test_vars() + self.use_binary_protocol = use_binary_protocol + yield + + async def test_presence_enter_without_connect(self): + """ + Test entering presence before connection is established. + Related to RTP8d. + """ + channel_name = self.get_channel_name('enter_without_connect') + + # Create listener client + listener_client = await TestApp.get_ably_realtime( + client_id='listener', + use_binary_protocol=self.use_binary_protocol + ) + listener_channel = listener_client.channels.get(channel_name) + + received = asyncio.Future() + + def on_presence(msg): + if msg.client_id == 'enterer' and msg.action == PresenceAction.ENTER: + received.set_result(msg) + + await listener_channel.presence.subscribe(on_presence) + + # Create client and enter before it's connected + enterer_client = await TestApp.get_ably_realtime( + client_id='enterer', + use_binary_protocol=self.use_binary_protocol + ) + enterer_channel = enterer_client.channels.get(channel_name) + + # Enter without waiting for connection + await enterer_channel.presence.enter('test data') + + # Should receive presence event + msg = await asyncio.wait_for(received, timeout=5.0) + assert msg.client_id == 'enterer' + assert msg.data == 'test data' + + await listener_client.close() + await enterer_client.close() + + async def test_presence_enter_after_close(self): + """ + Test re-entering presence after connection close and reconnect. + Related to RTP8d. + """ + channel_name = self.get_channel_name('enter_after_close') + + # Create listener + listener_client = await TestApp.get_ably_realtime( + client_id='listener', + use_binary_protocol=self.use_binary_protocol + ) + listener_channel = listener_client.channels.get(channel_name) + + second_enter_received = asyncio.Future() + + def on_presence(msg): + if msg.client_id == 'enterer' and msg.data == 'second' and msg.action == PresenceAction.ENTER: + second_enter_received.set_result(msg) + + await listener_channel.presence.subscribe(on_presence) + + # Create enterer client + enterer_client = await TestApp.get_ably_realtime( + client_id='enterer', + use_binary_protocol=self.use_binary_protocol + ) + enterer_channel = enterer_client.channels.get(channel_name) + + await enterer_client.connection.once_async('connected') + + # First enter + await enterer_channel.presence.enter('first') + await asyncio.sleep(0.3) + + # Close and wait + await enterer_client.close() + + # Reconnect + enterer_client.connection.connect() + await enterer_client.connection.once_async('connected') + + # Second enter - should automatically reattach + await enterer_channel.presence.enter('second') + + # Should receive second enter event + msg = await asyncio.wait_for(second_enter_received, timeout=5.0) + assert msg.data == 'second' + + await listener_client.close() + await enterer_client.close() + + async def test_presence_enter_closed_error(self): + """ + Test RTP15e: Entering presence on closed connection should error. + """ + channel_name = self.get_channel_name('enter_closed') + + client = await TestApp.get_ably_realtime(use_binary_protocol=self.use_binary_protocol) + channel = client.channels.get(channel_name) + + await client.connection.once_async('connected') + + # Close the connection + await client.close() + + # Try to enter - should fail + try: + await channel.presence.enter_client('client1', 'data') + pytest.fail('Should have raised exception for closed connection') + except Exception as e: + # Should get an error about closed/failed connection + assert 'closed' in str(e).lower() or 'failed' in str(e).lower() or '80017' in str(e) + + await client.close() + + +@pytest.mark.parametrize('use_binary_protocol', [True, False], ids=['msgpack', 'json']) +class TestRealtimePresenceAutoReentry(BaseAsyncTestCase): + """Test automatic re-entry of presence after connection suspension.""" + + @pytest.fixture(autouse=True) + async def setup(self, use_binary_protocol): + """Set up test fixtures.""" + self.test_vars = await TestApp.get_test_vars() + self.use_binary_protocol = use_binary_protocol + yield + + async def test_presence_auto_reenter_after_suspend(self): + """ + Test RTP5f, RTP17, RTP17g, RTP17i: Members automatically re-enter after suspension. + + This test verifies that when a connection is suspended and then reconnected, + presence members that were entered automatically re-enter. + """ + channel_name = self.get_channel_name('auto_reenter') + + client = await TestApp.get_ably_realtime( + client_id='test_client', + use_binary_protocol=self.use_binary_protocol + ) + channel = client.channels.get(channel_name) + + await channel.attach() + + # Enter presence + await channel.presence.enter('original_data') + await asyncio.sleep(0.5) + + # Verify member is present + members = await channel.presence.get() + assert len(members) == 1 + assert members[0].client_id == 'test_client' + assert members[0].data == 'original_data' + + # Suspend the connection + await force_suspended(client) + + # Reconnect - connection will be resumed with same connection ID + client.connection.connect() + await client.connection.once_async('connected') + + # Wait for channel to reattach after suspension + await channel.once_async('attached') + + # Give time for auto-reenter to complete + # Auto-reenter sends a presence message, server ACKs it, but doesn't + # broadcast a new ENTER event because on a resumed connection with + # unchanged data, no state change occurred from the server's perspective + await asyncio.sleep(0.5) + + # Verify member is still in presence set (auto-reenter worked) + # This is the actual requirement of RTP17i - members are automatically + # re-entered after suspension, ensuring they remain in the presence set + members = await channel.presence.get() + assert len(members) >= 1 + assert any(m.client_id == 'test_client' and m.data == 'original_data' for m in members) + + await client.close() + + async def test_presence_auto_reenter_different_connid(self): + """ + Test RTP17g, RTP17g1: Auto re-entry with different connectionId. + + When connection is suspended and reconnects with a different connectionId, + verify that: + 1. A LEAVE is sent for the old connectionId + 2. An ENTER is sent for the new connectionId + 3. The new ENTER does not have the same message ID as the original + """ + channel_name = self.get_channel_name('auto_reenter_different_connid') + + # Create observer client + observer_client = await TestApp.get_ably_realtime( + client_id='observer', + use_binary_protocol=self.use_binary_protocol + ) + observer_channel = observer_client.channels.get(channel_name) + await observer_channel.attach() + + # Track presence events + events = [] + + def on_presence(msg): + events.append({ + 'action': msg.action, + 'client_id': msg.client_id, + 'connection_id': msg.connection_id, + 'id': getattr(msg, 'id', None) + }) + + await observer_channel.presence.subscribe(on_presence) + + # Create main client with remainPresentFor to control LEAVE timing + # This tells the server to send LEAVE for presence members 5 seconds after disconnect + client = await TestApp.get_ably_realtime( + client_id='test_client', + transport_params={'remainPresentFor': 1000}, + use_binary_protocol=self.use_binary_protocol + ) + channel = client.channels.get(channel_name) + + await client.connection.once_async('connected') + first_conn_id = client.connection.connection_manager.connection_id + + # Enter presence + await channel.presence.enter('test_data') + await asyncio.sleep(0.5) + + # Get the original message ID + original_msg_id = None + for event in events: + if event['action'] == PresenceAction.ENTER and event['client_id'] == 'test_client': + original_msg_id = event['id'] + break + + # Force suspension and reconnection with different connection ID + await force_suspended(client) + + # Reconnect + client.connection.connect() + await client.connection.once_async('connected') + second_conn_id = client.connection.connection_manager.connection_id + + # Connection IDs should be different after suspend + assert first_conn_id != second_conn_id + + # Wait for presence events including LEAVE (which arrives after remainPresentFor timeout) + await asyncio.sleep(2) + + # Should see LEAVE for old connection and ENTER for new connection + leave_events = [e for e in events if e['action'] == PresenceAction.LEAVE + and e['client_id'] == 'test_client'] + enter_events = [e for e in events if e['action'] == PresenceAction.ENTER + and e['client_id'] == 'test_client'] + + assert len(leave_events) >= 1, "Should have LEAVE event for old connection" + assert len(enter_events) >= 2, "Should have ENTER event for new connection" + + # Find the leave for first connection + leave_for_first = [e for e in leave_events if e['connection_id'] == first_conn_id] + assert len(leave_for_first) >= 1, "Should have LEAVE for first connection ID" + + # Find the enter for second connection + enter_for_second = [e for e in enter_events if e['connection_id'] == second_conn_id] + assert len(enter_for_second) >= 1, "Should have ENTER for second connection ID" + + # The new ENTER should have a different message ID + new_msg_id = enter_for_second[0]['id'] + if original_msg_id and new_msg_id: + assert original_msg_id != new_msg_id, "New ENTER should have different message ID" + + await observer_client.close() + await client.close() + + +@pytest.mark.parametrize('use_binary_protocol', [True, False], ids=['msgpack', 'json']) +class TestRealtimePresenceSyncBehavior(BaseAsyncTestCase): + """Test presence SYNC behavior and state management.""" + + @pytest.fixture(autouse=True) + async def setup(self, use_binary_protocol): + """Set up test fixtures.""" + self.test_vars = await TestApp.get_test_vars() + self.use_binary_protocol = use_binary_protocol + yield + + async def test_presence_refresh_on_detach(self): + """ + Test RTP15b: Presence map refresh when channel detaches and reattaches. + + When a channel detaches and then reattaches, and the presence set has + changed during that time, verify that the presence map is correctly + refreshed with the new state. + """ + channel_name = self.get_channel_name('refresh_on_detach') + + # Client that manages presence + manager_client = await TestApp.get_ably_realtime( + client_id='*', + use_binary_protocol=self.use_binary_protocol + ) + manager_channel = manager_client.channels.get(channel_name) + + # Observer client that will detach/reattach + observer_client = await TestApp.get_ably_realtime( + client_id='observer', + use_binary_protocol=self.use_binary_protocol + ) + observer_channel = observer_client.channels.get(channel_name) + + # Enter two members + await manager_channel.presence.enter_client('client_one', 'data_one') + await manager_channel.presence.enter_client('client_two', 'data_two') + await asyncio.sleep(0.3) + + # Observer attaches and verifies + await observer_channel.attach() + members = await observer_channel.presence.get() + assert len(members) == 2 + client_ids = {m.client_id for m in members} + assert 'client_one' in client_ids + assert 'client_two' in client_ids + + # Observer detaches + await observer_channel.detach() + + # Change presence while observer is detached + await manager_channel.presence.leave_client('client_two') + await manager_channel.presence.enter_client('client_three', 'data_three') + await asyncio.sleep(0.3) + + # Track presence events on observer + presence_events = [] + + def on_presence(msg): + presence_events.append(msg.client_id) + + await observer_channel.presence.subscribe(on_presence) + + # Reattach and wait for sync + await observer_channel.attach() + await asyncio.sleep(1.0) + + # Should receive PRESENT events for current members + members = await observer_channel.presence.get() + assert len(members) == 2 + client_ids = {m.client_id for m in members} + assert 'client_one' in client_ids + assert 'client_three' in client_ids + assert 'client_two' not in client_ids + + await manager_client.close() + await observer_client.close() + + async def test_suspended_preserves_presence(self): + """ + Test RTP5f, RTP11d: Presence map is preserved during SUSPENDED state. + + Verify that: + 1. Presence map is preserved when connection goes to SUSPENDED + 2. get() with waitForSync=False works while suspended + 3. get() without waitForSync returns error while suspended + 4. Only changed members trigger events after reconnection + """ + channel_name = self.get_channel_name('suspended_preserves') + + # Create multiple clients + main_client = await TestApp.get_ably_realtime( + client_id='main', + use_binary_protocol=self.use_binary_protocol + ) + continuous_client = await TestApp.get_ably_realtime( + client_id='continuous', + use_binary_protocol=self.use_binary_protocol + ) + leaves_client = await TestApp.get_ably_realtime( + client_id='leaves', + use_binary_protocol=self.use_binary_protocol + ) + + main_channel = main_client.channels.get(channel_name) + continuous_channel = continuous_client.channels.get(channel_name) + leaves_channel = leaves_client.channels.get(channel_name) + + # All enter presence + await main_channel.presence.enter('main_data') + await continuous_channel.presence.enter('continuous_data') + await leaves_channel.presence.enter('leaves_data') + await asyncio.sleep(0.5) + + # Verify all present + members = await main_channel.presence.get() + assert len(members) == 3 + client_ids = {m.client_id for m in members} + assert client_ids == {'main', 'continuous', 'leaves'} + + # Simulate suspension on main client + await force_suspended(main_client) + + # leaves_client leaves while main is suspended + await leaves_client.close() + await asyncio.sleep(0.3) + + # Track presence events on main after reconnect + presence_events = [] + + def on_presence(msg): + presence_events.append({ + 'action': msg.action, + 'client_id': msg.client_id + }) + + await main_channel.presence.subscribe(on_presence) + + # Reconnect main client + main_client.connection.connect() + await main_client.connection.once_async('connected') + await main_channel.once_async('attached') + + # Wait for presence sync + await asyncio.sleep(1.0) + + # Should only see LEAVE for leaves_client + leave_events = [e for e in presence_events + if e['action'] == PresenceAction.LEAVE and e['client_id'] == 'leaves'] + assert len(leave_events) >= 1, "Should see LEAVE for leaves client" + + # Final state should have main and continuous + members = await main_channel.presence.get() + assert len(members) >= 2 + client_ids = {m.client_id for m in members} + assert 'main' in client_ids + assert 'continuous' in client_ids + + await main_client.close() + await continuous_client.close()