Skip to content

Commit 5fe20ba

Browse files
authored
Add an aiomqtt based MQTT session module (#366)
* feat: Add an aiomqtt based MQTT session module * feat: Add exception handling and increased test coverage
1 parent ba422aa commit 5fe20ba

File tree

8 files changed

+564
-8
lines changed

8 files changed

+564
-8
lines changed

poetry.lock

Lines changed: 16 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ paho-mqtt = ">=1.6.1,<3.0.0"
3131
construct = "^2.10.57"
3232
vacuum-map-parser-roborock = "*"
3333
pyrate-limiter = "^3.7.0"
34+
aiomqtt = "^2.3.2"
3435

3536

3637
[build-system]

roborock/mqtt/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""This module contains the low level MQTT client for the Roborock vacuum cleaner.
2+
3+
This is not meant to be used directly, but rather as a base for the higher level
4+
modules.
5+
"""
6+
7+
__all__: list[str] = []

roborock/mqtt/roborock_session.py

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
"""An MQTT session for sending and receiving messages.
2+
3+
See create_mqtt_session for a factory function to create an MQTT session.
4+
5+
This is a thin wrapper around the async MQTT client that handles dispatching messages
6+
from a topic to a callback function, since the async MQTT client does not
7+
support this out of the box. It also handles the authentication process and
8+
receiving messages from the vacuum cleaner.
9+
"""
10+
11+
import asyncio
12+
import datetime
13+
import logging
14+
from collections.abc import Callable
15+
from contextlib import asynccontextmanager
16+
17+
import aiomqtt
18+
from aiomqtt import MqttError, TLSParameters
19+
20+
from .session import MqttParams, MqttSession, MqttSessionException
21+
22+
_LOGGER = logging.getLogger(__name__)
23+
_MQTT_LOGGER = logging.getLogger(f"{__name__}.aiomqtt")
24+
25+
KEEPALIVE = 60
26+
27+
# Exponential backoff parameters
28+
MIN_BACKOFF_INTERVAL = datetime.timedelta(seconds=10)
29+
MAX_BACKOFF_INTERVAL = datetime.timedelta(minutes=30)
30+
BACKOFF_MULTIPLIER = 1.5
31+
32+
33+
class RoborockMqttSession(MqttSession):
34+
"""An MQTT session for sending and receiving messages.
35+
36+
You can start a session invoking the start() method which will connect to
37+
the MQTT broker. A caller may subscribe to a topic, and the session keeps
38+
track of which callbacks to invoke for each topic.
39+
40+
The client is run as a background task that will run until shutdown. Once
41+
connected, the client will wait for messages to be received in a loop. If
42+
the connection is lost, the client will be re-created and reconnected. There
43+
is backoff to avoid spamming the broker with connection attempts. The client
44+
will automatically re-establish any subscriptions when the connection is
45+
re-established.
46+
"""
47+
48+
def __init__(self, params: MqttParams):
49+
self._params = params
50+
self._background_task: asyncio.Task[None] | None = None
51+
self._healthy = False
52+
self._backoff = MIN_BACKOFF_INTERVAL
53+
self._client: aiomqtt.Client | None = None
54+
self._client_lock = asyncio.Lock()
55+
self._listeners: dict[str, list[Callable[[bytes], None]]] = {}
56+
57+
@property
58+
def connected(self) -> bool:
59+
"""True if the session is connected to the broker."""
60+
return self._healthy
61+
62+
async def start(self) -> None:
63+
"""Start the MQTT session.
64+
65+
This has special behavior for the first connection attempt where any
66+
failures are raised immediately. This is to allow the caller to
67+
handle the failure and retry if desired itself. Once connected,
68+
the session will retry connecting in the background.
69+
"""
70+
start_future: asyncio.Future[None] = asyncio.Future()
71+
loop = asyncio.get_event_loop()
72+
self._background_task = loop.create_task(self._run_task(start_future))
73+
try:
74+
await start_future
75+
except MqttError as err:
76+
raise MqttSessionException(f"Error starting MQTT session: {err}") from err
77+
except Exception as err:
78+
raise MqttSessionException(f"Unexpected error starting session: {err}") from err
79+
else:
80+
_LOGGER.debug("MQTT session started successfully")
81+
82+
async def close(self) -> None:
83+
"""Cancels the MQTT loop and shutdown the client library."""
84+
if self._background_task:
85+
self._background_task.cancel()
86+
try:
87+
await self._background_task
88+
except asyncio.CancelledError:
89+
pass
90+
async with self._client_lock:
91+
if self._client:
92+
await self._client.close()
93+
94+
self._healthy = False
95+
96+
async def _run_task(self, start_future: asyncio.Future[None] | None) -> None:
97+
"""Run the MQTT loop."""
98+
_LOGGER.info("Starting MQTT session")
99+
while True:
100+
try:
101+
async with self._mqtt_client(self._params) as client:
102+
# Reset backoff once we've successfully connected
103+
self._backoff = MIN_BACKOFF_INTERVAL
104+
self._healthy = True
105+
if start_future:
106+
start_future.set_result(None)
107+
start_future = None
108+
109+
await self._process_message_loop(client)
110+
111+
except MqttError as err:
112+
if start_future:
113+
_LOGGER.info("MQTT error starting session: %s", err)
114+
start_future.set_exception(err)
115+
return
116+
_LOGGER.info("MQTT error: %s", err)
117+
except asyncio.CancelledError as err:
118+
if start_future:
119+
_LOGGER.debug("MQTT loop was cancelled")
120+
start_future.set_exception(err)
121+
_LOGGER.debug("MQTT loop was cancelled whiel starting")
122+
return
123+
# Catch exceptions to avoid crashing the loop
124+
# and to allow the loop to retry.
125+
except Exception as err:
126+
# This error is thrown when the MQTT loop is cancelled
127+
# and the generator is not stopped.
128+
if "generator didn't stop" in str(err):
129+
_LOGGER.debug("MQTT loop was cancelled")
130+
return
131+
if start_future:
132+
_LOGGER.error("Uncaught error starting MQTT session: %s", err)
133+
start_future.set_exception(err)
134+
return
135+
_LOGGER.error("Uncaught error during MQTT session: %s", err)
136+
137+
self._healthy = False
138+
_LOGGER.info("MQTT session disconnected, retrying in %s seconds", self._backoff.total_seconds())
139+
await asyncio.sleep(self._backoff.total_seconds())
140+
self._backoff = min(self._backoff * BACKOFF_MULTIPLIER, MAX_BACKOFF_INTERVAL)
141+
142+
@asynccontextmanager
143+
async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client:
144+
"""Connect to the MQTT broker and listen for messages."""
145+
_LOGGER.debug("Connecting to %s:%s for %s", params.host, params.port, params.username)
146+
try:
147+
async with aiomqtt.Client(
148+
hostname=params.host,
149+
port=params.port,
150+
username=params.username,
151+
password=params.password,
152+
keepalive=KEEPALIVE,
153+
protocol=aiomqtt.ProtocolVersion.V5,
154+
tls_params=TLSParameters() if params.tls else None,
155+
timeout=params.timeout,
156+
logger=_MQTT_LOGGER,
157+
) as client:
158+
_LOGGER.debug("Connected to MQTT broker")
159+
# Re-establish any existing subscriptions
160+
async with self._client_lock:
161+
self._client = client
162+
for topic in self._listeners:
163+
_LOGGER.debug("Re-establising subscription to topic %s", topic)
164+
# TODO: If this fails it will break the whole connection. Make
165+
# this retry again in the background with backoff.
166+
await client.subscribe(topic)
167+
168+
yield client
169+
finally:
170+
async with self._client_lock:
171+
self._client = None
172+
173+
async def _process_message_loop(self, client: aiomqtt.Client) -> None:
174+
_LOGGER.debug("client=%s", client)
175+
_LOGGER.debug("Processing MQTT messages: %s", client.messages)
176+
async for message in client.messages:
177+
_LOGGER.debug("Received message: %s", message)
178+
for listener in self._listeners.get(message.topic.value, []):
179+
try:
180+
listener(message.payload)
181+
except asyncio.CancelledError:
182+
raise
183+
except Exception as e:
184+
_LOGGER.error("Uncaught exception in subscriber callback: %s", e)
185+
186+
async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Callable[[], None]:
187+
"""Subscribe to messages on the specified topic and invoke the callback for new messages.
188+
189+
The callback will be called with the message payload as a bytes object. The callback
190+
should not block since it runs in the async loop. It should not raise any exceptions.
191+
192+
The returned callable unsubscribes from the topic when called.
193+
"""
194+
_LOGGER.debug("Subscribing to topic %s", topic)
195+
if topic not in self._listeners:
196+
self._listeners[topic] = []
197+
self._listeners[topic].append(callback)
198+
199+
async with self._client_lock:
200+
if self._client:
201+
_LOGGER.debug("Establishing subscription to topic %s", topic)
202+
try:
203+
await self._client.subscribe(topic)
204+
except MqttError as err:
205+
raise MqttSessionException(f"Error subscribing to topic: {err}") from err
206+
else:
207+
_LOGGER.debug("Client not connected, will establish subscription later")
208+
209+
return lambda: self._listeners[topic].remove(callback)
210+
211+
async def publish(self, topic: str, message: bytes) -> None:
212+
"""Publish a message on the topic."""
213+
_LOGGER.debug("Sending message to topic %s: %s", topic, message)
214+
client: aiomqtt.Client
215+
async with self._client_lock:
216+
if self._client is None:
217+
raise MqttSessionException("Could not publish message, MQTT client not connected")
218+
client = self._client
219+
try:
220+
await client.publish(topic, message)
221+
except MqttError as err:
222+
raise MqttSessionException(f"Error publishing message: {err}") from err
223+
224+
225+
async def create_mqtt_session(params: MqttParams) -> MqttSession:
226+
"""Create an MQTT session.
227+
228+
This function is a factory for creating an MQTT session. This will
229+
raise an exception if initial attempt to connect fails. Once connected,
230+
the session will retry connecting on failure in the background.
231+
"""
232+
session = RoborockMqttSession(params)
233+
await session.start()
234+
return session

roborock/mqtt/session.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""An MQTT session for sending and receiving messages."""
2+
3+
from abc import ABC, abstractmethod
4+
from collections.abc import Callable
5+
from dataclasses import dataclass
6+
7+
from roborock.exceptions import RoborockException
8+
9+
DEFAULT_TIMEOUT = 30.0
10+
11+
12+
@dataclass
13+
class MqttParams:
14+
"""MQTT parameters for the connection."""
15+
16+
host: str
17+
"""MQTT host to connect to."""
18+
19+
port: int
20+
"""MQTT port to connect to."""
21+
22+
tls: bool
23+
"""Use TLS for the connection."""
24+
25+
username: str
26+
"""MQTT username to use for authentication."""
27+
28+
password: str
29+
"""MQTT password to use for authentication."""
30+
31+
timeout: float = DEFAULT_TIMEOUT
32+
"""Timeout for communications with the broker in seconds."""
33+
34+
35+
class MqttSession(ABC):
36+
"""An MQTT session for sending and receiving messages."""
37+
38+
@property
39+
@abstractmethod
40+
def connected(self) -> bool:
41+
"""True if the session is connected to the broker."""
42+
43+
@abstractmethod
44+
async def subscribe(self, device_id: str, callback: Callable[[bytes], None]) -> Callable[[], None]:
45+
"""Invoke the callback when messages are received on the topic.
46+
47+
The returned callable unsubscribes from the topic when called.
48+
"""
49+
50+
@abstractmethod
51+
async def publish(self, topic: str, message: bytes) -> None:
52+
"""Publish a message on the specified topic.
53+
54+
This will raise an exception if the message could not be sent.
55+
"""
56+
57+
@abstractmethod
58+
async def close(self) -> None:
59+
"""Cancels the mqtt loop"""
60+
61+
62+
class MqttSessionException(RoborockException):
63+
""" "Raised when there is an error communicating with MQTT."""

tests/conftest.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@ class FakeSocketHandler:
3333
handle request callback handles the incoming requests and prepares the responses.
3434
"""
3535

36-
def __init__(self, handle_request: RequestHandler) -> None:
36+
def __init__(self, handle_request: RequestHandler, response_queue: Queue[bytes]) -> None:
3737
self.response_buf = io.BytesIO()
3838
self.handle_request = handle_request
39+
self.response_queue = response_queue
3940

4041
def pending(self) -> int:
4142
"""Return the number of bytes in the response buffer."""
@@ -62,9 +63,17 @@ def handle_socket_send(self, client_request: bytes) -> int:
6263
# The buffer will be emptied when the client calls recv() on the socket
6364
_LOGGER.debug("Queued: 0x%s", response.hex())
6465
self.response_buf.write(response)
65-
6666
return len(client_request)
6767

68+
def push_response(self) -> None:
69+
"""Push a response to the client."""
70+
if not self.response_queue.empty():
71+
response = self.response_queue.get()
72+
# Enqueue a response to be sent back to the client in the buffer.
73+
# The buffer will be emptied when the client calls recv() on the socket
74+
_LOGGER.debug("Queued: 0x%s", response.hex())
75+
self.response_buf.write(response)
76+
6877

6978
@pytest.fixture(name="received_requests")
7079
def received_requests_fixture() -> Queue[bytes]:
@@ -97,9 +106,9 @@ def handle_request(client_request: bytes) -> bytes | None:
97106

98107

99108
@pytest.fixture(name="fake_socket_handler")
100-
def fake_socket_handler_fixture(request_handler: RequestHandler) -> FakeSocketHandler:
109+
def fake_socket_handler_fixture(request_handler: RequestHandler, response_queue: Queue[bytes]) -> FakeSocketHandler:
101110
"""Fixture that creates a fake MQTT broker."""
102-
return FakeSocketHandler(request_handler)
111+
return FakeSocketHandler(request_handler, response_queue)
103112

104113

105114
@pytest.fixture(name="mock_sock")

0 commit comments

Comments
 (0)