1
+ import asyncio
1
2
import logging
3
+ from asyncio import Lock , TimerHandle , Transport , get_running_loop
4
+ from collections .abc import Callable
5
+ from dataclasses import dataclass
2
6
3
- from roborock . local_api import RoborockLocalClient
7
+ import async_timeout
4
8
5
- from .. import CommandVacuumError , DeviceData , RoborockCommand , RoborockException
6
- from ..exceptions import VacuumError
9
+ from .. import CommandVacuumError , DeviceData , RoborockCommand
10
+ from ..api import RoborockClient
11
+ from ..exceptions import RoborockConnectionException , RoborockException , VacuumError
12
+ from ..protocol import Decoder , Encoder , create_local_decoder , create_local_encoder
7
13
from ..protocols .v1_protocol import encode_local_payload
8
14
from ..roborock_message import RoborockMessage , RoborockMessageProtocol
9
15
from ..util import RoborockLoggerAdapter
12
18
_LOGGER = logging .getLogger (__name__ )
13
19
14
20
15
- class RoborockLocalClientV1 (RoborockLocalClient , RoborockClientV1 ):
21
+ @dataclass
22
+ class _LocalProtocol (asyncio .Protocol ):
23
+ """Callbacks for the Roborock local client transport."""
24
+
25
+ messages_cb : Callable [[bytes ], None ]
26
+ connection_lost_cb : Callable [[Exception | None ], None ]
27
+
28
+ def data_received (self , bytes ) -> None :
29
+ """Called when data is received from the transport."""
30
+ self .messages_cb (bytes )
31
+
32
+ def connection_lost (self , exc : Exception | None ) -> None :
33
+ """Called when the transport connection is lost."""
34
+ self .connection_lost_cb (exc )
35
+
36
+
37
+ class RoborockLocalClientV1 (RoborockClientV1 , RoborockClient ):
16
38
"""Roborock local client for v1 devices."""
17
39
18
40
def __init__ (self , device_data : DeviceData , queue_timeout : int = 4 ):
19
41
"""Initialize the Roborock local client."""
20
- RoborockLocalClient .__init__ (self , device_data )
42
+ if device_data .host is None :
43
+ raise RoborockException ("Host is required" )
44
+ self .host = device_data .host
45
+ self ._batch_structs : list [RoborockMessage ] = []
46
+ self ._executing = False
47
+ self .transport : Transport | None = None
48
+ self ._mutex = Lock ()
49
+ self .keep_alive_task : TimerHandle | None = None
21
50
RoborockClientV1 .__init__ (self , device_data , "abc" )
51
+ RoborockClient .__init__ (self , device_data )
52
+ self ._local_protocol = _LocalProtocol (self ._data_received , self ._connection_lost )
53
+ self ._encoder : Encoder = create_local_encoder (device_data .device .local_key )
54
+ self ._decoder : Decoder = create_local_decoder (device_data .device .local_key )
22
55
self .queue_timeout = queue_timeout
23
56
self ._logger = RoborockLoggerAdapter (device_data .device .name , _LOGGER )
24
57
58
+ def _data_received (self , message ):
59
+ """Called when data is received from the transport."""
60
+ parsed_msg = self ._decoder (message )
61
+ self .on_message_received (parsed_msg )
62
+
63
+ def _connection_lost (self , exc : Exception | None ):
64
+ """Called when the transport connection is lost."""
65
+ self ._sync_disconnect ()
66
+ self .on_connection_lost (exc )
67
+
68
+ def is_connected (self ):
69
+ return self .transport and self .transport .is_reading ()
70
+
71
+ async def keep_alive_func (self , _ = None ):
72
+ try :
73
+ await self .ping ()
74
+ except RoborockException :
75
+ pass
76
+ loop = asyncio .get_running_loop ()
77
+ self .keep_alive_task = loop .call_later (10 , lambda : asyncio .create_task (self .keep_alive_func ()))
78
+
79
+ async def async_connect (self ) -> None :
80
+ should_ping = False
81
+ async with self ._mutex :
82
+ try :
83
+ if not self .is_connected ():
84
+ self ._sync_disconnect ()
85
+ async with async_timeout .timeout (self .queue_timeout ):
86
+ self ._logger .debug (f"Connecting to { self .host } " )
87
+ loop = get_running_loop ()
88
+ self .transport , _ = await loop .create_connection ( # type: ignore
89
+ lambda : self ._local_protocol , self .host , 58867
90
+ )
91
+ self ._logger .info (f"Connected to { self .host } " )
92
+ should_ping = True
93
+ except BaseException as e :
94
+ raise RoborockConnectionException (f"Failed connecting to { self .host } " ) from e
95
+ if should_ping :
96
+ await self .hello ()
97
+ await self .keep_alive_func ()
98
+
99
+ def _sync_disconnect (self ) -> None :
100
+ loop = asyncio .get_running_loop ()
101
+ if self .transport and loop .is_running ():
102
+ self ._logger .debug (f"Disconnecting from { self .host } " )
103
+ self .transport .close ()
104
+ if self .keep_alive_task :
105
+ self .keep_alive_task .cancel ()
106
+
107
+ async def async_disconnect (self ) -> None :
108
+ async with self ._mutex :
109
+ self ._sync_disconnect ()
110
+
111
+ async def hello (self ):
112
+ request_id = 1
113
+ protocol = RoborockMessageProtocol .HELLO_REQUEST
114
+ try :
115
+ return await self ._send_message (
116
+ RoborockMessage (
117
+ protocol = protocol ,
118
+ seq = request_id ,
119
+ random = 22 ,
120
+ )
121
+ )
122
+ except Exception as e :
123
+ self ._logger .error (e )
124
+
125
+ async def ping (self ) -> None :
126
+ request_id = 2
127
+ protocol = RoborockMessageProtocol .PING_REQUEST
128
+ return await self ._send_message (
129
+ RoborockMessage (
130
+ protocol = protocol ,
131
+ seq = request_id ,
132
+ random = 23 ,
133
+ )
134
+ )
135
+
136
+ def _send_msg_raw (self , data : bytes ):
137
+ try :
138
+ if not self .transport :
139
+ raise RoborockException ("Can not send message without connection" )
140
+ self .transport .write (data )
141
+ except Exception as e :
142
+ raise RoborockException (e ) from e
143
+
25
144
async def _send_command (
26
145
self ,
27
146
method : RoborockCommand | str ,
@@ -32,9 +151,9 @@ async def _send_command(
32
151
33
152
roborock_message = encode_local_payload (method , params )
34
153
self ._logger .debug ("Building message id %s for method %s" , roborock_message .get_request_id (), method )
35
- return await self .send_message (roborock_message )
154
+ return await self ._send_message (roborock_message )
36
155
37
- async def send_message (self , roborock_message : RoborockMessage ):
156
+ async def _send_message (self , roborock_message : RoborockMessage ):
38
157
await self .validate_connection ()
39
158
method = roborock_message .get_method ()
40
159
params = roborock_message .get_params ()
0 commit comments