Skip to content

Commit d8ce60f

Browse files
authored
chore: extract common module for manaing pending RPCs (#451)
1 parent 1addf95 commit d8ce60f

File tree

4 files changed

+132
-33
lines changed

4 files changed

+132
-33
lines changed

roborock/devices/local_channel.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from roborock.roborock_message import RoborockMessage
1212

1313
from .channel import Channel
14+
from .pending import PendingRpcs
1415

1516
_LOGGER = logging.getLogger(__name__)
1617
_PORT = 58867
@@ -47,10 +48,9 @@ def __init__(self, host: str, local_key: str):
4748
self._is_connected = False
4849

4950
# RPC support
50-
self._waiting_queue: dict[int, asyncio.Future[RoborockMessage]] = {}
51+
self._pending_rpcs: PendingRpcs[int, RoborockMessage] = PendingRpcs()
5152
self._decoder: Decoder = create_local_decoder(local_key)
5253
self._encoder: Encoder = create_local_encoder(local_key)
53-
self._queue_lock = asyncio.Lock()
5454

5555
@property
5656
def is_connected(self) -> bool:
@@ -114,11 +114,7 @@ async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
114114
if (request_id := message.get_request_id()) is None:
115115
_LOGGER.debug("Received message with no request_id")
116116
return
117-
async with self._queue_lock:
118-
if (future := self._waiting_queue.pop(request_id, None)) is not None:
119-
future.set_result(message)
120-
else:
121-
_LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id)
117+
await self._pending_rpcs.resolve(request_id, message)
122118

123119
async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
124120
"""Send a command message and wait for the response message."""
@@ -132,24 +128,17 @@ async def send_message(self, message: RoborockMessage, timeout: float = 10.0) ->
132128
_LOGGER.exception("Error getting request_id from message: %s", err)
133129
raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err
134130

135-
future: asyncio.Future[RoborockMessage] = asyncio.Future()
136-
async with self._queue_lock:
137-
if request_id in self._waiting_queue:
138-
raise RoborockException(f"Request ID {request_id} already pending, cannot send command")
139-
self._waiting_queue[request_id] = future
140-
131+
future: asyncio.Future[RoborockMessage] = await self._pending_rpcs.start(request_id)
141132
try:
142133
encoded_msg = self._encoder(message)
143134
self._transport.write(encoded_msg)
144135
return await asyncio.wait_for(future, timeout=timeout)
145136
except asyncio.TimeoutError as ex:
146-
async with self._queue_lock:
147-
self._waiting_queue.pop(request_id, None)
137+
await self._pending_rpcs.pop(request_id)
148138
raise RoborockException(f"Command timed out after {timeout}s") from ex
149139
except Exception:
150140
logging.exception("Uncaught error sending command")
151-
async with self._queue_lock:
152-
self._waiting_queue.pop(request_id, None)
141+
await self._pending_rpcs.pop(request_id)
153142
raise
154143

155144

roborock/devices/mqtt_channel.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from roborock.roborock_message import RoborockMessage
1313

1414
from .channel import Channel
15+
from .pending import PendingRpcs
1516

1617
_LOGGER = logging.getLogger(__name__)
1718

@@ -31,10 +32,9 @@ def __init__(self, mqtt_session: MqttSession, duid: str, local_key: str, rriot:
3132
self._mqtt_params = mqtt_params
3233

3334
# RPC support
34-
self._waiting_queue: dict[int, asyncio.Future[RoborockMessage]] = {}
35+
self._pending_rpcs: PendingRpcs[int, RoborockMessage] = PendingRpcs()
3536
self._decoder = create_mqtt_decoder(local_key)
3637
self._encoder = create_mqtt_encoder(local_key)
37-
self._queue_lock = asyncio.Lock()
3838
self._mqtt_unsub: Callable[[], None] | None = None
3939

4040
@property
@@ -89,11 +89,7 @@ async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
8989
if (request_id := message.get_request_id()) is None:
9090
_LOGGER.debug("Received message with no request_id")
9191
return
92-
async with self._queue_lock:
93-
if (future := self._waiting_queue.pop(request_id, None)) is not None:
94-
future.set_result(message)
95-
else:
96-
_LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id)
92+
await self._pending_rpcs.resolve(request_id, message)
9793

9894
async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
9995
"""Send a command message and wait for the response message.
@@ -107,11 +103,7 @@ async def send_message(self, message: RoborockMessage, timeout: float = 10.0) ->
107103
_LOGGER.exception("Error getting request_id from message: %s", err)
108104
raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err
109105

110-
future: asyncio.Future[RoborockMessage] = asyncio.Future()
111-
async with self._queue_lock:
112-
if request_id in self._waiting_queue:
113-
raise RoborockException(f"Request ID {request_id} already pending, cannot send command")
114-
self._waiting_queue[request_id] = future
106+
future: asyncio.Future[RoborockMessage] = await self._pending_rpcs.start(request_id)
115107

116108
try:
117109
encoded_msg = self._encoder(message)
@@ -120,13 +112,11 @@ async def send_message(self, message: RoborockMessage, timeout: float = 10.0) ->
120112
return await asyncio.wait_for(future, timeout=timeout)
121113

122114
except asyncio.TimeoutError as ex:
123-
async with self._queue_lock:
124-
self._waiting_queue.pop(request_id, None)
115+
await self._pending_rpcs.pop(request_id)
125116
raise RoborockException(f"Command timed out after {timeout}s") from ex
126117
except Exception:
127118
logging.exception("Uncaught error sending command")
128-
async with self._queue_lock:
129-
self._waiting_queue.pop(request_id, None)
119+
await self._pending_rpcs.pop(request_id)
130120
raise
131121

132122

roborock/devices/pending.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""Module for managing pending RPCs."""
2+
3+
import asyncio
4+
import logging
5+
from typing import Generic, TypeVar
6+
7+
from roborock.exceptions import RoborockException
8+
9+
_LOGGER = logging.getLogger(__name__)
10+
11+
12+
K = TypeVar("K")
13+
V = TypeVar("V")
14+
15+
16+
class PendingRpcs(Generic[K, V]):
17+
"""Manage pending RPCs."""
18+
19+
def __init__(self) -> None:
20+
"""Initialize the pending RPCs."""
21+
self._queue_lock = asyncio.Lock()
22+
self._waiting_queue: dict[K, asyncio.Future[V]] = {}
23+
24+
async def start(self, key: K) -> asyncio.Future[V]:
25+
"""Start the pending RPCs."""
26+
future: asyncio.Future[V] = asyncio.Future()
27+
async with self._queue_lock:
28+
if key in self._waiting_queue:
29+
raise RoborockException(f"Request ID {key} already pending, cannot send command")
30+
self._waiting_queue[key] = future
31+
return future
32+
33+
async def pop(self, key: K) -> None:
34+
"""Pop a pending RPC."""
35+
async with self._queue_lock:
36+
if (future := self._waiting_queue.pop(key, None)) is not None:
37+
future.cancel()
38+
39+
async def resolve(self, key: K, value: V) -> None:
40+
"""Resolve waiting future with proper locking."""
41+
async with self._queue_lock:
42+
if (future := self._waiting_queue.pop(key, None)) is not None:
43+
future.set_result(value)
44+
else:
45+
_LOGGER.debug("Received unsolicited message: %s", key)

tests/devices/test_pending.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""Tests for the PendingRpcs class."""
2+
3+
import asyncio
4+
5+
import pytest
6+
7+
from roborock.devices.pending import PendingRpcs
8+
from roborock.exceptions import RoborockException
9+
10+
11+
@pytest.fixture(name="pending_rpcs")
12+
def setup_pending_rpcs() -> PendingRpcs[int, str]:
13+
"""Fixture to set up the PendingRpcs for tests."""
14+
return PendingRpcs[int, str]()
15+
16+
17+
async def test_start_duplicate_rpc_raises_exception(pending_rpcs: PendingRpcs[int, str]) -> None:
18+
"""Test that starting a duplicate RPC raises an exception."""
19+
key = 1
20+
await pending_rpcs.start(key)
21+
with pytest.raises(RoborockException, match=f"Request ID {key} already pending, cannot send command"):
22+
await pending_rpcs.start(key)
23+
24+
25+
async def test_resolve_pending_rpc(pending_rpcs: PendingRpcs[int, str]) -> None:
26+
"""Test resolving a pending RPC."""
27+
key = 1
28+
value = "test_result"
29+
future = await pending_rpcs.start(key)
30+
await pending_rpcs.resolve(key, value)
31+
result = await future
32+
assert result == value
33+
34+
35+
async def test_resolve_unsolicited_message(
36+
pending_rpcs: PendingRpcs[int, str], caplog: pytest.LogCaptureFixture
37+
) -> None:
38+
"""Test resolving an unsolicited message does not raise."""
39+
key = 1
40+
value = "test_result"
41+
await pending_rpcs.resolve(key, value)
42+
43+
44+
async def test_pop_pending_rpc(pending_rpcs: PendingRpcs[int, str]) -> None:
45+
"""Test popping a pending RPC, which should cancel the future."""
46+
key = 1
47+
future = await pending_rpcs.start(key)
48+
await pending_rpcs.pop(key)
49+
with pytest.raises(asyncio.CancelledError):
50+
await future
51+
52+
53+
async def test_pop_non_existent_rpc(pending_rpcs: PendingRpcs[int, str]) -> None:
54+
"""Test that popping a non-existent RPC does not raise an exception."""
55+
key = 1
56+
await pending_rpcs.pop(key)
57+
58+
59+
async def test_concurrent_rpcs(pending_rpcs: PendingRpcs[int, str]) -> None:
60+
"""Test handling multiple concurrent RPCs."""
61+
62+
async def start_and_resolve(key: int, value: str) -> str:
63+
future = await pending_rpcs.start(key)
64+
await asyncio.sleep(0.01) # yield
65+
await pending_rpcs.resolve(key, value)
66+
return await future
67+
68+
tasks = [
69+
asyncio.create_task(start_and_resolve(1, "result1")),
70+
asyncio.create_task(start_and_resolve(2, "result2")),
71+
asyncio.create_task(start_and_resolve(3, "result3")),
72+
]
73+
74+
results = await asyncio.gather(*tasks)
75+
assert results == ["result1", "result2", "result3"]

0 commit comments

Comments
 (0)