diff --git a/roborock/api.py b/roborock/api.py index ee22d491..86b3bb5b 100644 --- a/roborock/api.py +++ b/roborock/api.py @@ -18,7 +18,6 @@ from .roborock_future import RoborockFuture from .roborock_message import ( RoborockMessage, - RoborockMessageProtocol, ) from .util import get_next_int @@ -97,9 +96,7 @@ async def _wait_response(self, request_id: int, queue: RoborockFuture) -> Any: def _async_response(self, request_id: int, protocol_id: int = 0) -> Any: queue = RoborockFuture(protocol_id) - if request_id in self._waiting_queue and not ( - request_id == 2 and protocol_id == RoborockMessageProtocol.PING_REQUEST - ): + if request_id in self._waiting_queue: new_id = get_next_int(10000, 32767) self._logger.warning( "Attempting to create a future with an existing id %s (%s)... New id is %s. " diff --git a/roborock/protocol.py b/roborock/protocol.py index 6b02836e..2c8c541c 100644 --- a/roborock/protocol.py +++ b/roborock/protocol.py @@ -282,6 +282,16 @@ def _encode(self, obj, context, _): iv = md5hex(f"{context.random:08x}" + B01_HASH)[9:25] decipher = AES.new(bytes(context.search("local_key"), "utf-8"), AES.MODE_CBC, bytes(iv, "utf-8")) return decipher.encrypt(obj) + elif context.version == b"L01": + return Utils.encrypt_gcm_l01( + plaintext=obj, + local_key=context.search("local_key"), + timestamp=context.timestamp, + sequence=context.seq, + nonce=context.random, + connect_nonce=context.search("connect_nonce"), + ack_nonce=context.search("ack_nonce"), + ) token = self.token_func(context) encrypted = Utils.encrypt_ecb(obj, token) return encrypted @@ -297,6 +307,16 @@ def _decode(self, obj, context, _): iv = md5hex(f"{context.random:08x}" + B01_HASH)[9:25] decipher = AES.new(bytes(context.search("local_key"), "utf-8"), AES.MODE_CBC, bytes(iv, "utf-8")) return decipher.decrypt(obj) + elif context.version == b"L01": + return Utils.decrypt_gcm_l01( + payload=obj, + local_key=context.search("local_key"), + timestamp=context.timestamp, + sequence=context.seq, + nonce=context.random, + connect_nonce=context.search("connect_nonce"), + ack_nonce=context.search("ack_nonce"), + ) token = self.token_func(context) decrypted = Utils.decrypt_ecb(obj, token) return decrypted @@ -321,7 +341,7 @@ class PrefixedStruct(Struct): def _parse(self, stream, context, path): subcon1 = Peek(Optional(Bytes(3))) peek_version = subcon1.parse_stream(stream, **context) - if peek_version not in (b"1.0", b"A01", b"B01"): + if peek_version not in (b"1.0", b"A01", b"B01", b"L01"): subcon2 = Bytes(4) subcon2.parse_stream(stream, **context) return super()._parse(stream, context, path) @@ -374,10 +394,12 @@ def __init__(self, con: Construct, required_local_key: bool): self.con = con self.required_local_key = required_local_key - def parse(self, data: bytes, local_key: str | None = None) -> tuple[list[RoborockMessage], bytes]: + def parse( + self, data: bytes, local_key: str | None = None, connect_nonce: int | None = None, ack_nonce: int | None = None + ) -> tuple[list[RoborockMessage], bytes]: if self.required_local_key and local_key is None: raise RoborockException("Local key is required") - parsed = self.con.parse(data, local_key=local_key) + parsed = self.con.parse(data, local_key=local_key, connect_nonce=connect_nonce, ack_nonce=ack_nonce) parsed_messages = [Container({"message": parsed.message})] if parsed.get("message") else parsed.messages messages = [] for message in parsed_messages: @@ -395,7 +417,12 @@ def parse(self, data: bytes, local_key: str | None = None) -> tuple[list[Roboroc return messages, remaining def build( - self, roborock_messages: list[RoborockMessage] | RoborockMessage, local_key: str, prefixed: bool = True + self, + roborock_messages: list[RoborockMessage] | RoborockMessage, + local_key: str, + prefixed: bool = True, + connect_nonce: int | None = None, + ack_nonce: int | None = None, ) -> bytes: if isinstance(roborock_messages, RoborockMessage): roborock_messages = [roborock_messages] @@ -416,7 +443,11 @@ def build( } ) return self.con.build( - {"messages": [message for message in messages], "remaining": b""}, local_key=local_key, prefixed=prefixed + {"messages": [message for message in messages], "remaining": b""}, + local_key=local_key, + prefixed=prefixed, + connect_nonce=connect_nonce, + ack_nonce=ack_nonce, ) @@ -466,29 +497,31 @@ def encode(messages: RoborockMessage) -> bytes: return encode -def create_local_decoder(local_key: str) -> Decoder: +def create_local_decoder(local_key: str, connect_nonce: int | None = None, ack_nonce: int | None = None) -> Decoder: """Create a decoder for local API messages.""" # This buffer is used to accumulate bytes until a complete message can be parsed. # It is defined outside the decode function to maintain state across calls. buffer: bytes = b"" - def decode(bytes: bytes) -> list[RoborockMessage]: + def decode(bytes_data: bytes) -> list[RoborockMessage]: """Parse the given data into Roborock messages.""" nonlocal buffer - buffer += bytes - parsed_messages, remaining = MessageParser.parse(buffer, local_key=local_key) + buffer += bytes_data + parsed_messages, remaining = MessageParser.parse( + buffer, local_key=local_key, connect_nonce=connect_nonce, ack_nonce=ack_nonce + ) buffer = remaining return parsed_messages return decode -def create_local_encoder(local_key: str) -> Encoder: +def create_local_encoder(local_key: str, connect_nonce: int | None = None, ack_nonce: int | None = None) -> Encoder: """Create an encoder for local API messages.""" def encode(message: RoborockMessage) -> bytes: """Called when data is sent to the transport.""" - return MessageParser.build(message, local_key=local_key) + return MessageParser.build(message, local_key=local_key, connect_nonce=connect_nonce, ack_nonce=ack_nonce) return encode diff --git a/roborock/protocols/v1_protocol.py b/roborock/protocols/v1_protocol.py index 6e4817d8..92540153 100644 --- a/roborock/protocols/v1_protocol.py +++ b/roborock/protocols/v1_protocol.py @@ -65,15 +65,14 @@ class RequestMessage: request_id: int = field(default_factory=lambda: get_next_int(10000, 32767)) def encode_message( - self, - protocol: RoborockMessageProtocol, - security_data: SecurityData | None = None, + self, protocol: RoborockMessageProtocol, security_data: SecurityData | None = None, version: str = "1.0" ) -> RoborockMessage: """Convert the request message to a RoborockMessage.""" return RoborockMessage( timestamp=self.timestamp, protocol=protocol, payload=self._as_payload(security_data=security_data), + version=version.encode(), ) def _as_payload(self, security_data: SecurityData | None) -> bytes: diff --git a/roborock/util.py b/roborock/util.py index 72037589..0b36c48d 100644 --- a/roborock/util.py +++ b/roborock/util.py @@ -90,7 +90,7 @@ def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> tuple[str, Muta counter_map: dict[tuple[int, int], int] = {} -def get_next_int(min_val: int, max_val: int): +def get_next_int(min_val: int, max_val: int) -> int: """Gets a random int in the range, precached to help keep it fast.""" if (min_val, max_val) not in counter_map: # If we have never seen this range, or if the cache is getting low, make a bunch of preshuffled values. diff --git a/roborock/version_1_apis/roborock_client_v1.py b/roborock/version_1_apis/roborock_client_v1.py index d1d2fddb..2a4d7b84 100644 --- a/roborock/version_1_apis/roborock_client_v1.py +++ b/roborock/version_1_apis/roborock_client_v1.py @@ -447,10 +447,16 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None: self._logger.debug( "Received unsolicited map response for request_id %s", map_response.request_id ) + elif data.protocol == RoborockMessageProtocol.GENERAL_RESPONSE and data.payload is None: + # Api will often send blank messages with matching sequences, we can ignore these. + continue else: queue = self._waiting_queue.get(data.seq) if queue: - queue.set_result(data.payload) + if data.protocol == RoborockMessageProtocol.HELLO_RESPONSE: + queue.set_result(data) + else: + queue.set_result(data.payload) else: self._logger.debug("Received response for unknown request id %s", data.seq) except Exception as ex: diff --git a/roborock/version_1_apis/roborock_local_client_v1.py b/roborock/version_1_apis/roborock_local_client_v1.py index e2c68118..42095405 100644 --- a/roborock/version_1_apis/roborock_local_client_v1.py +++ b/roborock/version_1_apis/roborock_local_client_v1.py @@ -3,32 +3,27 @@ from asyncio import Lock, TimerHandle, Transport, get_running_loop from collections.abc import Callable from dataclasses import dataclass +from enum import StrEnum import async_timeout from .. import CommandVacuumError, DeviceData, RoborockCommand from ..api import RoborockClient from ..exceptions import RoborockConnectionException, RoborockException, VacuumError -from ..protocol import Decoder, Encoder, create_local_decoder, create_local_encoder +from ..protocol import create_local_decoder, create_local_encoder from ..protocols.v1_protocol import RequestMessage from ..roborock_message import RoborockMessage, RoborockMessageProtocol -from ..util import RoborockLoggerAdapter +from ..util import RoborockLoggerAdapter, get_next_int from .roborock_client_v1 import CLOUD_REQUIRED, RoborockClientV1 _LOGGER = logging.getLogger(__name__) -_HELLO_REQUEST_MESSAGE = RoborockMessage( - protocol=RoborockMessageProtocol.HELLO_REQUEST, - seq=1, - random=22, -) +class LocalProtocolVersion(StrEnum): + """Supported local protocol versions. Different from vacuum protocol versions.""" -_PING_REQUEST_MESSAGE = RoborockMessage( - protocol=RoborockMessageProtocol.PING_REQUEST, - seq=2, - random=23, -) + L01 = "L01" + V1 = "1.0" @dataclass @@ -50,7 +45,12 @@ def connection_lost(self, exc: Exception | None) -> None: class RoborockLocalClientV1(RoborockClientV1, RoborockClient): """Roborock local client for v1 devices.""" - def __init__(self, device_data: DeviceData, queue_timeout: int = 4): + def __init__( + self, + device_data: DeviceData, + queue_timeout: int = 4, + local_protocol_version: LocalProtocolVersion | None = None, + ): """Initialize the Roborock local client.""" if device_data.host is None: raise RoborockException("Host is required") @@ -63,11 +63,17 @@ def __init__(self, device_data: DeviceData, queue_timeout: int = 4): RoborockClientV1.__init__(self, device_data, security_data=None) RoborockClient.__init__(self, device_data) self._local_protocol = _LocalProtocol(self._data_received, self._connection_lost) - self._encoder: Encoder = create_local_encoder(device_data.device.local_key) - self._decoder: Decoder = create_local_decoder(device_data.device.local_key) + self._local_protocol_version = local_protocol_version + self._connect_nonce = get_next_int(10000, 32767) + self._ack_nonce: int | None = None + self._set_encoder_decoder() self.queue_timeout = queue_timeout self._logger = RoborockLoggerAdapter(device_data.device.name, _LOGGER) + @property + def local_protocol_version(self) -> LocalProtocolVersion: + return LocalProtocolVersion.V1 if self._local_protocol_version is None else self._local_protocol_version + def _data_received(self, message): """Called when data is received from the transport.""" parsed_msg = self._decoder(message) @@ -121,20 +127,69 @@ async def async_disconnect(self) -> None: async with self._mutex: self._sync_disconnect() - async def hello(self): + def _set_encoder_decoder(self): + """Updates the encoder decoder. These are updated with nonces after the first hello. + Only L01 uses the nonces.""" + self._encoder = create_local_encoder(self.device_info.device.local_key, self._connect_nonce, self._ack_nonce) + self._decoder = create_local_decoder(self.device_info.device.local_key, self._connect_nonce, self._ack_nonce) + + async def _do_hello(self, local_protocol_version: LocalProtocolVersion) -> bool: + """Perform the initial handshaking.""" + self._logger.debug( + "Attempting to use the %s protocol for client %s...", + local_protocol_version, + self.device_info.device.duid, + ) + request = RoborockMessage( + protocol=RoborockMessageProtocol.HELLO_REQUEST, + version=local_protocol_version.encode(), + random=self._connect_nonce, + seq=1, + ) try: - return await self._send_message( - roborock_message=_HELLO_REQUEST_MESSAGE, - request_id=_HELLO_REQUEST_MESSAGE.seq, + response = await self._send_message( + roborock_message=request, + request_id=request.seq, response_protocol=RoborockMessageProtocol.HELLO_RESPONSE, ) - except Exception as e: - self._logger.error(e) + self._ack_nonce = response.random + self._set_encoder_decoder() + self._local_protocol_version = local_protocol_version + + self._logger.debug( + "Client %s speaks the %s protocol.", + self.device_info.device.duid, + local_protocol_version, + ) + return True + except RoborockException as e: + self._logger.debug( + "Client %s did not respond or does not speak the %s protocol. %s", + self.device_info.device.duid, + local_protocol_version, + e, + ) + return False + + async def hello(self): + """Send hello to the device to negotiate protocol.""" + if self._local_protocol_version: + # version is forced + if not await self._do_hello(self._local_protocol_version): + raise RoborockException(f"Failed to connect to device with protocol {self._local_protocol_version}") + else: + # try 1.0, then L01 + if not await self._do_hello(LocalProtocolVersion.V1): + if not await self._do_hello(LocalProtocolVersion.L01): + raise RoborockException("Failed to connect to device with any known protocol") async def ping(self) -> None: + ping_message = RoborockMessage( + protocol=RoborockMessageProtocol.PING_REQUEST, version=self.local_protocol_version.encode() + ) await self._send_message( - roborock_message=_PING_REQUEST_MESSAGE, - request_id=_PING_REQUEST_MESSAGE.seq, + roborock_message=ping_message, + request_id=ping_message.seq, response_protocol=RoborockMessageProtocol.PING_RESPONSE, ) @@ -154,7 +209,10 @@ async def _send_command( if method in CLOUD_REQUIRED: raise RoborockException(f"Method {method} is not supported over local connection") request_message = RequestMessage(method=method, params=params) - roborock_message = request_message.encode_message(RoborockMessageProtocol.GENERAL_REQUEST) + roborock_message = request_message.encode_message( + RoborockMessageProtocol.GENERAL_REQUEST, + version=self.local_protocol_version, + ) self._logger.debug("Building message id %s for method %s", request_message.request_id, method) return await self._send_message( roborock_message,