Skip to content
8 changes: 1 addition & 7 deletions roborock/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .util import get_next_int

_LOGGER = logging.getLogger(__name__)
KEEPALIVE = 60
KEEPALIVE = 70


class RoborockClient(ABC):
Expand Down Expand Up @@ -78,12 +78,6 @@ def should_keepalive(self) -> bool:
return False
return True

async def validate_connection(self) -> None:
if not self.should_keepalive():
self._logger.info("Resetting Roborock connection due to keepalive timeout")
await self.async_disconnect()
await self.async_connect()

async def _wait_response(self, request_id: int, queue: RoborockFuture) -> Any:
try:
response = await queue.async_get(self.queue_timeout)
Expand Down
105 changes: 92 additions & 13 deletions roborock/cloud_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
from typing import Any

import paho.mqtt.client as mqtt
from paho.mqtt.enums import MQTTErrorCode

# Mypy is not seeing this for some reason. It wants me to use the depreciated ReasonCodes
from paho.mqtt.reasoncodes import ReasonCode # type: ignore

from .api import KEEPALIVE, RoborockClient
from .containers import DeviceData, UserData
Expand Down Expand Up @@ -67,7 +71,8 @@ def __init__(self, user_data: UserData, device_info: DeviceData) -> None:
self._mqtt_client = _Mqtt()
self._mqtt_client.on_connect = self._mqtt_on_connect
self._mqtt_client.on_message = self._mqtt_on_message
self._mqtt_client.on_disconnect = self._mqtt_on_disconnect
# Due to the incorrect ReasonCode, it is confused by typing
self._mqtt_client.on_disconnect = self._mqtt_on_disconnect # type: ignore
if mqtt_params.tls:
self._mqtt_client.tls_set()

Expand All @@ -76,43 +81,57 @@ def __init__(self, user_data: UserData, device_info: DeviceData) -> None:
self._mutex = Lock()
self._decoder: Decoder = create_mqtt_decoder(device_info.device.local_key)
self._encoder: Encoder = create_mqtt_encoder(device_info.device.local_key)
self.received_message_since_last_disconnect = False
self._topic = f"rr/m/o/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}"

def _mqtt_on_connect(self, *args, **kwargs):
_, __, ___, rc, ____ = args
def _mqtt_on_connect(
self,
client: mqtt.Client,
userdata: object,
flags: dict[str, int],
rc: ReasonCode,
properties: mqtt.Properties | None = None,
):
connection_queue = self._waiting_queue.get(CONNECT_REQUEST_ID)
if rc != mqtt.MQTT_ERR_SUCCESS:
message = f"Failed to connect ({mqtt.error_string(rc)})"
if rc.is_failure:
message = f"Failed to connect ({rc})"
self._logger.error(message)
if connection_queue:
connection_queue.set_exception(VacuumError(message))
else:
self._logger.debug("Failed to notify connect future, not in queue")
return
self._logger.info(f"Connected to mqtt {self._mqtt_host}:{self._mqtt_port}")
topic = f"rr/m/o/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}"
(result, mid) = self._mqtt_client.subscribe(topic)
(result, mid) = self._mqtt_client.subscribe(self._topic)
if result != 0:
message = f"Failed to subscribe ({mqtt.error_string(rc)})"
message = f"Failed to subscribe ({str(rc)})"
self._logger.error(message)
if connection_queue:
connection_queue.set_exception(VacuumError(message))
return
self._logger.info(f"Subscribed to topic {topic}")
self._logger.info(f"Subscribed to topic {self._topic}")
if connection_queue:
connection_queue.set_result(True)

def _mqtt_on_message(self, *args, **kwargs):
self.received_message_since_last_disconnect = True
client, __, msg = args
try:
messages = self._decoder(msg.payload)
super().on_message_received(messages)
except Exception as ex:
self._logger.exception(ex)

def _mqtt_on_disconnect(self, *args, **kwargs):
_, __, rc, ___ = args
def _mqtt_on_disconnect(
self,
client: mqtt.Client,
data: object,
flags: dict[str, int],
rc: ReasonCode | None,
properties: mqtt.Properties | None = None,
):
try:
exc = RoborockException(mqtt.error_string(rc)) if rc != mqtt.MQTT_ERR_SUCCESS else None
exc = RoborockException(str(rc)) if rc is not None and rc.is_failure else None
super().on_connection_lost(exc)
connection_queue = self._waiting_queue.get(DISCONNECT_REQUEST_ID)
if connection_queue:
Expand All @@ -138,7 +157,7 @@ def _sync_disconnect(self) -> Any:

if rc != mqtt.MQTT_ERR_SUCCESS:
disconnected_future.cancel()
raise RoborockException(f"Failed to disconnect ({mqtt.error_string(rc)})")
raise RoborockException(f"Failed to disconnect ({str(rc)})")

return disconnected_future

Expand Down Expand Up @@ -178,3 +197,63 @@ def _send_msg_raw(self, msg: bytes) -> None:
)
if info.rc != mqtt.MQTT_ERR_SUCCESS:
raise RoborockException(f"Failed to publish ({mqtt.error_string(info.rc)})")

async def _unsubscribe(self) -> MQTTErrorCode:
"""Unsubscribe from the topic."""
loop = asyncio.get_running_loop()
(result, mid) = await loop.run_in_executor(None, self._mqtt_client.unsubscribe, self._topic)

if result != 0:
message = f"Failed to unsubscribe ({mqtt.error_string(result)})"
self._logger.error(message)
else:
self._logger.info(f"Unsubscribed from topic {self._topic}")
return result

async def _subscribe(self) -> MQTTErrorCode:
"""Subscribe to the topic."""
loop = asyncio.get_running_loop()
(result, mid) = await loop.run_in_executor(None, self._mqtt_client.subscribe, self._topic)

if result != 0:
message = f"Failed to subscribe ({mqtt.error_string(result)})"
self._logger.error(message)
else:
self._logger.info(f"Subscribed to topic {self._topic}")
return result

async def _reconnect(self) -> None:
"""Reconnect to the MQTT broker."""
await self.async_disconnect()
await self.async_connect()

async def _validate_connection(self) -> None:
"""Override the default validate connection to try to re-subscribe rather than disconnect.
When something seems to be wrong with our connection, we should follow the following steps:
1. Try to unsubscribe and resubscribe from the topic.
2. If we don't end up getting a message, we should completely disconnect and reconnect to the MQTT broker.
3. We will continue to try to disconnect and reconnect until we get a message.
4. If we get a message, the next time connection is lost, We will go back to step 1.
"""
# If we should no longer keep the current connection alive...
if not self.should_keepalive():
self._logger.info("Resetting Roborock connection due to keepalive timeout")
if not self.received_message_since_last_disconnect:
# If we have already tried to unsub and resub, and we are still in this state,
# we should try to reconnect.
return await self._reconnect()
try:
# Mark that we have tried to unsubscribe and resubscribe
self.received_message_since_last_disconnect = False
if await self._unsubscribe() != 0:
# If we fail to unsubscribe, reconnect to the broker
return await self._reconnect()
if await self._subscribe() != 0:
# If we fail to subscribe, reconnected to the broker.
return await self._reconnect()

except Exception: # noqa
# If we get any errors at all, we should just reconnect.
return await self._reconnect()
# Call connect to make sure everything is still in a good state.
await self.async_connect()
1 change: 1 addition & 0 deletions roborock/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Roborock exceptions."""

from __future__ import annotations


Expand Down
8 changes: 7 additions & 1 deletion roborock/version_1_apis/roborock_local_client_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@ async def ping(self) -> None:
response_protocol=RoborockMessageProtocol.PING_RESPONSE,
)

async def _validate_connection(self) -> None:
if not self.should_keepalive():
self._logger.info("Resetting Roborock connection due to keepalive timeout")
await self.async_disconnect()
await self.async_connect()

def _send_msg_raw(self, data: bytes):
try:
if not self.transport:
Expand Down Expand Up @@ -172,7 +178,7 @@ async def _send_message(
method: str | None = None,
params: list | dict | int | None = None,
) -> RoborockMessage:
await self.validate_connection()
await self._validate_connection()
msg = self._encoder(roborock_message)
if method:
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
Expand Down
2 changes: 1 addition & 1 deletion roborock/version_1_apis/roborock_mqtt_client_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def _send_command(
)
self._logger.debug("Building message id %s for method %s", request_message.request_id, method)

await self.validate_connection()
await self._validate_connection()
request_id = request_message.request_id
response_protocol = (
RoborockMessageProtocol.MAP_RESPONSE if method in COMMANDS_SECURED else RoborockMessageProtocol.RPC_RESPONSE
Expand Down
2 changes: 1 addition & 1 deletion roborock/version_a01_apis/roborock_mqtt_client_a01.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
self._logger = RoborockLoggerAdapter(device_info.device.name, _LOGGER)

async def _send_message(self, roborock_message: RoborockMessage):
await self.validate_connection()
await self._validate_connection()
response_protocol = RoborockMessageProtocol.RPC_RESPONSE

m = self._encoder(roborock_message)
Expand Down
Loading