Skip to content
55 changes: 44 additions & 11 deletions roborock/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,16 @@ def _encode(self, obj, context, _):
iv = md5hex(f"{context.random:08x}" + B01_HASH)[9:25]
decipher = AES.new(bytes(context.search("local_key"), "utf-8"), AES.MODE_CBC, bytes(iv, "utf-8"))
return decipher.encrypt(obj)
elif context.version == b"L01":
return Utils.encrypt_gcm_l01(
plaintext=obj,
local_key=context.search("local_key"),
timestamp=context.timestamp,
sequence=context.seq,
nonce=context.random,
connect_nonce=context.search("connect_nonce"),
ack_nonce=context.search("ack_nonce"),
)
token = self.token_func(context)
encrypted = Utils.encrypt_ecb(obj, token)
return encrypted
Expand All @@ -297,6 +307,16 @@ def _decode(self, obj, context, _):
iv = md5hex(f"{context.random:08x}" + B01_HASH)[9:25]
decipher = AES.new(bytes(context.search("local_key"), "utf-8"), AES.MODE_CBC, bytes(iv, "utf-8"))
return decipher.decrypt(obj)
elif context.version == b"L01":
return Utils.decrypt_gcm_l01(
payload=obj,
local_key=context.search("local_key"),
timestamp=context.timestamp,
sequence=context.seq,
nonce=context.random,
connect_nonce=context.search("connect_nonce"),
ack_nonce=context.search("ack_nonce"),
)
token = self.token_func(context)
decrypted = Utils.decrypt_ecb(obj, token)
return decrypted
Expand All @@ -321,7 +341,7 @@ class PrefixedStruct(Struct):
def _parse(self, stream, context, path):
subcon1 = Peek(Optional(Bytes(3)))
peek_version = subcon1.parse_stream(stream, **context)
if peek_version not in (b"1.0", b"A01", b"B01"):
if peek_version not in (b"1.0", b"A01", b"B01", b"L01"):
subcon2 = Bytes(4)
subcon2.parse_stream(stream, **context)
return super()._parse(stream, context, path)
Expand Down Expand Up @@ -374,10 +394,12 @@ def __init__(self, con: Construct, required_local_key: bool):
self.con = con
self.required_local_key = required_local_key

def parse(self, data: bytes, local_key: str | None = None) -> tuple[list[RoborockMessage], bytes]:
def parse(
self, data: bytes, local_key: str | None = None, connect_nonce: int | None = None, ack_nonce: int | None = None
) -> tuple[list[RoborockMessage], bytes]:
if self.required_local_key and local_key is None:
raise RoborockException("Local key is required")
parsed = self.con.parse(data, local_key=local_key)
parsed = self.con.parse(data, local_key=local_key, connect_nonce=connect_nonce, ack_nonce=ack_nonce)
parsed_messages = [Container({"message": parsed.message})] if parsed.get("message") else parsed.messages
messages = []
for message in parsed_messages:
Expand All @@ -395,7 +417,12 @@ def parse(self, data: bytes, local_key: str | None = None) -> tuple[list[Roboroc
return messages, remaining

def build(
self, roborock_messages: list[RoborockMessage] | RoborockMessage, local_key: str, prefixed: bool = True
self,
roborock_messages: list[RoborockMessage] | RoborockMessage,
local_key: str,
prefixed: bool = True,
connect_nonce: int | None = None,
ack_nonce: int | None = None,
) -> bytes:
if isinstance(roborock_messages, RoborockMessage):
roborock_messages = [roborock_messages]
Expand All @@ -416,7 +443,11 @@ def build(
}
)
return self.con.build(
{"messages": [message for message in messages], "remaining": b""}, local_key=local_key, prefixed=prefixed
{"messages": [message for message in messages], "remaining": b""},
local_key=local_key,
prefixed=prefixed,
connect_nonce=connect_nonce,
ack_nonce=ack_nonce,
)


Expand Down Expand Up @@ -466,29 +497,31 @@ def encode(messages: RoborockMessage) -> bytes:
return encode


def create_local_decoder(local_key: str) -> Decoder:
def create_local_decoder(local_key: str, connect_nonce: int | None = None, ack_nonce: int | None = None) -> Decoder:
"""Create a decoder for local API messages."""

# This buffer is used to accumulate bytes until a complete message can be parsed.
# It is defined outside the decode function to maintain state across calls.
buffer: bytes = b""

def decode(bytes: bytes) -> list[RoborockMessage]:
def decode(bytes_data: bytes) -> list[RoborockMessage]:
"""Parse the given data into Roborock messages."""
nonlocal buffer
buffer += bytes
parsed_messages, remaining = MessageParser.parse(buffer, local_key=local_key)
buffer += bytes_data
parsed_messages, remaining = MessageParser.parse(
buffer, local_key=local_key, connect_nonce=connect_nonce, ack_nonce=ack_nonce
)
buffer = remaining
return parsed_messages

return decode


def create_local_encoder(local_key: str) -> Encoder:
def create_local_encoder(local_key: str, connect_nonce: int | None = None, ack_nonce: int | None = None) -> Encoder:
"""Create an encoder for local API messages."""

def encode(message: RoborockMessage) -> bytes:
"""Called when data is sent to the transport."""
return MessageParser.build(message, local_key=local_key)
return MessageParser.build(message, local_key=local_key, connect_nonce=connect_nonce, ack_nonce=ack_nonce)

return encode
2 changes: 1 addition & 1 deletion roborock/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> tuple[str, Muta
counter_map: dict[tuple[int, int], int] = {}


def get_next_int(min_val: int, max_val: int):
def get_next_int(min_val: int, max_val: int) -> int:
"""Gets a random int in the range, precached to help keep it fast."""
if (min_val, max_val) not in counter_map:
# If we have never seen this range, or if the cache is getting low, make a bunch of preshuffled values.
Expand Down
8 changes: 7 additions & 1 deletion roborock/version_1_apis/roborock_client_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,9 +448,15 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
"Received unsolicited map response for request_id %s", map_response.request_id
)
else:
if data.protocol == RoborockMessageProtocol.GENERAL_RESPONSE and data.payload is None:
# Api will often send blank messages with matching sequences, we can ignore these.
continue
queue = self._waiting_queue.get(data.seq)
if queue:
queue.set_result(data.payload)
if data.protocol == RoborockMessageProtocol.HELLO_RESPONSE:
queue.set_result(data)
else:
queue.set_result(data.payload)
else:
self._logger.debug("Received response for unknown request id %s", data.seq)
except Exception as ex:
Expand Down
111 changes: 86 additions & 25 deletions roborock/version_1_apis/roborock_local_client_v1.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import asyncio
import json
import logging
import math
import time
from asyncio import Lock, TimerHandle, Transport, get_running_loop
from collections.abc import Callable
from dataclasses import dataclass
Expand All @@ -12,25 +15,12 @@
from ..protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
from ..protocols.v1_protocol import RequestMessage
from ..roborock_message import RoborockMessage, RoborockMessageProtocol
from ..util import RoborockLoggerAdapter
from ..util import RoborockLoggerAdapter, get_next_int
from .roborock_client_v1 import CLOUD_REQUIRED, RoborockClientV1

_LOGGER = logging.getLogger(__name__)


_HELLO_REQUEST_MESSAGE = RoborockMessage(
protocol=RoborockMessageProtocol.HELLO_REQUEST,
seq=1,
random=22,
)

_PING_REQUEST_MESSAGE = RoborockMessage(
protocol=RoborockMessageProtocol.PING_REQUEST,
seq=2,
random=23,
)


@dataclass
class _LocalProtocol(asyncio.Protocol):
"""Callbacks for the Roborock local client transport."""
Expand All @@ -50,7 +40,7 @@ def connection_lost(self, exc: Exception | None) -> None:
class RoborockLocalClientV1(RoborockClientV1, RoborockClient):
"""Roborock local client for v1 devices."""

def __init__(self, device_data: DeviceData, queue_timeout: int = 4):
def __init__(self, device_data: DeviceData, queue_timeout: int = 4, version: str | None = None):
"""Initialize the Roborock local client."""
if device_data.host is None:
raise RoborockException("Host is required")
Expand All @@ -63,8 +53,14 @@ def __init__(self, device_data: DeviceData, queue_timeout: int = 4):
RoborockClientV1.__init__(self, device_data, security_data=None)
RoborockClient.__init__(self, device_data)
self._local_protocol = _LocalProtocol(self._data_received, self._connection_lost)
self._encoder: Encoder = create_local_encoder(device_data.device.local_key)
self._decoder: Decoder = create_local_decoder(device_data.device.local_key)
self._version = version
self._connect_nonce: int | None = None
self._ack_nonce: int | None = None
if version == "L01":
self._set_l01_encoder_decoder()
else:
self._encoder: Encoder = create_local_encoder(device_data.device.local_key)
self._decoder: Decoder = create_local_decoder(device_data.device.local_key)
self.queue_timeout = queue_timeout
self._logger = RoborockLoggerAdapter(device_data.device.name, _LOGGER)

Expand Down Expand Up @@ -121,20 +117,58 @@ async def async_disconnect(self) -> None:
async with self._mutex:
self._sync_disconnect()

async def hello(self):
def _set_l01_encoder_decoder(self):
"""Tell the system to use the L01 encoder/decoder."""
self._encoder = create_local_encoder(self.device_info.device.local_key, self._connect_nonce, self._ack_nonce)
self._decoder = create_local_decoder(self.device_info.device.local_key, self._connect_nonce, self._ack_nonce)

async def _do_hello(self, version: str) -> bool:
"""Perform the initial handshaking."""
self._logger.debug(f"Attempting to use the {version} protocol for client {self.device_info.device.duid}...")
self._connect_nonce = get_next_int(10000, 32767)
request = RoborockMessage(
protocol=RoborockMessageProtocol.HELLO_REQUEST,
version=version.encode(),
random=self._connect_nonce,
seq=1,
)
try:
return await self._send_message(
roborock_message=_HELLO_REQUEST_MESSAGE,
request_id=_HELLO_REQUEST_MESSAGE.seq,
response = await self._send_message(
roborock_message=request,
request_id=request.seq,
response_protocol=RoborockMessageProtocol.HELLO_RESPONSE,
)
except Exception as e:
self._logger.error(e)
if response.version.decode() == "L01":
self._ack_nonce = response.random
self._set_l01_encoder_decoder()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, up until this point the encoder was used without a connect nonce or an ack nonce. But then for all future messages beyond the hello, these must be used? Is that a requirement or can the connect nonce be set up front in the encoder even for the hello message?

Is response.random set to None in the non-L01 protocols? Does it hurt to set it on future requests for non-L01 protocols too?

Copy link
Collaborator Author

@Lash-L Lash-L Sep 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

response.random is not None for non-L01 protocols. However, it is actually fine if they exist, that is a good call. It isn't used on encrypting the 1.0 payload, so those values can be anything.

As well, the hello message does not include any payload, so the value doesn't matter there as well.

The only potential issues to setting connect_nonce upfront is that if the user reconnects, it will be the same nonce, but the vac should handle that fine as it is a new transport thread.

self._version = version
self._logger.debug(f"Client {self.device_info.device.duid} speaks the {version} protocol.")
return True
except RoborockException as e:
self._logger.debug(
f"Client {self.device_info.device.duid} did not respond or does not speak the {version} protocol. {e}"
)
return False

async def hello(self):
"""Send hello to the device to negotiate protocol."""
if self._version:
# version is forced
if not await self._do_hello(self._version):
raise RoborockException(f"Failed to connect to device with protocol {self._version}")
else:
# try 1.0, then L01
if not await self._do_hello("1.0"):
if not await self._do_hello("L01"):
raise RoborockException("Failed to connect to device with any known protocol")

async def ping(self) -> None:
ping_message = RoborockMessage(
protocol=RoborockMessageProtocol.PING_REQUEST,
)
await self._send_message(
roborock_message=_PING_REQUEST_MESSAGE,
request_id=_PING_REQUEST_MESSAGE.seq,
roborock_message=ping_message,
request_id=ping_message.seq,
response_protocol=RoborockMessageProtocol.PING_RESPONSE,
)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to set ping request statically like we did. I know we talked about this before but I wasn't fully sure.

This works


Expand All @@ -153,6 +187,33 @@ async def _send_command(
):
if method in CLOUD_REQUIRED:
raise RoborockException(f"Method {method} is not supported over local connection")
if self._version == "L01":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This all seems like the responsibility of RequestMessage.encode_message and should live in protocol. Firstly because the code follow this pattern so its cleaner and unit testable, but also so that it can be reused by the other client.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great call, as it turns out all the logic was similar enough to just carry over with a small change.

request_id = get_next_int(10000, 999999)
dps_payload = {
"id": request_id,
"method": method,
"params": params,
}
ts = math.floor(time.time())
payload = {
"dps": {str(RoborockMessageProtocol.RPC_REQUEST.value): json.dumps(dps_payload, separators=(",", ":"))},
"t": ts,
}
roborock_message = RoborockMessage(
protocol=RoborockMessageProtocol.GENERAL_REQUEST,
payload=json.dumps(payload, separators=(",", ":")).encode("utf-8"),
version=self._version.encode(),
timestamp=ts,
)
self._logger.debug("Building message id %s for method %s", request_id, method)
return await self._send_message(
roborock_message,
request_id=request_id,
response_protocol=RoborockMessageProtocol.GENERAL_REQUEST,
method=method,
params=params,
)

request_message = RequestMessage(method=method, params=params)
roborock_message = request_message.encode_message(RoborockMessageProtocol.GENERAL_REQUEST)
self._logger.debug("Building message id %s for method %s", request_message.request_id, method)
Expand Down