Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
417 changes: 417 additions & 0 deletions roborock/b01_containers.py

Large diffs are not rendered by default.

15 changes: 1 addition & 14 deletions roborock/clean_modes.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,8 @@
from __future__ import annotations

from enum import StrEnum

from roborock import DeviceFeatures


class RoborockModeEnum(StrEnum):
"""A custom StrEnum that also stores an integer code for each member."""

code: int

def __new__(cls, value: str, code: int) -> RoborockModeEnum:
"""Creates a new enum member."""
member = str.__new__(cls, value)
member._value_ = value
member.code = code
return member
from .code_mappings import RoborockModeEnum


class CleanModes(RoborockModeEnum):
Expand Down
22 changes: 21 additions & 1 deletion roborock/code_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from collections import namedtuple
from enum import Enum, IntEnum
from enum import Enum, IntEnum, StrEnum

_LOGGER = logging.getLogger(__name__)
completed_warnings = set()
Expand Down Expand Up @@ -51,6 +51,26 @@ def items(cls: type[RoborockEnum]):
return cls.as_dict().items()


class RoborockModeEnum(StrEnum):
"""A custom StrEnum that also stores an integer code for each member."""

code: int

def __new__(cls, value: str, code: int) -> RoborockModeEnum:
"""Creates a new enum member."""
member = str.__new__(cls, value)
member._value_ = value
member.code = code
return member

@classmethod
def from_code(cls, code: int):
for member in cls:
if member.code == code:
return member
raise ValueError(f"{code} is not a valid code for {cls.__name__}")


ProductInfo = namedtuple("ProductInfo", ["nickname", "short_models"])


Expand Down
3 changes: 3 additions & 0 deletions roborock/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
RoborockFanSpeedSaros10R,
RoborockFinishReason,
RoborockInCleaning,
RoborockModeEnum,
RoborockMopIntensityCode,
RoborockMopIntensityP10,
RoborockMopIntensityQ7Max,
Expand Down Expand Up @@ -120,6 +121,8 @@ def _convert_to_class_obj(class_type: type, value):
return {k: RoborockBase._convert_to_class_obj(value_type, v) for k, v in value.items()}
if issubclass(class_type, RoborockBase):
return class_type.from_dict(value)
if issubclass(class_type, RoborockModeEnum):
return class_type.from_code(value)
if class_type is Any:
return value
return class_type(value) # type: ignore[call-arg]
Expand Down
7 changes: 2 additions & 5 deletions roborock/devices/b01_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
from __future__ import annotations

import logging
from typing import Any

from roborock.protocols.b01_protocol import (
CommandType,
ParamsType,
decode_rpc_response,
encode_mqtt_payload,
)

Expand All @@ -22,9 +20,8 @@ async def send_decoded_command(
dps: int,
command: CommandType,
params: ParamsType,
) -> dict[int, Any]:
) -> None:
"""Send a command on the MQTT channel and get a decoded response."""
_LOGGER.debug("Sending MQTT command: %s", params)
roborock_message = encode_mqtt_payload(dps, command, params)
response = await mqtt_channel.send_message(roborock_message)
return decode_rpc_response(response) # type: ignore[return-value]
await mqtt_channel.send_message_no_wait(roborock_message)
4 changes: 2 additions & 2 deletions roborock/devices/device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> Roborock
case _:
raise NotImplementedError(f"Device {device.name} has unsupported category {product.category}")
case DeviceVersion.B01:
mqtt_channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
traits.append(B01PropsApi(mqtt_channel))
channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
traits.append(B01PropsApi(channel))
case _:
raise NotImplementedError(f"Device {device.name} has unsupported version {device.pv}")
return RoborockDevice(device, channel, traits)
Expand Down
12 changes: 11 additions & 1 deletion roborock/devices/mqtt_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def message_handler(payload: bytes) -> None:
return
for message in messages:
_LOGGER.debug("Received message: %s", message)
asyncio.create_task(self._resolve_future_with_lock(message))
if message.version == b"1.0":
asyncio.create_task(self._resolve_future_with_lock(message))
try:
callback(message)
except Exception as e:
Expand All @@ -95,6 +96,15 @@ async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
else:
_LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id)

async def send_message_no_wait(self, message: RoborockMessage) -> None:
"""Send a command message without waiting for a response."""
try:
encoded_msg = self._encoder(message)
await self._mqtt_session.publish(self._publish_topic, encoded_msg)
except Exception:
logging.exception("Uncaught error sending command")
raise

async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
"""Send a command message and wait for the response message.

Expand Down
7 changes: 4 additions & 3 deletions roborock/devices/traits/b01/props.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import logging
from typing import Any

from roborock import RoborockB01Methods
from roborock.roborock_message import RoborockB01Props
Expand All @@ -26,6 +25,8 @@ def __init__(self, channel: MqttChannel) -> None:
"""Initialize the B01Props API."""
self._channel = channel

async def query_values(self, props: list[RoborockB01Props]) -> dict[int, Any]:
async def query_values(self, props: list[RoborockB01Props]) -> None:
"""Query the device for the values of the given Dyad protocols."""
return await send_decoded_command(self._channel, dps=10000, command=RoborockB01Methods.GET_PROP, params=props)
return await send_decoded_command(
self._channel, dps=10000, command=RoborockB01Methods.GET_PROP, params={"property": props}
)
2 changes: 1 addition & 1 deletion roborock/protocols/b01_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

def encode_mqtt_payload(dps: int, command: CommandType, params: ParamsType) -> RoborockMessage:
"""Encode payload for B01 commands over MQTT."""
dps_data = {"dps": {dps: {"method": command, "params": params or []}}}
dps_data = {"dps": {dps: {"method": str(command), "msgId": "1751755654575", "params": params or []}}}
payload = pad(json.dumps(dps_data).encode("utf-8"), AES.block_size)
return RoborockMessage(
protocol=RoborockMessageProtocol.RPC_REQUEST,
Expand Down