3030
3131
3232class RoborockApiClient :
33- def __init__ (self , username : str , base_url = None ) -> None :
33+ def __init__ (self , username : str , base_url = None , session : aiohttp . ClientSession | None = None ) -> None :
3434 """Sample API Client."""
3535 self ._username = username
3636 self ._default_url = "https://euiot.roborock.com"
3737 self .base_url = base_url
3838 self ._device_identifier = secrets .token_urlsafe (16 )
39+ if session is None :
40+ session = aiohttp .ClientSession ()
41+ self .session = session
3942
4043 async def _get_base_url (self ) -> str :
4144 if not self .base_url :
42- url_request = PreparedRequest (self ._default_url )
45+ url_request = PreparedRequest (self ._default_url , self . session )
4346 response = await url_request .request (
4447 "post" ,
4548 "/api/v1/getUrlByEmail" ,
@@ -113,7 +116,7 @@ async def nc_prepare(self, user_data: UserData, timezone: str) -> dict:
113116 ):
114117 raise RoborockException ("Your userdata is missing critical attributes." )
115118 base_url = user_data .rriot .r .a
116- prepare_request = PreparedRequest (base_url )
119+ prepare_request = PreparedRequest (base_url , self . session )
117120 hid = await self ._get_home_id (user_data )
118121
119122 data = FormData ()
@@ -151,7 +154,7 @@ async def add_device(self, user_data: UserData, s: str, t: str) -> dict:
151154 ):
152155 raise RoborockException ("Your userdata is missing critical attributes." )
153156 base_url = user_data .rriot .r .a
154- add_device_request = PreparedRequest (base_url )
157+ add_device_request = PreparedRequest (base_url , self . session )
155158
156159 add_device_response = await add_device_request .request (
157160 "GET" ,
@@ -176,7 +179,7 @@ async def add_device(self, user_data: UserData, s: str, t: str) -> dict:
176179 async def request_code (self ) -> None :
177180 base_url = await self ._get_base_url ()
178181 header_clientid = self ._get_header_client_id ()
179- code_request = PreparedRequest (base_url , {"header_clientid" : header_clientid })
182+ code_request = PreparedRequest (base_url , self . session , {"header_clientid" : header_clientid })
180183
181184 code_response = await code_request .request (
182185 "post" ,
@@ -201,7 +204,7 @@ async def pass_login(self, password: str) -> UserData:
201204 base_url = await self ._get_base_url ()
202205 header_clientid = self ._get_header_client_id ()
203206
204- login_request = PreparedRequest (base_url , {"header_clientid" : header_clientid })
207+ login_request = PreparedRequest (base_url , self . session , {"header_clientid" : header_clientid })
205208 login_response = await login_request .request (
206209 "post" ,
207210 "/api/v1/login" ,
@@ -239,7 +242,7 @@ async def code_login(self, code: int | str) -> UserData:
239242 base_url = await self ._get_base_url ()
240243 header_clientid = self ._get_header_client_id ()
241244
242- login_request = PreparedRequest (base_url , {"header_clientid" : header_clientid })
245+ login_request = PreparedRequest (base_url , self . session , {"header_clientid" : header_clientid })
243246 login_response = await login_request .request (
244247 "post" ,
245248 "/api/v1/loginWithCode" ,
@@ -270,7 +273,7 @@ async def code_login(self, code: int | str) -> UserData:
270273 async def _get_home_id (self , user_data : UserData ):
271274 base_url = await self ._get_base_url ()
272275 header_clientid = self ._get_header_client_id ()
273- home_id_request = PreparedRequest (base_url , {"header_clientid" : header_clientid })
276+ home_id_request = PreparedRequest (base_url , self . session , {"header_clientid" : header_clientid })
274277 home_id_response = await home_id_request .request (
275278 "get" ,
276279 "/api/v1/getHomeDetail" ,
@@ -296,6 +299,7 @@ async def get_home_data(self, user_data: UserData) -> HomeData:
296299 raise RoborockException ("Missing field 'a' in rriot reference" )
297300 home_request = PreparedRequest (
298301 rriot .r .a ,
302+ self .session ,
299303 {
300304 "Authorization" : self ._get_hawk_authentication (rriot , f"/user/homes/{ str (home_id )} " ),
301305 },
@@ -319,6 +323,7 @@ async def get_home_data_v2(self, user_data: UserData) -> HomeData:
319323 raise RoborockException ("Missing field 'a' in rriot reference" )
320324 home_request = PreparedRequest (
321325 rriot .r .a ,
326+ self .session ,
322327 {
323328 "Authorization" : self ._get_hawk_authentication (rriot , "/v2/user/homes/" + str (home_id )),
324329 },
@@ -362,6 +367,7 @@ async def get_rooms(self, user_data: UserData, home_id: int | None = None) -> li
362367 raise RoborockException ("Missing field 'a' in rriot reference" )
363368 room_request = PreparedRequest (
364369 rriot .r .a ,
370+ self .session ,
365371 {
366372 "Authorization" : self ._get_hawk_authentication (rriot , "/v2/user/homes/" + str (home_id )),
367373 },
@@ -386,6 +392,7 @@ async def get_scenes(self, user_data: UserData, device_id: str) -> list[HomeData
386392 raise RoborockException ("Missing field 'a' in rriot reference" )
387393 scenes_request = PreparedRequest (
388394 rriot .r .a ,
395+ self .session ,
389396 {
390397 "Authorization" : self ._get_hawk_authentication (rriot , f"/user/scene/device/{ str (device_id )} " ),
391398 },
@@ -407,6 +414,7 @@ async def execute_scene(self, user_data: UserData, scene_id: int) -> None:
407414 raise RoborockException ("Missing field 'a' in rriot reference" )
408415 execute_scene_request = PreparedRequest (
409416 rriot .r .a ,
417+ self .session ,
410418 {
411419 "Authorization" : self ._get_hawk_authentication (rriot , f"/user/scene/{ str (scene_id )} /execute" ),
412420 },
@@ -419,7 +427,7 @@ async def get_products(self, user_data: UserData) -> ProductResponse:
419427 """Gets all products and their schemas, good for determining status codes and model numbers."""
420428 base_url = await self ._get_base_url ()
421429 header_clientid = self ._get_header_client_id ()
422- product_request = PreparedRequest (base_url , {"header_clientid" : header_clientid })
430+ product_request = PreparedRequest (base_url , self . session , {"header_clientid" : header_clientid })
423431 product_response = await product_request .request (
424432 "get" ,
425433 "/api/v4/product" ,
@@ -437,7 +445,7 @@ async def get_products(self, user_data: UserData) -> ProductResponse:
437445 async def download_code (self , user_data : UserData , product_id : int ):
438446 base_url = await self ._get_base_url ()
439447 header_clientid = self ._get_header_client_id ()
440- product_request = PreparedRequest (base_url , {"header_clientid" : header_clientid })
448+ product_request = PreparedRequest (base_url , self . session , {"header_clientid" : header_clientid })
441449 request = {"apilevel" : 99999 , "productids" : [product_id ], "type" : 2 }
442450 response = await product_request .request (
443451 "post" ,
@@ -450,7 +458,7 @@ async def download_code(self, user_data: UserData, product_id: int):
450458 async def download_category_code (self , user_data : UserData ):
451459 base_url = await self ._get_base_url ()
452460 header_clientid = self ._get_header_client_id ()
453- product_request = PreparedRequest (base_url , {"header_clientid" : header_clientid })
461+ product_request = PreparedRequest (base_url , self . session , {"header_clientid" : header_clientid })
454462 response = await product_request .request (
455463 "get" ,
456464 "api/v1/plugins?apiLevel=99999&type=2" ,
@@ -462,25 +470,27 @@ async def download_category_code(self, user_data: UserData):
462470
463471
464472class PreparedRequest :
465- def __init__ (self , base_url : str , base_headers : dict | None = None ) -> None :
473+ def __init__ (self , base_url : str , session : aiohttp . ClientSession , base_headers : dict | None = None ) -> None :
466474 self .base_url = base_url
467475 self .base_headers = base_headers or {}
476+ self .session = session
468477
469478 async def request (self , method : str , url : str , params = None , data = None , headers = None , json = None ) -> dict :
470479 _url = "/" .join (s .strip ("/" ) for s in [self .base_url , url ])
471480 _headers = {** self .base_headers , ** (headers or {})}
472- async with aiohttp .ClientSession () as session :
481+ try :
482+ async with self .session .request (
483+ method , _url , params = params , data = data , headers = _headers , json = json
484+ ) as resp :
485+ return await resp .json ()
486+ except ContentTypeError as err :
487+ """If we get an error, lets log everything for debugging."""
473488 try :
474- async with session .request (method , _url , params = params , data = data , headers = _headers , json = json ) as resp :
475- return await resp .json ()
476- except ContentTypeError as err :
477- """If we get an error, lets log everything for debugging."""
478- try :
479- resp_json = await resp .json (content_type = None )
480- _LOGGER .info ("Resp: %s" , resp_json )
481- except ContentTypeError as err_2 :
482- _LOGGER .info (err_2 )
483- resp_raw = await resp .read ()
484- _LOGGER .info ("Resp raw: %s" , resp_raw )
485- # Still raise the err so that it's clear it failed.
486- raise err
489+ resp_json = await resp .json (content_type = None )
490+ _LOGGER .info ("Resp: %s" , resp_json )
491+ except ContentTypeError as err_2 :
492+ _LOGGER .info (err_2 )
493+ resp_raw = await resp .read ()
494+ _LOGGER .info ("Resp raw: %s" , resp_raw )
495+ # Still raise the err so that it's clear it failed.
496+ raise err
0 commit comments