Skip to content

Commit 599da6c

Browse files
Lash-L and allenporter authored
fix: mqtt error handling (#460)
* fix: handle auth expiring * fix: str some other rcs * fix: address some comments * fix: test * feat: add separate validate connection for the cloud api and bump keepalive * chore: remove extra exception * chore: add else back * Update roborock/cloud_api.py Co-authored-by: Allen Porter <[email protected]> * chore: clean up * fix: changes * chore: inverse boolean logic to match variable naming --------- Co-authored-by: Allen Porter <[email protected]>
1 parent 420e4ae commit 599da6c

File tree

6 files changed

+103
-23
lines changed

6 files changed

+103
-23
lines changed

roborock/api.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from .util import get_next_int
2424

2525
_LOGGER = logging.getLogger(__name__)
26-
KEEPALIVE = 60
26+
KEEPALIVE = 70
2727

2828

2929
class RoborockClient(ABC):
@@ -78,12 +78,6 @@ def should_keepalive(self) -> bool:
7878
return False
7979
return True
8080

81-
async def validate_connection(self) -> None:
82-
if not self.should_keepalive():
83-
self._logger.info("Resetting Roborock connection due to keepalive timeout")
84-
await self.async_disconnect()
85-
await self.async_connect()
86-
8781
async def _wait_response(self, request_id: int, queue: RoborockFuture) -> Any:
8882
try:
8983
response = await queue.async_get(self.queue_timeout)

roborock/cloud_api.py

Lines changed: 92 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
from typing import Any
99

1010
import paho.mqtt.client as mqtt
11+
from paho.mqtt.enums import MQTTErrorCode
12+
13+
# Mypy is not seeing this for some reason. It wants me to use the deprecated ReasonCodes
14+
from paho.mqtt.reasoncodes import ReasonCode # type: ignore
1115

1216
from .api import KEEPALIVE, RoborockClient
1317
from .containers import DeviceData, UserData
@@ -67,7 +71,8 @@ def __init__(self, user_data: UserData, device_info: DeviceData) -> None:
6771
self._mqtt_client = _Mqtt()
6872
self._mqtt_client.on_connect = self._mqtt_on_connect
6973
self._mqtt_client.on_message = self._mqtt_on_message
70-
self._mqtt_client.on_disconnect = self._mqtt_on_disconnect
74+
# Due to the incorrect ReasonCode, it is confused by typing
75+
self._mqtt_client.on_disconnect = self._mqtt_on_disconnect # type: ignore
7176
if mqtt_params.tls:
7277
self._mqtt_client.tls_set()
7378

@@ -76,43 +81,57 @@ def __init__(self, user_data: UserData, device_info: DeviceData) -> None:
7681
self._mutex = Lock()
7782
self._decoder: Decoder = create_mqtt_decoder(device_info.device.local_key)
7883
self._encoder: Encoder = create_mqtt_encoder(device_info.device.local_key)
84+
self.received_message_since_last_disconnect = False
85+
self._topic = f"rr/m/o/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}"
7986

80-
def _mqtt_on_connect(self, *args, **kwargs):
81-
_, __, ___, rc, ____ = args
87+
def _mqtt_on_connect(
88+
self,
89+
client: mqtt.Client,
90+
userdata: object,
91+
flags: dict[str, int],
92+
rc: ReasonCode,
93+
properties: mqtt.Properties | None = None,
94+
):
8295
connection_queue = self._waiting_queue.get(CONNECT_REQUEST_ID)
83-
if rc != mqtt.MQTT_ERR_SUCCESS:
84-
message = f"Failed to connect ({mqtt.error_string(rc)})"
96+
if rc.is_failure:
97+
message = f"Failed to connect ({rc})"
8598
self._logger.error(message)
8699
if connection_queue:
87100
connection_queue.set_exception(VacuumError(message))
88101
else:
89102
self._logger.debug("Failed to notify connect future, not in queue")
90103
return
91104
self._logger.info(f"Connected to mqtt {self._mqtt_host}:{self._mqtt_port}")
92-
topic = f"rr/m/o/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}"
93-
(result, mid) = self._mqtt_client.subscribe(topic)
105+
(result, mid) = self._mqtt_client.subscribe(self._topic)
94106
if result != 0:
95-
message = f"Failed to subscribe ({mqtt.error_string(rc)})"
107+
message = f"Failed to subscribe ({str(rc)})"
96108
self._logger.error(message)
97109
if connection_queue:
98110
connection_queue.set_exception(VacuumError(message))
99111
return
100-
self._logger.info(f"Subscribed to topic {topic}")
112+
self._logger.info(f"Subscribed to topic {self._topic}")
101113
if connection_queue:
102114
connection_queue.set_result(True)
103115

104116
def _mqtt_on_message(self, *args, **kwargs):
117+
self.received_message_since_last_disconnect = True
105118
client, __, msg = args
106119
try:
107120
messages = self._decoder(msg.payload)
108121
super().on_message_received(messages)
109122
except Exception as ex:
110123
self._logger.exception(ex)
111124

112-
def _mqtt_on_disconnect(self, *args, **kwargs):
113-
_, __, rc, ___ = args
125+
def _mqtt_on_disconnect(
126+
self,
127+
client: mqtt.Client,
128+
data: object,
129+
flags: dict[str, int],
130+
rc: ReasonCode | None,
131+
properties: mqtt.Properties | None = None,
132+
):
114133
try:
115-
exc = RoborockException(mqtt.error_string(rc)) if rc != mqtt.MQTT_ERR_SUCCESS else None
134+
exc = RoborockException(str(rc)) if rc is not None and rc.is_failure else None
116135
super().on_connection_lost(exc)
117136
connection_queue = self._waiting_queue.get(DISCONNECT_REQUEST_ID)
118137
if connection_queue:
@@ -138,7 +157,7 @@ def _sync_disconnect(self) -> Any:
138157

139158
if rc != mqtt.MQTT_ERR_SUCCESS:
140159
disconnected_future.cancel()
141-
raise RoborockException(f"Failed to disconnect ({mqtt.error_string(rc)})")
160+
raise RoborockException(f"Failed to disconnect ({str(rc)})")
142161

143162
return disconnected_future
144163

@@ -178,3 +197,63 @@ def _send_msg_raw(self, msg: bytes) -> None:
178197
)
179198
if info.rc != mqtt.MQTT_ERR_SUCCESS:
180199
raise RoborockException(f"Failed to publish ({mqtt.error_string(info.rc)})")
200+
201+
async def _unsubscribe(self) -> MQTTErrorCode:
202+
"""Unsubscribe from the topic."""
203+
loop = asyncio.get_running_loop()
204+
(result, mid) = await loop.run_in_executor(None, self._mqtt_client.unsubscribe, self._topic)
205+
206+
if result != 0:
207+
message = f"Failed to unsubscribe ({mqtt.error_string(result)})"
208+
self._logger.error(message)
209+
else:
210+
self._logger.info(f"Unsubscribed from topic {self._topic}")
211+
return result
212+
213+
async def _subscribe(self) -> MQTTErrorCode:
214+
"""Subscribe to the topic."""
215+
loop = asyncio.get_running_loop()
216+
(result, mid) = await loop.run_in_executor(None, self._mqtt_client.subscribe, self._topic)
217+
218+
if result != 0:
219+
message = f"Failed to subscribe ({mqtt.error_string(result)})"
220+
self._logger.error(message)
221+
else:
222+
self._logger.info(f"Subscribed to topic {self._topic}")
223+
return result
224+
225+
async def _reconnect(self) -> None:
226+
"""Reconnect to the MQTT broker."""
227+
await self.async_disconnect()
228+
await self.async_connect()
229+
230+
async def _validate_connection(self) -> None:
231+
"""Override the default validate connection to try to re-subscribe rather than disconnect.
232+
When something seems to be wrong with our connection, we should follow the following steps:
233+
1. Try to unsubscribe and resubscribe from the topic.
234+
2. If we don't end up getting a message, we should completely disconnect and reconnect to the MQTT broker.
235+
3. We will continue to try to disconnect and reconnect until we get a message.
236+
4. If we get a message, the next time connection is lost, we will go back to step 1.
237+
"""
238+
# If we should no longer keep the current connection alive...
239+
if not self.should_keepalive():
240+
self._logger.info("Resetting Roborock connection due to keepalive timeout")
241+
if not self.received_message_since_last_disconnect:
242+
# If we have already tried to unsub and resub, and we are still in this state,
243+
# we should try to reconnect.
244+
return await self._reconnect()
245+
try:
246+
# Mark that we have tried to unsubscribe and resubscribe
247+
self.received_message_since_last_disconnect = False
248+
if await self._unsubscribe() != 0:
249+
# If we fail to unsubscribe, reconnect to the broker
250+
return await self._reconnect()
251+
if await self._subscribe() != 0:
252+
# If we fail to subscribe, reconnect to the broker.
253+
return await self._reconnect()
254+
255+
except Exception: # noqa
256+
# If we get any errors at all, we should just reconnect.
257+
return await self._reconnect()
258+
# Call connect to make sure everything is still in a good state.
259+
await self.async_connect()

roborock/exceptions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Roborock exceptions."""
2+
23
from __future__ import annotations
34

45

roborock/version_1_apis/roborock_local_client_v1.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,12 @@ async def ping(self) -> None:
138138
response_protocol=RoborockMessageProtocol.PING_RESPONSE,
139139
)
140140

141+
async def _validate_connection(self) -> None:
142+
if not self.should_keepalive():
143+
self._logger.info("Resetting Roborock connection due to keepalive timeout")
144+
await self.async_disconnect()
145+
await self.async_connect()
146+
141147
def _send_msg_raw(self, data: bytes):
142148
try:
143149
if not self.transport:
@@ -172,7 +178,7 @@ async def _send_message(
172178
method: str | None = None,
173179
params: list | dict | int | None = None,
174180
) -> RoborockMessage:
175-
await self.validate_connection()
181+
await self._validate_connection()
176182
msg = self._encoder(roborock_message)
177183
if method:
178184
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")

roborock/version_1_apis/roborock_mqtt_client_v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ async def _send_command(
5151
)
5252
self._logger.debug("Building message id %s for method %s", request_message.request_id, method)
5353

54-
await self.validate_connection()
54+
await self._validate_connection()
5555
request_id = request_message.request_id
5656
response_protocol = (
5757
RoborockMessageProtocol.MAP_RESPONSE if method in COMMANDS_SECURED else RoborockMessageProtocol.RPC_RESPONSE

roborock/version_a01_apis/roborock_mqtt_client_a01.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(
4040
self._logger = RoborockLoggerAdapter(device_info.device.name, _LOGGER)
4141

4242
async def _send_message(self, roborock_message: RoborockMessage):
43-
await self.validate_connection()
43+
await self._validate_connection()
4444
response_protocol = RoborockMessageProtocol.RPC_RESPONSE
4545

4646
m = self._encoder(roborock_message)

0 commit comments

Comments
 (0)