Skip to content

Commit 89ed5cf

Browse files
committed
Simplify future usage within the api clients
1 parent 0872691 commit 89ed5cf

File tree

10 files changed

+85
-70
lines changed

10 files changed

+85
-70
lines changed

roborock/api.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import logging
88
import secrets
99
import time
10-
from collections.abc import Coroutine
1110
from typing import Any
1211

1312
from .containers import (
@@ -16,7 +15,6 @@
1615
from .exceptions import (
1716
RoborockTimeout,
1817
UnknownMethodError,
19-
VacuumError,
2018
)
2119
from .roborock_future import RoborockFuture
2220
from .roborock_message import (
@@ -89,20 +87,18 @@ async def validate_connection(self) -> None:
8987
await self.async_disconnect()
9088
await self.async_connect()
9189

92-
async def _wait_response(self, request_id: int, queue: RoborockFuture) -> tuple[Any, VacuumError | None]:
90+
async def _wait_response(self, request_id: int, queue: RoborockFuture) -> Any:
9391
try:
94-
(response, err) = await queue.async_get(self.queue_timeout)
92+
response = await queue.async_get(self.queue_timeout)
9593
if response == "unknown_method":
9694
raise UnknownMethodError("Unknown method")
97-
return response, err
95+
return response
9896
except (asyncio.TimeoutError, asyncio.CancelledError):
9997
raise RoborockTimeout(f"id={request_id} Timeout after {self.queue_timeout} seconds") from None
10098
finally:
10199
self._waiting_queue.pop(request_id, None)
102100

103-
def _async_response(
104-
self, request_id: int, protocol_id: int = 0
105-
) -> Coroutine[Any, Any, tuple[Any, VacuumError | None]]:
101+
def _async_response(self, request_id: int, protocol_id: int = 0) -> Any:
106102
queue = RoborockFuture(protocol_id)
107103
if request_id in self._waiting_queue:
108104
new_id = get_next_int(10000, 32767)
@@ -115,7 +111,7 @@ def _async_response(
115111
)
116112
request_id = new_id
117113
self._waiting_queue[request_id] = queue
118-
return self._wait_response(request_id, queue)
114+
return asyncio.ensure_future(self._wait_response(request_id, queue))
119115

120116
async def send_message(self, roborock_message: RoborockMessage):
121117
raise NotImplementedError

roborock/cloud_api.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from __future__ import annotations
22

3-
import asyncio
43
import base64
54
import logging
65
import threading
76
import typing
87
import uuid
9-
from asyncio import Lock, Task
8+
from asyncio import Lock
109
from typing import Any
1110
from urllib.parse import urlparse
1211

@@ -65,7 +64,7 @@ def on_connect(self, *args, **kwargs):
6564
message = f"Failed to connect ({mqtt.error_string(rc)})"
6665
self._logger.error(message)
6766
if connection_queue:
68-
connection_queue.resolve((None, VacuumError(message)))
67+
connection_queue.set_exception(VacuumError(message))
6968
return
7069
self._logger.info(f"Connected to mqtt {self._mqtt_host}:{self._mqtt_port}")
7170
topic = f"rr/m/o/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}"
@@ -74,11 +73,11 @@ def on_connect(self, *args, **kwargs):
7473
message = f"Failed to subscribe ({mqtt.error_string(rc)})"
7574
self._logger.error(message)
7675
if connection_queue:
77-
connection_queue.resolve((None, VacuumError(message)))
76+
connection_queue.set_exception(VacuumError(message))
7877
return
7978
self._logger.info(f"Subscribed to topic {topic}")
8079
if connection_queue:
81-
connection_queue.resolve((True, None))
80+
connection_queue.set_result(True)
8281

8382
def on_message(self, *args, **kwargs):
8483
client, __, msg = args
@@ -97,7 +96,7 @@ def on_disconnect(self, *args, **kwargs):
9796
self.update_client_id()
9897
connection_queue = self._waiting_queue.get(DISCONNECT_REQUEST_ID)
9998
if connection_queue:
100-
connection_queue.resolve((True, None))
99+
connection_queue.set_result(True)
101100
except Exception as ex:
102101
self._logger.exception(ex)
103102

@@ -115,53 +114,53 @@ def sync_start_loop(self) -> None:
115114
self._logger.info("Starting mqtt loop")
116115
super().loop_start()
117116

118-
def sync_disconnect(self) -> tuple[bool, Task[tuple[Any, VacuumError | None]] | None]:
117+
def sync_disconnect(self) -> Any:
119118
if not self.is_connected():
120-
return False, None
119+
return None
121120

122121
self._logger.info("Disconnecting from mqtt")
123-
disconnected_future = asyncio.ensure_future(self._async_response(DISCONNECT_REQUEST_ID))
122+
disconnected_future = self._async_response(DISCONNECT_REQUEST_ID)
124123
rc = super().disconnect()
125124

126125
if rc == mqtt.MQTT_ERR_NO_CONN:
127126
disconnected_future.cancel()
128-
return False, None
127+
return None
129128

130129
if rc != mqtt.MQTT_ERR_SUCCESS:
131130
disconnected_future.cancel()
132131
raise RoborockException(f"Failed to disconnect ({mqtt.error_string(rc)})")
133132

134-
return True, disconnected_future
133+
return disconnected_future
135134

136-
def sync_connect(self) -> tuple[bool, Task[tuple[Any, VacuumError | None]] | None]:
135+
def sync_connect(self) -> Any:
137136
if self.is_connected():
138137
self.sync_start_loop()
139-
return False, None
138+
return None
140139

141140
if self._mqtt_port is None or self._mqtt_host is None:
142141
raise RoborockException("Mqtt information was not entered. Cannot connect.")
143142

144143
self._logger.debug("Connecting to mqtt")
145-
connected_future = asyncio.ensure_future(self._async_response(CONNECT_REQUEST_ID))
144+
connected_future = self._async_response(CONNECT_REQUEST_ID)
146145
super().connect(host=self._mqtt_host, port=self._mqtt_port, keepalive=KEEPALIVE)
147146

148147
self.sync_start_loop()
149-
return True, connected_future
148+
return connected_future
150149

151150
async def async_disconnect(self) -> None:
152151
async with self._mutex:
153-
(disconnecting, disconnected_future) = self.sync_disconnect()
154-
if disconnecting and disconnected_future:
155-
(_, err) = await disconnected_future
156-
if err:
152+
if disconnected_future := self.sync_disconnect():
153+
try:
154+
await disconnected_future
155+
except VacuumError as err:
157156
raise RoborockException(err) from err
158157

159158
async def async_connect(self) -> None:
160159
async with self._mutex:
161-
(connecting, connected_future) = self.sync_connect()
162-
if connecting and connected_future:
163-
(_, err) = await connected_future
164-
if err:
160+
if connected_future := self.sync_connect():
161+
try:
162+
await connected_future
163+
except VacuumError as err:
165164
raise RoborockException(err) from err
166165

167166
def _send_msg_raw(self, msg: bytes) -> None:

roborock/roborock_future.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,19 @@ def __init__(self, protocol: int):
1414
self.fut: Future = Future()
1515
self.loop = self.fut.get_loop()
1616

17-
def _resolve(self, item: tuple[Any, VacuumError | None]) -> None:
17+
def _set_result(self, item: Any) -> None:
1818
if not self.fut.cancelled():
1919
self.fut.set_result(item)
2020

21-
def resolve(self, item: tuple[Any, VacuumError | None]) -> None:
22-
self.loop.call_soon_threadsafe(self._resolve, item)
21+
def set_result(self, item: Any) -> None:
22+
self.loop.call_soon_threadsafe(self._set_result, item)
23+
24+
def _set_exception(self, exc: VacuumError) -> None:
25+
if not self.fut.cancelled():
26+
self.fut.set_exception(exc)
27+
28+
def set_exception(self, exc: VacuumError) -> None:
29+
self.loop.call_soon_threadsafe(self._set_exception, exc)
2330

2431
async def async_get(self, timeout: float | int) -> tuple[Any, VacuumError | None]:
2532
try:

roborock/version_1_apis/roborock_client_v1.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -378,20 +378,17 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
378378
if queue and queue.protocol == protocol:
379379
error = data_point_response.get("error")
380380
if error:
381-
queue.resolve(
382-
(
383-
None,
384-
VacuumError(
385-
error.get("code"),
386-
error.get("message"),
387-
),
388-
)
381+
queue.set_exception(
382+
VacuumError(
383+
error.get("code"),
384+
error.get("message"),
385+
),
389386
)
390387
else:
391388
result = data_point_response.get("result")
392389
if isinstance(result, list) and len(result) == 1:
393390
result = result[0]
394-
queue.resolve((result, None))
391+
queue.set_result(result)
395392
else:
396393
self._logger.debug("Received response for unknown request id %s", request_id)
397394
else:
@@ -451,13 +448,13 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
451448
if queue:
452449
if isinstance(decompressed, list):
453450
decompressed = decompressed[0]
454-
queue.resolve((decompressed, None))
451+
queue.set_result(decompressed)
455452
else:
456453
self._logger.debug("Received response for unknown request id %s", request_id)
457454
else:
458455
queue = self._waiting_queue.get(data.seq)
459456
if queue:
460-
queue.resolve((data.payload, None))
457+
queue.set_result(data.payload)
461458
else:
462459
self._logger.debug("Received response for unknown request id %s", data.seq)
463460
except Exception as ex:

roborock/version_1_apis/roborock_local_client_v1.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
import asyncio
2-
31
from roborock.local_api import RoborockLocalClient
42

53
from .. import CommandVacuumError, DeviceData, RoborockCommand, RoborockException
4+
from ..exceptions import VacuumError
65
from ..protocol import MessageParser
76
from ..roborock_message import MessageRetry, RoborockMessage, RoborockMessageProtocol
87
from .roborock_client_v1 import COMMANDS_SECURED, RoborockClientV1
@@ -53,16 +52,21 @@ async def send_message(self, roborock_message: RoborockMessage):
5352
if method:
5453
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
5554
# Send the command to the Roborock device
56-
async_response = asyncio.ensure_future(self._async_response(request_id, response_protocol))
55+
async_response = self._async_response(request_id, response_protocol)
5756
self._send_msg_raw(msg)
58-
(response, err) = await async_response
59-
self._diagnostic_data[method if method is not None else "unknown"] = {
57+
diagnostic_key = method if method is not None else "unknown"
58+
try:
59+
response = await async_response
60+
except VacuumError as err:
61+
self._diagnostic_data[diagnostic_key] = {
62+
"params": roborock_message.get_params(),
63+
"error": err,
64+
}
65+
raise CommandVacuumError(method, err) from err
66+
self._diagnostic_data[diagnostic_key] = {
6067
"params": roborock_message.get_params(),
6168
"response": response,
62-
"error": err,
6369
}
64-
if err:
65-
raise CommandVacuumError(method, err) from err
6670
if roborock_message.protocol == RoborockMessageProtocol.GENERAL_REQUEST:
6771
self._logger.debug(f"id={request_id} Response from method {roborock_message.get_method()}: {response}")
6872
if response == "retry":

roborock/version_1_apis/roborock_mqtt_client_v1.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
import base64
32

43
import paho.mqtt.client as mqtt
@@ -10,7 +9,7 @@
109
from roborock.cloud_api import RoborockMqttClient
1110

1211
from ..containers import DeviceData, UserData
13-
from ..exceptions import CommandVacuumError, RoborockException
12+
from ..exceptions import CommandVacuumError, RoborockException, VacuumError
1413
from ..protocol import MessageParser, Utils
1514
from ..roborock_message import (
1615
RoborockMessage,
@@ -49,16 +48,21 @@ async def send_message(self, roborock_message: RoborockMessage):
4948
local_key = self.device_info.device.local_key
5049
msg = MessageParser.build(roborock_message, local_key, False)
5150
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
52-
async_response = asyncio.ensure_future(self._async_response(request_id, response_protocol))
51+
async_response = self._async_response(request_id, response_protocol)
5352
self._send_msg_raw(msg)
54-
(response, err) = await async_response
55-
self._diagnostic_data[method if method is not None else "unknown"] = {
53+
diagnostic_key = method if method is not None else "unknown"
54+
try:
55+
response = await async_response
56+
except VacuumError as err:
57+
self._diagnostic_data[diagnostic_key] = {
58+
"params": roborock_message.get_params(),
59+
"error": err,
60+
}
61+
raise CommandVacuumError(method, err) from err
62+
self._diagnostic_data[diagnostic_key] = {
5663
"params": roborock_message.get_params(),
5764
"response": response,
58-
"error": err,
5965
}
60-
if err:
61-
raise CommandVacuumError(method, err) from err
6266
if response_protocol == RoborockMessageProtocol.MAP_RESPONSE:
6367
self._logger.debug(f"id={request_id} Response from {method}: {len(response)} bytes")
6468
else:

roborock/version_a01_apis/roborock_client_a01.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
135135
converted_response = entries[data_point_protocol].post_process_fn(data_point)
136136
queue = self._waiting_queue.get(int(data_point_number))
137137
if queue and queue.protocol == protocol:
138-
queue.resolve((converted_response, None))
138+
queue.set_result(converted_response)
139139

140140
async def update_values(
141141
self, dyad_data_protocols: list[RoborockDyadDataProtocol | RoborockZeoProtocol]

roborock/version_a01_apis/roborock_mqtt_client_a01.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ async def send_message(self, roborock_message: RoborockMessage):
4343
futures = []
4444
if "10000" in payload["dps"]:
4545
for dps in json.loads(payload["dps"]["10000"]):
46-
futures.append(asyncio.ensure_future(self._async_response(dps, response_protocol)))
46+
futures.append(self._async_response(dps, response_protocol))
4747
self._send_msg_raw(m)
4848
responses = await asyncio.gather(*futures, return_exceptions=True)
4949
dps_responses: dict[int, typing.Any] = {}
@@ -54,7 +54,7 @@ async def send_message(self, roborock_message: RoborockMessage):
5454
self._logger.warning("Timed out get req for %s after %s s", dps, self.queue_timeout)
5555
dps_responses[dps] = None
5656
else:
57-
dps_responses[dps] = response[0]
57+
dps_responses[dps] = response
5858
return dps_responses
5959

6060
async def update_values(

tests/test_api.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ async def test_can_create_mqtt_roborock():
5050
async def test_sync_connect(mqtt_client):
5151
with patch("paho.mqtt.client.Client.connect", return_value=mqtt.MQTT_ERR_SUCCESS):
5252
with patch("paho.mqtt.client.Client.loop_start", return_value=mqtt.MQTT_ERR_SUCCESS):
53-
connecting, connected_future = mqtt_client.sync_connect()
54-
assert connecting is True
53+
connected_future = mqtt_client.sync_connect()
5554
assert connected_future is not None
5655

5756
connected_future.cancel()

tests/test_queue.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44

5+
from roborock.exceptions import VacuumError
56
from roborock.roborock_future import RoborockFuture
67

78

@@ -10,10 +11,18 @@ def test_can_create():
1011

1112

1213
@pytest.mark.asyncio
13-
async def test_put():
14+
async def test_set_result():
1415
rq = RoborockFuture(1)
15-
rq.resolve(("test", None))
16-
assert await rq.async_get(1) == ("test", None)
16+
rq.set_result("test")
17+
assert await rq.async_get(1) == "test"
18+
19+
20+
@pytest.mark.asyncio
21+
async def test_set_exception():
22+
rq = RoborockFuture(1)
23+
rq.set_exception(VacuumError("test"))
24+
with pytest.raises(VacuumError):
25+
assert await rq.async_get(1)
1726

1827

1928
@pytest.mark.asyncio

0 commit comments

Comments
 (0)