Skip to content

Commit 82c96e1

Browse files
Merge pull request #18 from humbertogontijo/local_and_cloud_clients
feat: sppliting clients into local and cloud
2 parents 1b4926e + 8019313 commit 82c96e1

File tree

12 files changed

+457
-451
lines changed

12 files changed

+457
-451
lines changed

roborock/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Roborock API."""
22

3-
from roborock.api import RoborockClient, RoborockMqttClient
3+
from roborock.api import RoborockApiClient
4+
from roborock.cloud_api import RoborockMqttClient
45
from roborock.containers import *
56
from roborock.exceptions import *
67
from roborock.typing import *

roborock/api.py

Lines changed: 53 additions & 280 deletions
Large diffs are not rendered by default.

roborock/cli.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
import click
77

88
from roborock import RoborockException
9-
from roborock.api import RoborockClient, RoborockMqttClient
9+
from roborock.api import RoborockApiClient
10+
from roborock.cloud_api import RoborockMqttClient
1011
from roborock.containers import LoginData
11-
from roborock.typing import RoborockDeviceInfo
1212
from roborock.util import run_sync
1313

1414
_LOGGER = logging.getLogger(__name__)
@@ -69,24 +69,28 @@ async def login(ctx, email, password):
6969
return
7070
except RoborockException:
7171
pass
72-
client = RoborockClient(email)
72+
client = RoborockApiClient(email)
7373
user_data = await client.pass_login(password)
7474
context.update(LoginData({"user_data": user_data, "email": email}))
7575

76+
7677
async def _discover(ctx):
7778
context: RoborockContext = ctx.obj
7879
login_data = context.login_data()
79-
client = RoborockClient(login_data.email)
80+
client = RoborockApiClient(login_data.email)
8081
home_data = await client.get_home_data(login_data.user_data)
8182
context.update(LoginData({**login_data, "home_data": home_data}))
82-
click.echo(f"Discovered devices {', '.join([device.name for device in home_data.devices + home_data.received_devices])}")
83+
click.echo(
84+
f"Discovered devices {', '.join([device.name for device in home_data.devices + home_data.received_devices])}")
85+
8386

8487
@click.command()
8588
@click.pass_context
8689
@run_sync()
8790
async def discover(ctx):
8891
await _discover(ctx)
8992

93+
9094
@click.command()
9195
@click.pass_context
9296
@run_sync()
@@ -99,6 +103,7 @@ async def list_devices(ctx):
99103
home_data = login_data.home_data
100104
click.echo(f"Known devices {', '.join([device.name for device in home_data.devices + home_data.received_devices])}")
101105

106+
102107
@click.command()
103108
@click.option("--cmd", required=True)
104109
@click.option("--params", required=False)
@@ -111,17 +116,9 @@ async def command(ctx, cmd, params):
111116
await _discover(ctx)
112117
login_data = context.login_data()
113118
home_data = login_data.home_data
114-
device_map: dict[str, RoborockDeviceInfo] = {}
119+
device_map: dict[str, str] = {}
115120
for device in home_data.devices + home_data.received_devices:
116-
product = next(
117-
(
118-
product
119-
for product in home_data.products
120-
if product.id == device.product_id
121-
),
122-
{},
123-
)
124-
device_map[device.duid] = RoborockDeviceInfo(device, product)
121+
device_map[device.duid] = device.local_key
125122
mqtt_client = RoborockMqttClient(login_data.user_data, device_map)
126123
await mqtt_client.send_command(home_data.devices[0].duid, cmd, params)
127124
mqtt_client.__del__()

roborock/cloud_api.py

Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
from __future__ import annotations
2+
3+
import base64
4+
import gzip
5+
import json
6+
import logging
7+
import secrets
8+
import struct
9+
import threading
10+
from asyncio import Lock
11+
from asyncio.exceptions import TimeoutError, CancelledError
12+
from typing import Any, Callable
13+
from urllib.parse import urlparse
14+
15+
import paho.mqtt.client as mqtt
16+
from Crypto.Cipher import AES
17+
from Crypto.Util.Padding import unpad
18+
19+
from roborock.api import md5hex, md5bin, RoborockClient
20+
from roborock.exceptions import (
21+
RoborockException,
22+
CommandVacuumError,
23+
VacuumError,
24+
RoborockTimeout,
25+
)
26+
from .code_mappings import STATE_CODE_TO_STATUS
27+
from .containers import (
28+
UserData,
29+
)
30+
from .roborock_queue import RoborockQueue
31+
from .typing import (
32+
RoborockCommand,
33+
)
34+
from .util import run_in_executor
35+
36+
_LOGGER = logging.getLogger(__name__)
37+
QUEUE_TIMEOUT = 4
38+
MQTT_KEEPALIVE = 60
39+
COMMANDS_WITH_BINARY_RESPONSE = [
40+
RoborockCommand.GET_MAP_V1,
41+
]
42+
43+
44+
class RoborockMqttClient(RoborockClient, mqtt.Client):
45+
_thread: threading.Thread
46+
47+
def __init__(self, user_data: UserData, device_localkey: dict[str, str]) -> None:
48+
rriot = user_data.rriot
49+
self._mqtt_user = rriot.user
50+
RoborockClient.__init__(self, rriot.endpoint, device_localkey)
51+
mqtt.Client.__init__(self, protocol=mqtt.MQTTv5)
52+
self._hashed_user = md5hex(self._mqtt_user + ":" + rriot.endpoint)[2:10]
53+
url = urlparse(rriot.reference.mqtt)
54+
self._mqtt_host = url.hostname
55+
self._mqtt_port = url.port
56+
self._mqtt_ssl = url.scheme == "ssl"
57+
if self._mqtt_ssl:
58+
super().tls_set()
59+
self._mqtt_password = rriot.password
60+
self._hashed_password = md5hex(self._mqtt_password + ":" + rriot.endpoint)[16:]
61+
super().username_pw_set(self._hashed_user, self._hashed_password)
62+
self._seq = 1
63+
self._random = 4711
64+
self._id_counter = 2
65+
self._salt = "TXdfu$jyZ#TZHsg4"
66+
self._endpoint = base64.b64encode(md5bin(rriot.endpoint)[8:14]).decode()
67+
self._nonce = secrets.token_bytes(16)
68+
self._waiting_queue: dict[int, RoborockQueue] = {}
69+
self._mutex = Lock()
70+
self._last_device_msg_in = mqtt.time_func()
71+
self._last_disconnection = mqtt.time_func()
72+
self._status_listeners: list[Callable[[str, str], None]] = []
73+
74+
def __del__(self) -> None:
75+
self.sync_disconnect()
76+
77+
@run_in_executor()
78+
async def on_connect(self, _client, _, __, rc, ___=None) -> None:
79+
connection_queue = self._waiting_queue.get(0)
80+
if rc != mqtt.MQTT_ERR_SUCCESS:
81+
message = f"Failed to connect (rc: {rc})"
82+
_LOGGER.error(message)
83+
if connection_queue:
84+
await connection_queue.async_put(
85+
(None, VacuumError(rc, message)), timeout=QUEUE_TIMEOUT
86+
)
87+
return
88+
_LOGGER.info(f"Connected to mqtt {self._mqtt_host}:{self._mqtt_port}")
89+
topic = f"rr/m/o/{self._mqtt_user}/{self._hashed_user}/#"
90+
(result, mid) = self.subscribe(topic)
91+
if result != 0:
92+
message = f"Failed to subscribe (rc: {result})"
93+
_LOGGER.error(message)
94+
if connection_queue:
95+
await connection_queue.async_put(
96+
(None, VacuumError(rc, message)), timeout=QUEUE_TIMEOUT
97+
)
98+
return
99+
_LOGGER.info(f"Subscribed to topic {topic}")
100+
if connection_queue:
101+
await connection_queue.async_put((True, None), timeout=QUEUE_TIMEOUT)
102+
103+
@run_in_executor()
104+
async def on_message(self, _client, _, msg, __=None) -> None:
105+
try:
106+
async with self._mutex:
107+
self._last_device_msg_in = mqtt.time_func()
108+
device_id = msg.topic.split("/").pop()
109+
data = self._decode_msg(msg.payload, self.device_localkey[device_id])
110+
protocol = data.get("protocol")
111+
if protocol == 102:
112+
payload = json.loads(data.get("payload").decode())
113+
for data_point_number, data_point in payload.get("dps").items():
114+
if data_point_number == "102":
115+
data_point_response = json.loads(data_point)
116+
request_id = data_point_response.get("id")
117+
queue = self._waiting_queue.get(request_id)
118+
if queue:
119+
if queue.protocol == protocol:
120+
error = data_point_response.get("error")
121+
if error:
122+
await queue.async_put(
123+
(
124+
None,
125+
VacuumError(
126+
error.get("code"), error.get("message")
127+
),
128+
),
129+
timeout=QUEUE_TIMEOUT,
130+
)
131+
else:
132+
result = data_point_response.get("result")
133+
if isinstance(result, list) and len(result) > 0:
134+
result = result[0]
135+
await queue.async_put(
136+
(result, None), timeout=QUEUE_TIMEOUT
137+
)
138+
elif request_id < self._id_counter:
139+
_LOGGER.debug(
140+
f"id={request_id} Ignoring response: {data_point_response}"
141+
)
142+
elif data_point_number == "121":
143+
status = STATE_CODE_TO_STATUS.get(data_point)
144+
_LOGGER.debug(f"Status updated to {status}")
145+
for listener in self._status_listeners:
146+
listener(device_id, status)
147+
else:
148+
_LOGGER.debug(
149+
f"Unknown data point number received {data_point_number} with {data_point}"
150+
)
151+
elif protocol == 301:
152+
payload = data.get("payload")[0:24]
153+
[endpoint, _, request_id, _] = struct.unpack("<15sBH6s", payload)
154+
if endpoint.decode().startswith(self._endpoint):
155+
iv = bytes(AES.block_size)
156+
decipher = AES.new(self._nonce, AES.MODE_CBC, iv)
157+
decrypted = unpad(
158+
decipher.decrypt(data.get("payload")[24:]), AES.block_size
159+
)
160+
decrypted = gzip.decompress(decrypted)
161+
queue = self._waiting_queue.get(request_id)
162+
if queue:
163+
if isinstance(decrypted, list):
164+
decrypted = decrypted[0]
165+
await queue.async_put((decrypted, None), timeout=QUEUE_TIMEOUT)
166+
except Exception as ex:
167+
_LOGGER.exception(ex)
168+
169+
@run_in_executor()
170+
async def on_disconnect(self, _client: mqtt.Client, _, rc, __=None) -> None:
171+
try:
172+
self._last_disconnection = mqtt.time_func()
173+
message = f"Roborock mqtt client disconnected (rc: {rc})"
174+
_LOGGER.warning(message)
175+
connection_queue = self._waiting_queue.get(1)
176+
if connection_queue:
177+
await connection_queue.async_put(
178+
(True, None), timeout=QUEUE_TIMEOUT
179+
)
180+
except Exception as ex:
181+
_LOGGER.exception(ex)
182+
183+
@run_in_executor()
184+
async def _async_check_keepalive(self) -> None:
185+
async with self._mutex:
186+
now = mqtt.time_func()
187+
if now - self._last_disconnection > self._keepalive ** 2 and now - self._last_device_msg_in > self._keepalive:
188+
self._ping_t = self._last_device_msg_in
189+
190+
def add_status_listener(self, callback: Callable[[str, str], None]):
191+
self._status_listeners.append(callback)
192+
193+
def _check_keepalive(self) -> None:
194+
self._async_check_keepalive()
195+
super()._check_keepalive()
196+
197+
def sync_stop_loop(self) -> None:
198+
if self._thread:
199+
_LOGGER.info("Stopping mqtt loop")
200+
super().loop_stop()
201+
202+
def sync_start_loop(self) -> None:
203+
if not self._thread or not self._thread.is_alive():
204+
self.sync_stop_loop()
205+
_LOGGER.info("Starting mqtt loop")
206+
super().loop_start()
207+
208+
def sync_disconnect(self) -> bool:
209+
rc = mqtt.MQTT_ERR_AGAIN
210+
if self.is_connected():
211+
_LOGGER.info("Disconnecting from mqtt")
212+
rc = super().disconnect()
213+
if not rc in [mqtt.MQTT_ERR_SUCCESS, mqtt.MQTT_ERR_NO_CONN]:
214+
raise RoborockException(f"Failed to disconnect (rc:{rc})")
215+
return rc == mqtt.MQTT_ERR_SUCCESS
216+
217+
def sync_connect(self) -> bool:
218+
rc = mqtt.MQTT_ERR_AGAIN
219+
self.sync_start_loop()
220+
if not self.is_connected():
221+
_LOGGER.info("Connecting to mqtt")
222+
rc = super().connect(
223+
host=self._mqtt_host,
224+
port=self._mqtt_port,
225+
keepalive=MQTT_KEEPALIVE
226+
)
227+
if rc != mqtt.MQTT_ERR_SUCCESS:
228+
raise RoborockException(f"Failed to connect (rc:{rc})")
229+
return rc == mqtt.MQTT_ERR_SUCCESS
230+
231+
async def _async_response(self, request_id: int, protocol_id: int = 0) -> tuple[Any, VacuumError | None]:
232+
try:
233+
queue = RoborockQueue(protocol_id)
234+
self._waiting_queue[request_id] = queue
235+
(response, err) = await queue.async_get(QUEUE_TIMEOUT)
236+
return response, err
237+
except (TimeoutError, CancelledError):
238+
raise RoborockTimeout(
239+
f"Timeout after {QUEUE_TIMEOUT} seconds waiting for response"
240+
) from None
241+
finally:
242+
del self._waiting_queue[request_id]
243+
244+
async def async_disconnect(self) -> Any:
245+
async with self._mutex:
246+
disconnecting = self.sync_disconnect()
247+
if disconnecting:
248+
(response, err) = await self._async_response(1)
249+
if err:
250+
raise RoborockException(err) from err
251+
return response
252+
253+
async def async_connect(self) -> Any:
254+
async with self._mutex:
255+
connecting = self.sync_connect()
256+
if connecting:
257+
(response, err) = await self._async_response(0)
258+
if err:
259+
raise RoborockException(err) from err
260+
return response
261+
262+
async def validate_connection(self) -> None:
263+
await self.async_connect()
264+
265+
def _send_msg_raw(self, device_id, msg) -> None:
266+
info = self.publish(
267+
f"rr/m/i/{self._mqtt_user}/{self._hashed_user}/{device_id}", msg
268+
)
269+
if info.rc != mqtt.MQTT_ERR_SUCCESS:
270+
raise RoborockException(f"Failed to publish (rc: {info.rc})")
271+
272+
async def send_command(
273+
self, device_id: str, method: RoborockCommand, params: list = None
274+
):
275+
await self.validate_connection()
276+
request_id, timestamp, payload = super()._get_payload(method, params)
277+
_LOGGER.debug(f"id={request_id} Requesting method {method} with {params}")
278+
request_protocol = 101
279+
response_protocol = 301 if method in COMMANDS_WITH_BINARY_RESPONSE else 102
280+
msg = super()._get_msg_raw(device_id, request_protocol, timestamp, payload)
281+
self._send_msg_raw(device_id, msg)
282+
(response, err) = await self._async_response(request_id, response_protocol)
283+
if err:
284+
raise CommandVacuumError(method, err) from err
285+
if response_protocol == 301:
286+
_LOGGER.debug(
287+
f"id={request_id} Response from {method}: {len(response)} bytes"
288+
)
289+
else:
290+
_LOGGER.debug(f"id={request_id} Response from {method}: {response}")
291+
return response

roborock/containers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class UserDataRRiotField(str, Enum):
1515
USER = "u"
1616
PASSWORD = "s"
1717
H_UNKNOWN = "h"
18-
DOMAIN = "k"
18+
ENDPOINT = "k"
1919
REFERENCE = "r"
2020

2121

@@ -277,8 +277,8 @@ def h_unknown(self) -> str:
277277
return self.get(UserDataRRiotField.H_UNKNOWN)
278278

279279
@property
280-
def domain(self) -> str:
281-
return self.get(UserDataRRiotField.DOMAIN)
280+
def endpoint(self) -> str:
281+
return self.get(UserDataRRiotField.ENDPOINT)
282282

283283
@property
284284
def reference(self) -> Reference:

0 commit comments

Comments
 (0)