Skip to content

Commit 450e35e

Browse files
authored
fix: Merge the local api with the local v1 api (#438)
1 parent b7bab8b commit 450e35e

File tree

5 files changed

+130
-154
lines changed

5 files changed

+130
-154
lines changed

roborock/api.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,3 @@ def _async_response(self, request_id: int, protocol_id: int = 0) -> Any:
116116
request_id = new_id
117117
self._waiting_queue[request_id] = queue
118118
return asyncio.ensure_future(self._wait_response(request_id, queue))
119-
120-
@abstractmethod
121-
async def send_message(self, roborock_message: RoborockMessage):
122-
"""Send a message to the Roborock device."""

roborock/local_api.py

Lines changed: 0 additions & 139 deletions
This file was deleted.

roborock/version_1_apis/roborock_local_client_v1.py

Lines changed: 126 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1+
import asyncio
12
import logging
3+
from asyncio import Lock, TimerHandle, Transport, get_running_loop
4+
from collections.abc import Callable
5+
from dataclasses import dataclass
26

3-
from roborock.local_api import RoborockLocalClient
7+
import async_timeout
48

5-
from .. import CommandVacuumError, DeviceData, RoborockCommand, RoborockException
6-
from ..exceptions import VacuumError
9+
from .. import CommandVacuumError, DeviceData, RoborockCommand
10+
from ..api import RoborockClient
11+
from ..exceptions import RoborockConnectionException, RoborockException, VacuumError
12+
from ..protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
713
from ..protocols.v1_protocol import encode_local_payload
814
from ..roborock_message import RoborockMessage, RoborockMessageProtocol
915
from ..util import RoborockLoggerAdapter
@@ -12,16 +18,129 @@
1218
_LOGGER = logging.getLogger(__name__)
1319

1420

15-
class RoborockLocalClientV1(RoborockLocalClient, RoborockClientV1):
21+
@dataclass
22+
class _LocalProtocol(asyncio.Protocol):
23+
"""Callbacks for the Roborock local client transport."""
24+
25+
messages_cb: Callable[[bytes], None]
26+
connection_lost_cb: Callable[[Exception | None], None]
27+
28+
def data_received(self, bytes) -> None:
29+
"""Called when data is received from the transport."""
30+
self.messages_cb(bytes)
31+
32+
def connection_lost(self, exc: Exception | None) -> None:
33+
"""Called when the transport connection is lost."""
34+
self.connection_lost_cb(exc)
35+
36+
37+
class RoborockLocalClientV1(RoborockClientV1, RoborockClient):
1638
"""Roborock local client for v1 devices."""
1739

1840
def __init__(self, device_data: DeviceData, queue_timeout: int = 4):
1941
"""Initialize the Roborock local client."""
20-
RoborockLocalClient.__init__(self, device_data)
42+
if device_data.host is None:
43+
raise RoborockException("Host is required")
44+
self.host = device_data.host
45+
self._batch_structs: list[RoborockMessage] = []
46+
self._executing = False
47+
self.transport: Transport | None = None
48+
self._mutex = Lock()
49+
self.keep_alive_task: TimerHandle | None = None
2150
RoborockClientV1.__init__(self, device_data, "abc")
51+
RoborockClient.__init__(self, device_data)
52+
self._local_protocol = _LocalProtocol(self._data_received, self._connection_lost)
53+
self._encoder: Encoder = create_local_encoder(device_data.device.local_key)
54+
self._decoder: Decoder = create_local_decoder(device_data.device.local_key)
2255
self.queue_timeout = queue_timeout
2356
self._logger = RoborockLoggerAdapter(device_data.device.name, _LOGGER)
2457

58+
def _data_received(self, message):
59+
"""Called when data is received from the transport."""
60+
parsed_msg = self._decoder(message)
61+
self.on_message_received(parsed_msg)
62+
63+
def _connection_lost(self, exc: Exception | None):
64+
"""Called when the transport connection is lost."""
65+
self._sync_disconnect()
66+
self.on_connection_lost(exc)
67+
68+
def is_connected(self):
69+
return self.transport and self.transport.is_reading()
70+
71+
async def keep_alive_func(self, _=None):
72+
try:
73+
await self.ping()
74+
except RoborockException:
75+
pass
76+
loop = asyncio.get_running_loop()
77+
self.keep_alive_task = loop.call_later(10, lambda: asyncio.create_task(self.keep_alive_func()))
78+
79+
async def async_connect(self) -> None:
80+
should_ping = False
81+
async with self._mutex:
82+
try:
83+
if not self.is_connected():
84+
self._sync_disconnect()
85+
async with async_timeout.timeout(self.queue_timeout):
86+
self._logger.debug(f"Connecting to {self.host}")
87+
loop = get_running_loop()
88+
self.transport, _ = await loop.create_connection( # type: ignore
89+
lambda: self._local_protocol, self.host, 58867
90+
)
91+
self._logger.info(f"Connected to {self.host}")
92+
should_ping = True
93+
except BaseException as e:
94+
raise RoborockConnectionException(f"Failed connecting to {self.host}") from e
95+
if should_ping:
96+
await self.hello()
97+
await self.keep_alive_func()
98+
99+
def _sync_disconnect(self) -> None:
100+
loop = asyncio.get_running_loop()
101+
if self.transport and loop.is_running():
102+
self._logger.debug(f"Disconnecting from {self.host}")
103+
self.transport.close()
104+
if self.keep_alive_task:
105+
self.keep_alive_task.cancel()
106+
107+
async def async_disconnect(self) -> None:
108+
async with self._mutex:
109+
self._sync_disconnect()
110+
111+
async def hello(self):
112+
request_id = 1
113+
protocol = RoborockMessageProtocol.HELLO_REQUEST
114+
try:
115+
return await self._send_message(
116+
RoborockMessage(
117+
protocol=protocol,
118+
seq=request_id,
119+
random=22,
120+
)
121+
)
122+
except Exception as e:
123+
self._logger.error(e)
124+
125+
async def ping(self) -> None:
126+
request_id = 2
127+
protocol = RoborockMessageProtocol.PING_REQUEST
128+
return await self._send_message(
129+
RoborockMessage(
130+
protocol=protocol,
131+
seq=request_id,
132+
random=23,
133+
)
134+
)
135+
136+
def _send_msg_raw(self, data: bytes):
137+
try:
138+
if not self.transport:
139+
raise RoborockException("Can not send message without connection")
140+
self.transport.write(data)
141+
except Exception as e:
142+
raise RoborockException(e) from e
143+
25144
async def _send_command(
26145
self,
27146
method: RoborockCommand | str,
@@ -32,9 +151,9 @@ async def _send_command(
32151

33152
roborock_message = encode_local_payload(method, params)
34153
self._logger.debug("Building message id %s for method %s", roborock_message.get_request_id(), method)
35-
return await self.send_message(roborock_message)
154+
return await self._send_message(roborock_message)
36155

37-
async def send_message(self, roborock_message: RoborockMessage):
156+
async def _send_message(self, roborock_message: RoborockMessage):
38157
await self.validate_connection()
39158
method = roborock_message.get_method()
40159
params = roborock_message.get_params()

roborock/version_a01_apis/roborock_mqtt_client_a01.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(
3939
self.queue_timeout = queue_timeout
4040
self._logger = RoborockLoggerAdapter(device_info.device.name, _LOGGER)
4141

42-
async def send_message(self, roborock_message: RoborockMessage):
42+
async def _send_message(self, roborock_message: RoborockMessage):
4343
await self.validate_connection()
4444
response_protocol = RoborockMessageProtocol.RPC_RESPONSE
4545

@@ -67,11 +67,11 @@ async def update_values(
6767
message = encode_mqtt_payload(
6868
{RoborockDyadDataProtocol.ID_QUERY: str([int(protocol) for protocol in dyad_data_protocols])}
6969
)
70-
return await self.send_message(message)
70+
return await self._send_message(message)
7171

7272
async def set_value(
7373
self, protocol: RoborockDyadDataProtocol | RoborockZeoProtocol, value: typing.Any
7474
) -> dict[int, typing.Any]:
7575
"""Set a value for a specific protocol on the A01 device."""
7676
message = encode_mqtt_payload({protocol: value})
77-
return await self.send_message(message)
77+
return await self._send_message(message)

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def handle_write(data: bytes) -> None:
307307

308308
return (mock_transport, "proto")
309309

310-
with patch("roborock.local_api.get_running_loop") as mock_loop:
310+
with patch("roborock.version_1_apis.roborock_local_client_v1.get_running_loop") as mock_loop:
311311
mock_loop.return_value.create_connection.side_effect = create_connection
312312
yield
313313

0 commit comments

Comments
 (0)