Skip to content

Commit 34b1422

Browse files
committed
chore: extract map parser into a separate function to share with new api
Simplify how security data works to avoid passing all the way up to the parent class.
1 parent 251b3f9 commit 34b1422

File tree

9 files changed

+159
-37
lines changed

9 files changed

+159
-37
lines changed

roborock/api.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
from __future__ import annotations
44

55
import asyncio
6-
import base64
76
import logging
8-
import secrets
97
import time
108
from abc import ABC, abstractmethod
119
from typing import Any
@@ -37,14 +35,11 @@ class RoborockClient(ABC):
3735
def __init__(self, device_info: DeviceData) -> None:
3836
"""Initialize RoborockClient."""
3937
self.device_info = device_info
40-
self._nonce = secrets.token_bytes(16)
4138
self._waiting_queue: dict[int, RoborockFuture] = {}
4239
self._last_device_msg_in = time.monotonic()
4340
self._last_disconnection = time.monotonic()
4441
self.keep_alive = KEEPALIVE
45-
self._diagnostic_data: dict[str, dict[str, Any]] = {
46-
"misc_info": {"Nonce": base64.b64encode(self._nonce).decode("utf-8")}
47-
}
42+
self._diagnostic_data: dict[str, dict[str, Any]] = {}
4843
self.is_available: bool = True
4944

5045
async def async_release(self) -> None:

roborock/protocol.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,26 @@ def decrypt_ecb(ciphertext: bytes, token: bytes) -> bytes:
147147
return unpad(decipher.decrypt(ciphertext), AES.block_size)
148148
return ciphertext
149149

150+
@staticmethod
151+
def encrypt_cbc(plaintext: bytes, token: bytes) -> bytes:
152+
"""Encrypt plaintext with a given token using cbc mode.
153+
154+
This is currently used for testing purposes only.
155+
156+
:param bytes plaintext: Plaintext (json) to encrypt
157+
:param bytes token: Token to use
158+
:return: Encrypted bytes
159+
"""
160+
if not isinstance(plaintext, bytes):
161+
raise TypeError("plaintext requires bytes")
162+
Utils.verify_token(token)
163+
iv = bytes(AES.block_size)
164+
cipher = AES.new(token, AES.MODE_CBC, iv)
165+
if plaintext:
166+
plaintext = pad(plaintext, AES.block_size)
167+
return cipher.encrypt(plaintext)
168+
return plaintext
169+
150170
@staticmethod
151171
def decrypt_cbc(ciphertext: bytes, token: bytes) -> bytes:
152172
"""Decrypt ciphertext with a given token using cbc mode.

roborock/protocols/v1_protocol.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import logging
88
import math
99
import secrets
10+
import struct
1011
import time
1112
from collections.abc import Callable
1213
from dataclasses import dataclass, field
@@ -44,6 +45,10 @@ def to_dict(self) -> dict[str, Any]:
4445
"""Convert security data to a dictionary for sending in the payload."""
4546
return {"security": {"endpoint": self.endpoint, "nonce": self.nonce.hex().lower()}}
4647

48+
def to_diagnostic_data(self) -> dict[str, Any]:
49+
"""Convert security data to a dictionary for debugging purposes."""
50+
return {"nonce": self.nonce.hex().lower()}
51+
4752

4853
def create_security_data(rriot: RRiot) -> SecurityData:
4954
"""Create a SecurityData instance for the given endpoint and nonce."""
@@ -142,3 +147,37 @@ def decode_rpc_response(message: RoborockMessage) -> dict[str, Any]:
142147
if not isinstance(result, dict):
143148
raise RoborockException(f"Invalid V1 message format: 'result' should be a dictionary for {message.payload!r}")
144149
return result
150+
151+
152+
@dataclass
153+
class MapResponse:
154+
"""Data structure for the V1 Map response."""
155+
156+
request_id: int
157+
"""The request ID of the map response."""
158+
159+
data: bytes
160+
"""The map data, decrypted and decompressed."""
161+
162+
163+
def create_map_response_decoder(security_data: SecurityData) -> Callable[[RoborockMessage], MapResponse]:
164+
"""Create a decoder for V1 map response messages."""
165+
166+
def _decode_map_response(message: RoborockMessage) -> MapResponse:
167+
"""Decode a V1 map response message."""
168+
if not message.payload or len(message.payload) < 24:
169+
raise RoborockException("Invalid V1 map response format: missing payload")
170+
header, body = message.payload[:24], message.payload[24:]
171+
[endpoint, _, request_id, _] = struct.unpack("<8s8sH6s", header)
172+
if not endpoint.decode().startswith(security_data.endpoint):
173+
raise RoborockException(
174+
f"Invalid V1 map response endpoint: {endpoint!r}, expected {security_data.endpoint!r}"
175+
)
176+
try:
177+
decrypted = Utils.decrypt_cbc(body, security_data.nonce)
178+
except ValueError as err:
179+
raise RoborockException("Failed to decode map message payload") from err
180+
decompressed = Utils.decompress(decrypted)
181+
return MapResponse(request_id=request_id, data=decompressed)
182+
183+
return _decode_map_response

roborock/version_1_apis/roborock_client_v1.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import asyncio
22
import dataclasses
33
import json
4-
import struct
54
import time
65
from abc import ABC, abstractmethod
76
from collections.abc import Callable, Coroutine
@@ -45,7 +44,7 @@
4544
ValleyElectricityTimer,
4645
WashTowelMode,
4746
)
48-
from roborock.protocol import Utils
47+
from roborock.protocols.v1_protocol import MapResponse, SecurityData, create_map_response_decoder
4948
from roborock.roborock_message import (
5049
ROBOROCK_DATA_CONSUMABLE_PROTOCOL,
5150
ROBOROCK_DATA_STATUS_PROTOCOL,
@@ -150,10 +149,15 @@ class RoborockClientV1(RoborockClient, ABC):
150149
"""Roborock client base class for version 1 devices."""
151150

152151
_listeners: dict[str, ListenerModel] = {}
152+
_map_response_decoder: Callable[[RoborockMessage], MapResponse] | None = None
153153

154-
def __init__(self, device_info: DeviceData, endpoint: str):
154+
def __init__(self, device_info: DeviceData, security_data: SecurityData | None) -> None:
155155
"""Initializes the Roborock client."""
156156
super().__init__(device_info)
157+
if security_data is not None:
158+
self._diagnostic_data.update({"misc_info": security_data.to_diagnostic_data()})
159+
self._map_response_decoder = create_map_response_decoder(security_data)
160+
157161
self._status_type: type[Status] = ModelStatus.get(device_info.model, S7MaxVStatus)
158162
self.cache: dict[CacheableAttribute, AttributeCache] = {
159163
cacheable_attribute: AttributeCache(attr, self._send_command)
@@ -162,7 +166,6 @@ def __init__(self, device_info: DeviceData, endpoint: str):
162166
if device_info.device.duid not in self._listeners:
163167
self._listeners[device_info.device.duid] = ListenerModel({}, self.cache)
164168
self.listener_model = self._listeners[device_info.device.duid]
165-
self._endpoint = endpoint
166169

167170
async def async_release(self) -> None:
168171
await super().async_release()
@@ -429,21 +432,15 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
429432
dps = {data_point_number: data_point}
430433
self._logger.debug(f"Got unknown data point {dps}")
431434
elif data.payload and protocol == RoborockMessageProtocol.MAP_RESPONSE:
432-
payload = data.payload[0:24]
433-
[endpoint, _, request_id, _] = struct.unpack("<8s8sH6s", payload)
434-
if endpoint.decode().startswith(self._endpoint):
435-
try:
436-
decrypted = Utils.decrypt_cbc(data.payload[24:], self._nonce)
437-
except ValueError as err:
438-
raise RoborockException(f"Failed to decode {data.payload!r} for {data.protocol}") from err
439-
decompressed = Utils.decompress(decrypted)
440-
queue = self._waiting_queue.get(request_id)
435+
if self._map_response_decoder is not None:
436+
map_response = self._map_response_decoder(data)
437+
queue = self._waiting_queue.get(map_response.request_id)
441438
if queue:
442-
if isinstance(decompressed, list):
443-
decompressed = decompressed[0]
444-
queue.set_result(decompressed)
439+
queue.set_result(map_response.data)
445440
else:
446-
self._logger.debug("Received response for unknown request id %s", request_id)
441+
self._logger.debug(
442+
"Received unsolicited map response for request_id %s", map_response.request_id
443+
)
447444
else:
448445
queue = self._waiting_queue.get(data.seq)
449446
if queue:

roborock/version_1_apis/roborock_local_client_v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(self, device_data: DeviceData, queue_timeout: int = 4):
6060
self.transport: Transport | None = None
6161
self._mutex = Lock()
6262
self.keep_alive_task: TimerHandle | None = None
63-
RoborockClientV1.__init__(self, device_data, "abc")
63+
RoborockClientV1.__init__(self, device_data, security_data=None)
6464
RoborockClient.__init__(self, device_data)
6565
self._local_protocol = _LocalProtocol(self._data_received, self._connection_lost)
6666
self._encoder: Encoder = create_local_encoder(device_data.device.local_key)

roborock/version_1_apis/roborock_mqtt_client_v1.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import base64
21
import logging
32

43
from vacuum_map_parser_base.config.color import ColorsPalette
@@ -10,8 +9,7 @@
109

1110
from ..containers import DeviceData, UserData
1211
from ..exceptions import CommandVacuumError, RoborockException, VacuumError
13-
from ..protocol import Utils
14-
from ..protocols.v1_protocol import SecurityData, create_mqtt_payload_encoder
12+
from ..protocols.v1_protocol import create_mqtt_payload_encoder, create_security_data
1513
from ..roborock_message import (
1614
RoborockMessageProtocol,
1715
)
@@ -30,15 +28,12 @@ def __init__(self, user_data: UserData, device_info: DeviceData, queue_timeout:
3028
rriot = user_data.rriot
3129
if rriot is None:
3230
raise RoborockException("Got no rriot data from user_data")
33-
endpoint = base64.b64encode(Utils.md5(rriot.k.encode())[8:14]).decode()
34-
31+
security_data = create_security_data(rriot)
3532
RoborockMqttClient.__init__(self, user_data, device_info)
36-
RoborockClientV1.__init__(self, device_info, endpoint)
33+
RoborockClientV1.__init__(self, device_info, security_data=security_data)
3734
self.queue_timeout = queue_timeout
3835
self._logger = RoborockLoggerAdapter(device_info.device.name, _LOGGER)
39-
self._payload_encoder = create_mqtt_payload_encoder(
40-
SecurityData(endpoint=self._endpoint, nonce=self._nonce),
41-
)
36+
self._payload_encoder = create_mqtt_payload_encoder(security_data)
4237

4338
async def _send_command(
4439
self,

tests/mock_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
BASE_URL = "https://usiot.roborock.com"
1010

1111
USER_ID = "user123"
12-
K_VALUE = "domain123"
12+
K_VALUE = "qiCNieZa"
1313
USER_DATA = {
1414
"uid": 123456,
1515
"tokentype": "token_type",

tests/protocols/test_v1_protocol.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77
from freezegun import freeze_time
88

99
from roborock.containers import RoborockBase, UserData
10+
from roborock.exceptions import RoborockException
11+
from roborock.protocol import Utils
1012
from roborock.protocols.v1_protocol import (
1113
SecurityData,
14+
create_map_response_decoder,
1215
create_mqtt_payload_encoder,
1316
decode_rpc_response,
1417
encode_local_payload,
@@ -20,7 +23,12 @@
2023

2124
USER_DATA = UserData.from_dict(mock_data.USER_DATA)
2225
TEST_REQUEST_ID = 44444
23-
SECURITY_DATA = SecurityData(endpoint="3PBTIjvc", nonce=b"fake-nonce")
26+
TEST_ENDPOINT = "87ItGWdb"
27+
TEST_ENDPOINT_BYTES = TEST_ENDPOINT.encode()
28+
SECURITY_DATA = SecurityData(
29+
endpoint=TEST_ENDPOINT,
30+
nonce=b"\x91\xbe\x10\xc9b+\x9d\x8a\xcdH*\x19\xf6\xfe\x81h",
31+
)
2432

2533

2634
@pytest.fixture(autouse=True)
@@ -62,7 +70,7 @@ def test_encode_local_payload(command, params, expected):
6270
(
6371
RoborockCommand.GET_STATUS,
6472
None,
65-
b'{"dps":{"101":"{\\"id\\":44444,\\"method\\":\\"get_status\\",\\"params\\":[],\\"security\\":{\\"endpoint\\":\\"3PBTIjvc\\",\\"nonce\\":\\"66616b652d6e6f6e6365\\"}}"},"t":1737374400}',
73+
b'{"dps":{"101":"{\\"id\\":44444,\\"method\\":\\"get_status\\",\\"params\\":[],\\"security\\":{\\"endpoint\\":\\"87ItGWdb\\",\\"nonce\\":\\"91be10c9622b9d8acd482a19f6fe8168\\"}}"},"t":1737374400}',
6674
)
6775
],
6876
)
@@ -122,3 +130,70 @@ def test_decode_rpc_response(payload: bytes, expected: RoborockBase) -> None:
122130
)
123131
decoded_message = decode_rpc_response(message)
124132
assert decoded_message == expected
133+
134+
135+
def test_create_map_response_decoder():
136+
"""Test creating and using a map response decoder."""
137+
test_data = b"some map\n"
138+
compressed_data = (
139+
b"\x1f\x8b\x08\x08\xf9\x13\x99h\x00\x03foo\x00+\xce\xcfMU\xc8M,\xe0\x02\x00@\xdb\xc6\x1a\t\x00\x00\x00"
140+
)
141+
142+
# Create header: endpoint(8) + padding(8) + request_id(2) + padding(6)
143+
# request_id = 44508 (0xaddc in little endian)
144+
header = TEST_ENDPOINT_BYTES + b"\x00" * 8 + b"\xdc\xad" + b"\x00" * 6
145+
encrypted_data = Utils.encrypt_cbc(compressed_data, SECURITY_DATA.nonce)
146+
payload = header + encrypted_data
147+
148+
message = RoborockMessage(
149+
protocol=RoborockMessageProtocol.MAP_RESPONSE,
150+
payload=payload,
151+
seq=12750,
152+
version=b"1.0",
153+
random=97431,
154+
timestamp=1652547161,
155+
)
156+
157+
decoder = create_map_response_decoder(SECURITY_DATA)
158+
result = decoder(message)
159+
160+
assert result.request_id == 44508
161+
assert result.data == test_data
162+
163+
164+
def test_create_map_response_decoder_invalid_endpoint():
165+
"""Test map response decoder with invalid endpoint."""
166+
# Create header with wrong endpoint
167+
header = b"wrongend" + b"\x00" * 8 + b"\xdc\xad" + b"\x00" * 6
168+
payload = header + b"encrypted_data"
169+
170+
message = RoborockMessage(
171+
protocol=RoborockMessageProtocol.MAP_RESPONSE,
172+
payload=payload,
173+
seq=12750,
174+
version=b"1.0",
175+
random=97431,
176+
timestamp=1652547161,
177+
)
178+
179+
decoder = create_map_response_decoder(SECURITY_DATA)
180+
181+
with pytest.raises(RoborockException, match="Invalid V1 map response endpoint"):
182+
decoder(message)
183+
184+
185+
def test_create_map_response_decoder_invalid_payload():
186+
"""Test map response decoder with invalid payload."""
187+
message = RoborockMessage(
188+
protocol=RoborockMessageProtocol.MAP_RESPONSE,
189+
payload=b"short", # Too short payload
190+
seq=12750,
191+
version=b"1.0",
192+
random=97431,
193+
timestamp=1652547161,
194+
)
195+
196+
decoder = create_map_response_decoder(SECURITY_DATA)
197+
198+
with pytest.raises(RoborockException, match="Invalid V1 map response format: missing payload"):
199+
decoder(message)

tests/test_containers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
CONSUMABLE,
2323
DND_TIMER,
2424
HOME_DATA_RAW,
25+
K_VALUE,
2526
LOCAL_KEY,
2627
PRODUCT_ID,
2728
STATUS,
@@ -130,7 +131,7 @@ def test_user_data():
130131
assert ud.rriot.u == "user123"
131132
assert ud.rriot.s == "pass123"
132133
assert ud.rriot.h == "unknown123"
133-
assert ud.rriot.k == "domain123"
134+
assert ud.rriot.k == K_VALUE
134135
assert ud.rriot.r.r == "US"
135136
assert ud.rriot.r.a == "https://api-us.roborock.com"
136137
assert ud.rriot.r.m == "tcp://mqtt-us.roborock.com:8883"

0 commit comments

Comments
 (0)