Skip to content

Commit aa66e1d

Browse files
fix: refactoring api
1 parent d6e3b34 commit aa66e1d

File tree

4 files changed

+125
-105
lines changed

4 files changed

+125
-105
lines changed

roborock/api.py

Lines changed: 87 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import asyncio
66
import base64
77
import binascii
8+
import gzip
89
import hashlib
910
import hmac
1011
import json
@@ -13,20 +14,19 @@
1314
import secrets
1415
import struct
1516
import time
16-
from typing import Any
17+
from typing import Any, Callable
1718

1819
import aiohttp
1920
from Crypto.Cipher import AES
2021
from Crypto.Util.Padding import pad, unpad
2122

2223
from roborock.exceptions import (
23-
RoborockException,
24+
RoborockException, RoborockTimeout, VacuumError,
2425
)
2526
from .code_mappings import WASH_MODE_MAP, DUST_COLLECTION_MAP, RoborockDockType, \
26-
RoborockDockDustCollectionType, RoborockDockWashingModeType
27+
RoborockDockDustCollectionType, RoborockDockWashingModeType, STATE_CODE_TO_STATUS
2728
from .containers import (
2829
UserData,
29-
HomeDataDevice,
3030
Status,
3131
CleanSummary,
3232
Consumable,
@@ -37,6 +37,7 @@
3737
SmartWashParameters,
3838

3939
)
40+
from .roborock_queue import RoborockQueue
4041
from .typing import (
4142
RoborockDeviceProp,
4243
RoborockCommand,
@@ -86,7 +87,7 @@ async def request(
8687
return await resp.json()
8788

8889

89-
class RoborockClient():
90+
class RoborockClient:
9091

9192
def __init__(self, endpoint: str, device_localkey: dict[str, str], prefixed=False) -> None:
9293
self.device_localkey = device_localkey
@@ -97,6 +98,8 @@ def __init__(self, endpoint: str, device_localkey: dict[str, str], prefixed=Fals
9798
self._endpoint = base64.b64encode(md5bin(endpoint)[8:14]).decode()
9899
self._nonce = secrets.token_bytes(16)
99100
self._prefixed = prefixed
101+
self._waiting_queue: dict[int, RoborockQueue] = {}
102+
self._status_listeners: list[Callable[[str, str], None]] = []
100103

101104
def _decode_msg(self, msg: bytes, local_key: str) -> dict[str, Any]:
102105
if self._prefixed:
@@ -112,13 +115,13 @@ def _decode_msg(self, msg: bytes, local_key: str) -> dict[str, Any]:
112115
"timestamp": timestamp,
113116
"protocol": protocol,
114117
}
115-
crc32 = binascii.crc32(msg[0: len(msg) - 4])
118+
# crc32 = binascii.crc32(msg[0: len(msg) - 4])
116119
[version, _seq, _random, timestamp, protocol, payload_len] = struct.unpack(
117120
"!3sIIIHH", msg[0:19]
118121
)
119122
[payload, expected_crc32] = struct.unpack_from(f"!{payload_len}sI", msg, 19)
120-
if crc32 != expected_crc32:
121-
raise RoborockException(f"Wrong CRC32 {crc32}, expected {expected_crc32}")
123+
# if crc32 != expected_crc32:
124+
# raise RoborockException(f"Wrong CRC32 {crc32}, expected {expected_crc32}")
122125

123126
aes_key = md5bin(encode_timestamp(timestamp) + local_key + self._salt)
124127
decipher = AES.new(aes_key, AES.MODE_ECB)
@@ -130,7 +133,7 @@ def _decode_msg(self, msg: bytes, local_key: str) -> dict[str, Any]:
130133
"payload": decrypted_payload,
131134
}
132135

133-
def _get_msg_raw(self, device_id, protocol, timestamp, payload, prefix='') -> bytes:
136+
def _encode_msg(self, device_id, protocol, timestamp, payload, prefix='') -> bytes:
134137
local_key = self.device_localkey[device_id]
135138
aes_key = md5bin(encode_timestamp(timestamp) + local_key + self._salt)
136139
cipher = AES.new(aes_key, AES.MODE_ECB)
@@ -155,6 +158,81 @@ def _get_msg_raw(self, device_id, protocol, timestamp, payload, prefix='') -> by
155158
msg += struct.pack("!I", crc32)
156159
return msg
157160

161+
async def on_message(self, device_id, msg) -> None:
162+
try:
163+
data = self._decode_msg(msg, self.device_localkey[device_id])
164+
protocol = data.get("protocol")
165+
if protocol == 102 or protocol == 4:
166+
payload = json.loads(data.get("payload").decode())
167+
for data_point_number, data_point in payload.get("dps").items():
168+
if data_point_number == "102":
169+
data_point_response = json.loads(data_point)
170+
request_id = data_point_response.get("id")
171+
queue = self._waiting_queue.get(request_id)
172+
if queue:
173+
if queue.protocol == protocol:
174+
error = data_point_response.get("error")
175+
if error:
176+
await queue.async_put(
177+
(
178+
None,
179+
VacuumError(
180+
error.get("code"), error.get("message")
181+
),
182+
),
183+
timeout=QUEUE_TIMEOUT,
184+
)
185+
else:
186+
result = data_point_response.get("result")
187+
if isinstance(result, list) and len(result) > 0:
188+
result = result[0]
189+
await queue.async_put(
190+
(result, None), timeout=QUEUE_TIMEOUT
191+
)
192+
elif request_id < self._id_counter:
193+
_LOGGER.debug(
194+
f"id={request_id} Ignoring response: {data_point_response}"
195+
)
196+
elif data_point_number == "121":
197+
status = STATE_CODE_TO_STATUS.get(data_point)
198+
_LOGGER.debug(f"Status updated to {status}")
199+
for listener in self._status_listeners:
200+
listener(device_id, status)
201+
else:
202+
_LOGGER.debug(
203+
f"Unknown data point number received {data_point_number} with {data_point}"
204+
)
205+
elif protocol == 301:
206+
payload = data.get("payload")[0:24]
207+
[endpoint, _, request_id, _] = struct.unpack("<15sBH6s", payload)
208+
if endpoint.decode().startswith(self._endpoint):
209+
iv = bytes(AES.block_size)
210+
decipher = AES.new(self._nonce, AES.MODE_CBC, iv)
211+
decrypted = unpad(
212+
decipher.decrypt(data.get("payload")[24:]), AES.block_size
213+
)
214+
decrypted = gzip.decompress(decrypted)
215+
queue = self._waiting_queue.get(request_id)
216+
if queue:
217+
if isinstance(decrypted, list):
218+
decrypted = decrypted[0]
219+
await queue.async_put((decrypted, None), timeout=QUEUE_TIMEOUT)
220+
except Exception as ex:
221+
_LOGGER.exception(ex)
222+
223+
async def _async_response(self, request_id: int, protocol_id: int = 0) -> tuple[Any, VacuumError | None]:
224+
try:
225+
queue = RoborockQueue(protocol_id)
226+
self._waiting_queue[request_id] = queue
227+
(response, err) = await queue.async_get(QUEUE_TIMEOUT)
228+
return response, err
229+
except (asyncio.TimeoutError, asyncio.CancelledError):
230+
raise RoborockTimeout(
231+
f"Timeout after {QUEUE_TIMEOUT} seconds waiting for response"
232+
) from None
233+
finally:
234+
del self._waiting_queue[request_id]
235+
158236
def _get_payload(
159237
self, method: RoborockCommand, params: list = None
160238
):

roborock/cloud_api.py

Lines changed: 5 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,21 @@
11
from __future__ import annotations
22

33
import base64
4-
import gzip
5-
import json
64
import logging
75
import secrets
8-
import struct
96
import threading
107
from asyncio import Lock
11-
from asyncio.exceptions import TimeoutError, CancelledError
128
from typing import Any, Callable
139
from urllib.parse import urlparse
1410

1511
import paho.mqtt.client as mqtt
16-
from Crypto.Cipher import AES
17-
from Crypto.Util.Padding import unpad
1812

1913
from roborock.api import md5hex, md5bin, RoborockClient
2014
from roborock.exceptions import (
2115
RoborockException,
2216
CommandVacuumError,
2317
VacuumError,
24-
RoborockTimeout,
2518
)
26-
from .code_mappings import STATE_CODE_TO_STATUS
2719
from .containers import (
2820
UserData,
2921
)
@@ -69,7 +61,6 @@ def __init__(self, user_data: UserData, device_localkey: dict[str, str]) -> None
6961
self._mutex = Lock()
7062
self._last_device_msg_in = mqtt.time_func()
7163
self._last_disconnection = mqtt.time_func()
72-
self._status_listeners: list[Callable[[str, str], None]] = []
7364

7465
def __del__(self) -> None:
7566
self.sync_disconnect()
@@ -102,69 +93,10 @@ async def on_connect(self, _client, _, __, rc, ___=None) -> None:
10293

10394
@run_in_executor()
10495
async def on_message(self, _client, _, msg, __=None) -> None:
105-
try:
106-
async with self._mutex:
107-
self._last_device_msg_in = mqtt.time_func()
108-
device_id = msg.topic.split("/").pop()
109-
data = self._decode_msg(msg.payload, self.device_localkey[device_id])
110-
protocol = data.get("protocol")
111-
if protocol == 102:
112-
payload = json.loads(data.get("payload").decode())
113-
for data_point_number, data_point in payload.get("dps").items():
114-
if data_point_number == "102":
115-
data_point_response = json.loads(data_point)
116-
request_id = data_point_response.get("id")
117-
queue = self._waiting_queue.get(request_id)
118-
if queue:
119-
if queue.protocol == protocol:
120-
error = data_point_response.get("error")
121-
if error:
122-
await queue.async_put(
123-
(
124-
None,
125-
VacuumError(
126-
error.get("code"), error.get("message")
127-
),
128-
),
129-
timeout=QUEUE_TIMEOUT,
130-
)
131-
else:
132-
result = data_point_response.get("result")
133-
if isinstance(result, list) and len(result) > 0:
134-
result = result[0]
135-
await queue.async_put(
136-
(result, None), timeout=QUEUE_TIMEOUT
137-
)
138-
elif request_id < self._id_counter:
139-
_LOGGER.debug(
140-
f"id={request_id} Ignoring response: {data_point_response}"
141-
)
142-
elif data_point_number == "121":
143-
status = STATE_CODE_TO_STATUS.get(data_point)
144-
_LOGGER.debug(f"Status updated to {status}")
145-
for listener in self._status_listeners:
146-
listener(device_id, status)
147-
else:
148-
_LOGGER.debug(
149-
f"Unknown data point number received {data_point_number} with {data_point}"
150-
)
151-
elif protocol == 301:
152-
payload = data.get("payload")[0:24]
153-
[endpoint, _, request_id, _] = struct.unpack("<15sBH6s", payload)
154-
if endpoint.decode().startswith(self._endpoint):
155-
iv = bytes(AES.block_size)
156-
decipher = AES.new(self._nonce, AES.MODE_CBC, iv)
157-
decrypted = unpad(
158-
decipher.decrypt(data.get("payload")[24:]), AES.block_size
159-
)
160-
decrypted = gzip.decompress(decrypted)
161-
queue = self._waiting_queue.get(request_id)
162-
if queue:
163-
if isinstance(decrypted, list):
164-
decrypted = decrypted[0]
165-
await queue.async_put((decrypted, None), timeout=QUEUE_TIMEOUT)
166-
except Exception as ex:
167-
_LOGGER.exception(ex)
96+
async with self._mutex:
97+
self._last_device_msg_in = mqtt.time_func()
98+
device_id = msg.topic.split("/").pop()
99+
super().on_message(device_id, msg)
168100

169101
@run_in_executor()
170102
async def on_disconnect(self, _client: mqtt.Client, _, rc, __=None) -> None:
@@ -228,19 +160,6 @@ def sync_connect(self) -> bool:
228160
raise RoborockException(f"Failed to connect (rc:{rc})")
229161
return rc == mqtt.MQTT_ERR_SUCCESS
230162

231-
async def _async_response(self, request_id: int, protocol_id: int = 0) -> tuple[Any, VacuumError | None]:
232-
try:
233-
queue = RoborockQueue(protocol_id)
234-
self._waiting_queue[request_id] = queue
235-
(response, err) = await queue.async_get(QUEUE_TIMEOUT)
236-
return response, err
237-
except (TimeoutError, CancelledError):
238-
raise RoborockTimeout(
239-
f"Timeout after {QUEUE_TIMEOUT} seconds waiting for response"
240-
) from None
241-
finally:
242-
del self._waiting_queue[request_id]
243-
244163
async def async_disconnect(self) -> Any:
245164
async with self._mutex:
246165
disconnecting = self.sync_disconnect()
@@ -277,7 +196,7 @@ async def send_command(
277196
_LOGGER.debug(f"id={request_id} Requesting method {method} with {params}")
278197
request_protocol = 101
279198
response_protocol = 301 if method in COMMANDS_WITH_BINARY_RESPONSE else 102
280-
msg = super()._get_msg_raw(device_id, request_protocol, timestamp, payload)
199+
msg = super()._encode_msg(device_id, request_protocol, timestamp, payload)
281200
self._send_msg_raw(device_id, msg)
282201
(response, err) = await self._async_response(request_id, response_protocol)
283202
if err:

0 commit comments

Comments
 (0)