Skip to content

Commit d0212e5

Browse files
authored
fix: improve new v1 apis to use mqtt lazily and work entirely locally (#491)
* fix: improve new v1 apis to use mqtt lazily and work entirely locally
* chore: remove unnecessary logging
* chore: Update comments
* chore: extract caching logic to one place
* chore: remove whitespace
1 parent c0c082b commit d0212e5

File tree

6 files changed

+189
-129
lines changed

6 files changed

+189
-129
lines changed

roborock/devices/device_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
UserData,
1515
)
1616
from roborock.devices.device import RoborockDevice
17-
from roborock.mqtt.roborock_session import create_mqtt_session
17+
from roborock.mqtt.roborock_session import create_lazy_mqtt_session
1818
from roborock.mqtt.session import MqttSession
1919
from roborock.protocol import create_mqtt_params
2020
from roborock.web_api import RoborockApiClient
@@ -141,7 +141,7 @@ async def create_device_manager(
141141
cache = NoCache()
142142

143143
mqtt_params = create_mqtt_params(user_data.rriot)
144-
mqtt_session = await create_mqtt_session(mqtt_params)
144+
mqtt_session = await create_lazy_mqtt_session(mqtt_params)
145145

146146
def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice:
147147
channel: Channel

roborock/devices/v1_channel.py

Lines changed: 109 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
This module provides a unified channel interface for V1 protocol devices,
44
handling both MQTT and local connections with automatic fallback.
55
"""
6-
6+
import asyncio
7+
import datetime
78
import logging
89
from collections.abc import Callable
910
from typing import TypeVar
@@ -22,7 +23,12 @@
2223
from .channel import Channel
2324
from .local_channel import LocalChannel, LocalSession, create_local_session
2425
from .mqtt_channel import MqttChannel
25-
from .v1_rpc_channel import PickFirstAvailable, V1RpcChannel, create_local_rpc_channel, create_mqtt_rpc_channel
26+
from .v1_rpc_channel import (
27+
PickFirstAvailable,
28+
V1RpcChannel,
29+
create_local_rpc_channel,
30+
create_mqtt_rpc_channel,
31+
)
2632

2733
_LOGGER = logging.getLogger(__name__)
2834

@@ -32,6 +38,15 @@
3238

3339
_T = TypeVar("_T", bound=RoborockBase)
3440

41+
# Exponential backoff parameters for reconnecting to the local channel
42+
MIN_RECONNECT_INTERVAL = datetime.timedelta(minutes=1)
43+
MAX_RECONNECT_INTERVAL = datetime.timedelta(minutes=10)
44+
RECONNECT_MULTIPLIER = 1.5
45+
# Once this interval has elapsed, the cached network info is refreshed
46+
NETWORK_INFO_REFRESH_INTERVAL = datetime.timedelta(hours=12)
47+
# Interval to check that the local connection is healthy
48+
LOCAL_CONNECTION_CHECK_INTERVAL = datetime.timedelta(seconds=15)
49+
3550

3651
class V1Channel(Channel):
3752
"""Unified V1 protocol channel with automatic MQTT/local connection handling.
@@ -69,6 +84,8 @@ def __init__(
6984
self._local_unsub: Callable[[], None] | None = None
7085
self._callback: Callable[[RoborockMessage], None] | None = None
7186
self._cache = cache
87+
self._reconnect_task: asyncio.Task[None] | None = None
88+
self._last_network_info_refresh: datetime.datetime | None = None
7289

7390
@property
7491
def is_connected(self) -> bool:
@@ -78,7 +95,7 @@ def is_connected(self) -> bool:
7895
@property
7996
def is_local_connected(self) -> bool:
8097
"""Return whether local connection is available."""
81-
return self._local_unsub is not None
98+
return self._local_channel is not None and self._local_channel.is_connected
8299

83100
@property
84101
def is_mqtt_connected(self) -> bool:
@@ -103,25 +120,35 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab
103120
a RoborockException. A local connection failure will not raise an exception,
104121
since the local connection is optional.
105122
"""
123+
if self._callback is not None:
124+
raise ValueError("Only one subscription allowed at a time")
106125

107-
if self._mqtt_unsub:
108-
raise ValueError("Already connected to the device")
109-
self._callback = callback
110-
111-
# First establish MQTT connection
112-
self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message)
113-
_LOGGER.debug("V1Channel connected to device %s via MQTT", self._device_uid)
114-
115-
# Try to establish an optional local connection as well.
126+
# Make an initial, optimistic attempt to connect to local with the
127+
# cache. The cache information will be refreshed by the background task.
116128
try:
117-
self._local_unsub = await self._local_connect()
129+
await self._local_connect(use_cache=True)
118130
except RoborockException as err:
119131
_LOGGER.warning("Could not establish local connection for device %s: %s", self._device_uid, err)
120-
else:
121-
_LOGGER.debug("Local connection established for device %s", self._device_uid)
132+
133+
# Start a background task to manage the local connection health. This
134+
# happens independent of whether we were able to connect locally now.
135+
_LOGGER.info("self._reconnect_task=%s", self._reconnect_task)
136+
if self._reconnect_task is None:
137+
loop = asyncio.get_running_loop()
138+
self._reconnect_task = loop.create_task(self._background_reconnect())
139+
140+
if not self.is_local_connected:
141+
# We were not able to connect locally, so fallback to MQTT and at least
142+
# establish that connection explicitly. If this fails then raise an
143+
# error and let the caller know we failed to subscribe.
144+
self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message)
145+
_LOGGER.debug("V1Channel connected to device %s via MQTT", self._device_uid)
122146

123147
def unsub() -> None:
124148
"""Unsubscribe from all messages."""
149+
if self._reconnect_task:
150+
self._reconnect_task.cancel()
151+
self._reconnect_task = None
125152
if self._mqtt_unsub:
126153
self._mqtt_unsub()
127154
self._mqtt_unsub = None
@@ -130,15 +157,16 @@ def unsub() -> None:
130157
self._local_unsub = None
131158
_LOGGER.debug("Unsubscribed from device %s", self._device_uid)
132159

160+
self._callback = callback
133161
return unsub
134162

135-
async def _get_networking_info(self) -> NetworkInfo:
163+
async def _get_networking_info(self, *, use_cache: bool = True) -> NetworkInfo:
136164
"""Retrieve networking information for the device.
137165
138166
This is a cloud only command used to get the local device's IP address.
139167
"""
140168
cache_data = await self._cache.get()
141-
if cache_data.network_info and (network_info := cache_data.network_info.get(self._device_uid)):
169+
if use_cache and cache_data.network_info and (network_info := cache_data.network_info.get(self._device_uid)):
142170
_LOGGER.debug("Using cached network info for device %s", self._device_uid)
143171
return network_info
144172
try:
@@ -148,24 +176,81 @@ async def _get_networking_info(self) -> NetworkInfo:
148176
except RoborockException as e:
149177
raise RoborockException(f"Network info failed for device {self._device_uid}") from e
150178
_LOGGER.debug("Network info for device %s: %s", self._device_uid, network_info)
179+
self._last_network_info_refresh = datetime.datetime.now(datetime.timezone.utc)
151180
cache_data.network_info[self._device_uid] = network_info
152181
await self._cache.set(cache_data)
153182
return network_info
154183

155-
async def _local_connect(self) -> Callable[[], None]:
184+
async def _local_connect(self, *, use_cache: bool = True) -> None:
156185
"""Set up local connection if possible."""
157-
_LOGGER.debug("Attempting to connect to local channel for device %s", self._device_uid)
158-
networking_info = await self._get_networking_info()
186+
_LOGGER.debug(
187+
"Attempting to connect to local channel for device %s (use_cache=%s)", self._device_uid, use_cache
188+
)
189+
networking_info = await self._get_networking_info(use_cache=use_cache)
159190
host = networking_info.ip
160191
_LOGGER.debug("Connecting to local channel at %s", host)
161-
self._local_channel = self._local_session(host)
192+
# Create a new local channel and connect
193+
local_channel = self._local_session(host)
162194
try:
163-
await self._local_channel.connect()
195+
await local_channel.connect()
164196
except RoborockException as e:
165-
self._local_channel = None
166197
raise RoborockException(f"Error connecting to local device {self._device_uid}: {e}") from e
198+
# Wire up the new channel
199+
self._local_channel = local_channel
167200
self._local_rpc_channel = create_local_rpc_channel(self._local_channel)
168-
return await self._local_channel.subscribe(self._on_local_message)
201+
self._local_unsub = await self._local_channel.subscribe(self._on_local_message)
202+
_LOGGER.info("Successfully connected to local device %s", self._device_uid)
203+
204+
async def _background_reconnect(self) -> None:
    """Supervise the local connection until this task is cancelled.

    While the local channel reports itself connected, the loop only sleeps
    for ``LOCAL_CONNECTION_CHECK_INTERVAL`` between checks. While it is not
    connected, the loop calls ``_local_connect``; the first attempt of each
    outage runs immediately and every later attempt is preceded by a sleep
    whose length grows by ``RECONNECT_MULTIPLIER`` up to
    ``MAX_RECONNECT_INTERVAL``. A successful connect resets both the
    backoff and the attempt counter.

    On cancellation the current local channel (if any) is closed and the
    coroutine returns normally; the ``CancelledError`` is not re-raised.
    Any other exception is logged and the loop keeps running.
    """
    _LOGGER.debug("Starting background task to manage local connection for %s", self._device_uid)
    backoff = MIN_RECONNECT_INTERVAL
    attempts = 0

    while True:
        try:
            if not self.is_local_connected:
                # The counter is bumped *before* the attempt so that the very
                # first attempt of an outage (attempts == 1) skips the sleep.
                attempts += 1
                if attempts > 1:
                    await asyncio.sleep(backoff.total_seconds())
                    backoff = min(backoff * RECONNECT_MULTIPLIER, MAX_RECONNECT_INTERVAL)

                await self._local_connect(use_cache=self._should_use_cache(attempts))

                # Connected again: start the next outage from a clean slate.
                backoff = MIN_RECONNECT_INTERVAL
                attempts = 0
            else:
                # Healthy connection; just poll again after the check interval.
                await asyncio.sleep(LOCAL_CONNECTION_CHECK_INTERVAL.total_seconds())
        except asyncio.CancelledError:
            _LOGGER.debug("Background reconnect task cancelled")
            channel = self._local_channel
            if channel:
                channel.close()
            return
        except RoborockException as err:
            # Expected failure mode (device unreachable, etc.); retry later.
            _LOGGER.debug("Background reconnect failed: %s", err)
        except Exception:
            # Never let an unexpected error kill the supervisor loop.
            _LOGGER.exception("Unhandled exception in background reconnect task")
238+
239+
def _should_use_cache(self, local_connect_failures: int) -> bool:
    """Decide whether a reconnect attempt may rely on cached network info.

    Returns ``False`` (bypass the cache) when ``local_connect_failures`` is
    exactly 1, so that the first attempt picks up an IP address that may
    have changed. It also returns ``False`` when a refresh time has been
    recorded and more than ``NETWORK_INFO_REFRESH_INTERVAL`` has passed
    since then. In every other case, including when no refresh time has
    been recorded, it returns ``True``.
    """
    if local_connect_failures == 1:
        return False

    last_refresh = self._last_network_info_refresh
    if not last_refresh:
        # Nothing recorded to measure staleness against; keep using the cache.
        return True

    age = datetime.datetime.now(datetime.timezone.utc) - last_refresh
    return age <= NETWORK_INFO_REFRESH_INTERVAL
169254

170255
def _on_mqtt_message(self, message: RoborockMessage) -> None:
171256
"""Handle incoming MQTT messages."""

roborock/mqtt/roborock_session.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,57 @@ async def publish(self, topic: str, message: bytes) -> None:
220220
raise MqttSessionException(f"Error publishing message: {err}") from err
221221

222222

223+
class LazyMqttSession(MqttSession):
    """MQTT session wrapper that defers starting until it is first used.

    Constructing the wrapper only stores the given session. The wrapped
    session's ``start()`` is awaited on the first call to ``subscribe`` or
    ``publish``; every call is then forwarded to the wrapped session.
    """

    def __init__(self, session: RoborockMqttSession) -> None:
        """Wrap ``session`` without starting it."""
        self._session = session
        self._started = False
        # Serializes concurrent first-use callers around the start() call.
        self._lock = asyncio.Lock()

    @property
    def connected(self) -> bool:
        """Whether the wrapped session reports being connected to the broker."""
        return self._session.connected

    async def _maybe_start(self) -> None:
        """Start the wrapped session unless a previous start already succeeded.

        ``_started`` is only set after ``start()`` returns, so a failed start
        is attempted again on the next call.
        """
        async with self._lock:
            if self._started:
                return
            await self._session.start()
            self._started = True

    async def subscribe(self, device_id: str, callback: Callable[[bytes], None]) -> Callable[[], None]:
        """Start the session if needed, then subscribe ``callback`` for ``device_id``.

        Returns the unsubscribe callable produced by the wrapped session.
        """
        await self._maybe_start()
        unsubscribe = await self._session.subscribe(device_id, callback)
        return unsubscribe

    async def publish(self, topic: str, message: bytes) -> None:
        """Start the session if needed, then publish ``message`` on ``topic``.

        Exceptions raised by the wrapped session (from starting or from
        publishing) propagate to the caller.
        """
        await self._maybe_start()
        outcome = await self._session.publish(topic, message)
        return outcome

    async def close(self) -> None:
        """Close the wrapped session.

        This delegates directly to the wrapped session's ``close()``,
        whether or not it was ever started. NOTE(review): the original
        docstring says the session cannot be restarted afterwards; this
        wrapper does not enforce that itself, so confirm it against the
        wrapped session's behavior.
        """
        await self._session.close()
272+
273+
223274
async def create_mqtt_session(params: MqttParams) -> MqttSession:
224275
"""Create an MQTT session.
225276
@@ -230,3 +281,12 @@ async def create_mqtt_session(params: MqttParams) -> MqttSession:
230281
session = RoborockMqttSession(params)
231282
await session.start()
232283
return session
284+
285+
286+
async def create_lazy_mqtt_session(params: MqttParams) -> MqttSession:
    """Build an MQTT session that is not started until first use.

    A ``RoborockMqttSession`` is constructed from ``params`` and wrapped in
    a ``LazyMqttSession``; nothing is awaited here, so no connection is
    attempted by this factory itself.
    """
    underlying = RoborockMqttSession(params)
    return LazyMqttSession(underlying)

tests/devices/test_device_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
@pytest.fixture(autouse=True, name="mqtt_session")
2020
def setup_mqtt_session() -> Generator[Mock, None, None]:
2121
"""Fixture to set up the MQTT session for the tests."""
22-
with patch("roborock.devices.device_manager.create_mqtt_session") as mock_create_session:
22+
with patch("roborock.devices.device_manager.create_lazy_mqtt_session") as mock_create_session:
2323
yield mock_create_session
2424

2525

tests/devices/test_mqtt_channel.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import asyncio
44
import json
55
import logging
6-
from collections.abc import AsyncGenerator, Callable, Generator
7-
from unittest.mock import AsyncMock, Mock, patch
6+
from collections.abc import AsyncGenerator, Callable
7+
from unittest.mock import AsyncMock, Mock
88

99
import pytest
1010

@@ -48,11 +48,9 @@
4848

4949

5050
@pytest.fixture(name="mqtt_session", autouse=True)
51-
def setup_mqtt_session() -> Generator[Mock, None, None]:
51+
def setup_mqtt_session() -> Mock:
5252
"""Fixture to set up the MQTT session for the tests."""
53-
mock_session = AsyncMock()
54-
with patch("roborock.devices.device_manager.create_mqtt_session", return_value=mock_session):
55-
yield mock_session
53+
return AsyncMock()
5654

5755

5856
@pytest.fixture(name="mqtt_channel", autouse=True)

0 commit comments

Comments
 (0)