Skip to content

Commit a444eae

Browse files
committed
chore: convert to store all iot login info together
1 parent 9c1ad7c commit a444eae

File tree

1 file changed

+62
-30
lines changed

1 file changed

+62
-30
lines changed

roborock/web_api.py

Lines changed: 62 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import secrets
99
import string
1010
import time
11+
from dataclasses import dataclass
1112

1213
import aiohttp
1314
from aiohttp import ContentTypeError, FormData
@@ -31,6 +32,15 @@
3132
_LOGGER = logging.getLogger(__name__)
3233

3334

35+
@dataclass
36+
class IotLoginInfo:
37+
"""Information about the login to the iot server."""
38+
39+
base_url: str
40+
country_code: str
41+
country: str
42+
43+
3444
class RoborockApiClient:
3545
_LOGIN_RATES = [
3646
Rate(1, Duration.SECOND),
@@ -48,23 +58,29 @@ class RoborockApiClient:
4858
_login_limiter = Limiter(_LOGIN_RATES)
4959
_home_data_limiter = Limiter(_HOME_DATA_RATES)
5060

51-
def __init__(self, username: str, base_url=None, session: aiohttp.ClientSession | None = None) -> None:
61+
def __init__(
62+
self, username: str, base_url: str | None = None, session: aiohttp.ClientSession | None = None
63+
) -> None:
5264
"""Sample API Client."""
5365
self._username = username
54-
self.base_url = base_url
66+
self._base_url = base_url
5567
self._device_identifier = secrets.token_urlsafe(16)
5668
self.session = session
57-
self._country = None
58-
self._country_code = None
59-
60-
async def _get_base_url(self) -> str:
61-
if not self.base_url:
62-
for iot_url in [
63-
"https://usiot.roborock.com",
64-
"https://euiot.roborock.com",
65-
"https://cniot.roborock.com",
66-
"https://ruiot.roborock.com",
67-
]:
69+
self._iot_login_info: IotLoginInfo | None = None
70+
71+
async def _get_iot_login_info(self) -> IotLoginInfo:
72+
if self._iot_login_info is None:
73+
valid_urls = (
74+
[
75+
"https://usiot.roborock.com",
76+
"https://euiot.roborock.com",
77+
"https://cniot.roborock.com",
78+
"https://ruiot.roborock.com",
79+
]
80+
if self._base_url is None
81+
else [self._base_url]
82+
)
83+
for iot_url in valid_urls:
6884
url_request = PreparedRequest(iot_url, self.session)
6985
response = await url_request.request(
7086
"post",
@@ -86,15 +102,31 @@ async def _get_base_url(self) -> str:
86102
"Failed to get base url for %s with the following context: %s", self._username, response
87103
)
88104
if response["data"]["countrycode"] is not None:
89-
self._country_code = response["data"]["countrycode"]
90-
self._country = response["data"]["country"]
91-
self.base_url = response["data"]["url"]
92-
return self.base_url
105+
self._iot_login_info = IotLoginInfo(
106+
base_url=response["data"]["url"],
107+
country=response["data"]["country"],
108+
country_code=response["data"]["countrycode"],
109+
)
110+
return self._iot_login_info
93111
raise RoborockNoResponseFromBaseURL(
94112
"No account was found for any base url we tried. Either your email is incorrect or we do not have a"
95113
" record of the roborock server your device is on."
96114
)
97-
return self.base_url
115+
return self._iot_login_info
116+
117+
@property
118+
async def base_url(self):
119+
if self._base_url is not None:
120+
return self._base_url
121+
return (await self._get_iot_login_info()).base_url
122+
123+
@property
124+
async def country(self):
125+
return (await self._get_iot_login_info()).country
126+
127+
@property
128+
async def country_code(self):
129+
return (await self._get_iot_login_info()).country_code
98130

99131
def _get_header_client_id(self):
100132
md5 = hashlib.md5()
@@ -178,7 +210,7 @@ async def request_code(self) -> None:
178210
except BucketFullException as ex:
179211
_LOGGER.info(ex.meta_info)
180212
raise RoborockRateLimit("Reached maximum requests for login. Please try again later.") from ex
181-
base_url = await self._get_base_url()
213+
base_url = await self.base_url
182214
header_clientid = self._get_header_client_id()
183215
code_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
184216

@@ -209,7 +241,7 @@ async def request_code_v4(self) -> None:
209241
except BucketFullException as ex:
210242
_LOGGER.info(ex.meta_info)
211243
raise RoborockRateLimit("Reached maximum requests for login. Please try again later.") from ex
212-
base_url = await self._get_base_url()
244+
base_url = await self.base_url
213245
header_clientid = self._get_header_client_id()
214246
code_request = PreparedRequest(
215247
base_url,
@@ -240,7 +272,7 @@ async def request_code_v4(self) -> None:
240272

241273
async def _sign_key_v3(self, s: str) -> str:
242274
"""Sign a randomly generated string."""
243-
base_url = await self._get_base_url()
275+
base_url = await self.base_url
244276
header_clientid = self._get_header_client_id()
245277
code_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
246278

@@ -269,11 +301,11 @@ async def code_login_v4(
269301
:param country: The two-character representation of the country, i.e. "US"
270302
:param country_code: the country phone number code i.e. 1 for US.
271303
"""
272-
base_url = await self._get_base_url()
304+
base_url = await self.base_url
273305
if country is None:
274-
country = self._country
306+
country = await self.country
275307
if country_code is None:
276-
country_code = self._country_code
308+
country_code = await self.country_code
277309
header_clientid = self._get_header_client_id()
278310
x_mercy_ks = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(16))
279311
x_mercy_k = await self._sign_key_v3(x_mercy_ks)
@@ -321,7 +353,7 @@ async def pass_login(self, password: str) -> UserData:
321353
except BucketFullException as ex:
322354
_LOGGER.info(ex.meta_info)
323355
raise RoborockRateLimit("Reached maximum requests for login. Please try again later.") from ex
324-
base_url = await self._get_base_url()
356+
base_url = await self.base_url
325357
header_clientid = self._get_header_client_id()
326358

327359
login_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
@@ -360,7 +392,7 @@ async def pass_login_v3(self, password: str) -> UserData:
360392
raise NotImplementedError("Pass_login_v3 has not yet been implemented")
361393

362394
async def code_login(self, code: int | str) -> UserData:
363-
base_url = await self._get_base_url()
395+
base_url = await self.base_url
364396
header_clientid = self._get_header_client_id()
365397

366398
login_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
@@ -393,7 +425,7 @@ async def code_login(self, code: int | str) -> UserData:
393425
return UserData.from_dict(user_data)
394426

395427
async def _get_home_id(self, user_data: UserData):
396-
base_url = await self._get_base_url()
428+
base_url = await self.base_url
397429
header_clientid = self._get_header_client_id()
398430
home_id_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
399431
home_id_response = await home_id_request.request(
@@ -564,7 +596,7 @@ async def execute_scene(self, user_data: UserData, scene_id: int) -> None:
564596

565597
async def get_products(self, user_data: UserData) -> ProductResponse:
566598
"""Gets all products and their schemas, good for determining status codes and model numbers."""
567-
base_url = await self._get_base_url()
599+
base_url = await self.base_url
568600
header_clientid = self._get_header_client_id()
569601
product_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
570602
product_response = await product_request.request(
@@ -582,7 +614,7 @@ async def get_products(self, user_data: UserData) -> ProductResponse:
582614
raise RoborockException("product result was an unexpected type")
583615

584616
async def download_code(self, user_data: UserData, product_id: int):
585-
base_url = await self._get_base_url()
617+
base_url = await self.base_url
586618
header_clientid = self._get_header_client_id()
587619
product_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
588620
request = {"apilevel": 99999, "productids": [product_id], "type": 2}
@@ -595,7 +627,7 @@ async def download_code(self, user_data: UserData, product_id: int):
595627
return response["data"][0]["url"]
596628

597629
async def download_category_code(self, user_data: UserData):
598-
base_url = await self._get_base_url()
630+
base_url = await self.base_url
599631
header_clientid = self._get_header_client_id()
600632
product_request = PreparedRequest(base_url, self.session, {"header_clientid": header_clientid})
601633
response = await product_request.request(

0 commit comments

Comments
 (0)