Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 23 additions & 38 deletions roborock/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import json
import logging
import threading
from collections.abc import Callable
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, cast
Expand All @@ -43,6 +44,8 @@
from roborock.devices.cache import Cache, CacheData
from roborock.devices.device import RoborockDevice
from roborock.devices.device_manager import DeviceManager, create_device_manager, create_home_data_api
from roborock.devices.traits import Trait
from roborock.devices.traits.v1 import V1TraitMixin
from roborock.protocol import MessageParser
from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1
from roborock.web_api import RoborockApiClient
Expand Down Expand Up @@ -377,23 +380,30 @@ async def execute_scene(ctx, scene_id):
await client.execute_scene(cache_data.user_data, scene_id)


async def _v1_trait(context: RoborockContext, device_id: str, display_func: Callable[[], V1TraitMixin]) -> Trait:
device_manager = await context.get_device_manager()
device = await device_manager.get_device(device_id)
if device.v1_properties is None:
raise RoborockException(f"Device {device.name} does not support V1 protocol")

trait = display_func(device.v1_properties)
await trait.refresh()
return trait


async def _display_v1_trait(context: RoborockContext, device_id: str, display_func: Callable[[], Trait]) -> None:
trait = await _v1_trait(context, device_id, display_func)
click.echo(dump_json(trait.as_dict()))


@session.command()
@click.option("--device_id", required=True)
@click.pass_context
@async_command
async def status(ctx, device_id: str):
"""Get device status."""
context: RoborockContext = ctx.obj

device_manager = await context.get_device_manager()
device = await device_manager.get_device(device_id)

if not (status_trait := device.traits.get("status")):
click.echo(f"Device {device.name} does not have a status trait")
return

status_result = await status_trait.get_status()
click.echo(dump_json(status_result.as_dict()))
await _display_v1_trait(context, device_id, lambda v1: v1.status)


@session.command()
Expand All @@ -403,15 +413,7 @@ async def status(ctx, device_id: str):
async def clean_summary(ctx, device_id: str):
"""Get device clean summary."""
context: RoborockContext = ctx.obj

device_manager = await context.get_device_manager()
device = await device_manager.get_device(device_id)
if not (clean_summary_trait := device.traits.get("clean_summary")):
click.echo(f"Device {device.name} does not have a clean summary trait")
return

clean_summary_result = await clean_summary_trait.get_clean_summary()
click.echo(dump_json(clean_summary_result.as_dict()))
await _display_v1_trait(context, device_id, lambda v1: v1.clean_summary)


@session.command()
Expand All @@ -421,17 +423,7 @@ async def clean_summary(ctx, device_id: str):
async def volume(ctx, device_id: str):
"""Get device volume."""
context: RoborockContext = ctx.obj

device_manager = await context.get_device_manager()
device = await device_manager.get_device(device_id)

if not (volume_trait := device.traits.get("sound_volume")):
click.echo(f"Device {device.name} does not have a volume trait")
return

volume_result = await volume_trait.get_volume()
click.echo(f"Device {device_id} volume:")
click.echo(volume_result)
await _display_v1_trait(context, device_id, lambda v1: v1.sound_volume)


@session.command()
Expand All @@ -442,14 +434,7 @@ async def volume(ctx, device_id: str):
async def set_volume(ctx, device_id: str, volume: int):
"""Set the devicevolume."""
context: RoborockContext = ctx.obj

device_manager = await context.get_device_manager()
device = await device_manager.get_device(device_id)

if not (volume_trait := device.traits.get("sound_volume")):
click.echo(f"Device {device.name} does not have a volume trait")
return

volume_trait = await _v1_trait(context, device_id, lambda v1: v1.sound_volume)
await volume_trait.set_volume(volume)
click.echo(f"Set Device {device_id} volume to {volume}")

Expand Down
23 changes: 10 additions & 13 deletions roborock/devices/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

import logging
from abc import ABC
from collections.abc import Callable, Mapping
from types import MappingProxyType
from collections.abc import Callable

from roborock.containers import HomeDataDevice
from roborock.roborock_message import RoborockMessage

from .channel import Channel
from .traits.trait import Trait
from .traits import Trait
from .traits.traits_mixin import TraitsMixin

_LOGGER = logging.getLogger(__name__)

Expand All @@ -22,33 +22,35 @@
]


class RoborockDevice(ABC):
class RoborockDevice(ABC, TraitsMixin):
"""A generic channel for establishing a connection with a Roborock device.

Individual channel implementations have their own methods for speaking to
the device that hide some of the protocol specific complexity, but they
are still specialized for the device type and protocol.

Attributes of the device are exposed through traits, which are mixed in
through the TraitsMixin class. Traits are optional and may not be present
on all devices.
"""

def __init__(
self,
device_info: HomeDataDevice,
channel: Channel,
traits: list[Trait],
trait: Trait,
) -> None:
"""Initialize the RoborockDevice.

The device takes ownership of the channel for communication with the device.
Use `connect()` to establish the connection, which will set up the appropriate
protocol channel. Use `close()` to clean up all connections.
"""
TraitsMixin.__init__(self, trait)
self._duid = device_info.duid
self._name = device_info.name
self._channel = channel
self._unsub: Callable[[], None] | None = None
self._trait_map = {trait.name: trait for trait in traits}
if len(self._trait_map) != len(traits):
raise ValueError("Duplicate trait names found in traits list")

@property
def duid(self) -> str:
Expand Down Expand Up @@ -81,8 +83,3 @@ async def close(self) -> None:
def _on_message(self, message: RoborockMessage) -> None:
"""Handle incoming messages from the device."""
_LOGGER.debug("Received message from device: %s", message)

@property
def traits(self) -> Mapping[str, Trait]:
"""Return the traits of the device."""
return MappingProxyType(self._trait_map)
32 changes: 7 additions & 25 deletions roborock/devices/device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import aiohttp

from roborock.code_mappings import RoborockCategory
from roborock.containers import (
HomeData,
HomeDataDevice,
Expand All @@ -23,14 +22,7 @@
from .cache import Cache, NoCache
from .channel import Channel
from .mqtt_channel import create_mqtt_channel
from .traits.b01.props import B01PropsApi
from .traits.clean_summary import CleanSummaryTrait
from .traits.dnd import DoNotDisturbTrait
from .traits.dyad import DyadApi
from .traits.sound_volume import SoundVolumeTrait
from .traits.status import StatusTrait
from .traits.trait import Trait
from .traits.zeo import ZeoApi
from .traits import Trait, a01, b01, v1
from .v1_channel import create_v1_channel

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -153,30 +145,20 @@ async def create_device_manager(

def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice:
channel: Channel
traits: list[Trait] = []
# TODO: Define a registration mechanism/factory for v1 traits
trait: Trait
match device.pv:
case DeviceVersion.V1:
channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device, cache)
traits.append(StatusTrait(product, channel.rpc_channel))
traits.append(DoNotDisturbTrait(channel.rpc_channel))
traits.append(CleanSummaryTrait(channel.rpc_channel))
traits.append(SoundVolumeTrait(channel.rpc_channel))
trait = v1.create(product, channel.rpc_channel)
case DeviceVersion.A01:
mqtt_channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
match product.category:
case RoborockCategory.WET_DRY_VAC:
traits.append(DyadApi(mqtt_channel))
case RoborockCategory.WASHING_MACHINE:
traits.append(ZeoApi(mqtt_channel))
case _:
raise NotImplementedError(f"Device {device.name} has unsupported category {product.category}")
channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
trait = a01.create(product, channel)
case DeviceVersion.B01:
channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
traits.append(B01PropsApi(channel))
trait = b01.create(channel)
case _:
raise NotImplementedError(f"Device {device.name} has unsupported version {device.pv}")
return RoborockDevice(device, channel, traits)
return RoborockDevice(device, channel, trait)

manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session, cache=cache)
await manager.discover_devices()
Expand Down
15 changes: 15 additions & 0 deletions roborock/devices/traits/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Module for device traits."""

from abc import ABC

__all__ = [
"Trait",
"traits_mixin",
"v1",
"a01",
"b01",
]


class Trait(ABC):
"""Base class for all traits."""
61 changes: 61 additions & 0 deletions roborock/devices/traits/a01/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import Any

from roborock.containers import HomeDataProduct, RoborockCategory
from roborock.devices.a01_channel import send_decoded_command
from roborock.devices.mqtt_channel import MqttChannel
from roborock.devices.traits import Trait
from roborock.roborock_message import RoborockDyadDataProtocol, RoborockZeoProtocol

__init__ = [
"DyadApi",
"ZeoApi",
]


class DyadApi(Trait):
"""API for interacting with Dyad devices."""

def __init__(self, channel: MqttChannel) -> None:
"""Initialize the Dyad API."""
self._channel = channel

async def query_values(self, protocols: list[RoborockDyadDataProtocol]) -> dict[RoborockDyadDataProtocol, Any]:
"""Query the device for the values of the given Dyad protocols."""
params = {RoborockDyadDataProtocol.ID_QUERY: [int(p) for p in protocols]}
return await send_decoded_command(self._channel, params)

async def set_value(self, protocol: RoborockDyadDataProtocol, value: Any) -> dict[RoborockDyadDataProtocol, Any]:
"""Set a value for a specific protocol on the device."""
params = {protocol: value}
return await send_decoded_command(self._channel, params)


class ZeoApi(Trait):
"""API for interacting with Zeo devices."""

name = "zeo"

def __init__(self, channel: MqttChannel) -> None:
"""Initialize the Zeo API."""
self._channel = channel

async def query_values(self, protocols: list[RoborockZeoProtocol]) -> dict[RoborockZeoProtocol, Any]:
"""Query the device for the values of the given protocols."""
params = {RoborockZeoProtocol.ID_QUERY: [int(p) for p in protocols]}
return await send_decoded_command(self._channel, params)

async def set_value(self, protocol: RoborockZeoProtocol, value: Any) -> dict[RoborockZeoProtocol, Any]:
"""Set a value for a specific protocol on the device."""
params = {protocol: value}
return await send_decoded_command(self._channel, params)


def create(product: HomeDataProduct, mqtt_channel: MqttChannel) -> DyadApi | ZeoApi:
"""Create traits for A01 devices."""
match product.category:
case RoborockCategory.WET_DRY_VAC:
return DyadApi(mqtt_channel)
case RoborockCategory.WASHING_MACHINE:
return ZeoApi(mqtt_channel)
case _:
raise NotImplementedError(f"Unsupported category {product.category}")
31 changes: 31 additions & 0 deletions roborock/devices/traits/b01/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Traits for B01 devices."""

from roborock import RoborockB01Methods
from roborock.devices.b01_channel import send_decoded_command
from roborock.devices.mqtt_channel import MqttChannel
from roborock.devices.traits import Trait
from roborock.roborock_message import RoborockB01Props

__init__ = [
"create_b01_traits",
"PropertiesApi",
]


class PropertiesApi(Trait):
"""API for interacting with B01 devices."""

def __init__(self, channel: MqttChannel) -> None:
"""Initialize the B01Props API."""
self._channel = channel

async def query_values(self, props: list[RoborockB01Props]) -> None:
"""Query the device for the values of the given Dyad protocols."""
await send_decoded_command(
self._channel, dps=10000, command=RoborockB01Methods.GET_PROP, params={"property": props}
)


def create(channel: MqttChannel) -> PropertiesApi:
"""Create traits for B01 devices."""
return PropertiesApi(channel)
32 changes: 0 additions & 32 deletions roborock/devices/traits/b01/props.py

This file was deleted.

Loading