Skip to content
2 changes: 1 addition & 1 deletion roborock/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .util import get_next_int

_LOGGER = logging.getLogger(__name__)
KEEPALIVE = 60
KEEPALIVE = 70


class RoborockClient(ABC):
Expand Down
86 changes: 70 additions & 16 deletions roborock/cloud_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@

import paho.mqtt.client as mqtt

# Mypy is not seeing this for some reason. It wants me to use the depreciated ReasonCodes
from paho.mqtt.reasoncodes import ReasonCode # type: ignore

from .api import KEEPALIVE, RoborockClient
from .containers import DeviceData, UserData
from .exceptions import RoborockException, VacuumError
from .exceptions import RoborockException, RoborockInvalidUserData, VacuumError
from .protocol import (
Decoder,
Encoder,
Expand Down Expand Up @@ -67,7 +70,8 @@ def __init__(self, user_data: UserData, device_info: DeviceData) -> None:
self._mqtt_client = _Mqtt()
self._mqtt_client.on_connect = self._mqtt_on_connect
self._mqtt_client.on_message = self._mqtt_on_message
self._mqtt_client.on_disconnect = self._mqtt_on_disconnect
# Due to the incorrect ReasonCode, it is confused by typing
self._mqtt_client.on_disconnect = self._mqtt_on_disconnect # type: ignore
if mqtt_params.tls:
self._mqtt_client.tls_set()

Expand All @@ -76,43 +80,63 @@ def __init__(self, user_data: UserData, device_info: DeviceData) -> None:
self._mutex = Lock()
self._decoder: Decoder = create_mqtt_decoder(device_info.device.local_key)
self._encoder: Encoder = create_mqtt_encoder(device_info.device.local_key)

def _mqtt_on_connect(self, *args, **kwargs):
_, __, ___, rc, ____ = args
self.previous_attempt_was_subscribe = False
self._topic = f"rr/m/o/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}"

def _mqtt_on_connect(
self,
client: mqtt.Client,
data: object,
flags: dict[str, int],
rc: ReasonCode,
properties: mqtt.Properties | None = None,
):
connection_queue = self._waiting_queue.get(CONNECT_REQUEST_ID)
if rc != mqtt.MQTT_ERR_SUCCESS:
message = f"Failed to connect ({mqtt.error_string(rc)})"
if rc.is_failure:
message = f"Failed to connect ({rc})"
self._logger.error(message)
if connection_queue:
connection_queue.set_exception(VacuumError(message))
# These are the ReasonCodes relating to authorization issues.
if rc.value in {24, 25, 133, 134, 135, 144}:
connection_queue.set_exception(
RoborockInvalidUserData("Failed to connect to mqtt. Invalid user data. Re-auth is needed.")
)
else:
connection_queue.set_exception(VacuumError(message))
else:
self._logger.debug("Failed to notify connect future, not in queue")
return
self._logger.info(f"Connected to mqtt {self._mqtt_host}:{self._mqtt_port}")
topic = f"rr/m/o/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}"
(result, mid) = self._mqtt_client.subscribe(topic)
(result, mid) = self._mqtt_client.subscribe(self._topic)
if result != 0:
message = f"Failed to subscribe ({mqtt.error_string(rc)})"
message = f"Failed to subscribe ({str(rc)})"
self._logger.error(message)
if connection_queue:
connection_queue.set_exception(VacuumError(message))
return
self._logger.info(f"Subscribed to topic {topic}")
self._logger.info(f"Subscribed to topic {self._topic}")
if connection_queue:
connection_queue.set_result(True)

def _mqtt_on_message(self, *args, **kwargs):
self.previous_attempt_was_subscribe = False
client, __, msg = args
try:
messages = self._decoder(msg.payload)
super().on_message_received(messages)
except Exception as ex:
self._logger.exception(ex)

def _mqtt_on_disconnect(self, *args, **kwargs):
_, __, rc, ___ = args
def _mqtt_on_disconnect(
self,
client: mqtt.Client,
data: object,
flags: dict[str, int],
rc: ReasonCode | None,
properties: mqtt.Properties | None = None,
):
try:
exc = RoborockException(mqtt.error_string(rc)) if rc != mqtt.MQTT_ERR_SUCCESS else None
exc = RoborockException(str(rc)) if rc is not None and rc.is_failure else None
super().on_connection_lost(exc)
connection_queue = self._waiting_queue.get(DISCONNECT_REQUEST_ID)
if connection_queue:
Expand All @@ -138,7 +162,7 @@ def _sync_disconnect(self) -> Any:

if rc != mqtt.MQTT_ERR_SUCCESS:
disconnected_future.cancel()
raise RoborockException(f"Failed to disconnect ({mqtt.error_string(rc)})")
raise RoborockException(f"Failed to disconnect ({str(rc)})")

return disconnected_future

Expand Down Expand Up @@ -178,3 +202,33 @@ def _send_msg_raw(self, msg: bytes) -> None:
)
if info.rc != mqtt.MQTT_ERR_SUCCESS:
raise RoborockException(f"Failed to publish ({mqtt.error_string(info.rc)})")

async def validate_connection(self) -> None:
"""Override the default validate connection to try to re-subscribe rather than disconnect."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm having trouble following everything going on here because of:

  • the new state member variable
  • keepalive
  • interaction with the parent class, invoking parent calss validate connection 4 different ways

Could this have an opening description of (1) big picture whats supposed to be happening here then (2) at a detailed level what these flows are and how the 3 things interact? Im also unsure how this pare relates to the error handling issue above.

A suggestion for simplification would be to no longer call the parent class validate_connection logic. Instead move that down into the local and cloud clients and inline it here so there are fewer interactions with the parent class.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Take a look now, i broke things out into some helper functions

if self.previous_attempt_was_subscribe:
# If we have already tried to unsub and resub, and we are still in this state,
# we should just do the normal validate connection.
return await super().validate_connection()
try:
if not self.should_keepalive():
self.previous_attempt_was_subscribe = True
loop = asyncio.get_running_loop()

self._logger.info("Resetting Roborock connection due to keepalive timeout")
(result, mid) = await loop.run_in_executor(None, self._mqtt_client.unsubscribe, self._topic)

if result != 0:
message = f"Failed to unsubscribe ({mqtt.error_string(result)})"
self._logger.error(message)
return await super().validate_connection()
(result, mid) = await loop.run_in_executor(None, self._mqtt_client.subscribe, self._topic)

if result != 0:
message = f"Failed to subscribe ({mqtt.error_string(result)})"
self._logger.error(message)
return await super().validate_connection()

self._logger.info(f"Subscribed to topic {self._topic}")
except Exception: # noqa
return await super().validate_connection()
await self.async_connect()
5 changes: 5 additions & 0 deletions roborock/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Roborock exceptions."""

from __future__ import annotations


Expand Down Expand Up @@ -76,3 +77,7 @@ class RoborockTooManyRequest(RoborockException):

class RoborockRateLimit(RoborockException):
"""Class for our rate limits exceptions."""


class RoborockInvalidUserData(RoborockException):
"""Class to state the user data is invalid (expired or manipulated)."""
6 changes: 3 additions & 3 deletions roborock/roborock_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import async_timeout

from .exceptions import VacuumError
from .exceptions import RoborockInvalidUserData, VacuumError


class RoborockFuture:
Expand All @@ -21,11 +21,11 @@ def _set_result(self, item: Any) -> None:
def set_result(self, item: Any) -> None:
self.loop.call_soon_threadsafe(self._set_result, item)

def _set_exception(self, exc: VacuumError) -> None:
def _set_exception(self, exc: VacuumError | RoborockInvalidUserData) -> None:
if not self.fut.cancelled():
self.fut.set_exception(exc)

def set_exception(self, exc: VacuumError) -> None:
def set_exception(self, exc: VacuumError | RoborockInvalidUserData) -> None:
self.loop.call_soon_threadsafe(self._set_exception, exc)

async def async_get(self, timeout: float | int) -> tuple[Any, VacuumError | None]:
Expand Down