Skip to content

Commit c317f8e

Browse files
authored
Organize test fixtures to prepare for more end to end tests (#699)
* chore: Organize test fixtures Rewrite the test fixtures to have a more clear split between local and mqtt fixtures. This is in prepration for running both at once for e2 tests for device manager. All fixtures are moved into a fixtures subdirectory. The helper classes that are imported into other tests are added in separate files for importing, and to avoid import warnings from pytests. This renames all the fixtures to have mqtt prefixed names and local fixtures to have local prefixed names. There is one minor change to make the local asyncio tests uses asyncio Queues rather than blocking queues. * chore: Address co-pilot review feedback * chore: fix lint errors * chore: Remove duplicate captured request log
1 parent 8cd51cc commit c317f8e

22 files changed

+938
-752
lines changed

tests/conftest.py

Lines changed: 0 additions & 532 deletions
This file was deleted.

tests/devices/test_a01_channel.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
RoborockMessage,
1212
RoborockMessageProtocol,
1313
)
14-
15-
from ..conftest import FakeChannel
14+
from tests.fixtures.channel_fixtures import FakeChannel
1615

1716

1817
@pytest.fixture

tests/devices/test_v1_channel.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,8 @@
2525
from roborock.protocols.v1_protocol import MapResponse, SecurityData, V1RpcChannel
2626
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
2727
from roborock.roborock_typing import RoborockCommand
28-
29-
from .. import mock_data
30-
from ..conftest import FakeChannel
28+
from tests import mock_data
29+
from tests.fixtures.channel_fixtures import FakeChannel
3130

3231
USER_DATA = UserData.from_dict(mock_data.USER_DATA)
3332
TEST_DEVICE_UID = "abc123"

tests/devices/traits/a01/test_init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from roborock.devices.traits.a01 import DyadApi, ZeoApi
1010
from roborock.roborock_message import RoborockDyadDataProtocol, RoborockMessageProtocol, RoborockZeoProtocol
11-
from tests.conftest import FakeChannel
11+
from tests.fixtures.channel_fixtures import FakeChannel
1212
from tests.protocols.common import build_a01_message
1313

1414

tests/devices/traits/b01/test_init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from roborock.exceptions import RoborockException
1919
from roborock.protocols.b01_protocol import B01_VERSION
2020
from roborock.roborock_message import RoborockB01Props, RoborockMessage, RoborockMessageProtocol
21-
from tests.conftest import FakeChannel
21+
from tests.fixtures.channel_fixtures import FakeChannel
2222

2323

2424
def build_b01_message(message: dict[Any, Any], msg_id: str = "123456789", seq: int = 2020) -> RoborockMessage:

tests/e2e/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,8 @@
11
"""End-to-end tests package."""
2+
3+
pytest_plugins = [
4+
"tests.fixtures.logging_fixtures",
5+
"tests.fixtures.local_async_fixtures",
6+
"tests.fixtures.pahomqtt_fixtures",
7+
"tests.fixtures.aiomqtt_fixtures",
8+
]

tests/e2e/test_local_session.py

Lines changed: 21 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
11
"""End-to-end tests for LocalChannel using fake sockets."""
22

33
import asyncio
4-
from collections.abc import AsyncGenerator, Callable, Generator
5-
from queue import Queue
6-
from typing import Any
7-
from unittest.mock import Mock, patch
4+
from collections.abc import AsyncGenerator
5+
from unittest.mock import patch
86

97
import pytest
108

119
from roborock.devices.local_channel import LocalChannel
1210
from roborock.protocol import create_local_decoder, create_local_encoder
1311
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
14-
from tests.conftest import RequestHandler
1512
from tests.mock_data import LOCAL_KEY
1613

1714
TEST_HOST = "192.168.1.100"
@@ -21,35 +18,8 @@
2118
TEST_RANDOM = 13579
2219

2320

24-
@pytest.fixture(name="mock_create_local_connection")
25-
def create_local_connection_fixture(request_handler: RequestHandler) -> Generator[None, None, None]:
26-
"""Fixture that overrides the transport creation to wire it up to the mock socket."""
27-
28-
async def create_connection(protocol_factory: Callable[[], asyncio.Protocol], *args, **kwargs) -> tuple[Any, Any]:
29-
protocol = protocol_factory()
30-
31-
def handle_write(data: bytes) -> None:
32-
response = request_handler(data)
33-
if response is not None:
34-
# Call data_received directly to avoid loop scheduling issues in test
35-
protocol.data_received(response)
36-
37-
closed = asyncio.Event()
38-
39-
mock_transport = Mock()
40-
mock_transport.write = handle_write
41-
mock_transport.close = closed.set
42-
mock_transport.is_reading = lambda: not closed.is_set()
43-
44-
return (mock_transport, protocol)
45-
46-
with patch("roborock.devices.local_channel.asyncio.get_running_loop") as mock_loop:
47-
mock_loop.return_value.create_connection.side_effect = create_connection
48-
yield
49-
50-
5121
@pytest.fixture(name="local_channel")
52-
async def local_channel_fixture(mock_create_local_connection: None) -> AsyncGenerator[LocalChannel, None]:
22+
async def local_channel_fixture(mock_async_create_local_connection: None) -> AsyncGenerator[LocalChannel, None]:
5323
with patch(
5424
"roborock.devices.local_channel.get_next_int", return_value=TEST_CONNECT_NONCE, device_uid=TEST_DEVICE_UID
5525
):
@@ -80,19 +50,23 @@ def build_response(
8050

8151

8252
async def test_connect(
83-
local_channel: LocalChannel, response_queue: Queue[bytes], received_requests: Queue[bytes]
53+
local_channel: LocalChannel,
54+
local_response_queue: asyncio.Queue[bytes],
55+
local_received_requests: asyncio.Queue[bytes],
8456
) -> None:
8557
"""Test connecting to the device."""
8658
# Queue HELLO response with payload to ensure it can be parsed
87-
response_queue.put(build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM))
59+
local_response_queue.put_nowait(
60+
build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM)
61+
)
8862

8963
await local_channel.connect()
9064

9165
assert local_channel.is_connected
92-
assert received_requests.qsize() == 1
66+
assert local_received_requests.qsize() == 1
9367

9468
# Verify HELLO request
95-
request_bytes = received_requests.get()
69+
request_bytes = await local_received_requests.get()
9670
# Note: We cannot use create_local_decoder here because HELLO_REQUEST has payload=None
9771
# which causes MessageParser to fail parsing. For now we verify the raw bytes.
9872

@@ -104,17 +78,21 @@ async def test_connect(
10478

10579

10680
async def test_send_command(
107-
local_channel: LocalChannel, response_queue: Queue[bytes], received_requests: Queue[bytes]
81+
local_channel: LocalChannel,
82+
local_response_queue: asyncio.Queue[bytes],
83+
local_received_requests: asyncio.Queue[bytes],
10884
) -> None:
10985
"""Test sending a command."""
11086
# Queue HELLO response
111-
response_queue.put(build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM))
87+
local_response_queue.put_nowait(
88+
build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM)
89+
)
11290

11391
await local_channel.connect()
11492

11593
# Clear requests from handshake
116-
while not received_requests.empty():
117-
received_requests.get()
94+
while not local_received_requests.empty():
95+
await local_received_requests.get()
11896

11997
# Send command
12098
cmd_seq = 123
@@ -127,8 +105,8 @@ async def test_send_command(
127105
await local_channel.publish(msg)
128106

129107
# Verify request
130-
assert received_requests.qsize() == 1
131-
request_bytes = received_requests.get()
108+
request_bytes = await local_received_requests.get()
109+
assert local_received_requests.empty()
132110

133111
# Decode request
134112
decoder = create_local_decoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=TEST_ACK_NONCE)

tests/e2e/test_mqtt_session.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
from roborock.protocol import MessageParser
1919
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
2020
from tests import mqtt_packet
21+
from tests.fixtures.mqtt import FAKE_PARAMS, Subscriber
2122
from tests.mock_data import LOCAL_KEY
22-
from tests.mqtt_fixtures import FAKE_PARAMS, Subscriber
2323

2424

2525
@pytest.fixture(autouse=True)
26-
def auto_mock_mqtt_client(mock_mqtt_client_fixture: None) -> None:
26+
def auto_mock_mqtt_client(mock_aiomqtt_client: None) -> None:
2727
"""Automatically use the mock mqtt client fixture."""
2828

2929

@@ -33,7 +33,7 @@ def auto_fast_backoff(fast_backoff_fixture: None) -> None:
3333

3434

3535
@pytest.fixture(autouse=True)
36-
def mqtt_server_fixture(mock_create_connection: None, mock_select: None) -> None:
36+
def mqtt_server_fixture(mock_paho_mqtt_create_connection: None, mock_paho_mqtt_select: None) -> None:
3737
"""Fixture to mock the MQTT connection.
3838
3939
This is here to pull in the mock socket pixtures into all tests used here.
@@ -42,10 +42,10 @@ def mqtt_server_fixture(mock_create_connection: None, mock_select: None) -> None
4242

4343
@pytest.fixture(name="session")
4444
async def session_fixture(
45-
push_response: Callable[[bytes], None],
45+
push_mqtt_response: Callable[[bytes], None],
4646
) -> AsyncGenerator[MqttSession, None]:
4747
"""Fixture to create a new connected MQTT session."""
48-
push_response(mqtt_packet.gen_connack(rc=0, flags=2))
48+
push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2))
4949
session = await create_mqtt_session(FAKE_PARAMS)
5050
assert session.connected
5151
try:
@@ -54,12 +54,12 @@ async def session_fixture(
5454
await session.close()
5555

5656

57-
async def test_session_e2e_receive_message(push_response: Callable[[bytes], None], session: MqttSession) -> None:
57+
async def test_session_e2e_receive_message(push_mqtt_response: Callable[[bytes], None], session: MqttSession) -> None:
5858
"""Test receiving a real Roborock message through the session."""
5959
assert session.connected
6060

6161
# Subscribe to the topic. We'll next construct and push a message.
62-
push_response(mqtt_packet.gen_suback(mid=1))
62+
push_mqtt_response(mqtt_packet.gen_suback(mid=1))
6363
subscriber = Subscriber()
6464
await session.subscribe("topic-1", subscriber.append)
6565

@@ -71,12 +71,13 @@ async def test_session_e2e_receive_message(push_response: Callable[[bytes], None
7171
payload = MessageParser.build(msg, local_key=LOCAL_KEY, prefixed=False)
7272

7373
# Simulate receiving the message from the broker
74-
push_response(mqtt_packet.gen_publish("topic-1", mid=2, payload=payload))
74+
push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=2, payload=payload))
7575

7676
# Verify it was dispatched to the subscriber
7777
await subscriber.wait()
7878
assert len(subscriber.messages) == 1
7979
received_payload = subscriber.messages[0]
80+
assert isinstance(received_payload, bytes)
8081
assert received_payload == payload
8182

8283
# Verify the message payload contents
@@ -90,8 +91,8 @@ async def test_session_e2e_receive_message(push_response: Callable[[bytes], None
9091

9192

9293
async def test_session_e2e_publish_message(
93-
push_response: Callable[[bytes], None],
94-
received_requests: Queue,
94+
push_mqtt_response: Callable[[bytes], None],
95+
mqtt_received_requests: Queue,
9596
session: MqttSession,
9697
) -> None:
9798
"""Test publishing a real Roborock message."""
@@ -109,8 +110,8 @@ async def test_session_e2e_publish_message(
109110
# Verify what was sent to the broker
110111
# We expect the payload to be present in the sent bytes
111112
found = False
112-
while not received_requests.empty():
113-
request = received_requests.get()
113+
while not mqtt_received_requests.empty():
114+
request = mqtt_received_requests.get()
114115
if payload in request:
115116
found = True
116117
break

tests/fixtures/__init__.py

Whitespace-only changes.
Lines changed: 8 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,40 +10,11 @@
1010
import paho.mqtt.client as mqtt
1111
import pytest
1212

13-
from roborock.mqtt.session import MqttParams
14-
from tests.conftest import FakeSocketHandler
13+
from .mqtt import FakeMqttSocketHandler
1514

16-
FAKE_PARAMS = MqttParams(
17-
host="localhost",
18-
port=1883,
19-
tls=False,
20-
username="username",
21-
password="password",
22-
timeout=10.0,
23-
)
2415

25-
26-
class Subscriber:
27-
"""Mock subscriber class.
28-
29-
We use this to hold on to received messages for verification.
30-
"""
31-
32-
def __init__(self) -> None:
33-
self.messages: list[bytes] = []
34-
self._event = asyncio.Event()
35-
36-
def append(self, message: bytes) -> None:
37-
self.messages.append(message)
38-
self._event.set()
39-
40-
async def wait(self) -> None:
41-
await asyncio.wait_for(self._event.wait(), timeout=1.0)
42-
self._event.clear()
43-
44-
45-
@pytest.fixture
46-
async def mock_mqtt_client_fixture() -> AsyncGenerator[None, None]:
16+
@pytest.fixture(name="mock_aiomqtt_client")
17+
async def mock_aiomqtt_client_fixture() -> AsyncGenerator[None, None]:
4718
"""Fixture to patch the MQTT underlying sync client.
4819
4920
The tests use fake sockets, so this ensures that the async mqtt client does not
@@ -94,11 +65,13 @@ def fast_backoff_fixture() -> Generator[None, None, None]:
9465

9566

9667
@pytest.fixture
97-
def push_response(response_queue: Queue, fake_socket_handler: FakeSocketHandler) -> Callable[[bytes], None]:
68+
def push_mqtt_response(
69+
mqtt_response_queue: Queue, fake_mqtt_socket_handler: FakeMqttSocketHandler
70+
) -> Callable[[bytes], None]:
9871
"""Fixture to push a response to the client."""
9972

10073
def _push(data: bytes) -> None:
101-
response_queue.put(data)
102-
fake_socket_handler.push_response()
74+
mqtt_response_queue.put(data)
75+
fake_mqtt_socket_handler.push_response()
10376

10477
return _push

0 commit comments

Comments
 (0)