Skip to content

Commit 5a17e6f

Browse files
committed
fix: make threadsafe waiting queue
1 parent 777b736 commit 5a17e6f

File tree

9 files changed

+217
-63
lines changed

9 files changed

+217
-63
lines changed

roborock/api.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
RoborockTimeout,
1818
UnknownMethodError,
1919
)
20-
from .roborock_future import RoborockFuture
20+
from .roborock_future import RequestKey, RoborockFuture, WaitingQueue
2121
from .roborock_message import (
2222
RoborockMessage,
2323
)
24-
from .util import get_next_int, get_running_loop_or_create_one
24+
from .util import get_running_loop_or_create_one
2525

2626
_LOGGER = logging.getLogger(__name__)
2727
KEEPALIVE = 60
@@ -37,7 +37,7 @@ def __init__(self, device_info: DeviceData, queue_timeout: int = 4) -> None:
3737
self.event_loop = get_running_loop_or_create_one()
3838
self.device_info = device_info
3939
self._nonce = secrets.token_bytes(16)
40-
self._waiting_queue: dict[int, RoborockFuture] = {}
40+
self._waiting_queue = WaitingQueue()
4141
self._last_device_msg_in = time.monotonic()
4242
self._last_disconnection = time.monotonic()
4343
self.keep_alive = KEEPALIVE
@@ -94,31 +94,21 @@ async def validate_connection(self) -> None:
9494
await self.async_disconnect()
9595
await self.async_connect()
9696

97-
async def _wait_response(self, request_id: int, queue: RoborockFuture) -> Any:
97+
async def _wait_response(self, request_key: RequestKey, future: RoborockFuture) -> Any:
9898
try:
99-
response = await queue.async_get(self.queue_timeout)
99+
response = await future.async_get(self.queue_timeout)
100100
if response == "unknown_method":
101101
raise UnknownMethodError("Unknown method")
102102
return response
103103
except (asyncio.TimeoutError, asyncio.CancelledError):
104-
raise RoborockTimeout(f"id={request_id} Timeout after {self.queue_timeout} seconds") from None
104+
raise RoborockTimeout(f"id={request_key} Timeout after {self.queue_timeout} seconds") from None
105105
finally:
106-
self._waiting_queue.pop(request_id, None)
107-
108-
def _async_response(self, request_id: int, protocol_id: int = 0) -> Any:
109-
queue = RoborockFuture(protocol_id)
110-
if request_id in self._waiting_queue:
111-
new_id = get_next_int(10000, 32767)
112-
self._logger.warning(
113-
"Attempting to create a future with an existing id %s (%s)... New id is %s. "
114-
"Code may not function properly.",
115-
request_id,
116-
protocol_id,
117-
new_id,
118-
)
119-
request_id = new_id
120-
self._waiting_queue[request_id] = queue
121-
return asyncio.ensure_future(self._wait_response(request_id, queue))
106+
self._waiting_queue.safe_pop(request_key)
107+
108+
def _async_response(self, request_key: RequestKey) -> Any:
109+
future = RoborockFuture()
110+
self._waiting_queue.put(request_key, future)
111+
return asyncio.ensure_future(self._wait_response(request_key, future))
122112

123113
@abstractmethod
124114
async def send_message(self, roborock_message: RoborockMessage):

roborock/cloud_api.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .containers import DeviceData, UserData
1414
from .exceptions import RoborockException, VacuumError
1515
from .protocol import MessageParser, md5hex
16-
from .roborock_future import RoborockFuture
16+
from .roborock_future import RequestKey
1717

1818
_LOGGER = logging.getLogger(__name__)
1919
CONNECT_REQUEST_ID = 0
@@ -71,12 +71,11 @@ def __init__(self, user_data: UserData, device_info: DeviceData, queue_timeout:
7171
self._mqtt_password = rriot.s
7272
self._hashed_password = md5hex(self._mqtt_password + ":" + rriot.k)[16:]
7373
self._mqtt_client.username_pw_set(self._hashed_user, self._hashed_password)
74-
self._waiting_queue: dict[int, RoborockFuture] = {}
7574
self._mutex = Lock()
7675

7776
def _mqtt_on_connect(self, *args, **kwargs):
7877
_, __, ___, rc, ____ = args
79-
connection_queue = self._waiting_queue.get(CONNECT_REQUEST_ID)
78+
connection_queue = self._waiting_queue.safe_pop(RequestKey(CONNECT_REQUEST_ID))
8079
if rc != mqtt.MQTT_ERR_SUCCESS:
8180
message = f"Failed to connect ({mqtt.error_string(rc)})"
8281
self._logger.error(message)
@@ -95,6 +94,8 @@ def _mqtt_on_connect(self, *args, **kwargs):
9594
self._logger.info(f"Subscribed to topic {topic}")
9695
if connection_queue:
9796
connection_queue.set_result(True)
97+
else:
98+
self._logger.debug("Connected but no connect future")
9899

99100
def _mqtt_on_message(self, *args, **kwargs):
100101
client, __, msg = args
@@ -109,9 +110,11 @@ def _mqtt_on_disconnect(self, *args, **kwargs):
109110
try:
110111
exc = RoborockException(mqtt.error_string(rc)) if rc != mqtt.MQTT_ERR_SUCCESS else None
111112
super().on_connection_lost(exc)
112-
connection_queue = self._waiting_queue.get(DISCONNECT_REQUEST_ID)
113+
connection_queue = self._waiting_queue.safe_pop(RequestKey(DISCONNECT_REQUEST_ID))
113114
if connection_queue:
114115
connection_queue.set_result(True)
116+
else:
117+
self._logger.debug("Disconnected but no disconnect future")
115118
except Exception as ex:
116119
self._logger.exception(ex)
117120

@@ -121,10 +124,11 @@ def is_connected(self) -> bool:
121124

122125
def sync_disconnect(self) -> Any:
123126
if not self.is_connected():
127+
self._logger.debug("Already disconnected from mqtt")
124128
return None
125129

126130
self._logger.info("Disconnecting from mqtt")
127-
disconnected_future = self._async_response(DISCONNECT_REQUEST_ID)
131+
disconnected_future = self._async_response(RequestKey(DISCONNECT_REQUEST_ID))
128132
rc = self._mqtt_client.disconnect()
129133

130134
if rc == mqtt.MQTT_ERR_NO_CONN:
@@ -146,7 +150,7 @@ def sync_connect(self) -> Any:
146150
raise RoborockException("Mqtt information was not entered. Cannot connect.")
147151

148152
self._logger.debug("Connecting to mqtt")
149-
connected_future = self._async_response(CONNECT_REQUEST_ID)
153+
connected_future = self._async_response(RequestKey(CONNECT_REQUEST_ID))
150154
self._mqtt_client.connect(host=self._mqtt_host, port=self._mqtt_port, keepalive=KEEPALIVE)
151155
self._mqtt_client.maybe_restart_loop()
152156
return connected_future

roborock/roborock_future.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,64 @@
11
from __future__ import annotations
22

3+
import logging
34
from asyncio import Future
5+
from dataclasses import dataclass
6+
from threading import Lock
47
from typing import Any
58

69
import async_timeout
710

8-
from .exceptions import VacuumError
11+
from .exceptions import UnknownMethodError, VacuumError
12+
from .roborock_message import RoborockMessageProtocol
13+
14+
_LOGGER = logging.getLogger(__name__)
15+
_TRIES = 3
16+
17+
18+
@dataclass(frozen=True)
19+
class RequestKey:
20+
"""A key for a Roborock message request."""
21+
22+
request_id: int
23+
protocol: RoborockMessageProtocol | int = 0
24+
25+
def __str__(self) -> str:
26+
"""Get the key for the request."""
27+
return f"{self.request_id}-{self.protocol}"
28+
29+
30+
class WaitingQueue:
31+
"""A threadsafe waiting queue for Roborock messages."""
32+
33+
def __init__(self) -> None:
34+
"""Initialize the waiting queue."""
35+
self._lock = Lock()
36+
self._queue: dict[RequestKey, RoborockFuture] = {}
37+
38+
def put(self, request_key: RequestKey, future: RoborockFuture) -> None:
39+
"""Create a future for the given protocol."""
40+
_LOGGER.debug("Putting request key %s in the queue", request_key)
41+
with self._lock:
42+
if request_key in self._queue:
43+
raise ValueError(f"Request key {request_key} already exists in the queue")
44+
self._queue[request_key] = future
45+
46+
def safe_pop(self, request_key: RequestKey) -> RoborockFuture | None:
47+
"""Get the future from the queue if it has not yet been popped, otherwise ignore."""
48+
_LOGGER.debug("Popping request key %s from the queue", request_key)
49+
with self._lock:
50+
return self._queue.pop(request_key, None)
951

1052

1153
class RoborockFuture:
12-
def __init__(self, protocol: int):
13-
self.protocol = protocol
54+
"""A threadsafe asyncio Future for Roborock messages.
55+
56+
The results may be set from a background thread. The future
57+
must be awaited in an asyncio event loop.
58+
"""
59+
60+
def __init__(self):
61+
"""Initialize the Roborock future."""
1462
self.fut: Future = Future()
1563
self.loop = self.fut.get_loop()
1664

@@ -28,9 +76,15 @@ def _set_exception(self, exc: VacuumError) -> None:
2876
def set_exception(self, exc: VacuumError) -> None:
2977
self.loop.call_soon_threadsafe(self._set_exception, exc)
3078

31-
async def async_get(self, timeout: float | int) -> tuple[Any, VacuumError | None]:
79+
async def async_get(self, timeout: float | int) -> Any:
80+
"""Get the result from the future or raises an error."""
3281
try:
3382
async with async_timeout.timeout(timeout):
34-
return await self.fut
83+
response = await self.fut
84+
# This should be moved to the specific client that handles this
85+
# and set an exception directly rather than patching an exception here
86+
if response == "unknown_method":
87+
raise UnknownMethodError("Unknown method")
88+
return response
3589
finally:
3690
self.fut.cancel()

roborock/version_1_apis/roborock_client_v1.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
WashTowelMode,
4848
)
4949
from roborock.protocol import Utils
50+
from roborock.roborock_future import RequestKey
5051
from roborock.roborock_message import (
5152
ROBOROCK_DATA_CONSUMABLE_PROTOCOL,
5253
ROBOROCK_DATA_STATUS_PROTOCOL,
@@ -390,8 +391,9 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
390391
if data_point_number == "102":
391392
data_point_response = json.loads(data_point)
392393
request_id = data_point_response.get("id")
393-
queue = self._waiting_queue.get(request_id)
394-
if queue and queue.protocol == protocol:
394+
request_key = RequestKey(request_id, protocol)
395+
queue = self._waiting_queue.safe_pop(request_key)
396+
if queue:
395397
error = data_point_response.get("error")
396398
if error:
397399
queue.set_exception(
@@ -406,7 +408,7 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
406408
result = result[0]
407409
queue.set_result(result)
408410
else:
409-
self._logger.debug("Received response for unknown request id %s", request_id)
411+
self._logger.debug("Received response for unknown request id %s", request_key)
410412
else:
411413
try:
412414
data_protocol = RoborockDataProtocol(int(data_point_number))
@@ -460,19 +462,21 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
460462
except ValueError as err:
461463
raise RoborockException(f"Failed to decode {data.payload!r} for {data.protocol}") from err
462464
decompressed = Utils.decompress(decrypted)
463-
queue = self._waiting_queue.get(request_id)
465+
request_key = RequestKey(request_id, protocol)
466+
queue = self._waiting_queue.safe_pop(request_key)
464467
if queue:
465468
if isinstance(decompressed, list):
466469
decompressed = decompressed[0]
467470
queue.set_result(decompressed)
468471
else:
469-
self._logger.debug("Received response for unknown request id %s", request_id)
472+
self._logger.debug("Received response for unknown request id %s", request_key)
470473
else:
471-
queue = self._waiting_queue.get(data.seq)
474+
request_key = RequestKey(data.seq, protocol)
475+
queue = self._waiting_queue.safe_pop(request_key)
472476
if queue:
473477
queue.set_result(data.payload)
474478
else:
475-
self._logger.debug("Received response for unknown request id %s", data.seq)
479+
self._logger.debug("Received response for unknown request id %s", request_key)
476480
except Exception as ex:
477481
self._logger.exception(ex)
478482

roborock/version_1_apis/roborock_local_client_v1.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .. import CommandVacuumError, DeviceData, RoborockCommand, RoborockException
66
from ..exceptions import VacuumError
77
from ..protocol import MessageParser
8+
from ..roborock_future import RequestKey
89
from ..roborock_message import MessageRetry, RoborockMessage, RoborockMessageProtocol
910
from ..util import RoborockLoggerAdapter
1011
from .roborock_client_v1 import COMMANDS_SECURED, RoborockClientV1
@@ -53,15 +54,19 @@ async def send_message(self, roborock_message: RoborockMessage):
5354
response_protocol = request_id + 1
5455
else:
5556
request_id = roborock_message.get_request_id()
57+
_LOGGER.debug("Getting next request id: %s", request_id)
5658
response_protocol = RoborockMessageProtocol.GENERAL_REQUEST
5759
if request_id is None:
5860
raise RoborockException(f"Failed build message {roborock_message}")
5961
local_key = self.device_info.device.local_key
6062
msg = MessageParser.build(roborock_message, local_key=local_key)
63+
request_key = RequestKey(request_id, response_protocol)
6164
if method:
62-
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
65+
self._logger.debug(f"id={request_key} Requesting method {method} with {params}")
66+
else:
67+
self._logger.debug(f"id={request_key} Requesting with {params}")
6368
# Send the command to the Roborock device
64-
async_response = self._async_response(request_id, response_protocol)
69+
async_response = self._async_response(request_key)
6570
self._send_msg_raw(msg)
6671
diagnostic_key = method if method is not None else "unknown"
6772
try:

roborock/version_1_apis/roborock_mqtt_client_v1.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ..containers import DeviceData, UserData
1212
from ..exceptions import CommandVacuumError, RoborockException, VacuumError
1313
from ..protocol import MessageParser, Utils
14+
from ..roborock_future import RequestKey
1415
from ..roborock_message import (
1516
RoborockMessage,
1617
RoborockMessageProtocol,
@@ -46,11 +47,11 @@ async def send_message(self, roborock_message: RoborockMessage):
4647
response_protocol = (
4748
RoborockMessageProtocol.MAP_RESPONSE if method in COMMANDS_SECURED else RoborockMessageProtocol.RPC_RESPONSE
4849
)
49-
50+
request_key = RequestKey(request_id, response_protocol)
5051
local_key = self.device_info.device.local_key
5152
msg = MessageParser.build(roborock_message, local_key, False)
52-
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
53-
async_response = self._async_response(request_id, response_protocol)
53+
self._logger.debug(f"id={request_key} Requesting method {method} with {params}")
54+
async_response = self._async_response(request_key)
5455
self._send_msg_raw(msg)
5556
diagnostic_key = method if method is not None else "unknown"
5657
try:
@@ -66,9 +67,9 @@ async def send_message(self, roborock_message: RoborockMessage):
6667
"response": response,
6768
}
6869
if response_protocol == RoborockMessageProtocol.MAP_RESPONSE:
69-
self._logger.debug(f"id={request_id} Response from {method}: {len(response)} bytes")
70+
self._logger.debug(f"id={request_key} Response from {method}: {len(response)} bytes")
7071
else:
71-
self._logger.debug(f"id={request_id} Response from {method}: {response}")
72+
self._logger.debug(f"id={request_key} Response from {method}: {response}")
7273
return response
7374

7475
async def _send_command(

roborock/version_a01_apis/roborock_client_a01.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
ZeoTemperature,
3434
)
3535
from roborock.containers import DyadProductInfo, DyadSndState, RoborockCategory
36+
from roborock.roborock_future import RequestKey
3637
from roborock.roborock_message import (
3738
RoborockDyadDataProtocol,
3839
RoborockMessage,
@@ -142,9 +143,12 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
142143
if data_point_protocol in entries:
143144
# Auto convert into data struct we want.
144145
converted_response = entries[data_point_protocol].post_process_fn(data_point)
145-
queue = self._waiting_queue.get(int(data_point_number))
146-
if queue and queue.protocol == protocol:
147-
queue.set_result(converted_response)
146+
request_key = RequestKey(int(data_point_number), protocol)
147+
future = self._waiting_queue.safe_pop(request_key)
148+
if future is not None:
149+
future.set_result(converted_response)
150+
else:
151+
self._logger.debug(f"Got response for {request_key} but no future found")
148152

149153
@abstractmethod
150154
async def update_values(

roborock/version_a01_apis/roborock_mqtt_client_a01.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from roborock.containers import DeviceData, RoborockCategory, UserData
1111
from roborock.exceptions import RoborockException
1212
from roborock.protocol import MessageParser
13+
from roborock.roborock_future import RequestKey
1314
from roborock.roborock_message import (
1415
RoborockDyadDataProtocol,
1516
RoborockMessage,
@@ -49,7 +50,7 @@ async def send_message(self, roborock_message: RoborockMessage):
4950
futures = []
5051
if "10000" in payload["dps"]:
5152
for dps in json.loads(payload["dps"]["10000"]):
52-
futures.append(self._async_response(dps, response_protocol))
53+
futures.append(self._async_response(RequestKey(dps, response_protocol)))
5354
self._send_msg_raw(m)
5455
responses = await asyncio.gather(*futures, return_exceptions=True)
5556
dps_responses: dict[int, typing.Any] = {}

0 commit comments

Comments
 (0)