Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions reflex/istate/manager/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""State manager for managing client states."""

import contextlib
import dataclasses
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator

from reflex import constants
from reflex.config import get_config
from reflex.state import BaseState
from reflex.utils import console, prerequisites
from reflex.utils.exceptions import InvalidStateManagerModeError


@dataclasses.dataclass
class StateManager(ABC):
"""A class to manage many client states."""

# The state class to use.
state: type[BaseState]

@classmethod
def create(cls, state: type[BaseState]):
"""Create a new state manager.

Args:
state: The state class to use.

Raises:
InvalidStateManagerModeError: If the state manager mode is invalid.

Returns:
The state manager (either disk, memory or redis).
"""
config = get_config()
if prerequisites.parse_redis_url() is not None:
config.state_manager_mode = constants.StateManagerMode.REDIS
if config.state_manager_mode == constants.StateManagerMode.MEMORY:
from reflex.istate.manager.memory import StateManagerMemory

return StateManagerMemory(state=state)
if config.state_manager_mode == constants.StateManagerMode.DISK:
from reflex.istate.manager.disk import StateManagerDisk

return StateManagerDisk(state=state)
if config.state_manager_mode == constants.StateManagerMode.REDIS:
redis = prerequisites.get_redis()
if redis is not None:
from reflex.istate.manager.redis import StateManagerRedis

# make sure expiration values are obtained only from the config object on creation
return StateManagerRedis(
state=state,
redis=redis,
token_expiration=config.redis_token_expiration,
lock_expiration=config.redis_lock_expiration,
lock_warning_threshold=config.redis_lock_warning_threshold,
)
msg = f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
raise InvalidStateManagerModeError(msg)

@abstractmethod
async def get_state(self, token: str) -> BaseState:
"""Get the state for a token.

Args:
token: The token to get the state for.

Returns:
The state for the token.
"""

@abstractmethod
async def set_state(self, token: str, state: BaseState):
"""Set the state for a token.

Args:
token: The token to set the state for.
state: The state to set.
"""

@abstractmethod
@contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""Modify the state for a token while holding exclusive lock.

Args:
token: The token to modify the state for.

Yields:
The state for the token.
"""
yield self.state()


def _default_token_expiration() -> int:
"""Get the default token expiration time.

Returns:
The default token expiration time.
"""
return get_config().redis_token_expiration


def reset_disk_state_manager():
"""Reset the disk state manager."""
console.debug("Resetting disk state manager.")
states_directory = prerequisites.get_states_dir()
if states_directory.exists():
for path in states_directory.iterdir():
path.unlink()


def get_state_manager() -> StateManager:
"""Get the state manager for the app that is currently running.

Returns:
The state manager.
"""
return prerequisites.get_and_validate_app().app.state_manager
210 changes: 210 additions & 0 deletions reflex/istate/manager/disk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
"""A state manager that stores states on disk."""

import asyncio
import contextlib
import dataclasses
import functools
from collections.abc import AsyncIterator
from hashlib import md5
from pathlib import Path

from typing_extensions import override

from reflex.istate.manager import StateManager, _default_token_expiration
from reflex.state import BaseState, _split_substate_key, _substate_key
from reflex.utils import path_ops, prerequisites


@dataclasses.dataclass
class StateManagerDisk(StateManager):
"""A state manager that stores states on disk."""

# The mapping of client ids to states.
states: dict[str, BaseState] = dataclasses.field(default_factory=dict)

# The mutex ensures the dict of mutexes is updated exclusively
_state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock())

# The dict of mutexes for each client
_states_locks: dict[str, asyncio.Lock] = dataclasses.field(
default_factory=dict,
init=False,
)

# The token expiration time (s).
token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)

def __post_init__(self):
"""Create a new state manager."""
path_ops.mkdir(self.states_directory)

self._purge_expired_states()

@functools.cached_property
def states_directory(self) -> Path:
"""Get the states directory.

Returns:
The states directory.
"""
return prerequisites.get_states_dir()

def _purge_expired_states(self):
"""Purge expired states from the disk."""
import time

for path in path_ops.ls(self.states_directory):
# check path is a pickle file
if path.suffix != ".pkl":
continue

# load last edited field from file
last_edited = path.stat().st_mtime

# check if the file is older than the token expiration time
if time.time() - last_edited > self.token_expiration:
# remove the file
path.unlink()

def token_path(self, token: str) -> Path:
"""Get the path for a token.

Args:
token: The token to get the path for.

Returns:
The path for the token.
"""
return (
self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl"
).absolute()

async def load_state(self, token: str) -> BaseState | None:
"""Load a state object based on the provided token.

Args:
token: The token used to identify the state object.

Returns:
The loaded state object or None.
"""
token_path = self.token_path(token)

if token_path.exists():
try:
with token_path.open(mode="rb") as file:
return BaseState._deserialize(fp=file)
except Exception:
pass
return None

async def populate_substates(
self, client_token: str, state: BaseState, root_state: BaseState
):
"""Populate the substates of a state object.

Args:
client_token: The client token.
state: The state object to populate.
root_state: The root state object.
"""
for substate in state.get_substates():
substate_token = _substate_key(client_token, substate)

fresh_instance = await root_state.get_state(substate)
instance = await self.load_state(substate_token)
if instance is not None:
# Ensure all substates exist, even if they weren't serialized previously.
instance.substates = fresh_instance.substates
else:
instance = fresh_instance
state.substates[substate.get_name()] = instance
instance.parent_state = state

await self.populate_substates(client_token, instance, root_state)

@override
async def get_state(
self,
token: str,
) -> BaseState:
"""Get the state for a token.

Args:
token: The token to get the state for.

Returns:
The state for the token.
"""
client_token = _split_substate_key(token)[0]
root_state = self.states.get(client_token)
if root_state is not None:
# Retrieved state from memory.
return root_state

# Deserialize root state from disk.
root_state = await self.load_state(_substate_key(client_token, self.state))
# Create a new root state tree with all substates instantiated.
fresh_root_state = self.state(_reflex_internal_init=True)
if root_state is None:
root_state = fresh_root_state
else:
# Ensure all substates exist, even if they were not serialized previously.
root_state.substates = fresh_root_state.substates
self.states[client_token] = root_state
await self.populate_substates(client_token, root_state, root_state)
return root_state

async def set_state_for_substate(self, client_token: str, substate: BaseState):
"""Set the state for a substate.

Args:
client_token: The client token.
substate: The substate to set.
"""
substate_token = _substate_key(client_token, substate)

if substate._get_was_touched():
substate._was_touched = False # Reset the touched flag after serializing.
pickle_state = substate._serialize()
if pickle_state:
if not self.states_directory.exists():
self.states_directory.mkdir(parents=True, exist_ok=True)
self.token_path(substate_token).write_bytes(pickle_state)

for substate_substate in substate.substates.values():
await self.set_state_for_substate(client_token, substate_substate)

@override
async def set_state(self, token: str, state: BaseState):
"""Set the state for a token.

Args:
token: The token to set the state for.
state: The state to set.
"""
client_token, _ = _split_substate_key(token)
await self.set_state_for_substate(client_token, state)

@override
@contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""Modify the state for a token while holding exclusive lock.

Args:
token: The token to modify the state for.

Yields:
The state for the token.
"""
# Disk state manager ignores the substate suffix and always returns the top-level state.
client_token, _ = _split_substate_key(token)
if client_token not in self._states_locks:
async with self._state_manager_lock:
if client_token not in self._states_locks:
self._states_locks[client_token] = asyncio.Lock()

async with self._states_locks[client_token]:
state = await self.get_state(token)
yield state
await self.set_state(token, state)
Loading