3
3
from asyncio import Lock , TimerHandle , Transport , get_running_loop
4
4
from collections .abc import Callable
5
5
from dataclasses import dataclass
6
+ from enum import StrEnum
6
7
7
8
import async_timeout
8
9
18
19
_LOGGER = logging .getLogger (__name__ )
19
20
20
21
22
+ class LocalProtocolVersion (StrEnum ):
23
+ """Supported local protocol versions. Different from vacuum protocol versions."""
24
+
25
+ L01 = "L01"
26
+ V1 = "1.0"
27
+
28
+
21
29
@dataclass
22
30
class _LocalProtocol (asyncio .Protocol ):
23
31
"""Callbacks for the Roborock local client transport."""
@@ -37,7 +45,12 @@ def connection_lost(self, exc: Exception | None) -> None:
37
45
class RoborockLocalClientV1 (RoborockClientV1 , RoborockClient ):
38
46
"""Roborock local client for v1 devices."""
39
47
40
- def __init__ (self , device_data : DeviceData , queue_timeout : int = 4 , version : str | None = None ):
48
+ def __init__ (
49
+ self ,
50
+ device_data : DeviceData ,
51
+ queue_timeout : int = 4 ,
52
+ local_protocol_version : LocalProtocolVersion | None = None ,
53
+ ):
41
54
"""Initialize the Roborock local client."""
42
55
if device_data .host is None :
43
56
raise RoborockException ("Host is required" )
@@ -50,13 +63,17 @@ def __init__(self, device_data: DeviceData, queue_timeout: int = 4, version: str
50
63
RoborockClientV1 .__init__ (self , device_data , security_data = None )
51
64
RoborockClient .__init__ (self , device_data )
52
65
self ._local_protocol = _LocalProtocol (self ._data_received , self ._connection_lost )
53
- self ._version = version
66
+ self ._local_protocol_version = local_protocol_version
54
67
self ._connect_nonce = get_next_int (10000 , 32767 )
55
68
self ._ack_nonce : int | None = None
56
69
self ._set_encoder_decoder ()
57
70
self .queue_timeout = queue_timeout
58
71
self ._logger = RoborockLoggerAdapter (device_data .device .name , _LOGGER )
59
72
73
+ @property
74
+ def local_protocol_version (self ) -> LocalProtocolVersion :
75
+ return LocalProtocolVersion .V1 if self ._local_protocol_version is None else self ._local_protocol_version
76
+
60
77
def _data_received (self , message ):
61
78
"""Called when data is received from the transport."""
62
79
parsed_msg = self ._decoder (message )
@@ -111,16 +128,21 @@ async def async_disconnect(self) -> None:
111
128
self ._sync_disconnect ()
112
129
113
130
def _set_encoder_decoder (self ):
114
- """Updates the encoder decoder. For L01 these are updated with nonces after the first hello."""
131
+ """Updates the encoder decoder. These are updated with nonces after the first hello.
132
+ Only L01 uses the nonces."""
115
133
self ._encoder = create_local_encoder (self .device_info .device .local_key , self ._connect_nonce , self ._ack_nonce )
116
134
self ._decoder = create_local_decoder (self .device_info .device .local_key , self ._connect_nonce , self ._ack_nonce )
117
135
118
- async def _do_hello (self , version : str ) -> bool :
136
+ async def _do_hello (self , local_protocol_version : LocalProtocolVersion ) -> bool :
119
137
"""Perform the initial handshaking."""
120
- self ._logger .debug (f"Attempting to use the { version } protocol for client { self .device_info .device .duid } ..." )
138
+ self ._logger .debug (
139
+ "Attempting to use the %s protocol for client %s..." ,
140
+ local_protocol_version ,
141
+ self .device_info .device .duid ,
142
+ )
121
143
request = RoborockMessage (
122
144
protocol = RoborockMessageProtocol .HELLO_REQUEST ,
123
- version = version .encode (),
145
+ version = local_protocol_version .encode (),
124
146
random = self ._connect_nonce ,
125
147
seq = 1 ,
126
148
)
@@ -132,31 +154,39 @@ async def _do_hello(self, version: str) -> bool:
132
154
)
133
155
self ._ack_nonce = response .random
134
156
self ._set_encoder_decoder ()
135
- self ._version = version
136
- self ._logger .debug (f"Client { self .device_info .device .duid } speaks the { version } protocol." )
157
+ self ._local_protocol_version = local_protocol_version
158
+
159
+ self ._logger .debug (
160
+ "Client %s speaks the %s protocol." ,
161
+ self .device_info .device .duid ,
162
+ local_protocol_version ,
163
+ )
137
164
return True
138
165
except RoborockException as e :
139
166
self ._logger .debug (
140
- f"Client { self .device_info .device .duid } did not respond or does not speak the { version } protocol. { e } "
167
+ "Client %s did not respond or does not speak the %s protocol. %s" ,
168
+ self .device_info .device .duid ,
169
+ local_protocol_version ,
170
+ e ,
141
171
)
142
172
return False
143
173
144
174
async def hello (self ):
145
175
"""Send hello to the device to negotiate protocol."""
146
- if self ._version :
176
+ if self ._local_protocol_version :
147
177
# version is forced
148
- if not await self ._do_hello (self ._version ):
149
- raise RoborockException (f"Failed to connect to device with protocol { self ._version } " )
178
+ if not await self ._do_hello (self ._local_protocol_version ):
179
+ raise RoborockException (f"Failed to connect to device with protocol { self ._local_protocol_version } " )
150
180
else :
151
181
# try 1.0, then L01
152
- if not await self ._do_hello ("1.0" ):
153
- if not await self ._do_hello (" L01" ):
182
+ if not await self ._do_hello (LocalProtocolVersion . V1 ):
183
+ if not await self ._do_hello (LocalProtocolVersion . L01 ):
154
184
raise RoborockException ("Failed to connect to device with any known protocol" )
155
185
156
186
async def ping (self ) -> None :
157
- # Realistically, this should be set here, but this is to be safe and for typing.
158
- version = b"1.0" if self . _version is None else self ._version .encode ()
159
- ping_message = RoborockMessage ( protocol = RoborockMessageProtocol . PING_REQUEST , version = version )
187
+ ping_message = RoborockMessage (
188
+ protocol = RoborockMessageProtocol . PING_REQUEST , version = self .local_protocol_version .encode ()
189
+ )
160
190
await self ._send_message (
161
191
roborock_message = ping_message ,
162
192
request_id = ping_message .seq ,
@@ -180,7 +210,8 @@ async def _send_command(
180
210
raise RoborockException (f"Method { method } is not supported over local connection" )
181
211
request_message = RequestMessage (method = method , params = params )
182
212
roborock_message = request_message .encode_message (
183
- RoborockMessageProtocol .GENERAL_REQUEST , version = self ._version if self ._version is not None else "1.0"
213
+ RoborockMessageProtocol .GENERAL_REQUEST ,
214
+ version = self .local_protocol_version ,
184
215
)
185
216
self ._logger .debug ("Building message id %s for method %s" , request_message .request_id , method )
186
217
return await self ._send_message (
0 commit comments