diff --git a/src/hrm/bt_client.py b/src/hrm/bt_client.py index 95de2a0..ee5db12 100644 --- a/src/hrm/bt_client.py +++ b/src/hrm/bt_client.py @@ -6,7 +6,7 @@ import tempfile import time from datetime import datetime -from typing import List +from typing import List, Optional import matplotlib.pyplot as plt from bleak import BleakClient, BleakScanner @@ -29,7 +29,7 @@ logging.basicConfig(level=logging.INFO) -def upload_file(file_path: str) -> str: +def upload_file(file_path: str) -> Optional[str]: """Upload a file to QINIU Storage, it will let DeepChat to download the file.""" load_dotenv() QINIU_ACCESS_KEY = os.getenv("QINIU_ACCESS_KEY") @@ -67,6 +67,7 @@ def __init__(self): logger.info("BtClient initialized") # 50000 is the max length of the db self.db = TsDB(50000) + self.client: Optional[BleakClient] = None async def list_bluetooth_devices(self) -> dict[str, dict]: """Discover Bluetooth devices and filter by HRM profile. Returns a dic, key is the device id, @@ -105,7 +106,6 @@ async def background_monitor(self, duration: int): if not self.client: return async with self.client: - await self.client.connect() self.db.clear() await self.client.start_notify( HR_MEASUREMENT_CHAR_UUID, self.count_heart_rate @@ -113,7 +113,6 @@ async def background_monitor(self, duration: int): # Keep listening for duration seconds await asyncio.sleep(duration) await self.client.stop_notify(HR_MEASUREMENT_CHAR_UUID) - await self.client.disconnect() logger.info(f"Stopped monitoring heart rate of {self.client.address}") def count_heart_rate(self, sender: int, data: bytearray): @@ -258,6 +257,7 @@ def build_heart_rate_chart(self, since_from: float = 600.0) -> str: key = upload_file(debug_file) if key: logger.info(f"Debug PNG chart uploaded to {key}") + plt.close() else: logger.error("Failed to upload debug PNG chart") svg_buffer = io.StringIO() diff --git a/src/hrm/ts_db.py b/src/hrm/ts_db.py index 193b868..1d4b332 100644 --- a/src/hrm/ts_db.py +++ b/src/hrm/ts_db.py @@ -59,7 +59,7 @@ def clear(self): """ self.data.clear() - def time_bucket(self, start: float, end: float, bucket_size: float) -> List[float]: + def time_bucket(self, start: float, end: float, bucket_size: float) -> List[tuple[float, float]]: """ Bucket the data from the given start timestamp to the given end timestamp into the given time bucket size. """ diff --git a/tests/mcp/test_bt_client.py b/tests/mcp/test_bt_client.py index d01c07e..a07d568 100644 --- a/tests/mcp/test_bt_client.py +++ b/tests/mcp/test_bt_client.py @@ -78,8 +78,8 @@ async def test_monitoring_heart_rate_already_connected(bt_client): async def test_background_monitor(bt_client): # Patch client and its methods mock_client = MagicMock() - mock_client.__aenter__.return_value = mock_client - mock_client.__aexit__.return_value = None + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) mock_client.connect = AsyncMock() mock_client.start_notify = AsyncMock() mock_client.stop_notify = AsyncMock() @@ -87,11 +87,13 @@ async def test_background_monitor(bt_client): bt_client.client = mock_client with patch.object(bt_client.db, "clear") as mock_clear: await bt_client.background_monitor(duration=0.01) - mock_client.connect.assert_called_once() + mock_client.__aenter__.assert_called_once() mock_client.start_notify.assert_called_once() mock_client.stop_notify.assert_called_once() - mock_client.disconnect.assert_called_once() + mock_client.__aexit__.assert_called_once() mock_clear.assert_called_once() + mock_client.connect.assert_not_called() + mock_client.disconnect.assert_not_called() @pytest.mark.parametrize(