Skip to content

Commit 8e43167

Browse files
committed
chore: add socket based tests for the new APIs
This adds socket based tests for the new API so we can have tests that exercise the exact byte streams that are sent and received, similar to the old API. These currently are testing the lower level "session" and "channel" just to get the request infrastructure ready to be used in a real end to end device manaver tests. This violates our rule that tests align with a module, but this is an exception given its for end to end tests.
1 parent 4c4051e commit 8e43167

File tree

3 files changed

+322
-0
lines changed

3 files changed

+322
-0
lines changed

tests/e2e/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""End-to-end tests package."""

tests/e2e/test_local_session.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""End-to-end tests for LocalChannel using fake sockets."""
2+
3+
import asyncio
4+
from collections.abc import AsyncGenerator, Callable, Generator
5+
from queue import Queue
6+
from typing import Any
7+
from unittest.mock import Mock, patch
8+
9+
import pytest
10+
11+
from roborock.devices.local_channel import LocalChannel
12+
from roborock.protocol import create_local_decoder, create_local_encoder
13+
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
14+
from tests.conftest import RequestHandler
15+
from tests.mock_data import LOCAL_KEY
16+
17+
TEST_HOST = "192.168.1.100"
18+
TEST_DEVICE_UID = "test_device_uid"
19+
TEST_CONNECT_NONCE = 12345
20+
TEST_ACK_NONCE = 67890
21+
TEST_RANDOM = 13579
22+
23+
24+
@pytest.fixture(name="mock_create_local_connection")
25+
def create_local_connection_fixture(request_handler: RequestHandler) -> Generator[None, None, None]:
26+
"""Fixture that overrides the transport creation to wire it up to the mock socket."""
27+
28+
async def create_connection(protocol_factory: Callable[[], asyncio.Protocol], *args, **kwargs) -> tuple[Any, Any]:
29+
protocol = protocol_factory()
30+
31+
def handle_write(data: bytes) -> None:
32+
response = request_handler(data)
33+
if response is not None:
34+
# Call data_received directly to avoid loop scheduling issues in test
35+
protocol.data_received(response)
36+
37+
closed = asyncio.Event()
38+
39+
mock_transport = Mock()
40+
mock_transport.write = handle_write
41+
mock_transport.close = closed.set
42+
mock_transport.is_reading = lambda: not closed.is_set()
43+
44+
return (mock_transport, protocol)
45+
46+
with patch("roborock.devices.local_channel.asyncio.get_running_loop") as mock_loop:
47+
mock_loop.return_value.create_connection.side_effect = create_connection
48+
yield
49+
50+
51+
@pytest.fixture(name="local_channel")
52+
async def local_channel_fixture(mock_create_local_connection: None) -> AsyncGenerator[LocalChannel, None]:
53+
with patch("roborock.devices.local_channel.get_next_int", return_value=TEST_CONNECT_NONCE):
54+
channel = LocalChannel(host=TEST_HOST, local_key=LOCAL_KEY)
55+
yield channel
56+
channel.close()
57+
58+
59+
def build_response(
60+
protocol: RoborockMessageProtocol,
61+
seq: int,
62+
payload: bytes,
63+
random: int,
64+
) -> bytes:
65+
"""Build an encoded response message."""
66+
if protocol == RoborockMessageProtocol.HELLO_RESPONSE:
67+
encoder = create_local_encoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=None)
68+
else:
69+
encoder = create_local_encoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=TEST_ACK_NONCE)
70+
71+
msg = RoborockMessage(
72+
protocol=protocol,
73+
random=random,
74+
seq=seq,
75+
payload=payload,
76+
)
77+
return encoder(msg)
78+
79+
80+
async def test_connect(local_channel: LocalChannel, response_queue: Queue[bytes], received_requests: Queue[bytes]):
81+
"""Test connecting to the device."""
82+
# Queue HELLO response with payload to ensure it can be parsed
83+
response_queue.put(build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM))
84+
85+
await local_channel.connect()
86+
87+
assert local_channel.is_connected
88+
assert received_requests.qsize() == 1
89+
90+
# Verify HELLO request
91+
request_bytes = received_requests.get()
92+
# Note: We cannot use create_local_decoder here because HELLO_REQUEST has payload=None
93+
# which causes MessageParser to fail parsing. For now we verify the raw bytes.
94+
95+
# Protocol is at offset 19 (2 bytes)
96+
# Prefix(4) + Version(3) + Seq(4) + Random(4) + Timestamp(4) = 19
97+
assert len(request_bytes) >= 21
98+
protocol_bytes = request_bytes[19:21]
99+
assert int.from_bytes(protocol_bytes, "big") == RoborockMessageProtocol.HELLO_REQUEST
100+
101+
102+
async def test_send_command(local_channel: LocalChannel, response_queue: Queue[bytes], received_requests: Queue[bytes]):
103+
"""Test sending a command."""
104+
# Queue HELLO response
105+
response_queue.put(build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM))
106+
107+
await local_channel.connect()
108+
109+
# Clear requests from handshake
110+
while not received_requests.empty():
111+
received_requests.get()
112+
113+
# Send command
114+
cmd_seq = 123
115+
msg = RoborockMessage(
116+
protocol=RoborockMessageProtocol.RPC_REQUEST,
117+
seq=cmd_seq,
118+
payload=b'{"method":"get_status"}',
119+
)
120+
121+
await local_channel.publish(msg)
122+
123+
# Verify request
124+
assert received_requests.qsize() == 1
125+
request_bytes = received_requests.get()
126+
127+
# Decode request
128+
decoder = create_local_decoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=TEST_ACK_NONCE)
129+
msgs = list(decoder(request_bytes))
130+
assert len(msgs) == 1
131+
assert msgs[0].protocol == RoborockMessageProtocol.RPC_REQUEST
132+
assert msgs[0].payload == b'{"method":"get_status"}'

tests/e2e/test_mqtt_session.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
"""End-to-end tests for MQTT session.
2+
3+
These tests use a fake MQTT broker to verify the session implementation. We
4+
mock out the lower level socket connections to simulate a broker which gets us
5+
close to an "end to end" test without needing an actual MQTT broker server.
6+
"""
7+
8+
import asyncio
9+
from collections.abc import AsyncGenerator, Callable, Generator
10+
from queue import Queue
11+
from typing import Any
12+
from unittest.mock import patch
13+
14+
import paho.mqtt.client as mqtt
15+
import pytest
16+
17+
from roborock.mqtt.roborock_session import create_mqtt_session
18+
from roborock.mqtt.session import MqttParams, MqttSession
19+
from roborock.protocol import MessageParser
20+
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
21+
from tests import mqtt_packet
22+
from tests.conftest import FakeSocketHandler
23+
from tests.mock_data import LOCAL_KEY
24+
25+
FAKE_PARAMS = MqttParams(
26+
host="localhost",
27+
port=1883,
28+
tls=False,
29+
username="username",
30+
password="password",
31+
timeout=10.0,
32+
)
33+
34+
35+
@pytest.fixture(autouse=True)
36+
async def mock_client_fixture() -> AsyncGenerator[None, None]:
37+
"""Fixture to patch the MQTT underlying sync client.
38+
39+
The tests use fake sockets, so this ensures that the async mqtt client does not
40+
attempt to listen on them directly. We instead just poll the socket for
41+
data ourselves.
42+
"""
43+
44+
event_loop = asyncio.get_running_loop()
45+
46+
orig_class = mqtt.Client
47+
48+
async def poll_sockets(client: mqtt.Client) -> None:
49+
"""Poll the mqtt client sockets in a loop to pick up new data."""
50+
while True:
51+
event_loop.call_soon_threadsafe(client.loop_read)
52+
event_loop.call_soon_threadsafe(client.loop_write)
53+
await asyncio.sleep(0.01)
54+
55+
task: asyncio.Task[None] | None = None
56+
57+
def new_client(*args: Any, **kwargs: Any) -> mqtt.Client:
58+
"""Create a new mqtt client and start the socket polling task."""
59+
nonlocal task
60+
client = orig_class(*args, **kwargs)
61+
task = event_loop.create_task(poll_sockets(client))
62+
return client
63+
64+
with (
65+
patch("aiomqtt.client.Client._on_socket_open"),
66+
patch("aiomqtt.client.Client._on_socket_close"),
67+
patch("aiomqtt.client.Client._on_socket_register_write"),
68+
patch("aiomqtt.client.Client._on_socket_unregister_write"),
69+
patch("aiomqtt.client.mqtt.Client", side_effect=new_client),
70+
):
71+
yield
72+
if task:
73+
task.cancel()
74+
75+
76+
@pytest.fixture(autouse=True)
77+
def mqtt_server_fixture(mock_create_connection: None, mock_select: None) -> None:
78+
"""Fixture to mock the MQTT connection."""
79+
80+
81+
@pytest.fixture(autouse=True)
82+
def fast_backoff_fixture() -> Generator[None, None, None]:
83+
"""Fixture to speed up backoff."""
84+
with patch("roborock.mqtt.roborock_session.MIN_BACKOFF_INTERVAL", 0.01):
85+
yield
86+
87+
88+
@pytest.fixture
89+
def push_response(response_queue: Queue, fake_socket_handler: FakeSocketHandler) -> Callable[[bytes], None]:
90+
"""Fixture to push a response to the client."""
91+
92+
def _push(data: bytes) -> None:
93+
response_queue.put(data)
94+
fake_socket_handler.push_response()
95+
96+
return _push
97+
98+
99+
@pytest.fixture(name="session")
100+
async def session_fixture(push_response: Callable[[bytes], None]) -> AsyncGenerator[MqttSession, None]:
101+
"""Fixture to create a new connected MQTT session."""
102+
try:
103+
push_response(mqtt_packet.gen_connack(rc=0, flags=2))
104+
session = await create_mqtt_session(FAKE_PARAMS)
105+
assert session.connected
106+
yield session
107+
finally:
108+
await session.close()
109+
110+
111+
class Subscriber:
112+
"""Mock subscriber class.
113+
114+
We use this to hold on to received messages for verification.
115+
"""
116+
117+
def __init__(self) -> None:
118+
self.messages: list[bytes] = []
119+
self._event = asyncio.Event()
120+
121+
def append(self, message: bytes) -> None:
122+
self.messages.append(message)
123+
self._event.set()
124+
125+
async def wait(self) -> None:
126+
await asyncio.wait_for(self._event.wait(), timeout=1.0)
127+
self._event.clear()
128+
129+
130+
async def test_session_e2e_receive_message(push_response: Callable[[bytes], None], session: MqttSession) -> None:
131+
"""Test receiving a real Roborock message through the session."""
132+
assert session.connected
133+
134+
# Subscribe to the topic. We'll next construct and push a message.
135+
push_response(mqtt_packet.gen_suback(mid=1))
136+
subscriber = Subscriber()
137+
await session.subscribe("topic-1", subscriber.append)
138+
139+
msg = RoborockMessage(
140+
protocol=RoborockMessageProtocol.RPC_RESPONSE,
141+
payload=b'{"result":"ok"}',
142+
seq=123,
143+
)
144+
payload = MessageParser.build(msg, local_key=LOCAL_KEY, prefixed=False)
145+
146+
# Simulate receiving the message from the broker
147+
push_response(mqtt_packet.gen_publish("topic-1", mid=2, payload=payload))
148+
149+
# Verify it was dispatched to the subscriber
150+
await subscriber.wait()
151+
assert len(subscriber.messages) == 1
152+
received_payload = subscriber.messages[0]
153+
assert received_payload == payload
154+
155+
# Verify the message payload contents
156+
parsed_msgs, _ = MessageParser.parse(received_payload, local_key=LOCAL_KEY)
157+
assert len(parsed_msgs) == 1
158+
parsed_msg = parsed_msgs[0]
159+
assert parsed_msg.protocol == RoborockMessageProtocol.RPC_RESPONSE
160+
assert parsed_msg.seq == 123
161+
# The payload in parsed_msg should be the decrypted bytes
162+
assert parsed_msg.payload == b'{"result":"ok"}'
163+
164+
165+
async def test_session_e2e_publish_message(
166+
push_response: Callable[[bytes], None], received_requests: Queue, session: MqttSession
167+
) -> None:
168+
"""Test publishing a real Roborock message."""
169+
170+
# Publish a message to the brokwer
171+
msg = RoborockMessage(
172+
protocol=RoborockMessageProtocol.RPC_REQUEST,
173+
payload=b'{"method":"get_status"}',
174+
seq=456,
175+
)
176+
payload = MessageParser.build(msg, local_key=LOCAL_KEY, prefixed=False)
177+
178+
await session.publish("topic-1", payload)
179+
180+
# Verify what was sent to the broker
181+
# We expect the payload to be present in the sent bytes
182+
found = False
183+
while not received_requests.empty():
184+
request = received_requests.get()
185+
if payload in request:
186+
found = True
187+
break
188+
189+
assert found, "Published payload not found in sent requests"

0 commit comments

Comments
 (0)