Skip to content

Commit f4dcea5

Browse files
authored
chore: remove dependencies on get_request_id in RequestMessage (#452)
1 parent d5d79b6 commit f4dcea5

File tree

5 files changed

+64
-68
lines changed

5 files changed

+64
-68
lines changed

roborock/devices/v1_rpc_channel.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,11 @@
1414
from roborock.protocols.v1_protocol import (
1515
CommandType,
1616
ParamsType,
17+
RequestMessage,
1718
SecurityData,
18-
create_mqtt_payload_encoder,
1919
decode_rpc_response,
20-
encode_local_payload,
2120
)
22-
from roborock.roborock_message import RoborockMessage
21+
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
2322

2423
from .local_channel import LocalChannel
2524
from .mqtt_channel import MqttChannel
@@ -116,7 +115,7 @@ def __init__(
116115
self,
117116
name: str,
118117
channel: MqttChannel | LocalChannel,
119-
payload_encoder: Callable[[CommandType, ParamsType], RoborockMessage],
118+
payload_encoder: Callable[[RequestMessage], RoborockMessage],
120119
) -> None:
121120
"""Initialize the channel with a raw channel and an encoder function."""
122121
self._name = name
@@ -131,18 +130,26 @@ async def _send_raw_command(
131130
) -> Any:
132131
"""Send a command and return a parsed response RoborockBase type."""
133132
_LOGGER.debug("Sending command (%s): %s, params=%s", self._name, method, params)
134-
message = self._payload_encoder(method, params)
133+
request_message = RequestMessage(method, params=params)
134+
message = self._payload_encoder(request_message)
135135
response = await self._channel.send_message(message)
136136
return decode_rpc_response(response)
137137

138138

139139
def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityData) -> V1RpcChannel:
140140
"""Create a V1 RPC channel using an MQTT channel."""
141-
payload_encoder = create_mqtt_payload_encoder(security_data)
142-
return PayloadEncodedV1RpcChannel("mqtt", mqtt_channel, payload_encoder)
141+
return PayloadEncodedV1RpcChannel(
142+
"mqtt",
143+
mqtt_channel,
144+
lambda x: x.encode_message(RoborockMessageProtocol.RPC_REQUEST, security_data=security_data),
145+
)
143146

144147

145148
def create_combined_rpc_channel(local_channel: LocalChannel, mqtt_rpc_channel: V1RpcChannel) -> V1RpcChannel:
146149
"""Create a V1 RPC channel that combines local and MQTT channels."""
147-
local_rpc_channel = PayloadEncodedV1RpcChannel("local", local_channel, encode_local_payload)
150+
local_rpc_channel = PayloadEncodedV1RpcChannel(
151+
"local",
152+
local_channel,
153+
lambda x: x.encode_message(RoborockMessageProtocol.GENERAL_REQUEST),
154+
)
148155
return CombinedV1RpcChannel(local_channel, local_rpc_channel, mqtt_rpc_channel)

roborock/protocols/v1_protocol.py

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525
__all__ = [
2626
"SecurityData",
2727
"create_security_data",
28-
"create_mqtt_payload_encoder",
29-
"encode_local_payload",
3028
"decode_rpc_response",
3129
]
3230

@@ -66,7 +64,19 @@ class RequestMessage:
6664
timestamp: int = field(default_factory=lambda: math.floor(time.time()))
6765
request_id: int = field(default_factory=lambda: get_next_int(10000, 32767))
6866

69-
def as_payload(self, security_data: SecurityData | None) -> bytes:
67+
def encode_message(
68+
self,
69+
protocol: RoborockMessageProtocol,
70+
security_data: SecurityData | None = None,
71+
) -> RoborockMessage:
72+
"""Convert the request message to a RoborockMessage."""
73+
return RoborockMessage(
74+
timestamp=self.timestamp,
75+
protocol=protocol,
76+
payload=self._as_payload(security_data=security_data),
77+
)
78+
79+
def _as_payload(self, security_data: SecurityData | None) -> bytes:
7080
"""Convert the request arguments to a dictionary."""
7181
inner = {
7282
"id": self.request_id,
@@ -85,35 +95,6 @@ def as_payload(self, security_data: SecurityData | None) -> bytes:
8595
)
8696

8797

88-
def create_mqtt_payload_encoder(security_data: SecurityData) -> Callable[[CommandType, ParamsType], RoborockMessage]:
89-
"""Create a payload encoder for V1 commands over MQTT."""
90-
91-
def _get_payload(method: CommandType, params: ParamsType) -> RoborockMessage:
92-
"""Build the payload for a V1 command."""
93-
request = RequestMessage(method=method, params=params)
94-
payload = request.as_payload(security_data) # always secure
95-
return RoborockMessage(
96-
timestamp=request.timestamp,
97-
protocol=RoborockMessageProtocol.RPC_REQUEST,
98-
payload=payload,
99-
)
100-
101-
return _get_payload
102-
103-
104-
def encode_local_payload(method: CommandType, params: ParamsType) -> RoborockMessage:
105-
"""Encode payload for V1 commands over local connection."""
106-
107-
request = RequestMessage(method=method, params=params)
108-
payload = request.as_payload(security_data=None)
109-
110-
return RoborockMessage(
111-
timestamp=request.timestamp,
112-
protocol=RoborockMessageProtocol.GENERAL_REQUEST,
113-
payload=payload,
114-
)
115-
116-
11798
def decode_rpc_response(message: RoborockMessage) -> dict[str, Any]:
11899
"""Decode a V1 RPC_RESPONSE message."""
119100
if not message.payload:

roborock/version_1_apis/roborock_local_client_v1.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ..api import RoborockClient
1111
from ..exceptions import RoborockConnectionException, RoborockException, VacuumError
1212
from ..protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
13-
from ..protocols.v1_protocol import encode_local_payload
13+
from ..protocols.v1_protocol import RequestMessage
1414
from ..roborock_message import RoborockMessage, RoborockMessageProtocol
1515
from ..util import RoborockLoggerAdapter
1616
from .roborock_client_v1 import CLOUD_REQUIRED, RoborockClientV1
@@ -123,12 +123,20 @@ async def async_disconnect(self) -> None:
123123

124124
async def hello(self):
125125
try:
126-
return await self._send_message(_HELLO_REQUEST_MESSAGE)
126+
return await self._send_message(
127+
roborock_message=_HELLO_REQUEST_MESSAGE,
128+
request_id=_HELLO_REQUEST_MESSAGE.seq,
129+
response_protocol=RoborockMessageProtocol.HELLO_RESPONSE,
130+
)
127131
except Exception as e:
128132
self._logger.error(e)
129133

130134
async def ping(self) -> None:
131-
await self._send_message(_PING_REQUEST_MESSAGE)
135+
await self._send_message(
136+
roborock_message=_PING_REQUEST_MESSAGE,
137+
request_id=_PING_REQUEST_MESSAGE.seq,
138+
response_protocol=RoborockMessageProtocol.PING_RESPONSE,
139+
)
132140

133141
def _send_msg_raw(self, data: bytes):
134142
try:
@@ -145,27 +153,26 @@ async def _send_command(
145153
):
146154
if method in CLOUD_REQUIRED:
147155
raise RoborockException(f"Method {method} is not supported over local connection")
148-
149-
roborock_message = encode_local_payload(method, params)
150-
self._logger.debug("Building message id %s for method %s", roborock_message.get_request_id(), method)
151-
return await self._send_message(roborock_message, method, params)
156+
request_message = RequestMessage(method=method, params=params)
157+
roborock_message = request_message.encode_message(RoborockMessageProtocol.GENERAL_REQUEST)
158+
self._logger.debug("Building message id %s for method %s", request_message.request_id, method)
159+
return await self._send_message(
160+
roborock_message,
161+
request_id=request_message.request_id,
162+
response_protocol=RoborockMessageProtocol.GENERAL_REQUEST,
163+
method=method,
164+
params=params,
165+
)
152166

153167
async def _send_message(
154168
self,
155169
roborock_message: RoborockMessage,
170+
request_id: int,
171+
response_protocol: int,
156172
method: str | None = None,
157173
params: list | dict | int | None = None,
158174
) -> RoborockMessage:
159175
await self.validate_connection()
160-
request_id: int | None
161-
if not method or not method.startswith("get"):
162-
request_id = roborock_message.seq
163-
response_protocol = request_id + 1
164-
else:
165-
request_id = roborock_message.get_request_id()
166-
response_protocol = RoborockMessageProtocol.GENERAL_REQUEST
167-
if request_id is None:
168-
raise RoborockException(f"Failed build message {roborock_message}")
169176
msg = self._encoder(roborock_message)
170177
if method:
171178
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")

roborock/version_1_apis/roborock_mqtt_client_v1.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from ..containers import DeviceData, UserData
1111
from ..exceptions import CommandVacuumError, RoborockException, VacuumError
12-
from ..protocols.v1_protocol import create_mqtt_payload_encoder, create_security_data
12+
from ..protocols.v1_protocol import RequestMessage, create_security_data
1313
from ..roborock_message import (
1414
RoborockMessageProtocol,
1515
)
@@ -28,12 +28,12 @@ def __init__(self, user_data: UserData, device_info: DeviceData, queue_timeout:
2828
rriot = user_data.rriot
2929
if rriot is None:
3030
raise RoborockException("Got no rriot data from user_data")
31-
security_data = create_security_data(rriot)
3231
RoborockMqttClient.__init__(self, user_data, device_info)
32+
security_data = create_security_data(rriot)
3333
RoborockClientV1.__init__(self, device_info, security_data=security_data)
3434
self.queue_timeout = queue_timeout
3535
self._logger = RoborockLoggerAdapter(device_info.device.name, _LOGGER)
36-
self._payload_encoder = create_mqtt_payload_encoder(security_data)
36+
self._security_data = security_data
3737

3838
async def _send_command(
3939
self,
@@ -44,13 +44,15 @@ async def _send_command(
4444
# When we have more custom commands do something more complicated here
4545
return await self._get_calibration_points()
4646

47-
roborock_message = self._payload_encoder(method, params)
47+
request_message = RequestMessage(method=method, params=params)
48+
roborock_message = request_message.encode_message(
49+
RoborockMessageProtocol.RPC_REQUEST,
50+
security_data=self._security_data,
51+
)
4852
self._logger.debug("Building message id %s for method %s", roborock_message.get_request_id, method)
4953

5054
await self.validate_connection()
51-
request_id = roborock_message.get_request_id()
52-
if request_id is None:
53-
raise RoborockException(f"Failed build message {roborock_message}")
55+
request_id = request_message.request_id
5456
response_protocol = (
5557
RoborockMessageProtocol.MAP_RESPONSE if method in COMMANDS_SECURED else RoborockMessageProtocol.RPC_RESPONSE
5658
)

tests/protocols/test_v1_protocol.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@
1010
from roborock.exceptions import RoborockException
1111
from roborock.protocol import Utils
1212
from roborock.protocols.v1_protocol import (
13+
RequestMessage,
1314
SecurityData,
1415
create_map_response_decoder,
15-
create_mqtt_payload_encoder,
1616
decode_rpc_response,
17-
encode_local_payload,
1817
)
1918
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
2019
from roborock.roborock_typing import RoborockCommand
@@ -58,7 +57,7 @@ def request_id_fixture() -> Generator[int, None, None]:
5857
)
5958
def test_encode_local_payload(command, params, expected):
6059
"""Test encoding of local payload for V1 commands."""
61-
message = encode_local_payload(command, params)
60+
message = RequestMessage(command, params).encode_message(RoborockMessageProtocol.GENERAL_REQUEST)
6261
assert isinstance(message, RoborockMessage)
6362
assert message.protocol == RoborockMessageProtocol.GENERAL_REQUEST
6463
assert message.payload == expected
@@ -76,8 +75,8 @@ def test_encode_local_payload(command, params, expected):
7675
)
7776
def test_encode_mqtt_payload(command, params, expected):
7877
"""Test encoding of local payload for V1 commands."""
79-
encoder = create_mqtt_payload_encoder(SECURITY_DATA)
80-
message = encoder(command, params)
78+
request_message = RequestMessage(command, params=params)
79+
message = request_message.encode_message(RoborockMessageProtocol.RPC_REQUEST, SECURITY_DATA)
8180
assert isinstance(message, RoborockMessage)
8281
assert message.protocol == RoborockMessageProtocol.RPC_REQUEST
8382
assert message.payload == expected

0 commit comments

Comments
 (0)