|
3 | 3 | import logging |
4 | 4 | import re |
5 | 5 | from asyncio import Protocol |
6 | | -from collections.abc import Callable, Generator |
| 6 | +from collections.abc import AsyncGenerator, Callable, Generator |
7 | 7 | from queue import Queue |
8 | 8 | from typing import Any |
9 | 9 | from unittest.mock import Mock, patch |
@@ -138,16 +138,22 @@ def handle_select(rlist: list, wlist: list, *args: Any) -> list: |
138 | 138 |
|
139 | 139 |
|
140 | 140 | @pytest.fixture(name="mqtt_client") |
141 | | -def mqtt_client(mock_create_connection: None, mock_select: None) -> Generator[RoborockMqttClientV1, None, None]: |
| 141 | +async def mqtt_client(mock_create_connection: None, mock_select: None) -> AsyncGenerator[RoborockMqttClientV1, None]: |
142 | 142 | user_data = UserData.from_dict(USER_DATA) |
143 | 143 | home_data = HomeData.from_dict(HOME_DATA_RAW) |
144 | 144 | device_info = DeviceData( |
145 | 145 | device=home_data.devices[0], |
146 | 146 | model=home_data.products[0].model, |
147 | 147 | ) |
148 | 148 | client = RoborockMqttClientV1(user_data, device_info) |
149 | | - yield client |
150 | | - # Clean up any resources after the test |
| 149 | + try: |
| 150 | + yield client |
| 151 | + finally: |
| 152 | + if not client.is_connected(): |
| 153 | + try: |
| 154 | + await client.async_release() |
| 155 | + except Exception: |
| 156 | + pass |
151 | 157 |
|
152 | 158 |
|
153 | 159 | @pytest.fixture(name="mock_rest", autouse=True) |
@@ -226,11 +232,19 @@ def handle_write(data: bytes) -> None: |
226 | 232 |
|
227 | 233 |
|
228 | 234 | @pytest.fixture(name="local_client") |
229 | | -def local_client_fixture(mock_create_local_connection: None) -> Generator[RoborockLocalClientV1, None, None]: |
| 235 | +async def local_client_fixture(mock_create_local_connection: None) -> AsyncGenerator[RoborockLocalClientV1, None]: |
230 | 236 | home_data = HomeData.from_dict(HOME_DATA_RAW) |
231 | 237 | device_info = DeviceData( |
232 | 238 | device=home_data.devices[0], |
233 | 239 | model=home_data.products[0].model, |
234 | 240 | host=TEST_LOCAL_API_HOST, |
235 | 241 | ) |
236 | | - yield RoborockLocalClientV1(device_info) |
| 242 | + client = RoborockLocalClientV1(device_info) |
| 243 | + try: |
| 244 | + yield client |
| 245 | + finally: |
| 246 | + if not client.is_connected(): |
| 247 | + try: |
| 248 | + await client.async_release() |
| 249 | + except Exception: |
| 250 | + pass |
0 commit comments