30
30
31
31
32
32
class RoborockApiClient :
33
- def __init__ (self , username : str , base_url = None ) -> None :
33
+ def __init__ (self , username : str , base_url = None , session : aiohttp . ClientSession | None = None ) -> None :
34
34
"""Sample API Client."""
35
35
self ._username = username
36
36
self ._default_url = "https://euiot.roborock.com"
37
37
self .base_url = base_url
38
38
self ._device_identifier = secrets .token_urlsafe (16 )
39
+ if session is None :
40
+ session = aiohttp .ClientSession ()
41
+ self .session = session
39
42
40
43
async def _get_base_url (self ) -> str :
41
44
if not self .base_url :
42
- url_request = PreparedRequest (self ._default_url )
45
+ url_request = PreparedRequest (self ._default_url , self . session )
43
46
response = await url_request .request (
44
47
"post" ,
45
48
"/api/v1/getUrlByEmail" ,
@@ -113,7 +116,7 @@ async def nc_prepare(self, user_data: UserData, timezone: str) -> dict:
113
116
):
114
117
raise RoborockException ("Your userdata is missing critical attributes." )
115
118
base_url = user_data .rriot .r .a
116
- prepare_request = PreparedRequest (base_url )
119
+ prepare_request = PreparedRequest (base_url , self . session )
117
120
hid = await self ._get_home_id (user_data )
118
121
119
122
data = FormData ()
@@ -151,7 +154,7 @@ async def add_device(self, user_data: UserData, s: str, t: str) -> dict:
151
154
):
152
155
raise RoborockException ("Your userdata is missing critical attributes." )
153
156
base_url = user_data .rriot .r .a
154
- add_device_request = PreparedRequest (base_url )
157
+ add_device_request = PreparedRequest (base_url , self . session )
155
158
156
159
add_device_response = await add_device_request .request (
157
160
"GET" ,
@@ -176,7 +179,7 @@ async def add_device(self, user_data: UserData, s: str, t: str) -> dict:
176
179
async def request_code (self ) -> None :
177
180
base_url = await self ._get_base_url ()
178
181
header_clientid = self ._get_header_client_id ()
179
- code_request = PreparedRequest (base_url , {"header_clientid" : header_clientid })
182
+ code_request = PreparedRequest (base_url , self . session , {"header_clientid" : header_clientid })
180
183
181
184
code_response = await code_request .request (
182
185
"post" ,
@@ -201,7 +204,7 @@ async def pass_login(self, password: str) -> UserData:
201
204
base_url = await self ._get_base_url ()
202
205
header_clientid = self ._get_header_client_id ()
203
206
204
- login_request = PreparedRequest (base_url , {"header_clientid" : header_clientid })
207
+ login_request = PreparedRequest (base_url , self . session , {"header_clientid" : header_clientid })
205
208
login_response = await login_request .request (
206
209
"post" ,
207
210
"/api/v1/login" ,
@@ -239,7 +242,7 @@ async def code_login(self, code: int | str) -> UserData:
239
242
base_url = await self ._get_base_url ()
240
243
header_clientid = self ._get_header_client_id ()
241
244
242
- login_request = PreparedRequest (base_url , {"header_clientid" : header_clientid })
245
+ login_request = PreparedRequest (base_url , self . session , {"header_clientid" : header_clientid })
243
246
login_response = await login_request .request (
244
247
"post" ,
245
248
"/api/v1/loginWithCode" ,
@@ -270,7 +273,7 @@ async def code_login(self, code: int | str) -> UserData:
270
273
async def _get_home_id (self , user_data : UserData ):
271
274
base_url = await self ._get_base_url ()
272
275
header_clientid = self ._get_header_client_id ()
273
- home_id_request = PreparedRequest (base_url , {"header_clientid" : header_clientid })
276
+ home_id_request = PreparedRequest (base_url , self . session , {"header_clientid" : header_clientid })
274
277
home_id_response = await home_id_request .request (
275
278
"get" ,
276
279
"/api/v1/getHomeDetail" ,
@@ -296,6 +299,7 @@ async def get_home_data(self, user_data: UserData) -> HomeData:
296
299
raise RoborockException ("Missing field 'a' in rriot reference" )
297
300
home_request = PreparedRequest (
298
301
rriot .r .a ,
302
+ self .session ,
299
303
{
300
304
"Authorization" : self ._get_hawk_authentication (rriot , f"/user/homes/{ str (home_id )} " ),
301
305
},
@@ -319,6 +323,7 @@ async def get_home_data_v2(self, user_data: UserData) -> HomeData:
319
323
raise RoborockException ("Missing field 'a' in rriot reference" )
320
324
home_request = PreparedRequest (
321
325
rriot .r .a ,
326
+ self .session ,
322
327
{
323
328
"Authorization" : self ._get_hawk_authentication (rriot , "/v2/user/homes/" + str (home_id )),
324
329
},
@@ -362,6 +367,7 @@ async def get_rooms(self, user_data: UserData, home_id: int | None = None) -> li
362
367
raise RoborockException ("Missing field 'a' in rriot reference" )
363
368
room_request = PreparedRequest (
364
369
rriot .r .a ,
370
+ self .session ,
365
371
{
366
372
"Authorization" : self ._get_hawk_authentication (rriot , "/v2/user/homes/" + str (home_id )),
367
373
},
@@ -386,6 +392,7 @@ async def get_scenes(self, user_data: UserData, device_id: str) -> list[HomeData
386
392
raise RoborockException ("Missing field 'a' in rriot reference" )
387
393
scenes_request = PreparedRequest (
388
394
rriot .r .a ,
395
+ self .session ,
389
396
{
390
397
"Authorization" : self ._get_hawk_authentication (rriot , f"/user/scene/device/{ str (device_id )} " ),
391
398
},
@@ -407,6 +414,7 @@ async def execute_scene(self, user_data: UserData, scene_id: int) -> None:
407
414
raise RoborockException ("Missing field 'a' in rriot reference" )
408
415
execute_scene_request = PreparedRequest (
409
416
rriot .r .a ,
417
+ self .session ,
410
418
{
411
419
"Authorization" : self ._get_hawk_authentication (rriot , f"/user/scene/{ str (scene_id )} /execute" ),
412
420
},
@@ -419,7 +427,7 @@ async def get_products(self, user_data: UserData) -> ProductResponse:
419
427
"""Gets all products and their schemas, good for determining status codes and model numbers."""
420
428
base_url = await self ._get_base_url ()
421
429
header_clientid = self ._get_header_client_id ()
422
- product_request = PreparedRequest (base_url , {"header_clientid" : header_clientid })
430
+ product_request = PreparedRequest (base_url , self . session , {"header_clientid" : header_clientid })
423
431
product_response = await product_request .request (
424
432
"get" ,
425
433
"/api/v4/product" ,
@@ -437,7 +445,7 @@ async def get_products(self, user_data: UserData) -> ProductResponse:
437
445
async def download_code (self , user_data : UserData , product_id : int ):
438
446
base_url = await self ._get_base_url ()
439
447
header_clientid = self ._get_header_client_id ()
440
- product_request = PreparedRequest (base_url , {"header_clientid" : header_clientid })
448
+ product_request = PreparedRequest (base_url , self . session , {"header_clientid" : header_clientid })
441
449
request = {"apilevel" : 99999 , "productids" : [product_id ], "type" : 2 }
442
450
response = await product_request .request (
443
451
"post" ,
@@ -450,7 +458,7 @@ async def download_code(self, user_data: UserData, product_id: int):
450
458
async def download_category_code (self , user_data : UserData ):
451
459
base_url = await self ._get_base_url ()
452
460
header_clientid = self ._get_header_client_id ()
453
- product_request = PreparedRequest (base_url , {"header_clientid" : header_clientid })
461
+ product_request = PreparedRequest (base_url , self . session , {"header_clientid" : header_clientid })
454
462
response = await product_request .request (
455
463
"get" ,
456
464
"api/v1/plugins?apiLevel=99999&type=2" ,
@@ -462,25 +470,27 @@ async def download_category_code(self, user_data: UserData):
462
470
463
471
464
472
class PreparedRequest :
465
- def __init__ (self , base_url : str , base_headers : dict | None = None ) -> None :
473
+ def __init__ (self , base_url : str , session : aiohttp . ClientSession , base_headers : dict | None = None ) -> None :
466
474
self .base_url = base_url
467
475
self .base_headers = base_headers or {}
476
+ self .session = session
468
477
469
478
async def request (self , method : str , url : str , params = None , data = None , headers = None , json = None ) -> dict :
470
479
_url = "/" .join (s .strip ("/" ) for s in [self .base_url , url ])
471
480
_headers = {** self .base_headers , ** (headers or {})}
472
- async with aiohttp .ClientSession () as session :
481
+ try :
482
+ async with self .session .request (
483
+ method , _url , params = params , data = data , headers = _headers , json = json
484
+ ) as resp :
485
+ return await resp .json ()
486
+ except ContentTypeError as err :
487
+ """If we get an error, lets log everything for debugging."""
473
488
try :
474
- async with session .request (method , _url , params = params , data = data , headers = _headers , json = json ) as resp :
475
- return await resp .json ()
476
- except ContentTypeError as err :
477
- """If we get an error, lets log everything for debugging."""
478
- try :
479
- resp_json = await resp .json (content_type = None )
480
- _LOGGER .info ("Resp: %s" , resp_json )
481
- except ContentTypeError as err_2 :
482
- _LOGGER .info (err_2 )
483
- resp_raw = await resp .read ()
484
- _LOGGER .info ("Resp raw: %s" , resp_raw )
485
- # Still raise the err so that it's clear it failed.
486
- raise err
489
+ resp_json = await resp .json (content_type = None )
490
+ _LOGGER .info ("Resp: %s" , resp_json )
491
+ except ContentTypeError as err_2 :
492
+ _LOGGER .info (err_2 )
493
+ resp_raw = await resp .read ()
494
+ _LOGGER .info ("Resp raw: %s" , resp_raw )
495
+ # Still raise the err so that it's clear it failed.
496
+ raise err
0 commit comments