Skip to content

Commit 96cc718

Browse files
authored
fix: close the session if we created it (#356)
1 parent 7608dd4 commit 96cc718

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

roborock/web_api.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ def __init__(self, username: str, base_url=None, session: aiohttp.ClientSession
3636
self._default_url = "https://euiot.roborock.com"
3737
self.base_url = base_url
3838
self._device_identifier = secrets.token_urlsafe(16)
39-
if session is None:
40-
session = aiohttp.ClientSession()
4139
self.session = session
4240

4341
async def _get_base_url(self) -> str:
@@ -470,18 +468,20 @@ async def download_category_code(self, user_data: UserData):
470468

471469

472470
class PreparedRequest:
473-
def __init__(self, base_url: str, session: aiohttp.ClientSession, base_headers: dict | None = None) -> None:
471+
def __init__(
472+
self, base_url: str, session: aiohttp.ClientSession | None = None, base_headers: dict | None = None
473+
) -> None:
474474
self.base_url = base_url
475475
self.base_headers = base_headers or {}
476476
self.session = session
477477

478478
async def request(self, method: str, url: str, params=None, data=None, headers=None, json=None) -> dict:
479479
_url = "/".join(s.strip("/") for s in [self.base_url, url])
480480
_headers = {**self.base_headers, **(headers or {})}
481+
close_session = self.session is None
482+
session = self.session if self.session is not None else aiohttp.ClientSession()
481483
try:
482-
async with self.session.request(
483-
method, _url, params=params, data=data, headers=_headers, json=json
484-
) as resp:
484+
async with session.request(method, _url, params=params, data=data, headers=_headers, json=json) as resp:
485485
return await resp.json()
486486
except ContentTypeError as err:
487487
"""If we get an error, lets log everything for debugging."""
@@ -494,3 +494,6 @@ async def request(self, method: str, url: str, params=None, data=None, headers=N
494494
_LOGGER.info("Resp raw: %s", resp_raw)
495495
# Still raise the err so that it's clear it failed.
496496
raise err
497+
finally:
498+
if close_session:
499+
await session.close()

tests/test_web_api.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1+
import aiohttp
2+
13
from roborock import HomeData, HomeDataScene, UserData
24
from roborock.web_api import RoborockApiClient
35
from tests.mock_data import HOME_DATA_RAW, USER_DATA
46

57

68
async def test_pass_login_flow() -> None:
79
"""Test that we can login with a password and we get back the correct userdata object."""
8-
api = RoborockApiClient(username="[email protected]")
10+
my_session = aiohttp.ClientSession()
11+
api = RoborockApiClient(username="[email protected]", session=my_session)
912
ud = await api.pass_login("password")
1013
assert ud == UserData.from_dict(USER_DATA)
14+
assert not my_session.closed
1115

1216

1317
async def test_code_login_flow() -> None:

0 commit comments

Comments
 (0)