Skip to content

Commit 7d113db

Browse files
authored
chore: add socket based tests for the new APIs (#677)
* chore: add socket based tests for the new APIs This adds socket based tests for the new API so we can have tests that exercise the exact byte streams that are sent and received, similar to the old API. These currently are testing the lower level "session" and "channel" just to get the request infrastructure ready to be used in a real end to end device manaver tests. This violates our rule that tests align with a module, but this is an exception given its for end to end tests. * chore: Fix local session
1 parent 33c174b commit 7d113db

File tree

6 files changed

+373
-92
lines changed

6 files changed

+373
-92
lines changed

tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@
2020
from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1
2121
from tests.mock_data import HOME_DATA_RAW, HOME_DATA_SCENES_RAW, TEST_LOCAL_API_HOST, USER_DATA
2222

23+
# Fixtures for the newer APIs in subdirectories
24+
pytest_plugins = [
25+
"tests.mqtt_fixtures",
26+
]
27+
2328
_LOGGER = logging.getLogger(__name__)
2429

2530

tests/e2e/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""End-to-end tests package."""

tests/e2e/test_local_session.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
"""End-to-end tests for LocalChannel using fake sockets."""
2+
3+
import asyncio
4+
from collections.abc import AsyncGenerator, Callable, Generator
5+
from queue import Queue
6+
from typing import Any
7+
from unittest.mock import Mock, patch
8+
9+
import pytest
10+
11+
from roborock.devices.local_channel import LocalChannel
12+
from roborock.protocol import create_local_decoder, create_local_encoder
13+
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
14+
from tests.conftest import RequestHandler
15+
from tests.mock_data import LOCAL_KEY
16+
17+
TEST_HOST = "192.168.1.100"
18+
TEST_DEVICE_UID = "test_device_uid"
19+
TEST_CONNECT_NONCE = 12345
20+
TEST_ACK_NONCE = 67890
21+
TEST_RANDOM = 13579
22+
23+
24+
@pytest.fixture(name="mock_create_local_connection")
25+
def create_local_connection_fixture(request_handler: RequestHandler) -> Generator[None, None, None]:
26+
"""Fixture that overrides the transport creation to wire it up to the mock socket."""
27+
28+
async def create_connection(protocol_factory: Callable[[], asyncio.Protocol], *args, **kwargs) -> tuple[Any, Any]:
29+
protocol = protocol_factory()
30+
31+
def handle_write(data: bytes) -> None:
32+
response = request_handler(data)
33+
if response is not None:
34+
# Call data_received directly to avoid loop scheduling issues in test
35+
protocol.data_received(response)
36+
37+
closed = asyncio.Event()
38+
39+
mock_transport = Mock()
40+
mock_transport.write = handle_write
41+
mock_transport.close = closed.set
42+
mock_transport.is_reading = lambda: not closed.is_set()
43+
44+
return (mock_transport, protocol)
45+
46+
with patch("roborock.devices.local_channel.asyncio.get_running_loop") as mock_loop:
47+
mock_loop.return_value.create_connection.side_effect = create_connection
48+
yield
49+
50+
51+
@pytest.fixture(name="local_channel")
52+
async def local_channel_fixture(mock_create_local_connection: None) -> AsyncGenerator[LocalChannel, None]:
53+
with patch(
54+
"roborock.devices.local_channel.get_next_int", return_value=TEST_CONNECT_NONCE, device_uid=TEST_DEVICE_UID
55+
):
56+
channel = LocalChannel(host=TEST_HOST, local_key=LOCAL_KEY, device_uid=TEST_DEVICE_UID)
57+
yield channel
58+
channel.close()
59+
60+
61+
def build_response(
62+
protocol: RoborockMessageProtocol,
63+
seq: int,
64+
payload: bytes,
65+
random: int,
66+
) -> bytes:
67+
"""Build an encoded response message."""
68+
if protocol == RoborockMessageProtocol.HELLO_RESPONSE:
69+
encoder = create_local_encoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=None)
70+
else:
71+
encoder = create_local_encoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=TEST_ACK_NONCE)
72+
73+
msg = RoborockMessage(
74+
protocol=protocol,
75+
random=random,
76+
seq=seq,
77+
payload=payload,
78+
)
79+
return encoder(msg)
80+
81+
82+
async def test_connect(
83+
local_channel: LocalChannel, response_queue: Queue[bytes], received_requests: Queue[bytes]
84+
) -> None:
85+
"""Test connecting to the device."""
86+
# Queue HELLO response with payload to ensure it can be parsed
87+
response_queue.put(build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM))
88+
89+
await local_channel.connect()
90+
91+
assert local_channel.is_connected
92+
assert received_requests.qsize() == 1
93+
94+
# Verify HELLO request
95+
request_bytes = received_requests.get()
96+
# Note: We cannot use create_local_decoder here because HELLO_REQUEST has payload=None
97+
# which causes MessageParser to fail parsing. For now we verify the raw bytes.
98+
99+
# Protocol is at offset 19 (2 bytes)
100+
# Prefix(4) + Version(3) + Seq(4) + Random(4) + Timestamp(4) = 19
101+
assert len(request_bytes) >= 21
102+
protocol_bytes = request_bytes[19:21]
103+
assert int.from_bytes(protocol_bytes, "big") == RoborockMessageProtocol.HELLO_REQUEST
104+
105+
106+
async def test_send_command(
107+
local_channel: LocalChannel, response_queue: Queue[bytes], received_requests: Queue[bytes]
108+
) -> None:
109+
"""Test sending a command."""
110+
# Queue HELLO response
111+
response_queue.put(build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM))
112+
113+
await local_channel.connect()
114+
115+
# Clear requests from handshake
116+
while not received_requests.empty():
117+
received_requests.get()
118+
119+
# Send command
120+
cmd_seq = 123
121+
msg = RoborockMessage(
122+
protocol=RoborockMessageProtocol.RPC_REQUEST,
123+
seq=cmd_seq,
124+
payload=b'{"method":"get_status"}',
125+
)
126+
127+
await local_channel.publish(msg)
128+
129+
# Verify request
130+
assert received_requests.qsize() == 1
131+
request_bytes = received_requests.get()
132+
133+
# Decode request
134+
decoder = create_local_decoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=TEST_ACK_NONCE)
135+
msgs = list(decoder(request_bytes))
136+
assert len(msgs) == 1
137+
assert msgs[0].protocol == RoborockMessageProtocol.RPC_REQUEST
138+
assert msgs[0].payload == b'{"method":"get_status"}'

tests/e2e/test_mqtt_session.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""End-to-end tests for MQTT session.
2+
3+
These tests use a fake MQTT broker to verify the session implementation. We
4+
mock out the lower level socket connections to simulate a broker which gets us
5+
close to an "end to end" test without needing an actual MQTT broker server.
6+
7+
These are higher level tests that the similar tests in tests/mqtt/test_roborock_session.py
8+
which use mocks to verify specific behaviors.
9+
"""
10+
11+
from collections.abc import AsyncGenerator, Callable
12+
from queue import Queue
13+
14+
import pytest
15+
16+
from roborock.mqtt.roborock_session import create_mqtt_session
17+
from roborock.mqtt.session import MqttSession
18+
from roborock.protocol import MessageParser
19+
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
20+
from tests import mqtt_packet
21+
from tests.mock_data import LOCAL_KEY
22+
from tests.mqtt_fixtures import FAKE_PARAMS, Subscriber
23+
24+
25+
@pytest.fixture(autouse=True)
26+
def auto_mock_mqtt_client(mock_mqtt_client_fixture: None) -> None:
27+
"""Automatically use the mock mqtt client fixture."""
28+
29+
30+
@pytest.fixture(autouse=True)
31+
def auto_fast_backoff(fast_backoff_fixture: None) -> None:
32+
"""Automatically use the fast backoff fixture."""
33+
34+
35+
@pytest.fixture(autouse=True)
36+
def mqtt_server_fixture(mock_create_connection: None, mock_select: None) -> None:
37+
"""Fixture to mock the MQTT connection.
38+
39+
This is here to pull in the mock socket pixtures into all tests used here.
40+
"""
41+
42+
43+
@pytest.fixture(name="session")
44+
async def session_fixture(
45+
push_response: Callable[[bytes], None],
46+
) -> AsyncGenerator[MqttSession, None]:
47+
"""Fixture to create a new connected MQTT session."""
48+
push_response(mqtt_packet.gen_connack(rc=0, flags=2))
49+
session = await create_mqtt_session(FAKE_PARAMS)
50+
assert session.connected
51+
try:
52+
yield session
53+
finally:
54+
await session.close()
55+
56+
57+
async def test_session_e2e_receive_message(push_response: Callable[[bytes], None], session: MqttSession) -> None:
58+
"""Test receiving a real Roborock message through the session."""
59+
assert session.connected
60+
61+
# Subscribe to the topic. We'll next construct and push a message.
62+
push_response(mqtt_packet.gen_suback(mid=1))
63+
subscriber = Subscriber()
64+
await session.subscribe("topic-1", subscriber.append)
65+
66+
msg = RoborockMessage(
67+
protocol=RoborockMessageProtocol.RPC_RESPONSE,
68+
payload=b'{"result":"ok"}',
69+
seq=123,
70+
)
71+
payload = MessageParser.build(msg, local_key=LOCAL_KEY, prefixed=False)
72+
73+
# Simulate receiving the message from the broker
74+
push_response(mqtt_packet.gen_publish("topic-1", mid=2, payload=payload))
75+
76+
# Verify it was dispatched to the subscriber
77+
await subscriber.wait()
78+
assert len(subscriber.messages) == 1
79+
received_payload = subscriber.messages[0]
80+
assert received_payload == payload
81+
82+
# Verify the message payload contents
83+
parsed_msgs, _ = MessageParser.parse(received_payload, local_key=LOCAL_KEY)
84+
assert len(parsed_msgs) == 1
85+
parsed_msg = parsed_msgs[0]
86+
assert parsed_msg.protocol == RoborockMessageProtocol.RPC_RESPONSE
87+
assert parsed_msg.seq == 123
88+
# The payload in parsed_msg should be the decrypted bytes
89+
assert parsed_msg.payload == b'{"result":"ok"}'
90+
91+
92+
async def test_session_e2e_publish_message(
93+
push_response: Callable[[bytes], None],
94+
received_requests: Queue,
95+
session: MqttSession,
96+
) -> None:
97+
"""Test publishing a real Roborock message."""
98+
99+
# Publish a message to the brokwer
100+
msg = RoborockMessage(
101+
protocol=RoborockMessageProtocol.RPC_REQUEST,
102+
payload=b'{"method":"get_status"}',
103+
seq=456,
104+
)
105+
payload = MessageParser.build(msg, local_key=LOCAL_KEY, prefixed=False)
106+
107+
await session.publish("topic-1", payload)
108+
109+
# Verify what was sent to the broker
110+
# We expect the payload to be present in the sent bytes
111+
found = False
112+
while not received_requests.empty():
113+
request = received_requests.get()
114+
if payload in request:
115+
found = True
116+
break
117+
118+
assert found, "Published payload not found in sent requests"

0 commit comments

Comments
 (0)