From 77e63455cc0c6b83bb6a14e62a833f0b81af23f3 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Fri, 3 Oct 2025 16:52:28 -0700 Subject: [PATCH 1/6] split manager --- reflex/istate/manager/__init__.py | 120 ++++++ reflex/istate/manager/disk.py | 210 ++++++++++ reflex/istate/manager/memory.py | 76 ++++ .../istate/{manager.py => manager/redis.py} | 379 +----------------- reflex/state.py | 6 +- reflex/testing.py | 12 +- tests/integration/test_client_storage.py | 11 +- tests/integration/test_connection_banner.py | 2 +- tests/units/test_app.py | 6 +- tests/units/test_state.py | 12 +- tests/units/test_state_tree.py | 3 +- 11 files changed, 433 insertions(+), 404 deletions(-) create mode 100644 reflex/istate/manager/__init__.py create mode 100644 reflex/istate/manager/disk.py create mode 100644 reflex/istate/manager/memory.py rename reflex/istate/{manager.py => manager/redis.py} (58%) diff --git a/reflex/istate/manager/__init__.py b/reflex/istate/manager/__init__.py new file mode 100644 index 00000000000..5918712ac61 --- /dev/null +++ b/reflex/istate/manager/__init__.py @@ -0,0 +1,120 @@ +"""State manager for managing client states.""" + +import contextlib +import dataclasses +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator + +from reflex import constants +from reflex.config import get_config +from reflex.state import BaseState +from reflex.utils import console, prerequisites +from reflex.utils.exceptions import InvalidStateManagerModeError + + +@dataclasses.dataclass +class StateManager(ABC): + """A class to manage many client states.""" + + # The state class to use. + state: type[BaseState] + + @classmethod + def create(cls, state: type[BaseState]): + """Create a new state manager. + + Args: + state: The state class to use. + + Raises: + InvalidStateManagerModeError: If the state manager mode is invalid. + + Returns: + The state manager (either disk, memory or redis). + """ + config = get_config() + if prerequisites.parse_redis_url() is not None: + config.state_manager_mode = constants.StateManagerMode.REDIS + if config.state_manager_mode == constants.StateManagerMode.MEMORY: + from reflex.istate.manager.memory import StateManagerMemory + + return StateManagerMemory(state=state) + if config.state_manager_mode == constants.StateManagerMode.DISK: + from reflex.istate.manager.disk import StateManagerDisk + + return StateManagerDisk(state=state) + if config.state_manager_mode == constants.StateManagerMode.REDIS: + redis = prerequisites.get_redis() + if redis is not None: + from reflex.istate.manager.redis import StateManagerRedis + + # make sure expiration values are obtained only from the config object on creation + return StateManagerRedis( + state=state, + redis=redis, + token_expiration=config.redis_token_expiration, + lock_expiration=config.redis_lock_expiration, + lock_warning_threshold=config.redis_lock_warning_threshold, + ) + msg = f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}" + raise InvalidStateManagerModeError(msg) + + @abstractmethod + async def get_state(self, token: str) -> BaseState: + """Get the state for a token. + + Args: + token: The token to get the state for. + + Returns: + The state for the token. + """ + + @abstractmethod + async def set_state(self, token: str, state: BaseState): + """Set the state for a token. + + Args: + token: The token to set the state for. + state: The state to set. + """ + + @abstractmethod + @contextlib.asynccontextmanager + async def modify_state(self, token: str) -> AsyncIterator[BaseState]: + """Modify the state for a token while holding exclusive lock. + + Args: + token: The token to modify the state for. + + Yields: + The state for the token. + """ + yield self.state() + + +def _default_token_expiration() -> int: + """Get the default token expiration time. + + Returns: + The default token expiration time. + """ + return get_config().redis_token_expiration + + +def reset_disk_state_manager(): + """Reset the disk state manager.""" + console.debug("Resetting disk state manager.") + states_directory = prerequisites.get_states_dir() + if states_directory.exists(): + for path in states_directory.iterdir(): + path.unlink() + + +def get_state_manager() -> StateManager: + """Get the state manager for the app that is currently running. + + Returns: + The state manager. + """ + return prerequisites.get_and_validate_app().app.state_manager diff --git a/reflex/istate/manager/disk.py b/reflex/istate/manager/disk.py new file mode 100644 index 00000000000..0db38b567fb --- /dev/null +++ b/reflex/istate/manager/disk.py @@ -0,0 +1,210 @@ +"""A state manager that stores states on disk.""" + +import asyncio +import contextlib +import dataclasses +import functools +from collections.abc import AsyncIterator +from hashlib import md5 +from pathlib import Path + +from typing_extensions import override + +from reflex.istate.manager import StateManager, _default_token_expiration +from reflex.state import BaseState, _split_substate_key, _substate_key +from reflex.utils import path_ops, prerequisites + + +@dataclasses.dataclass +class StateManagerDisk(StateManager): + """A state manager that stores states on disk.""" + + # The mapping of client ids to states. + states: dict[str, BaseState] = dataclasses.field(default_factory=dict) + + # The mutex ensures the dict of mutexes is updated exclusively + _state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock()) + + # The dict of mutexes for each client + _states_locks: dict[str, asyncio.Lock] = dataclasses.field( + default_factory=dict, + init=False, + ) + + # The token expiration time (s). + token_expiration: int = dataclasses.field(default_factory=_default_token_expiration) + + def __post_init_(self): + """Create a new state manager.""" + path_ops.mkdir(self.states_directory) + + self._purge_expired_states() + + @functools.cached_property + def states_directory(self) -> Path: + """Get the states directory. + + Returns: + The states directory. + """ + return prerequisites.get_states_dir() + + def _purge_expired_states(self): + """Purge expired states from the disk.""" + import time + + for path in path_ops.ls(self.states_directory): + # check path is a pickle file + if path.suffix != ".pkl": + continue + + # load last edited field from file + last_edited = path.stat().st_mtime + + # check if the file is older than the token expiration time + if time.time() - last_edited > self.token_expiration: + # remove the file + path.unlink() + + def token_path(self, token: str) -> Path: + """Get the path for a token. + + Args: + token: The token to get the path for. + + Returns: + The path for the token. + """ + return ( + self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl" + ).absolute() + + async def load_state(self, token: str) -> BaseState | None: + """Load a state object based on the provided token. + + Args: + token: The token used to identify the state object. + + Returns: + The loaded state object or None. + """ + token_path = self.token_path(token) + + if token_path.exists(): + try: + with token_path.open(mode="rb") as file: + return BaseState._deserialize(fp=file) + except Exception: + pass + return None + + async def populate_substates( + self, client_token: str, state: BaseState, root_state: BaseState + ): + """Populate the substates of a state object. + + Args: + client_token: The client token. + state: The state object to populate. + root_state: The root state object. + """ + for substate in state.get_substates(): + substate_token = _substate_key(client_token, substate) + + fresh_instance = await root_state.get_state(substate) + instance = await self.load_state(substate_token) + if instance is not None: + # Ensure all substates exist, even if they weren't serialized previously. + instance.substates = fresh_instance.substates + else: + instance = fresh_instance + state.substates[substate.get_name()] = instance + instance.parent_state = state + + await self.populate_substates(client_token, instance, root_state) + + @override + async def get_state( + self, + token: str, + ) -> BaseState: + """Get the state for a token. + + Args: + token: The token to get the state for. + + Returns: + The state for the token. + """ + client_token = _split_substate_key(token)[0] + root_state = self.states.get(client_token) + if root_state is not None: + # Retrieved state from memory. + return root_state + + # Deserialize root state from disk. + root_state = await self.load_state(_substate_key(client_token, self.state)) + # Create a new root state tree with all substates instantiated. + fresh_root_state = self.state(_reflex_internal_init=True) + if root_state is None: + root_state = fresh_root_state + else: + # Ensure all substates exist, even if they were not serialized previously. + root_state.substates = fresh_root_state.substates + self.states[client_token] = root_state + await self.populate_substates(client_token, root_state, root_state) + return root_state + + async def set_state_for_substate(self, client_token: str, substate: BaseState): + """Set the state for a substate. + + Args: + client_token: The client token. + substate: The substate to set. + """ + substate_token = _substate_key(client_token, substate) + + if substate._get_was_touched(): + substate._was_touched = False # Reset the touched flag after serializing. + pickle_state = substate._serialize() + if pickle_state: + if not self.states_directory.exists(): + self.states_directory.mkdir(parents=True, exist_ok=True) + self.token_path(substate_token).write_bytes(pickle_state) + + for substate_substate in substate.substates.values(): + await self.set_state_for_substate(client_token, substate_substate) + + @override + async def set_state(self, token: str, state: BaseState): + """Set the state for a token. + + Args: + token: The token to set the state for. + state: The state to set. + """ + client_token, _ = _split_substate_key(token) + await self.set_state_for_substate(client_token, state) + + @override + @contextlib.asynccontextmanager + async def modify_state(self, token: str) -> AsyncIterator[BaseState]: + """Modify the state for a token while holding exclusive lock. + + Args: + token: The token to modify the state for. + + Yields: + The state for the token. + """ + # Memory state manager ignores the substate suffix and always returns the top-level state. + client_token, _ = _split_substate_key(token) + if client_token not in self._states_locks: + async with self._state_manager_lock: + if client_token not in self._states_locks: + self._states_locks[client_token] = asyncio.Lock() + + async with self._states_locks[client_token]: + state = await self.get_state(token) + yield state + await self.set_state(token, state) diff --git a/reflex/istate/manager/memory.py b/reflex/istate/manager/memory.py new file mode 100644 index 00000000000..9706c9e49fc --- /dev/null +++ b/reflex/istate/manager/memory.py @@ -0,0 +1,76 @@ +"""A state manager that stores states in memory.""" + +import asyncio +import contextlib +import dataclasses +from collections.abc import AsyncIterator + +from typing_extensions import override + +from reflex.istate.manager import StateManager +from reflex.state import BaseState, _split_substate_key + + +@dataclasses.dataclass +class StateManagerMemory(StateManager): + """A state manager that stores states in memory.""" + + # The mapping of client ids to states. + states: dict[str, BaseState] = dataclasses.field(default_factory=dict) + + # The mutex ensures the dict of mutexes is updated exclusively + _state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock()) + + # The dict of mutexes for each client + _states_locks: dict[str, asyncio.Lock] = dataclasses.field( + default_factory=dict, init=False + ) + + @override + async def get_state(self, token: str) -> BaseState: + """Get the state for a token. + + Args: + token: The token to get the state for. + + Returns: + The state for the token. + """ + # Memory state manager ignores the substate suffix and always returns the top-level state. + token = _split_substate_key(token)[0] + if token not in self.states: + self.states[token] = self.state(_reflex_internal_init=True) + return self.states[token] + + @override + async def set_state(self, token: str, state: BaseState): + """Set the state for a token. + + Args: + token: The token to set the state for. + state: The state to set. + """ + token = _split_substate_key(token)[0] + self.states[token] = state + + @override + @contextlib.asynccontextmanager + async def modify_state(self, token: str) -> AsyncIterator[BaseState]: + """Modify the state for a token while holding exclusive lock. + + Args: + token: The token to modify the state for. + + Yields: + The state for the token. + """ + # Memory state manager ignores the substate suffix and always returns the top-level state. + token = _split_substate_key(token)[0] + if token not in self._states_locks: + async with self._state_manager_lock: + if token not in self._states_locks: + self._states_locks[token] = asyncio.Lock() + + async with self._states_locks[token]: + state = await self.get_state(token) + yield state diff --git a/reflex/istate/manager.py b/reflex/istate/manager/redis.py similarity index 58% rename from reflex/istate/manager.py rename to reflex/istate/manager/redis.py index 6400528f534..b8dd3d82fd5 100644 --- a/reflex/istate/manager.py +++ b/reflex/istate/manager/redis.py @@ -1,387 +1,29 @@ -"""State manager for managing client states.""" +"""A state manager that stores states in redis.""" import asyncio import contextlib import dataclasses -import functools import time import uuid -from abc import ABC, abstractmethod from collections.abc import AsyncIterator -from hashlib import md5 -from pathlib import Path +from typing import override from redis import ResponseError from redis.asyncio import Redis from redis.asyncio.client import PubSub -from typing_extensions import override -from reflex import constants from reflex.config import get_config from reflex.environment import environment +from reflex.istate.manager import StateManager, _default_token_expiration from reflex.state import BaseState, _split_substate_key, _substate_key -from reflex.utils import console, path_ops, prerequisites +from reflex.utils import console from reflex.utils.exceptions import ( InvalidLockWarningThresholdError, - InvalidStateManagerModeError, LockExpiredError, StateSchemaMismatchError, ) -@dataclasses.dataclass -class StateManager(ABC): - """A class to manage many client states.""" - - # The state class to use. - state: type[BaseState] - - @classmethod - def create(cls, state: type[BaseState]): - """Create a new state manager. - - Args: - state: The state class to use. - - Raises: - InvalidStateManagerModeError: If the state manager mode is invalid. - - Returns: - The state manager (either disk, memory or redis). - """ - config = get_config() - if prerequisites.parse_redis_url() is not None: - config.state_manager_mode = constants.StateManagerMode.REDIS - if config.state_manager_mode == constants.StateManagerMode.MEMORY: - return StateManagerMemory(state=state) - if config.state_manager_mode == constants.StateManagerMode.DISK: - return StateManagerDisk(state=state) - if config.state_manager_mode == constants.StateManagerMode.REDIS: - redis = prerequisites.get_redis() - if redis is not None: - # make sure expiration values are obtained only from the config object on creation - return StateManagerRedis( - state=state, - redis=redis, - token_expiration=config.redis_token_expiration, - lock_expiration=config.redis_lock_expiration, - lock_warning_threshold=config.redis_lock_warning_threshold, - ) - msg = f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}" - raise InvalidStateManagerModeError(msg) - - @abstractmethod - async def get_state(self, token: str) -> BaseState: - """Get the state for a token. - - Args: - token: The token to get the state for. - - Returns: - The state for the token. - """ - - @abstractmethod - async def set_state(self, token: str, state: BaseState): - """Set the state for a token. - - Args: - token: The token to set the state for. - state: The state to set. - """ - - @abstractmethod - @contextlib.asynccontextmanager - async def modify_state(self, token: str) -> AsyncIterator[BaseState]: - """Modify the state for a token while holding exclusive lock. - - Args: - token: The token to modify the state for. - - Yields: - The state for the token. - """ - yield self.state() - - -@dataclasses.dataclass -class StateManagerMemory(StateManager): - """A state manager that stores states in memory.""" - - # The mapping of client ids to states. - states: dict[str, BaseState] = dataclasses.field(default_factory=dict) - - # The mutex ensures the dict of mutexes is updated exclusively - _state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock()) - - # The dict of mutexes for each client - _states_locks: dict[str, asyncio.Lock] = dataclasses.field( - default_factory=dict, init=False - ) - - @override - async def get_state(self, token: str) -> BaseState: - """Get the state for a token. - - Args: - token: The token to get the state for. - - Returns: - The state for the token. - """ - # Memory state manager ignores the substate suffix and always returns the top-level state. - token = _split_substate_key(token)[0] - if token not in self.states: - self.states[token] = self.state(_reflex_internal_init=True) - return self.states[token] - - @override - async def set_state(self, token: str, state: BaseState): - """Set the state for a token. - - Args: - token: The token to set the state for. - state: The state to set. - """ - token = _split_substate_key(token)[0] - self.states[token] = state - - @override - @contextlib.asynccontextmanager - async def modify_state(self, token: str) -> AsyncIterator[BaseState]: - """Modify the state for a token while holding exclusive lock. - - Args: - token: The token to modify the state for. - - Yields: - The state for the token. - """ - # Memory state manager ignores the substate suffix and always returns the top-level state. - token = _split_substate_key(token)[0] - if token not in self._states_locks: - async with self._state_manager_lock: - if token not in self._states_locks: - self._states_locks[token] = asyncio.Lock() - - async with self._states_locks[token]: - state = await self.get_state(token) - yield state - - -def _default_token_expiration() -> int: - """Get the default token expiration time. - - Returns: - The default token expiration time. - """ - return get_config().redis_token_expiration - - -def reset_disk_state_manager(): - """Reset the disk state manager.""" - console.debug("Resetting disk state manager.") - states_directory = prerequisites.get_states_dir() - if states_directory.exists(): - for path in states_directory.iterdir(): - path.unlink() - - -@dataclasses.dataclass -class StateManagerDisk(StateManager): - """A state manager that stores states in memory.""" - - # The mapping of client ids to states. - states: dict[str, BaseState] = dataclasses.field(default_factory=dict) - - # The mutex ensures the dict of mutexes is updated exclusively - _state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock()) - - # The dict of mutexes for each client - _states_locks: dict[str, asyncio.Lock] = dataclasses.field( - default_factory=dict, - init=False, - ) - - # The token expiration time (s). - token_expiration: int = dataclasses.field(default_factory=_default_token_expiration) - - def __post_init_(self): - """Create a new state manager.""" - path_ops.mkdir(self.states_directory) - - self._purge_expired_states() - - @functools.cached_property - def states_directory(self) -> Path: - """Get the states directory. - - Returns: - The states directory. - """ - return prerequisites.get_states_dir() - - def _purge_expired_states(self): - """Purge expired states from the disk.""" - import time - - for path in path_ops.ls(self.states_directory): - # check path is a pickle file - if path.suffix != ".pkl": - continue - - # load last edited field from file - last_edited = path.stat().st_mtime - - # check if the file is older than the token expiration time - if time.time() - last_edited > self.token_expiration: - # remove the file - path.unlink() - - def token_path(self, token: str) -> Path: - """Get the path for a token. - - Args: - token: The token to get the path for. - - Returns: - The path for the token. - """ - return ( - self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl" - ).absolute() - - async def load_state(self, token: str) -> BaseState | None: - """Load a state object based on the provided token. - - Args: - token: The token used to identify the state object. - - Returns: - The loaded state object or None. - """ - token_path = self.token_path(token) - - if token_path.exists(): - try: - with token_path.open(mode="rb") as file: - return BaseState._deserialize(fp=file) - except Exception: - pass - return None - - async def populate_substates( - self, client_token: str, state: BaseState, root_state: BaseState - ): - """Populate the substates of a state object. - - Args: - client_token: The client token. - state: The state object to populate. - root_state: The root state object. - """ - for substate in state.get_substates(): - substate_token = _substate_key(client_token, substate) - - fresh_instance = await root_state.get_state(substate) - instance = await self.load_state(substate_token) - if instance is not None: - # Ensure all substates exist, even if they weren't serialized previously. - instance.substates = fresh_instance.substates - else: - instance = fresh_instance - state.substates[substate.get_name()] = instance - instance.parent_state = state - - await self.populate_substates(client_token, instance, root_state) - - @override - async def get_state( - self, - token: str, - ) -> BaseState: - """Get the state for a token. - - Args: - token: The token to get the state for. - - Returns: - The state for the token. - """ - client_token = _split_substate_key(token)[0] - root_state = self.states.get(client_token) - if root_state is not None: - # Retrieved state from memory. - return root_state - - # Deserialize root state from disk. - root_state = await self.load_state(_substate_key(client_token, self.state)) - # Create a new root state tree with all substates instantiated. - fresh_root_state = self.state(_reflex_internal_init=True) - if root_state is None: - root_state = fresh_root_state - else: - # Ensure all substates exist, even if they were not serialized previously. - root_state.substates = fresh_root_state.substates - self.states[client_token] = root_state - await self.populate_substates(client_token, root_state, root_state) - return root_state - - async def set_state_for_substate(self, client_token: str, substate: BaseState): - """Set the state for a substate. - - Args: - client_token: The client token. - substate: The substate to set. - """ - substate_token = _substate_key(client_token, substate) - - if substate._get_was_touched(): - substate._was_touched = False # Reset the touched flag after serializing. - pickle_state = substate._serialize() - if pickle_state: - if not self.states_directory.exists(): - self.states_directory.mkdir(parents=True, exist_ok=True) - self.token_path(substate_token).write_bytes(pickle_state) - - for substate_substate in substate.substates.values(): - await self.set_state_for_substate(client_token, substate_substate) - - @override - async def set_state(self, token: str, state: BaseState): - """Set the state for a token. - - Args: - token: The token to set the state for. - state: The state to set. - """ - client_token, _ = _split_substate_key(token) - await self.set_state_for_substate(client_token, state) - - @override - @contextlib.asynccontextmanager - async def modify_state(self, token: str) -> AsyncIterator[BaseState]: - """Modify the state for a token while holding exclusive lock. - - Args: - token: The token to modify the state for. - - Yields: - The state for the token. - """ - # Memory state manager ignores the substate suffix and always returns the top-level state. - client_token, _ = _split_substate_key(token) - if client_token not in self._states_locks: - async with self._state_manager_lock: - if client_token not in self._states_locks: - self._states_locks[client_token] = asyncio.Lock() - - async with self._states_locks[client_token]: - state = await self.get_state(token) - yield state - await self.set_state(token, state) - - def _default_lock_expiration() -> int: """Get the default lock expiration time. @@ -748,7 +390,7 @@ async def _get_pubsub_message( if timeout is None: timeout = self.lock_expiration / 1000.0 - started = time.time() + started = time.monotonic() message = await pubsub.get_message( ignore_subscribe_messages=True, timeout=timeout, @@ -757,7 +399,7 @@ async def _get_pubsub_message( message is None or message["data"] not in self._redis_keyspace_lock_release_events ): - remaining = timeout - (time.time() - started) + remaining = timeout - (time.monotonic() - started) if remaining <= 0: return await self._get_pubsub_message(pubsub, timeout=remaining) @@ -847,12 +489,3 @@ async def close(self): Note: Connections will be automatically reopened when needed. """ await self.redis.aclose(close_connection_pool=True) - - -def get_state_manager() -> StateManager: - """Get the state manager for the app that is currently running. - - Returns: - The state manager. - """ - return prerequisites.get_and_validate_app().app.state_manager diff --git a/reflex/state.py b/reflex/state.py index b1abb1fee33..ddb95ca39c1 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1564,6 +1564,8 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE: RuntimeError: If redis is not used in this backend process. StateMismatchError: If the state instance is not of the expected type. """ + from reflex.istate.manager.redis import StateManagerRedis + # Then get the target state and all its substates. state_manager = get_state_manager() if not isinstance(state_manager, StateManagerRedis): @@ -2742,11 +2744,7 @@ def reload_state_module( state.get_class_substate.cache_clear() -from reflex.istate.manager import LockExpiredError as LockExpiredError # noqa: E402 from reflex.istate.manager import StateManager as StateManager # noqa: E402 -from reflex.istate.manager import StateManagerDisk as StateManagerDisk # noqa: E402 -from reflex.istate.manager import StateManagerMemory as StateManagerMemory # noqa: E402 -from reflex.istate.manager import StateManagerRedis as StateManagerRedis # noqa: E402 from reflex.istate.manager import get_state_manager as get_state_manager # noqa: E402 from reflex.istate.manager import ( # noqa: E402 reset_disk_state_manager as reset_disk_state_manager, diff --git a/reflex/testing.py b/reflex/testing.py index 2f04b08ad1f..77edea6b371 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -37,14 +37,10 @@ from reflex.components.component import CustomComponent from reflex.config import get_config from reflex.environment import environment -from reflex.state import ( - BaseState, - StateManager, - StateManagerDisk, - StateManagerMemory, - StateManagerRedis, - reload_state_module, -) +from reflex.istate.manager.disk import StateManagerDisk +from reflex.istate.manager.memory import StateManagerMemory +from reflex.istate.manager.redis import StateManagerRedis +from reflex.state import BaseState, StateManager, reload_state_module from reflex.utils import console, js_runtimes from reflex.utils.export import export from reflex.utils.token_manager import TokenManager diff --git a/tests/integration/test_client_storage.py b/tests/integration/test_client_storage.py index 994fa265b24..313f527d62b 100644 --- a/tests/integration/test_client_storage.py +++ b/tests/integration/test_client_storage.py @@ -11,13 +11,10 @@ from selenium.webdriver.remote.webdriver import WebDriver from reflex.constants.state import FIELD_MARKER -from reflex.state import ( - State, - StateManagerDisk, - StateManagerMemory, - StateManagerRedis, - _substate_key, -) +from reflex.istate.manager.disk import StateManagerDisk +from reflex.istate.manager.memory import StateManagerMemory +from reflex.istate.manager.redis import StateManagerRedis +from reflex.state import State, _substate_key from reflex.testing import AppHarness from . import utils diff --git a/tests/integration/test_connection_banner.py b/tests/integration/test_connection_banner.py index 044d431edb7..9af3bdcc9cd 100644 --- a/tests/integration/test_connection_banner.py +++ b/tests/integration/test_connection_banner.py @@ -8,7 +8,7 @@ from reflex import constants from reflex.environment import environment -from reflex.istate.manager import StateManagerRedis +from reflex.istate.manager.redis import StateManagerRedis from reflex.testing import AppHarness, WebDriver from reflex.utils.token_manager import RedisTokenManager diff --git a/tests/units/test_app.py b/tests/units/test_app.py index d77c9f804e7..9a51081a661 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -36,6 +36,9 @@ from reflex.components.radix.themes.typography.text import Text from reflex.constants.state import FIELD_MARKER from reflex.event import Event +from reflex.istate.manager.disk import StateManagerDisk +from reflex.istate.manager.memory import StateManagerMemory +from reflex.istate.manager.redis import StateManagerRedis from reflex.middleware import HydrateMiddleware from reflex.model import Model from reflex.state import ( @@ -43,9 +46,6 @@ OnLoadInternalState, RouterData, State, - StateManagerDisk, - StateManagerMemory, - StateManagerRedis, StateUpdate, _substate_key, ) diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 7394ff55703..756c87e793a 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -29,13 +29,10 @@ from reflex.constants import CompileVars, RouteVar, SocketEvent from reflex.constants.state import FIELD_MARKER from reflex.event import Event, EventHandler -from reflex.istate.manager import ( - LockExpiredError, - StateManager, - StateManagerDisk, - StateManagerMemory, - StateManagerRedis, -) +from reflex.istate.manager import StateManager +from reflex.istate.manager.disk import StateManagerDisk +from reflex.istate.manager.memory import StateManagerMemory +from reflex.istate.manager.redis import StateManagerRedis from reflex.state import ( BaseState, ImmutableStateError, @@ -51,6 +48,7 @@ from reflex.utils import format, prerequisites, types from reflex.utils.exceptions import ( InvalidLockWarningThresholdError, + LockExpiredError, ReflexRuntimeError, SetUndefinedStateVarError, StateSerializationError, diff --git a/tests/units/test_state_tree.py b/tests/units/test_state_tree.py index d3afa05fa0e..7ed19500cc2 100644 --- a/tests/units/test_state_tree.py +++ b/tests/units/test_state_tree.py @@ -7,7 +7,8 @@ import reflex as rx from reflex.constants.state import FIELD_MARKER -from reflex.state import BaseState, StateManager, StateManagerRedis, _substate_key +from reflex.istate.manager.redis import StateManagerRedis +from reflex.state import BaseState, StateManager, _substate_key class Root(BaseState): From 344abbb67d62fa5cc18cee64f50953abe3086a55 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Fri, 3 Oct 2025 16:54:11 -0700 Subject: [PATCH 2/6] sure --- reflex/istate/manager/redis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reflex/istate/manager/redis.py b/reflex/istate/manager/redis.py index b8dd3d82fd5..7981974c36b 100644 --- a/reflex/istate/manager/redis.py +++ b/reflex/istate/manager/redis.py @@ -6,11 +6,11 @@ import time import uuid from collections.abc import AsyncIterator -from typing import override from redis import ResponseError from redis.asyncio import Redis from redis.asyncio.client import PubSub +from typing_extensions import override from reflex.config import get_config from reflex.environment import environment From dea39849760caf254a23689bd9962a82d48b99f8 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Fri, 3 Oct 2025 16:56:12 -0700 Subject: [PATCH 3/6] huh --- reflex/istate/manager/disk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reflex/istate/manager/disk.py b/reflex/istate/manager/disk.py index 0db38b567fb..3d1e2c71f56 100644 --- a/reflex/istate/manager/disk.py +++ b/reflex/istate/manager/disk.py @@ -34,7 +34,7 @@ class StateManagerDisk(StateManager): # The token expiration time (s). token_expiration: int = dataclasses.field(default_factory=_default_token_expiration) - def __post_init_(self): + def __post_init__(self): """Create a new state manager.""" path_ops.mkdir(self.states_directory) From 85c00e7ae814b5d31ace9410d5dedd6ad60d6edd Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Fri, 3 Oct 2025 16:56:47 -0700 Subject: [PATCH 4/6] typo --- reflex/istate/manager/disk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reflex/istate/manager/disk.py b/reflex/istate/manager/disk.py index 3d1e2c71f56..10fecf7f3c0 100644 --- a/reflex/istate/manager/disk.py +++ b/reflex/istate/manager/disk.py @@ -197,7 +197,7 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]: Yields: The state for the token. """ - # Memory state manager ignores the substate suffix and always returns the top-level state. + # Disk state manager ignores the substate suffix and always returns the top-level state. client_token, _ = _split_substate_key(token) if client_token not in self._states_locks: async with self._state_manager_lock: From e8918bdd82de37031f84e75554778938021e47d0 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Fri, 3 Oct 2025 17:10:05 -0700 Subject: [PATCH 5/6] a few more ig --- reflex/utils/prerequisites.py | 11 ++++++++--- reflex/utils/processes.py | 4 +++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index e892f42396b..ef2eda23194 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -17,9 +17,6 @@ from alembic.util.exc import CommandError from packaging import version -from redis import Redis as RedisSync -from redis.asyncio import Redis -from redis.exceptions import RedisError from reflex import constants, model from reflex.config import Config, get_config @@ -29,6 +26,9 @@ from reflex.utils.misc import get_module_path if typing.TYPE_CHECKING: + from redis import Redis as RedisSync + from redis.asyncio import Redis + from reflex.app import App @@ -370,6 +370,8 @@ def get_redis() -> Redis | None: Returns: The asynchronous redis client. """ + from redis.asyncio import Redis + from redis.exceptions import RedisError if (redis_url := parse_redis_url()) is not None: return Redis.from_url( redis_url, @@ -384,6 +386,8 @@ def get_redis_sync() -> RedisSync | None: Returns: The synchronous redis client. """ + from redis import Redis as RedisSync + from redis.exceptions import RedisError if (redis_url := parse_redis_url()) is not None: return RedisSync.from_url( redis_url, @@ -418,6 +422,7 @@ async def get_redis_status() -> dict[str, bool | None]: Returns: The status of the Redis connection. """ + from redis.exceptions import RedisError try: status = True redis_client = get_redis() diff --git a/reflex/utils/processes.py b/reflex/utils/processes.py index 0891798994d..d3ff92beecc 100644 --- a/reflex/utils/processes.py +++ b/reflex/utils/processes.py @@ -16,7 +16,6 @@ from typing import Any, Literal, overload import rich.markup -from redis.exceptions import RedisError from rich.progress import Progress from reflex import constants @@ -45,6 +44,9 @@ def get_num_workers() -> int: """ if (redis_client := prerequisites.get_redis_sync()) is None: return 1 + + from redis.exceptions import RedisError + try: redis_client.ping() except RedisError as re: From 6f149c2625dbda62cc954f3ef0ab79eaa4b9a528 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Fri, 3 Oct 2025 17:14:40 -0700 Subject: [PATCH 6/6] okie --- reflex/utils/prerequisites.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index ef2eda23194..a543c4dad78 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -370,8 +370,12 @@ def get_redis() -> Redis | None: Returns: The asynchronous redis client. """ - from redis.asyncio import Redis - from redis.exceptions import RedisError + try: + from redis.asyncio import Redis + from redis.exceptions import RedisError + except ImportError: + console.debug("Redis package not installed.") + return None if (redis_url := parse_redis_url()) is not None: return Redis.from_url( redis_url, @@ -386,8 +390,12 @@ def get_redis_sync() -> RedisSync | None: Returns: The synchronous redis client. """ - from redis import Redis as RedisSync - from redis.exceptions import RedisError + try: + from redis import Redis as RedisSync + from redis.exceptions import RedisError + except ImportError: + console.debug("Redis package not installed.") + return None if (redis_url := parse_redis_url()) is not None: return RedisSync.from_url( redis_url, @@ -423,6 +431,7 @@ async def get_redis_status() -> dict[str, bool | None]: The status of the Redis connection. """ from redis.exceptions import RedisError + try: status = True redis_client = get_redis()