Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion ably/realtime/connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import functools
import logging
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -64,7 +65,7 @@ async def close(self) -> None:
connection without an explicit call to connect()
"""
self.connection_manager.request_state(ConnectionState.CLOSING)
await self.once_async(ConnectionState.CLOSED)
await self._when_state(ConnectionState.CLOSED)

# RTN13
async def ping(self) -> float:
Expand All @@ -86,6 +87,13 @@ async def ping(self) -> float:
"""
return await self.__connection_manager.ping()

def _when_state(self, state: ConnectionState):
if self.state == state:
fut = asyncio.get_event_loop().create_future()
fut.set_result(None)
return fut
return self.once_async(state)

def _on_state_update(self, state_change: ConnectionStateChange) -> None:
log.info(f'Connection state changing from {self.state} to {state_change.current}')
self.__state = state_change.current
Expand Down
28 changes: 26 additions & 2 deletions ably/realtime/connectionmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,17 @@ def enact_state_change(self, state: ConnectionState, reason: AblyException | Non
self.__state = state
if reason:
self.__error_reason = reason

# RTN16d: Clear connection state when entering SUSPENDED or terminal states
if state == ConnectionState.SUSPENDED or state in (
ConnectionState.CLOSED,
ConnectionState.FAILED
):
self.__connection_details = None
self.connection_id = None
self.__connection_key = None
self.msg_serial = 0

self._emit('connectionstate', ConnectionStateChange(current_state, state, state, reason))

def check_connection(self) -> bool:
Expand All @@ -157,6 +168,10 @@ async def __get_transport_params(self) -> dict:
# RTN2a: Set format to msgpack if use_binary_protocol is enabled
if self.options.use_binary_protocol:
params["format"] = "msgpack"

# Add any custom transport params from options
params.update(self.options.transport_params)

return params

async def close_impl(self) -> None:
Expand All @@ -165,13 +180,23 @@ async def close_impl(self) -> None:
self.cancel_suspend_timer()
self.start_transition_timer(ConnectionState.CLOSING, fail_state=ConnectionState.CLOSED)
if self.transport:
await self.transport.dispose()
# Try to send protocol CLOSE message in the background
asyncio.create_task(self.transport.close())
# Yield to event loop to give the close message a chance to send
await asyncio.sleep(0)
await self.transport.dispose() # Dispose transport resources
if self.connect_base_task:
self.connect_base_task.cancel()
if self.disconnect_transport_task:
await self.disconnect_transport_task
self.cancel_retry_timer()

# Clear connection details to prevent resume on next connect
# When explicitly closed, we want a fresh connection, not a resume
self.__connection_details = None
self.connection_id = None
self.msg_serial = 0

self.notify_state(ConnectionState.CLOSED)

async def send_protocol_message(self, protocol_message: dict) -> None:
Expand Down Expand Up @@ -648,7 +673,6 @@ def on_suspend_timer_expire() -> None:
AblyException("Connection to server unavailable", 400, 80002)
)
self.__fail_state = ConnectionState.SUSPENDED
self.__connection_details = None

self.suspend_timer = Timer(Defaults.connection_state_ttl, on_suspend_timer_expire)

Expand Down
Loading