|
| 1 | +"""End-to-end tests for LocalChannel using fake sockets.""" |
| 2 | + |
| 3 | +import asyncio |
| 4 | +from collections.abc import AsyncGenerator, Callable, Generator |
| 5 | +from queue import Queue |
| 6 | +from typing import Any |
| 7 | +from unittest.mock import Mock, patch |
| 8 | + |
| 9 | +import pytest |
| 10 | + |
| 11 | +from roborock.devices.local_channel import LocalChannel |
| 12 | +from roborock.protocol import create_local_decoder, create_local_encoder |
| 13 | +from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol |
| 14 | +from tests.conftest import RequestHandler |
| 15 | +from tests.mock_data import LOCAL_KEY |
| 16 | + |
| 17 | +TEST_HOST = "192.168.1.100" |
| 18 | +TEST_DEVICE_UID = "test_device_uid" |
| 19 | +TEST_CONNECT_NONCE = 12345 |
| 20 | +TEST_ACK_NONCE = 67890 |
| 21 | +TEST_RANDOM = 13579 |
| 22 | + |
| 23 | + |
| 24 | +@pytest.fixture(name="mock_create_local_connection") |
| 25 | +def create_local_connection_fixture(request_handler: RequestHandler) -> Generator[None, None, None]: |
| 26 | + """Fixture that overrides the transport creation to wire it up to the mock socket.""" |
| 27 | + |
| 28 | + async def create_connection(protocol_factory: Callable[[], asyncio.Protocol], *args, **kwargs) -> tuple[Any, Any]: |
| 29 | + protocol = protocol_factory() |
| 30 | + |
| 31 | + def handle_write(data: bytes) -> None: |
| 32 | + response = request_handler(data) |
| 33 | + if response is not None: |
| 34 | + # Call data_received directly to avoid loop scheduling issues in test |
| 35 | + protocol.data_received(response) |
| 36 | + |
| 37 | + closed = asyncio.Event() |
| 38 | + |
| 39 | + mock_transport = Mock() |
| 40 | + mock_transport.write = handle_write |
| 41 | + mock_transport.close = closed.set |
| 42 | + mock_transport.is_reading = lambda: not closed.is_set() |
| 43 | + |
| 44 | + return (mock_transport, protocol) |
| 45 | + |
| 46 | + with patch("roborock.devices.local_channel.asyncio.get_running_loop") as mock_loop: |
| 47 | + mock_loop.return_value.create_connection.side_effect = create_connection |
| 48 | + yield |
| 49 | + |
| 50 | + |
| 51 | +@pytest.fixture(name="local_channel") |
| 52 | +async def local_channel_fixture(mock_create_local_connection: None) -> AsyncGenerator[LocalChannel, None]: |
| 53 | + with patch( |
| 54 | + "roborock.devices.local_channel.get_next_int", return_value=TEST_CONNECT_NONCE, device_uid=TEST_DEVICE_UID |
| 55 | + ): |
| 56 | + channel = LocalChannel(host=TEST_HOST, local_key=LOCAL_KEY, device_uid=TEST_DEVICE_UID) |
| 57 | + yield channel |
| 58 | + channel.close() |
| 59 | + |
| 60 | + |
| 61 | +def build_response( |
| 62 | + protocol: RoborockMessageProtocol, |
| 63 | + seq: int, |
| 64 | + payload: bytes, |
| 65 | + random: int, |
| 66 | +) -> bytes: |
| 67 | + """Build an encoded response message.""" |
| 68 | + if protocol == RoborockMessageProtocol.HELLO_RESPONSE: |
| 69 | + encoder = create_local_encoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=None) |
| 70 | + else: |
| 71 | + encoder = create_local_encoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=TEST_ACK_NONCE) |
| 72 | + |
| 73 | + msg = RoborockMessage( |
| 74 | + protocol=protocol, |
| 75 | + random=random, |
| 76 | + seq=seq, |
| 77 | + payload=payload, |
| 78 | + ) |
| 79 | + return encoder(msg) |
| 80 | + |
| 81 | + |
| 82 | +async def test_connect( |
| 83 | + local_channel: LocalChannel, response_queue: Queue[bytes], received_requests: Queue[bytes] |
| 84 | +) -> None: |
| 85 | + """Test connecting to the device.""" |
| 86 | + # Queue HELLO response with payload to ensure it can be parsed |
| 87 | + response_queue.put(build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM)) |
| 88 | + |
| 89 | + await local_channel.connect() |
| 90 | + |
| 91 | + assert local_channel.is_connected |
| 92 | + assert received_requests.qsize() == 1 |
| 93 | + |
| 94 | + # Verify HELLO request |
| 95 | + request_bytes = received_requests.get() |
| 96 | + # Note: We cannot use create_local_decoder here because HELLO_REQUEST has payload=None |
| 97 | + # which causes MessageParser to fail parsing. For now we verify the raw bytes. |
| 98 | + |
| 99 | + # Protocol is at offset 19 (2 bytes) |
| 100 | + # Prefix(4) + Version(3) + Seq(4) + Random(4) + Timestamp(4) = 19 |
| 101 | + assert len(request_bytes) >= 21 |
| 102 | + protocol_bytes = request_bytes[19:21] |
| 103 | + assert int.from_bytes(protocol_bytes, "big") == RoborockMessageProtocol.HELLO_REQUEST |
| 104 | + |
| 105 | + |
| 106 | +async def test_send_command( |
| 107 | + local_channel: LocalChannel, response_queue: Queue[bytes], received_requests: Queue[bytes] |
| 108 | +) -> None: |
| 109 | + """Test sending a command.""" |
| 110 | + # Queue HELLO response |
| 111 | + response_queue.put(build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM)) |
| 112 | + |
| 113 | + await local_channel.connect() |
| 114 | + |
| 115 | + # Clear requests from handshake |
| 116 | + while not received_requests.empty(): |
| 117 | + received_requests.get() |
| 118 | + |
| 119 | + # Send command |
| 120 | + cmd_seq = 123 |
| 121 | + msg = RoborockMessage( |
| 122 | + protocol=RoborockMessageProtocol.RPC_REQUEST, |
| 123 | + seq=cmd_seq, |
| 124 | + payload=b'{"method":"get_status"}', |
| 125 | + ) |
| 126 | + |
| 127 | + await local_channel.publish(msg) |
| 128 | + |
| 129 | + # Verify request |
| 130 | + assert received_requests.qsize() == 1 |
| 131 | + request_bytes = received_requests.get() |
| 132 | + |
| 133 | + # Decode request |
| 134 | + decoder = create_local_decoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=TEST_ACK_NONCE) |
| 135 | + msgs = list(decoder(request_bytes)) |
| 136 | + assert len(msgs) == 1 |
| 137 | + assert msgs[0].protocol == RoborockMessageProtocol.RPC_REQUEST |
| 138 | + assert msgs[0].payload == b'{"method":"get_status"}' |
0 commit comments