diff --git a/roborock/api.py b/roborock/api.py index 777ead6d..88ae565b 100644 --- a/roborock/api.py +++ b/roborock/api.py @@ -21,7 +21,7 @@ from .roborock_message import ( RoborockMessage, ) -from .util import get_next_int, get_running_loop_or_create_one +from .util import get_next_int _LOGGER = logging.getLogger(__name__) KEEPALIVE = 60 @@ -35,7 +35,6 @@ class RoborockClient(ABC): def __init__(self, device_info: DeviceData) -> None: """Initialize RoborockClient.""" - self.event_loop = get_running_loop_or_create_one() self.device_info = device_info self._nonce = secrets.token_bytes(16) self._waiting_queue: dict[int, RoborockFuture] = {} diff --git a/roborock/cloud_api.py b/roborock/cloud_api.py index f2731371..4387fcbf 100644 --- a/roborock/cloud_api.py +++ b/roborock/cloud_api.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import logging import threading from abc import ABC @@ -158,7 +159,8 @@ async def async_disconnect(self) -> None: if disconnected_future := self._sync_disconnect(): # There are no errors set on this future await disconnected_future - await self.event_loop.run_in_executor(None, self._mqtt_client.loop_stop) + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self._mqtt_client.loop_stop) async def async_connect(self) -> None: async with self._mutex: diff --git a/roborock/local_api.py b/roborock/local_api.py index f0a2de96..5f761876 100644 --- a/roborock/local_api.py +++ b/roborock/local_api.py @@ -3,7 +3,7 @@ import asyncio import logging from abc import ABC -from asyncio import Lock, TimerHandle, Transport +from asyncio import Lock, TimerHandle, Transport, get_running_loop from collections.abc import Callable from dataclasses import dataclass @@ -72,7 +72,8 @@ async def keep_alive_func(self, _=None): await self.ping() except RoborockException: pass - self.keep_alive_task = self.event_loop.call_later(10, lambda: asyncio.create_task(self.keep_alive_func())) + loop = asyncio.get_running_loop() + self.keep_alive_task = loop.call_later(10, lambda: asyncio.create_task(self.keep_alive_func())) async def async_connect(self) -> None: should_ping = False @@ -82,7 +83,8 @@ async def async_connect(self) -> None: self._sync_disconnect() async with async_timeout.timeout(self.queue_timeout): self._logger.debug(f"Connecting to {self.host}") - self.transport, _ = await self.event_loop.create_connection( # type: ignore + loop = get_running_loop() + self.transport, _ = await loop.create_connection( # type: ignore lambda: self._local_protocol, self.host, 58867 ) self._logger.info(f"Connected to {self.host}") @@ -94,7 +96,8 @@ async def async_connect(self) -> None: await self.keep_alive_func() def _sync_disconnect(self) -> None: - if self.transport and self.event_loop.is_running(): + loop = asyncio.get_running_loop() + if self.transport and loop.is_running(): self._logger.debug(f"Disconnecting from {self.host}") self.transport.close() if self.keep_alive_task: diff --git a/roborock/util.py b/roborock/util.py index c6013d24..574b0362 100644 --- a/roborock/util.py +++ b/roborock/util.py @@ -74,8 +74,7 @@ def wrapped(*args, **kwargs): class RepeatableTask: - def __init__(self, loop: AbstractEventLoop, callback: Callable[[], Coroutine], interval: int): - self.loop = loop + def __init__(self, callback: Callable[[], Coroutine], interval: int): self.callback = callback self.interval = interval self._task: TimerHandle | None = None @@ -86,7 +85,8 @@ async def _run_task(self): response = await self.callback() except RoborockException: pass - self._task = self.loop.call_later(self.interval, self._run_task_soon) + loop = asyncio.get_running_loop() + self._task = loop.call_later(self.interval, self._run_task_soon) return response def _run_task_soon(self): diff --git a/roborock/version_1_apis/roborock_client_v1.py b/roborock/version_1_apis/roborock_client_v1.py index 754b1af8..a8dbb098 100644 --- a/roborock/version_1_apis/roborock_client_v1.py +++ b/roborock/version_1_apis/roborock_client_v1.py @@ -82,11 +82,11 @@ class AttributeCache: - def __init__(self, attribute: RoborockAttribute, loop: asyncio.AbstractEventLoop, send_command: _SendCommandT): + def __init__(self, attribute: RoborockAttribute, send_command: _SendCommandT): self.attribute = attribute self._send_command = send_command self.attribute = attribute - self.task = RepeatableTask(loop, self._async_value, EVICT_TIME) + self.task = RepeatableTask(self._async_value, EVICT_TIME) self._value: Any = None self._mutex = asyncio.Lock() self.unsupported: bool = False @@ -156,7 +156,7 @@ def __init__(self, device_info: DeviceData, endpoint: str): super().__init__(device_info) self._status_type: type[Status] = ModelStatus.get(device_info.model, S7MaxVStatus) self.cache: dict[CacheableAttribute, AttributeCache] = { - cacheable_attribute: AttributeCache(attr, self.event_loop, self._send_command) + cacheable_attribute: AttributeCache(attr, self._send_command) for cacheable_attribute, attr in get_cache_map().items() } if device_info.device.duid not in self._listeners: diff --git a/tests/conftest.py b/tests/conftest.py index e2199e6c..1986e9b7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -227,7 +227,7 @@ def handle_write(data: bytes) -> None: return (mock_transport, "proto") - with patch("roborock.api.get_running_loop_or_create_one") as mock_loop: + with patch("roborock.local_api.get_running_loop") as mock_loop: mock_loop.return_value.create_connection.side_effect = create_connection yield