Skip to content

Commit f076a51

Browse files
authored
feat: Add an explicit module for caching (#432)
* feat: Update the cli cache to also store network info A follow up step will support caching network info for the new device manager API, and will use the CLI to do this. * chore: fix lint errors * feat: Add explicit cache module * fix: adjust cache implementation defaults
1 parent 2453081 commit f076a51

File tree

7 files changed

+196
-21
lines changed

7 files changed

+196
-21
lines changed

roborock/cli.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from roborock import RoborockException
1414
from roborock.containers import DeviceData, HomeData, HomeDataProduct, LoginData, NetworkInfo, RoborockBase, UserData
15+
from roborock.devices.cache import Cache, CacheData
1516
from roborock.devices.device_manager import create_device_manager, create_home_data_api
1617
from roborock.protocol import MessageParser
1718
from roborock.util import run_sync
@@ -39,7 +40,7 @@ class ConnectionCache(RoborockBase):
3940
network_info: dict[str, NetworkInfo] | None = None
4041

4142

42-
class RoborockContext:
43+
class RoborockContext(Cache):
4344
roborock_file = Path("~/.roborock").expanduser()
4445
_cache_data: ConnectionCache | None = None
4546

@@ -68,6 +69,18 @@ def cache_data(self) -> ConnectionCache:
6869
self.validate()
6970
return self._cache_data
7071

72+
async def get(self) -> CacheData:
73+
"""Get cached value."""
74+
connection_cache = self.cache_data()
75+
return CacheData(home_data=connection_cache.home_data, network_info=connection_cache.network_info or {})
76+
77+
async def set(self, value: CacheData) -> None:
78+
"""Set value in the cache."""
79+
connection_cache = self.cache_data()
80+
connection_cache.home_data = value.home_data
81+
connection_cache.network_info = value.network_info
82+
self.update(connection_cache)
83+
7184

7285
@click.option("-d", "--debug", default=False, count=True)
7386
@click.version_option(package_name="python-roborock")
@@ -119,14 +132,8 @@ async def session(ctx, duration: int):
119132

120133
home_data_api = create_home_data_api(cache_data.email, cache_data.user_data)
121134

122-
async def home_data_cache() -> HomeData:
123-
if cache_data.home_data is None:
124-
cache_data.home_data = await home_data_api()
125-
context.update(cache_data)
126-
return cache_data.home_data
127-
128135
# Create device manager
129-
device_manager = await create_device_manager(cache_data.user_data, home_data_cache)
136+
device_manager = await create_device_manager(cache_data.user_data, home_data_api, context)
130137

131138
devices = await device_manager.get_devices()
132139
click.echo(f"Discovered devices: {', '.join([device.name for device in devices])}")

roborock/devices/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33
__all__ = [
44
"device",
55
"device_manager",
6+
"cache",
67
]

roborock/devices/cache.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""This module provides caching functionality for the Roborock device management system.
2+
3+
This module defines a cache interface that you may use to cache device
4+
information to avoid unnecessary API calls. Callers may implement
5+
this interface to provide their own caching mechanism.
6+
"""
7+
8+
from dataclasses import dataclass, field
9+
from typing import Protocol
10+
11+
from roborock.containers import HomeData, NetworkInfo
12+
13+
14+
@dataclass
15+
class CacheData:
16+
"""Data structure for caching device information."""
17+
18+
home_data: HomeData | None = None
19+
"""Home data containing device and product information."""
20+
21+
network_info: dict[str, NetworkInfo] = field(default_factory=dict)
22+
"""Network information indexed by device DUID."""
23+
24+
25+
class Cache(Protocol):
26+
"""Protocol for a cache that can store and retrieve values."""
27+
28+
async def get(self) -> CacheData:
29+
"""Get cached value."""
30+
...
31+
32+
async def set(self, value: CacheData) -> None:
33+
"""Set value in the cache."""
34+
...
35+
36+
37+
class InMemoryCache(Cache):
38+
"""In-memory cache implementation."""
39+
40+
def __init__(self):
41+
self._data = CacheData()
42+
43+
async def get(self) -> CacheData:
44+
return self._data
45+
46+
async def set(self, value: CacheData) -> None:
47+
self._data = value
48+
49+
50+
class NoCache(Cache):
51+
"""No-op cache implementation."""
52+
53+
async def get(self) -> CacheData:
54+
return CacheData()
55+
56+
async def set(self, value: CacheData) -> None:
57+
pass

roborock/devices/device_manager.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from roborock.protocol import create_mqtt_params
1919
from roborock.web_api import RoborockApiClient
2020

21+
from .cache import Cache, NoCache
2122
from .channel import Channel
2223
from .mqtt_channel import create_mqtt_channel
2324
from .traits.dyad import DyadApi
@@ -32,8 +33,6 @@
3233
"create_device_manager",
3334
"create_home_data_api",
3435
"DeviceManager",
35-
"HomeDataApi",
36-
"DeviceCreator",
3736
]
3837

3938

@@ -57,19 +56,27 @@ def __init__(
5756
home_data_api: HomeDataApi,
5857
device_creator: DeviceCreator,
5958
mqtt_session: MqttSession,
59+
cache: Cache,
6060
) -> None:
6161
"""Initialize the DeviceManager with user data and optional cache storage.
6262
6363
This takes ownership of the MQTT session and will close it when the manager is closed.
6464
"""
6565
self._home_data_api = home_data_api
66+
self._cache = cache
6667
self._device_creator = device_creator
6768
self._devices: dict[str, RoborockDevice] = {}
6869
self._mqtt_session = mqtt_session
6970

7071
async def discover_devices(self) -> list[RoborockDevice]:
7172
"""Discover all devices for the logged-in user."""
72-
home_data = await self._home_data_api()
73+
cache_data = await self._cache.get()
74+
if not cache_data.home_data:
75+
_LOGGER.debug("No cached home data found, fetching from API")
76+
cache_data.home_data = await self._home_data_api()
77+
await self._cache.set(cache_data)
78+
home_data = cache_data.home_data
79+
7380
device_products = home_data.device_products
7481
_LOGGER.debug("Discovered %d devices %s", len(device_products), home_data)
7582

@@ -118,13 +125,19 @@ async def home_data_api() -> HomeData:
118125
return home_data_api
119126

120127

121-
async def create_device_manager(user_data: UserData, home_data_api: HomeDataApi) -> DeviceManager:
128+
async def create_device_manager(
129+
user_data: UserData,
130+
home_data_api: HomeDataApi,
131+
cache: Cache | None = None,
132+
) -> DeviceManager:
122133
"""Convenience function to create and initialize a DeviceManager.
123134
124135
The Home Data is fetched using the provided home_data_api callable which
125136
is exposed this way to allow for swapping out other implementations to
126137
include caching or other optimizations.
127138
"""
139+
if cache is None:
140+
cache = NoCache()
128141

129142
mqtt_params = create_mqtt_params(user_data.rriot)
130143
mqtt_session = await create_mqtt_session(mqtt_params)
@@ -135,7 +148,7 @@ def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> Roborock
135148
# TODO: Define a registration mechanism/factory for v1 traits
136149
match device.pv:
137150
case DeviceVersion.V1:
138-
channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device)
151+
channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device, cache)
139152
traits.append(StatusTrait(product, channel.rpc_channel))
140153
case DeviceVersion.A01:
141154
mqtt_channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
@@ -150,6 +163,6 @@ def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> Roborock
150163
raise NotImplementedError(f"Device {device.name} has unsupported version {device.pv}")
151164
return RoborockDevice(device, channel, traits)
152165

153-
manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session)
166+
manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session, cache=cache)
154167
await manager.discover_devices()
155168
return manager

roborock/devices/v1_channel.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from roborock.roborock_message import RoborockMessage
1919
from roborock.roborock_typing import RoborockCommand
2020

21+
from .cache import Cache
2122
from .channel import Channel
2223
from .local_channel import LocalChannel, LocalSession, create_local_session
2324
from .mqtt_channel import MqttChannel
@@ -46,6 +47,7 @@ def __init__(
4647
security_data: SecurityData,
4748
mqtt_channel: MqttChannel,
4849
local_session: LocalSession,
50+
cache: Cache,
4951
) -> None:
5052
"""Initialize the V1Channel.
5153
@@ -62,7 +64,7 @@ def __init__(
6264
self._mqtt_unsub: Callable[[], None] | None = None
6365
self._local_unsub: Callable[[], None] | None = None
6466
self._callback: Callable[[RoborockMessage], None] | None = None
65-
self._networking_info: NetworkInfo | None = None
67+
self._cache = cache
6668

6769
@property
6870
def is_connected(self) -> bool:
@@ -131,19 +133,26 @@ async def _get_networking_info(self) -> NetworkInfo:
131133
132134
This is a cloud only command used to get the local device's IP address.
133135
"""
136+
cache_data = await self._cache.get()
137+
if cache_data.network_info and (network_info := cache_data.network_info.get(self._device_uid)):
138+
_LOGGER.debug("Using cached network info for device %s", self._device_uid)
139+
return network_info
134140
try:
135-
return await self._mqtt_rpc_channel.send_command(
141+
network_info = await self._mqtt_rpc_channel.send_command(
136142
RoborockCommand.GET_NETWORK_INFO, response_type=NetworkInfo
137143
)
138144
except RoborockException as e:
139145
raise RoborockException(f"Network info failed for device {self._device_uid}") from e
146+
_LOGGER.debug("Network info for device %s: %s", self._device_uid, network_info)
147+
cache_data.network_info[self._device_uid] = network_info
148+
await self._cache.set(cache_data)
149+
return network_info
140150

141151
async def _local_connect(self) -> Callable[[], None]:
142152
"""Set up local connection if possible."""
143153
_LOGGER.debug("Attempting to connect to local channel for device %s", self._device_uid)
144-
if self._networking_info is None:
145-
self._networking_info = await self._get_networking_info()
146-
host = self._networking_info.ip
154+
networking_info = await self._get_networking_info()
155+
host = networking_info.ip
147156
_LOGGER.debug("Connecting to local channel at %s", host)
148157
self._local_channel = self._local_session(host)
149158
try:
@@ -168,10 +177,14 @@ def _on_local_message(self, message: RoborockMessage) -> None:
168177

169178

170179
def create_v1_channel(
171-
user_data: UserData, mqtt_params: MqttParams, mqtt_session: MqttSession, device: HomeDataDevice
180+
user_data: UserData,
181+
mqtt_params: MqttParams,
182+
mqtt_session: MqttSession,
183+
device: HomeDataDevice,
184+
cache: Cache,
172185
) -> V1Channel:
173186
"""Create a V1Channel for the given device."""
174187
security_data = create_security_data(user_data.rriot)
175188
mqtt_channel = MqttChannel(mqtt_session, device.duid, device.local_key, user_data.rriot, mqtt_params)
176189
local_session = create_local_session(device.local_key)
177-
return V1Channel(device.duid, security_data, mqtt_channel, local_session=local_session)
190+
return V1Channel(device.duid, security_data, mqtt_channel, local_session=local_session, cache=cache)

tests/devices/test_device_manager.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77

88
from roborock.containers import HomeData, UserData
9+
from roborock.devices.cache import CacheData, InMemoryCache
910
from roborock.devices.device_manager import create_device_manager, create_home_data_api
1011
from roborock.exceptions import RoborockException
1112

@@ -98,3 +99,37 @@ async def test_create_home_data_api_exception() -> None:
9899

99100
with pytest.raises(RoborockException, match="Test exception"):
100101
await api()
102+
103+
104+
async def test_cache_logic() -> None:
105+
"""Test that the cache logic works correctly."""
106+
call_count = 0
107+
108+
async def mock_home_data_with_counter() -> HomeData:
109+
nonlocal call_count
110+
call_count += 1
111+
return HomeData.from_dict(mock_data.HOME_DATA_RAW)
112+
113+
class TestCache:
114+
def __init__(self):
115+
self._data = CacheData()
116+
117+
async def get(self) -> CacheData:
118+
return self._data
119+
120+
async def set(self, value: CacheData) -> None:
121+
self._data = value
122+
123+
# First call happens during create_device_manager initialization
124+
device_manager = await create_device_manager(USER_DATA, mock_home_data_with_counter, cache=InMemoryCache())
125+
assert call_count == 1
126+
127+
# Second call should use cache, not increment call_count
128+
devices2 = await device_manager.discover_devices()
129+
assert call_count == 1 # Should still be 1, not 2
130+
assert len(devices2) == 1
131+
132+
await device_manager.close()
133+
assert len(devices2) == 1
134+
135+
await device_manager.close()

tests/devices/test_v1_channel.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import pytest
1212

1313
from roborock.containers import NetworkInfo, RoborockStateCode, S5MaxStatus, UserData
14+
from roborock.devices.cache import CacheData, InMemoryCache
1415
from roborock.devices.local_channel import LocalChannel, LocalSession
1516
from roborock.devices.mqtt_channel import MqttChannel
1617
from roborock.devices.v1_channel import V1Channel
@@ -105,6 +106,7 @@ def setup_v1_channel(
105106
security_data=TEST_SECURITY_DATA,
106107
mqtt_channel=mock_mqtt_channel,
107108
local_session=mock_local_session,
109+
cache=InMemoryCache(),
108110
)
109111

110112

@@ -408,6 +410,52 @@ async def test_v1_channel_networking_info_retrieved_during_connection(
408410
mock_local_session.assert_called_once_with(mock_data.NETWORK_INFO["ip"])
409411

410412

413+
async def test_v1_channel_networking_info_cached_during_connection(
414+
mock_mqtt_channel: Mock,
415+
mock_local_channel: Mock,
416+
mock_local_session: Mock,
417+
) -> None:
418+
"""Test that networking information is cached and reused on subsequent connections."""
419+
420+
# Create a cache with pre-populated network info
421+
cache_data = CacheData()
422+
cache_data.network_info[TEST_DEVICE_UID] = TEST_NETWORKING_INFO
423+
424+
mock_cache = AsyncMock()
425+
mock_cache.get.return_value = cache_data
426+
mock_cache.set = AsyncMock()
427+
428+
# Setup: MQTT and local connections succeed
429+
mock_mqtt_channel.subscribe.return_value = Mock()
430+
mock_local_channel.subscribe.return_value = Mock()
431+
432+
# Create V1Channel with the mock cache
433+
v1_channel = V1Channel(
434+
device_uid=TEST_DEVICE_UID,
435+
security_data=TEST_SECURITY_DATA,
436+
mqtt_channel=mock_mqtt_channel,
437+
local_session=mock_local_session,
438+
cache=mock_cache,
439+
)
440+
441+
# Subscribe - should use cached network info
442+
await v1_channel.subscribe(Mock())
443+
444+
# Verify both connections are established
445+
assert v1_channel.is_mqtt_connected
446+
assert v1_channel.is_local_connected
447+
448+
# Verify network info was NOT requested via MQTT (cache hit)
449+
mock_mqtt_channel.send_message.assert_not_called()
450+
451+
# Verify local session was created with the correct IP from cache
452+
mock_local_session.assert_called_once_with(mock_data.NETWORK_INFO["ip"])
453+
454+
# Verify cache was accessed but not updated (cache hit)
455+
mock_cache.get.assert_called_once()
456+
mock_cache.set.assert_not_called()
457+
458+
411459
# V1Channel edge cases tests
412460

413461

@@ -513,6 +561,7 @@ async def test_v1_channel_full_subscribe_and_command_flow(
513561
security_data=TEST_SECURITY_DATA,
514562
mqtt_channel=mock_mqtt_channel,
515563
local_session=mock_local_session,
564+
cache=InMemoryCache(),
516565
)
517566

518567
# Mock network info for local connection

0 commit comments

Comments
 (0)