Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions roborock/devices/device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
UserData,
)
from roborock.devices.device import RoborockDevice
from roborock.mqtt.roborock_session import create_mqtt_session
from roborock.mqtt.roborock_session import create_lazy_mqtt_session
from roborock.mqtt.session import MqttSession
from roborock.protocol import create_mqtt_params
from roborock.web_api import RoborockApiClient
Expand Down Expand Up @@ -141,7 +141,7 @@ async def create_device_manager(
cache = NoCache()

mqtt_params = create_mqtt_params(user_data.rriot)
mqtt_session = await create_mqtt_session(mqtt_params)
mqtt_session = await create_lazy_mqtt_session(mqtt_params)

def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice:
channel: Channel
Expand Down
127 changes: 103 additions & 24 deletions roborock/devices/v1_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
This module provides a unified channel interface for V1 protocol devices,
handling both MQTT and local connections with automatic fallback.
"""

import asyncio
import datetime
import logging
from collections.abc import Callable
from typing import TypeVar
Expand All @@ -22,7 +23,12 @@
from .channel import Channel
from .local_channel import LocalChannel, LocalSession, create_local_session
from .mqtt_channel import MqttChannel
from .v1_rpc_channel import PickFirstAvailable, V1RpcChannel, create_local_rpc_channel, create_mqtt_rpc_channel
from .v1_rpc_channel import (
PickFirstAvailable,
V1RpcChannel,
create_local_rpc_channel,
create_mqtt_rpc_channel,
)

_LOGGER = logging.getLogger(__name__)

Expand All @@ -32,6 +38,15 @@

_T = TypeVar("_T", bound=RoborockBase)

# Exponential backoff parameters for reconnecting to the local channel: the
# wait between attempts starts at MIN_RECONNECT_INTERVAL, is scaled by
# RECONNECT_MULTIPLIER after each wait, and is capped at MAX_RECONNECT_INTERVAL.
MIN_RECONNECT_INTERVAL = datetime.timedelta(minutes=1)
MAX_RECONNECT_INTERVAL = datetime.timedelta(minutes=10)
RECONNECT_MULTIPLIER = 1.5
# Once the last network info refresh is older than this, the next reconnect
# attempt bypasses the cache and fetches the network info again.
NETWORK_INFO_REFRESH_INTERVAL = datetime.timedelta(hours=12)
# How long the background task sleeps between health checks while the local
# connection is up.
LOCAL_CONNECTION_CHECK_INTERVAL = datetime.timedelta(seconds=15)


class V1Channel(Channel):
"""Unified V1 protocol channel with automatic MQTT/local connection handling.
Expand Down Expand Up @@ -69,6 +84,8 @@ def __init__(
self._local_unsub: Callable[[], None] | None = None
self._callback: Callable[[RoborockMessage], None] | None = None
self._cache = cache
self._reconnect_task: asyncio.Task[None] | None = None
self._last_network_info_refresh: datetime.datetime | None = None

@property
def is_connected(self) -> bool:
Expand All @@ -78,7 +95,7 @@ def is_connected(self) -> bool:
@property
def is_local_connected(self) -> bool:
"""Return whether local connection is available."""
return self._local_unsub is not None
return self._local_channel is not None and self._local_channel.is_connected

@property
def is_mqtt_connected(self) -> bool:
Expand All @@ -103,25 +120,35 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab
a RoborockException. A local connection failure will not raise an exception,
since the local connection is optional.
"""
if self._callback is not None:
raise ValueError("Only one subscription allowed at a time")

if self._mqtt_unsub:
raise ValueError("Already connected to the device")
self._callback = callback

# First establish MQTT connection
self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message)
_LOGGER.debug("V1Channel connected to device %s via MQTT", self._device_uid)

# Try to establish an optional local connection as well.
# Make an initial, optimistic attempt to connect to local with the
# cache. The cache information will be refreshed by the background task.
try:
self._local_unsub = await self._local_connect()
await self._local_connect(use_cache=True)
except RoborockException as err:
_LOGGER.warning("Could not establish local connection for device %s: %s", self._device_uid, err)
else:
_LOGGER.debug("Local connection established for device %s", self._device_uid)

# Start a background task to manage the local connection health. This
# happens independent of whether we were able to connect locally now.
_LOGGER.info("self._reconnect_task=%s", self._reconnect_task)
if self._reconnect_task is None:
loop = asyncio.get_running_loop()
self._reconnect_task = loop.create_task(self._background_reconnect())

if not self.is_local_connected:
# We were not able to connect locally, so fallback to MQTT and at least
# establish that connection explicitly. If this fails then raise an
# error and let the caller know we failed to subscribe.
self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message)
_LOGGER.debug("V1Channel connected to device %s via MQTT", self._device_uid)

def unsub() -> None:
"""Unsubscribe from all messages."""
if self._reconnect_task:
self._reconnect_task.cancel()
self._reconnect_task = None
if self._mqtt_unsub:
self._mqtt_unsub()
self._mqtt_unsub = None
Expand All @@ -130,15 +157,16 @@ def unsub() -> None:
self._local_unsub = None
_LOGGER.debug("Unsubscribed from device %s", self._device_uid)

self._callback = callback
return unsub

async def _get_networking_info(self) -> NetworkInfo:
async def _get_networking_info(self, *, use_cache: bool = True) -> NetworkInfo:
"""Retrieve networking information for the device.

This is a cloud only command used to get the local device's IP address.
"""
cache_data = await self._cache.get()
if cache_data.network_info and (network_info := cache_data.network_info.get(self._device_uid)):
if use_cache and cache_data.network_info and (network_info := cache_data.network_info.get(self._device_uid)):
_LOGGER.debug("Using cached network info for device %s", self._device_uid)
return network_info
try:
Expand All @@ -148,24 +176,75 @@ async def _get_networking_info(self) -> NetworkInfo:
except RoborockException as e:
raise RoborockException(f"Network info failed for device {self._device_uid}") from e
_LOGGER.debug("Network info for device %s: %s", self._device_uid, network_info)
self._last_network_info_refresh = datetime.datetime.now(datetime.timezone.utc)
cache_data.network_info[self._device_uid] = network_info
await self._cache.set(cache_data)
return network_info

async def _local_connect(self) -> Callable[[], None]:
async def _local_connect(self, *, use_cache: bool = True) -> None:
"""Set up local connection if possible."""
_LOGGER.debug("Attempting to connect to local channel for device %s", self._device_uid)
networking_info = await self._get_networking_info()
_LOGGER.debug(
"Attempting to connect to local channel for device %s (use_cache=%s)", self._device_uid, use_cache
)
networking_info = await self._get_networking_info(use_cache=use_cache)
host = networking_info.ip
_LOGGER.debug("Connecting to local channel at %s", host)
self._local_channel = self._local_session(host)
# Create a new local channel and connect
local_channel = self._local_session(host)
try:
await self._local_channel.connect()
await local_channel.connect()
except RoborockException as e:
self._local_channel = None
raise RoborockException(f"Error connecting to local device {self._device_uid}: {e}") from e
# Wire up the new channel
self._local_channel = local_channel
self._local_rpc_channel = create_local_rpc_channel(self._local_channel)
return await self._local_channel.subscribe(self._on_local_message)
self._local_unsub = await self._local_channel.subscribe(self._on_local_message)
_LOGGER.info("Successfully connected to local device %s", self._device_uid)

async def _background_reconnect(self) -> None:
"""Task to run in the background to manage the local connection."""
_LOGGER.debug("Starting background task to manage local connection for %s", self._device_uid)
reconnect_backoff = MIN_RECONNECT_INTERVAL
local_connect_failures = 0

while True:
try:
if self.is_local_connected:
await asyncio.sleep(LOCAL_CONNECTION_CHECK_INTERVAL.total_seconds())
continue

# Not connected, so wait with backoff before trying to connect.
# The first time through, we don't sleep, we just try to connect.
local_connect_failures += 1
if local_connect_failures > 1:
await asyncio.sleep(reconnect_backoff.total_seconds())
reconnect_backoff = min(reconnect_backoff * RECONNECT_MULTIPLIER, MAX_RECONNECT_INTERVAL)

# First failure refreshes cache. Subsequent failures use the cache
# until the refresh interval expires.
use_cache = True
if local_connect_failures == 1:
use_cache = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could simplify this by setting use_cache = True inside the if statement for >1. That would be one fewer if statement.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although I guess you'd have to move that elif below.

You don't have to do this if you don't want to. If you think it hurts readability or understandability, just ignore

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I looked at doing that before, but it's like this because there are a couple of scenarios for deciding whether the cache should be used, and it makes sense to handle them all at once. So I decided instead to extract a helper function, so that at least this flow is a little easier to read.

elif self._last_network_info_refresh and (
datetime.datetime.now(datetime.timezone.utc) - self._last_network_info_refresh
> NETWORK_INFO_REFRESH_INTERVAL
):
use_cache = False

await self._local_connect(use_cache=use_cache)
# Reset backoff and failures on success
reconnect_backoff = MIN_RECONNECT_INTERVAL
local_connect_failures = 0

except asyncio.CancelledError:
_LOGGER.debug("Background reconnect task cancelled")
if self._local_channel:
self._local_channel.close()
return
except RoborockException as err:
_LOGGER.debug("Background reconnect failed: %s", err)
except Exception:
_LOGGER.exception("Unhandled exception in background reconnect task")

def _on_mqtt_message(self, message: RoborockMessage) -> None:
"""Handle incoming MQTT messages."""
Expand Down
60 changes: 60 additions & 0 deletions roborock/mqtt/roborock_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,57 @@ async def publish(self, topic: str, message: bytes) -> None:
raise MqttSessionException(f"Error publishing message: {err}") from err


class LazyMqttSession(MqttSession):
    """MQTT session wrapper that defers starting until it is first used.

    The wrapped session is handed in unstarted. It is started the first
    time a caller subscribes or publishes; later calls reuse the already
    started session.
    """

    def __init__(self, session: RoborockMqttSession) -> None:
        """Wrap ``session`` without starting it."""
        self._session = session
        self._started = False
        # Serializes the start-up so concurrent first calls start only once.
        self._lock = asyncio.Lock()

    @property
    def connected(self) -> bool:
        """Whether the wrapped session reports being connected to the broker."""
        return self._session.connected

    async def _maybe_start(self) -> None:
        """Start the wrapped session unless a previous call already did."""
        async with self._lock:
            if self._started:
                return
            await self._session.start()
            # Only flagged after start() returns, so a failed start is retried.
            self._started = True

    async def subscribe(self, device_id: str, callback: Callable[[bytes], None]) -> Callable[[], None]:
        """Start the session if needed, then subscribe ``callback``.

        Returns the unsubscribe callable produced by the wrapped session.
        """
        await self._maybe_start()
        unsubscribe = await self._session.subscribe(device_id, callback)
        return unsubscribe

    async def publish(self, topic: str, message: bytes) -> None:
        """Start the session if needed, then publish ``message`` on ``topic``.

        Any exception raised by the wrapped session propagates to the caller.
        """
        await self._maybe_start()
        return await self._session.publish(topic, message)

    async def close(self) -> None:
        """Close the wrapped session.

        The started flag is left as-is, so this wrapper will not call
        ``start()`` on the wrapped session again after a close.
        """
        await self._session.close()


async def create_mqtt_session(params: MqttParams) -> MqttSession:
"""Create an MQTT session.

Expand All @@ -230,3 +281,12 @@ async def create_mqtt_session(params: MqttParams) -> MqttSession:
session = RoborockMqttSession(params)
await session.start()
return session


async def create_lazy_mqtt_session(params: MqttParams) -> MqttSession:
    """Build an MQTT session that is not started up front.

    The returned session wraps a fresh ``RoborockMqttSession`` and only
    starts it on the first subscribe or publish call.
    """
    underlying_session = RoborockMqttSession(params)
    return LazyMqttSession(underlying_session)
2 changes: 1 addition & 1 deletion tests/devices/test_device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
@pytest.fixture(autouse=True, name="mqtt_session")
def setup_mqtt_session() -> Generator[Mock, None, None]:
"""Fixture to set up the MQTT session for the tests."""
with patch("roborock.devices.device_manager.create_mqtt_session") as mock_create_session:
with patch("roborock.devices.device_manager.create_lazy_mqtt_session") as mock_create_session:
yield mock_create_session


Expand Down
10 changes: 4 additions & 6 deletions tests/devices/test_mqtt_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import asyncio
import json
import logging
from collections.abc import AsyncGenerator, Callable, Generator
from unittest.mock import AsyncMock, Mock, patch
from collections.abc import AsyncGenerator, Callable
from unittest.mock import AsyncMock, Mock

import pytest

Expand Down Expand Up @@ -48,11 +48,9 @@


@pytest.fixture(name="mqtt_session", autouse=True)
def setup_mqtt_session() -> Generator[Mock, None, None]:
def setup_mqtt_session() -> Mock:
"""Fixture to set up the MQTT session for the tests."""
mock_session = AsyncMock()
with patch("roborock.devices.device_manager.create_mqtt_session", return_value=mock_session):
yield mock_session
return AsyncMock()


@pytest.fixture(name="mqtt_channel", autouse=True)
Expand Down
Loading