Skip to content

Commit 1d31cf6

Browse files
authored
feat: allow passing in clientsession (#354)
* feat: allow passing in clientsession * fix: test
1 parent 404a47c commit 1d31cf6

File tree

2 files changed

+38
-28
lines changed

2 files changed

+38
-28
lines changed

roborock/web_api.py

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,19 @@
3030

3131

3232
class RoborockApiClient:
33-
def __init__(self, username: str, base_url=None) -> None:
33+
def __init__(self, username: str, base_url=None, session: aiohttp.ClientSession | None = None) -> None:
3434
"""Sample API Client."""
3535
self._username = username
3636
self._default_url = "https://euiot.roborock.com"
3737
self.base_url = base_url
3838
self._device_identifier = secrets.token_urlsafe(16)
39+
if session is None:
40+
session = aiohttp.ClientSession()
41+
self.session = session
3942

4043
async def _get_base_url(self) -> str:
4144
if not self.base_url:
42-
url_request = PreparedRequest(self._default_url)
45+
url_request = PreparedRequest(self._default_url, self.session)
4346
response = await url_request.request(
4447
"post",
4548
"/api/v1/getUrlByEmail",
@@ -113,7 +116,7 @@ async def nc_prepare(self, user_data: UserData, timezone: str) -> dict:
113116
):
114117
raise RoborockException("Your userdata is missing critical attributes.")
115118
base_url = user_data.rriot.r.a
116-
prepare_request = PreparedRequest(base_url)
119+
prepare_request = PreparedRequest(base_url, self.session)
117120
hid = await self._get_home_id(user_data)
118121

119122
data = FormData()
@@ -151,7 +154,7 @@ async def add_device(self, user_data: UserData, s: str, t: str) -> dict:
151154
):
152155
raise RoborockException("Your userdata is missing critical attributes.")
153156
base_url = user_data.rriot.r.a
154-
add_device_request = PreparedRequest(base_url)
157+
add_device_request = PreparedRequest(base_url, self.session)
155158

156159
add_device_response = await add_device_request.request(
157160
"GET",
@@ -176,7 +179,7 @@ async def add_device(self, user_data: UserData, s: str, t: str) -> dict:
176179
async def request_code(self) -> None:
177180
base_url = await self._get_base_url()
178181
header_clientid = self._get_header_client_id()
179-
code_request = PreparedRequest(base_url, {"header_clientid": header_clientid})
182+
code_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
180183

181184
code_response = await code_request.request(
182185
"post",
@@ -201,7 +204,7 @@ async def pass_login(self, password: str) -> UserData:
201204
base_url = await self._get_base_url()
202205
header_clientid = self._get_header_client_id()
203206

204-
login_request = PreparedRequest(base_url, {"header_clientid": header_clientid})
207+
login_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
205208
login_response = await login_request.request(
206209
"post",
207210
"/api/v1/login",
@@ -239,7 +242,7 @@ async def code_login(self, code: int | str) -> UserData:
239242
base_url = await self._get_base_url()
240243
header_clientid = self._get_header_client_id()
241244

242-
login_request = PreparedRequest(base_url, {"header_clientid": header_clientid})
245+
login_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
243246
login_response = await login_request.request(
244247
"post",
245248
"/api/v1/loginWithCode",
@@ -270,7 +273,7 @@ async def code_login(self, code: int | str) -> UserData:
270273
async def _get_home_id(self, user_data: UserData):
271274
base_url = await self._get_base_url()
272275
header_clientid = self._get_header_client_id()
273-
home_id_request = PreparedRequest(base_url, {"header_clientid": header_clientid})
276+
home_id_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
274277
home_id_response = await home_id_request.request(
275278
"get",
276279
"/api/v1/getHomeDetail",
@@ -296,6 +299,7 @@ async def get_home_data(self, user_data: UserData) -> HomeData:
296299
raise RoborockException("Missing field 'a' in rriot reference")
297300
home_request = PreparedRequest(
298301
rriot.r.a,
302+
self.session,
299303
{
300304
"Authorization": self._get_hawk_authentication(rriot, f"/user/homes/{str(home_id)}"),
301305
},
@@ -319,6 +323,7 @@ async def get_home_data_v2(self, user_data: UserData) -> HomeData:
319323
raise RoborockException("Missing field 'a' in rriot reference")
320324
home_request = PreparedRequest(
321325
rriot.r.a,
326+
self.session,
322327
{
323328
"Authorization": self._get_hawk_authentication(rriot, "/v2/user/homes/" + str(home_id)),
324329
},
@@ -362,6 +367,7 @@ async def get_rooms(self, user_data: UserData, home_id: int | None = None) -> li
362367
raise RoborockException("Missing field 'a' in rriot reference")
363368
room_request = PreparedRequest(
364369
rriot.r.a,
370+
self.session,
365371
{
366372
"Authorization": self._get_hawk_authentication(rriot, "/v2/user/homes/" + str(home_id)),
367373
},
@@ -386,6 +392,7 @@ async def get_scenes(self, user_data: UserData, device_id: str) -> list[HomeData
386392
raise RoborockException("Missing field 'a' in rriot reference")
387393
scenes_request = PreparedRequest(
388394
rriot.r.a,
395+
self.session,
389396
{
390397
"Authorization": self._get_hawk_authentication(rriot, f"/user/scene/device/{str(device_id)}"),
391398
},
@@ -407,6 +414,7 @@ async def execute_scene(self, user_data: UserData, scene_id: int) -> None:
407414
raise RoborockException("Missing field 'a' in rriot reference")
408415
execute_scene_request = PreparedRequest(
409416
rriot.r.a,
417+
self.session,
410418
{
411419
"Authorization": self._get_hawk_authentication(rriot, f"/user/scene/{str(scene_id)}/execute"),
412420
},
@@ -419,7 +427,7 @@ async def get_products(self, user_data: UserData) -> ProductResponse:
419427
"""Gets all products and their schemas, good for determining status codes and model numbers."""
420428
base_url = await self._get_base_url()
421429
header_clientid = self._get_header_client_id()
422-
product_request = PreparedRequest(base_url, {"header_clientid": header_clientid})
430+
product_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
423431
product_response = await product_request.request(
424432
"get",
425433
"/api/v4/product",
@@ -437,7 +445,7 @@ async def get_products(self, user_data: UserData) -> ProductResponse:
437445
async def download_code(self, user_data: UserData, product_id: int):
438446
base_url = await self._get_base_url()
439447
header_clientid = self._get_header_client_id()
440-
product_request = PreparedRequest(base_url, {"header_clientid": header_clientid})
448+
product_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
441449
request = {"apilevel": 99999, "productids": [product_id], "type": 2}
442450
response = await product_request.request(
443451
"post",
@@ -450,7 +458,7 @@ async def download_code(self, user_data: UserData, product_id: int):
450458
async def download_category_code(self, user_data: UserData):
451459
base_url = await self._get_base_url()
452460
header_clientid = self._get_header_client_id()
453-
product_request = PreparedRequest(base_url, {"header_clientid": header_clientid})
461+
product_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
454462
response = await product_request.request(
455463
"get",
456464
"api/v1/plugins?apiLevel=99999&type=2",
@@ -462,25 +470,27 @@ async def download_category_code(self, user_data: UserData):
462470

463471

464472
class PreparedRequest:
465-
def __init__(self, base_url: str, base_headers: dict | None = None) -> None:
473+
def __init__(self, base_url: str, session: aiohttp.ClientSession, base_headers: dict | None = None) -> None:
466474
self.base_url = base_url
467475
self.base_headers = base_headers or {}
476+
self.session = session
468477

469478
async def request(self, method: str, url: str, params=None, data=None, headers=None, json=None) -> dict:
470479
_url = "/".join(s.strip("/") for s in [self.base_url, url])
471480
_headers = {**self.base_headers, **(headers or {})}
472-
async with aiohttp.ClientSession() as session:
481+
try:
482+
async with self.session.request(
483+
method, _url, params=params, data=data, headers=_headers, json=json
484+
) as resp:
485+
return await resp.json()
486+
except ContentTypeError as err:
487+
"""If we get an error, lets log everything for debugging."""
473488
try:
474-
async with session.request(method, _url, params=params, data=data, headers=_headers, json=json) as resp:
475-
return await resp.json()
476-
except ContentTypeError as err:
477-
"""If we get an error, lets log everything for debugging."""
478-
try:
479-
resp_json = await resp.json(content_type=None)
480-
_LOGGER.info("Resp: %s", resp_json)
481-
except ContentTypeError as err_2:
482-
_LOGGER.info(err_2)
483-
resp_raw = await resp.read()
484-
_LOGGER.info("Resp raw: %s", resp_raw)
485-
# Still raise the err so that it's clear it failed.
486-
raise err
489+
resp_json = await resp.json(content_type=None)
490+
_LOGGER.info("Resp: %s", resp_json)
491+
except ContentTypeError as err_2:
492+
_LOGGER.info(err_2)
493+
resp_raw = await resp.read()
494+
_LOGGER.info("Resp raw: %s", resp_raw)
495+
# Still raise the err so that it's clear it failed.
496+
raise err

tests/test_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import AsyncGenerator
55
from queue import Queue
66
from typing import Any
7-
from unittest.mock import patch
7+
from unittest.mock import AsyncMock, patch
88

99
import paho.mqtt.client as mqtt
1010
import pytest
@@ -36,7 +36,7 @@
3636

3737

3838
def test_can_create_prepared_request():
39-
PreparedRequest("https://sample.com")
39+
PreparedRequest("https://sample.com", AsyncMock())
4040

4141

4242
async def test_can_create_mqtt_roborock():

0 commit comments

Comments
 (0)