Skip to content

Commit cbd6df2

Browse files
authored
chore: Refactor some of the internal channel details used by the device. (#424)
* chore: Separate V1 API connection logic from encoding logic * chore: Remove unnecessary command * chore: rename rpc channels to have v1 in the name
1 parent 717654a commit cbd6df2

File tree

9 files changed

+252
-126
lines changed

9 files changed

+252
-126
lines changed

roborock/devices/device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,4 +117,4 @@ async def get_status(self) -> Status:
117117
This is a placeholder command and will likely be changed/moved in the future.
118118
"""
119119
status_type: type[Status] = ModelStatus.get(self._product_info.model, S7MaxVStatus)
120-
return await self._v1_channel.send_decoded_command(RoborockCommand.GET_STATUS, response_type=status_type)
120+
return await self._v1_channel.rpc_channel.send_command(RoborockCommand.GET_STATUS, response_type=status_type)

roborock/devices/local_channel.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ def __init__(self, host: str, local_key: str):
5050
self._encoder: Encoder = create_local_encoder(local_key)
5151
self._queue_lock = asyncio.Lock()
5252

53+
@property
54+
def is_connected(self) -> bool:
55+
"""Check if the channel is currently connected."""
56+
return self._is_connected
57+
5358
async def connect(self) -> None:
5459
"""Connect to the device."""
5560
if self._is_connected:
@@ -113,7 +118,7 @@ async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
113118
else:
114119
_LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id)
115120

116-
async def send_command(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
121+
async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
117122
"""Send a command message and wait for the response message."""
118123
if not self._transport or not self._is_connected:
119124
raise RoborockConnectionException("Not connected to device")

roborock/devices/mqtt_channel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
8080
else:
8181
_LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id)
8282

83-
async def send_command(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
83+
async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
8484
"""Send a command message and wait for the response message.
8585
8686
Returns the raw response message - caller is responsible for parsing.

roborock/devices/v1_channel.py

Lines changed: 18 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,21 @@
66

77
import logging
88
from collections.abc import Callable
9-
from typing import Any, TypeVar
9+
from typing import TypeVar
1010

1111
from roborock.containers import HomeDataDevice, NetworkInfo, RoborockBase, UserData
1212
from roborock.exceptions import RoborockException
1313
from roborock.mqtt.session import MqttParams, MqttSession
1414
from roborock.protocols.v1_protocol import (
15-
CommandType,
16-
ParamsType,
1715
SecurityData,
18-
create_mqtt_payload_encoder,
1916
create_security_data,
20-
decode_rpc_response,
21-
encode_local_payload,
2217
)
2318
from roborock.roborock_message import RoborockMessage
2419
from roborock.roborock_typing import RoborockCommand
2520

2621
from .local_channel import LocalChannel, LocalSession, create_local_session
2722
from .mqtt_channel import MqttChannel
23+
from .v1_rpc_channel import V1RpcChannel, create_combined_rpc_channel, create_mqtt_rpc_channel
2824

2925
_LOGGER = logging.getLogger(__name__)
3026

@@ -58,9 +54,10 @@ def __init__(
5854
"""
5955
self._device_uid = device_uid
6056
self._mqtt_channel = mqtt_channel
61-
self._mqtt_payload_encoder = create_mqtt_payload_encoder(security_data)
57+
self._mqtt_rpc_channel = create_mqtt_rpc_channel(mqtt_channel, security_data)
6258
self._local_session = local_session
6359
self._local_channel: LocalChannel | None = None
60+
self._combined_rpc_channel: V1RpcChannel | None = None
6461
self._mqtt_unsub: Callable[[], None] | None = None
6562
self._local_unsub: Callable[[], None] | None = None
6663
self._callback: Callable[[RoborockMessage], None] | None = None
@@ -76,6 +73,16 @@ def is_mqtt_connected(self) -> bool:
7673
"""Return whether MQTT connection is available."""
7774
return self._mqtt_unsub is not None
7875

76+
@property
77+
def rpc_channel(self) -> V1RpcChannel:
78+
"""Return the combined RPC channel prefers local with a fallback to MQTT."""
79+
return self._combined_rpc_channel or self._mqtt_rpc_channel
80+
81+
@property
82+
def mqtt_rpc_channel(self) -> V1RpcChannel:
83+
"""Return the MQTT RPC channel."""
84+
return self._mqtt_rpc_channel
85+
7986
async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
8087
"""Subscribe to all messages from the device.
8188
@@ -119,7 +126,9 @@ async def _get_networking_info(self) -> NetworkInfo:
119126
This is a cloud only command used to get the local device's IP address.
120127
"""
121128
try:
122-
return await self._send_mqtt_decoded_command(RoborockCommand.GET_NETWORK_INFO, response_type=NetworkInfo)
129+
return await self._mqtt_rpc_channel.send_command(
130+
RoborockCommand.GET_NETWORK_INFO, response_type=NetworkInfo
131+
)
123132
except RoborockException as e:
124133
raise RoborockException(f"Network info failed for device {self._device_uid}") from e
125134

@@ -136,59 +145,9 @@ async def _local_connect(self) -> Callable[[], None]:
136145
except RoborockException as e:
137146
self._local_channel = None
138147
raise RoborockException(f"Error connecting to local device {self._device_uid}: {e}") from e
139-
148+
self._combined_rpc_channel = create_combined_rpc_channel(self._local_channel, self._mqtt_rpc_channel)
140149
return await self._local_channel.subscribe(self._on_local_message)
141150

142-
async def send_decoded_command(
143-
self,
144-
method: CommandType,
145-
*,
146-
response_type: type[_T],
147-
params: ParamsType = None,
148-
) -> _T:
149-
"""Send a command using the best available transport.
150-
151-
Will prefer local connection if available, falling back to MQTT.
152-
"""
153-
connection = "local" if self.is_local_connected else "mqtt"
154-
_LOGGER.debug("Sending command (%s): %s, params=%s", connection, method, params)
155-
if self._local_channel:
156-
return await self._send_local_decoded_command(method, response_type=response_type, params=params)
157-
return await self._send_mqtt_decoded_command(method, response_type=response_type, params=params)
158-
159-
async def _send_mqtt_raw_command(self, method: CommandType, params: ParamsType | None = None) -> dict[str, Any]:
160-
"""Send a raw command and return a raw unparsed response."""
161-
message = self._mqtt_payload_encoder(method, params)
162-
_LOGGER.debug("Sending MQTT message for device %s: %s", self._device_uid, message)
163-
response = await self._mqtt_channel.send_command(message)
164-
return decode_rpc_response(response)
165-
166-
async def _send_mqtt_decoded_command(
167-
self, method: CommandType, *, response_type: type[_T], params: ParamsType | None = None
168-
) -> _T:
169-
"""Send a command over MQTT and decode the response."""
170-
decoded_response = await self._send_mqtt_raw_command(method, params)
171-
return response_type.from_dict(decoded_response)
172-
173-
async def _send_local_raw_command(self, method: CommandType, params: ParamsType | None = None) -> dict[str, Any]:
174-
"""Send a raw command over local connection."""
175-
if not self._local_channel:
176-
raise RoborockException("Local channel is not connected")
177-
178-
message = encode_local_payload(method, params)
179-
_LOGGER.debug("Sending local message for device %s: %s", self._device_uid, message)
180-
response = await self._local_channel.send_command(message)
181-
return decode_rpc_response(response)
182-
183-
async def _send_local_decoded_command(
184-
self, method: CommandType, *, response_type: type[_T], params: ParamsType | None = None
185-
) -> _T:
186-
"""Send a command over local connection and decode the response."""
187-
if not self._local_channel:
188-
raise RoborockException("Local channel is not connected")
189-
decoded_response = await self._send_local_raw_command(method, params)
190-
return response_type.from_dict(decoded_response)
191-
192151
def _on_mqtt_message(self, message: RoborockMessage) -> None:
193152
"""Handle incoming MQTT messages."""
194153
_LOGGER.debug("V1Channel received MQTT message from device %s: %s", self._device_uid, message)

roborock/devices/v1_rpc_channel.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
"""V1 Rpc Channel for Roborock devices.
2+
3+
This is a wrapper around the V1 channel that provides a higher level interface
4+
for sending typed commands and receiving typed responses. This also provides
5+
a simple interface for sending commands and receiving responses over both MQTT
6+
and local connections, preferring local when available.
7+
"""
8+
9+
import logging
10+
from collections.abc import Callable
11+
from typing import Any, Protocol, TypeVar, overload
12+
13+
from roborock.containers import RoborockBase
14+
from roborock.protocols.v1_protocol import (
15+
CommandType,
16+
ParamsType,
17+
SecurityData,
18+
create_mqtt_payload_encoder,
19+
decode_rpc_response,
20+
encode_local_payload,
21+
)
22+
from roborock.roborock_message import RoborockMessage
23+
24+
from .local_channel import LocalChannel
25+
from .mqtt_channel import MqttChannel
26+
27+
_LOGGER = logging.getLogger(__name__)
28+
29+
30+
_T = TypeVar("_T", bound=RoborockBase)
31+
32+
33+
class V1RpcChannel(Protocol):
34+
"""Protocol for V1 RPC channels.
35+
36+
This is a wrapper around a raw channel that provides a high-level interface
37+
for sending commands and receiving responses.
38+
"""
39+
40+
@overload
41+
async def send_command(
42+
self,
43+
method: CommandType,
44+
*,
45+
params: ParamsType = None,
46+
) -> Any:
47+
"""Send a command and return a decoded response."""
48+
...
49+
50+
@overload
51+
async def send_command(
52+
self,
53+
method: CommandType,
54+
*,
55+
response_type: type[_T],
56+
params: ParamsType = None,
57+
) -> _T:
58+
"""Send a command and return a parsed response RoborockBase type."""
59+
...
60+
61+
62+
class BaseV1RpcChannel(V1RpcChannel):
63+
"""Base implementation that provides the typed response logic."""
64+
65+
async def send_command(
66+
self,
67+
method: CommandType,
68+
*,
69+
response_type: type[_T] | None = None,
70+
params: ParamsType = None,
71+
) -> _T | Any:
72+
"""Send a command and return either a decoded or parsed response."""
73+
decoded_response = await self._send_raw_command(method, params=params)
74+
75+
if response_type is not None:
76+
return response_type.from_dict(decoded_response)
77+
return decoded_response
78+
79+
async def _send_raw_command(
80+
self,
81+
method: CommandType,
82+
*,
83+
params: ParamsType = None,
84+
) -> Any:
85+
"""Send a raw command and return the decoded response. Must be implemented by subclasses."""
86+
raise NotImplementedError
87+
88+
89+
class CombinedV1RpcChannel(BaseV1RpcChannel):
90+
"""A V1 RPC channel that can use both local and MQTT channels, preferring local when available."""
91+
92+
def __init__(
93+
self, local_channel: LocalChannel, local_rpc_channel: V1RpcChannel, mqtt_channel: V1RpcChannel
94+
) -> None:
95+
"""Initialize the combined channel with local and MQTT channels."""
96+
self._local_channel = local_channel
97+
self._local_rpc_channel = local_rpc_channel
98+
self._mqtt_rpc_channel = mqtt_channel
99+
100+
async def _send_raw_command(
101+
self,
102+
method: CommandType,
103+
*,
104+
params: ParamsType = None,
105+
) -> Any:
106+
"""Send a command and return a parsed response RoborockBase type."""
107+
if self._local_channel.is_connected:
108+
return await self._local_rpc_channel.send_command(method, params=params)
109+
return await self._mqtt_rpc_channel.send_command(method, params=params)
110+
111+
112+
class PayloadEncodedV1RpcChannel(BaseV1RpcChannel):
113+
"""Protocol for V1 channels that send encoded commands."""
114+
115+
def __init__(
116+
self,
117+
name: str,
118+
channel: MqttChannel | LocalChannel,
119+
payload_encoder: Callable[[CommandType, ParamsType], RoborockMessage],
120+
) -> None:
121+
"""Initialize the channel with a raw channel and an encoder function."""
122+
self._name = name
123+
self._channel = channel
124+
self._payload_encoder = payload_encoder
125+
126+
async def _send_raw_command(
127+
self,
128+
method: CommandType,
129+
*,
130+
params: ParamsType = None,
131+
) -> Any:
132+
"""Send a command and return a parsed response RoborockBase type."""
133+
_LOGGER.debug("Sending command (%s): %s, params=%s", self._name, method, params)
134+
message = self._payload_encoder(method, params)
135+
response = await self._channel.send_message(message)
136+
return decode_rpc_response(response)
137+
138+
139+
def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityData) -> V1RpcChannel:
140+
"""Create a V1 RPC channel using an MQTT channel."""
141+
payload_encoder = create_mqtt_payload_encoder(security_data)
142+
return PayloadEncodedV1RpcChannel("mqtt", mqtt_channel, payload_encoder)
143+
144+
145+
def create_combined_rpc_channel(local_channel: LocalChannel, mqtt_rpc_channel: V1RpcChannel) -> V1RpcChannel:
146+
"""Create a V1 RPC channel that combines local and MQTT channels."""
147+
local_rpc_channel = PayloadEncodedV1RpcChannel("local", local_channel, encode_local_payload)
148+
return CombinedV1RpcChannel(local_channel, local_rpc_channel, mqtt_rpc_channel)

tests/devices/test_device.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ async def test_device_connection(device: RoborockDevice, channel: AsyncMock) ->
5656
async def test_device_get_status_command(device: RoborockDevice, channel: AsyncMock) -> None:
5757
"""Test the device get_status command."""
5858
# Mock response for get_status command
59-
channel.send_decoded_command.return_value = STATUS
59+
channel.rpc_channel.send_command.return_value = STATUS
6060

6161
# Test get_status and verify the command was sent
6262
status = await device.get_status()
63-
assert channel.send_decoded_command.called
63+
assert channel.rpc_channel.send_command.called
6464

6565
# Verify the result
6666
assert status is not None

0 commit comments

Comments
 (0)