Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyi_hashes.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"reflex/__init__.pyi": "b304ed6f7a2fa028a194cad81bd83112",
"reflex/__init__.pyi": "0a3ae880e256b9fd3b960e12a2cb51a7",
"reflex/components/__init__.pyi": "ac05995852baa81062ba3d18fbc489fb",
"reflex/components/base/__init__.pyi": "16e47bf19e0d62835a605baa3d039c5a",
"reflex/components/base/app_wrap.pyi": "22e94feaa9fe675bcae51c412f5b67f1",
Expand Down
1 change: 1 addition & 0 deletions reflex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@
"State",
"dynamic",
],
"istate.shared": ["SharedState"],
"istate.wrappers": ["get_state"],
"style": ["Style", "toggle_color_mode"],
"utils.imports": ["ImportDict", "ImportVar"],
Expand Down
16 changes: 12 additions & 4 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1562,13 +1562,17 @@ def all_routes(_request: Request) -> Response:

@contextlib.asynccontextmanager
async def modify_state(
self, token: str, background: bool = False
self,
token: str,
background: bool = False,
previous_dirty_vars: set[str] | None = None,
) -> AsyncIterator[BaseState]:
"""Modify the state out of band.

Args:
token: The token to modify the state for.
background: Whether the modification is happening in a background task.
previous_dirty_vars: Vars that are considered dirty from a previous operation.

Yields:
The state to modify.
Expand All @@ -1581,7 +1585,9 @@ async def modify_state(
raise RuntimeError(msg)

# Get exclusive access to the state.
async with self.state_manager.modify_state(token) as state:
async with self.state_manager.modify_state_with_links(
token, previous_dirty_vars=previous_dirty_vars
) as state:
# No other event handler can modify the state while in this context.
yield state
delta = await state._get_resolved_delta()
Expand Down Expand Up @@ -1769,7 +1775,7 @@ async def process(
constants.RouteVar.CLIENT_IP: client_ip,
})
# Get the state for the session exclusively.
async with app.state_manager.modify_state(
async with app.state_manager.modify_state_with_links(
event.substate_token, event=event
) as state:
# When this is a brand new instance of the state, signal the
Expand Down Expand Up @@ -2003,7 +2009,9 @@ async def _ndjson_updates():
Each state update as JSON followed by a new line.
"""
# Process the event.
async with app.state_manager.modify_state(event.substate_token) as state:
async with app.state_manager.modify_state_with_links(
event.substate_token
) as state:
async for update in state._process(event):
# Postprocess the event.
update = await app._postprocess(state, event, update)
Expand Down
30 changes: 30 additions & 0 deletions reflex/istate/manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,36 @@ async def modify_state(
"""
yield self.state()

@contextlib.asynccontextmanager
async def modify_state_with_links(
self,
token: str,
previous_dirty_vars: set[str] | None = None,
**context: Unpack[StateModificationContext],
) -> AsyncIterator[BaseState]:
"""Modify the state for a token, including linked substates, while holding exclusive lock.

Args:
token: The token to modify the state for.
previous_dirty_vars: The previously dirty vars for linked states.
context: The state modification context.

Yields:
The state for the token with linked states patched in.
"""
from reflex.istate.shared import SharedStateBaseInternal

shared_state_name = SharedStateBaseInternal.get_name()

async with self.modify_state(token, **context) as root_state:
if shared_state_name in root_state.substates:
async with root_state.substates[
shared_state_name
]._modify_linked_states(previous_dirty_vars=previous_dirty_vars) as _:
yield root_state
else:
yield root_state

async def close(self): # noqa: B027
"""Close the state manager."""

Expand Down
255 changes: 255 additions & 0 deletions reflex/istate/shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
"""Base classes for shared / linked states."""

import contextlib
from collections.abc import AsyncIterator

from reflex.event import Event, get_hydrate_event
from reflex.state import BaseState, State, _override_base_method, _substate_key
from reflex.utils.exceptions import ReflexRuntimeError


class SharedStateBaseInternal(State):
"""The private base state for all shared states."""

# Maps the state full_name to an arbitrary token it is linked to.
_links: dict[str, str]
# While _modify_linked_states is active, this holds the original substates for the client's tree.
_original_substates: dict[str, tuple[BaseState, BaseState | None]]

@classmethod
def _init_var_dependency_dicts(cls):
super()._init_var_dependency_dicts()
if (
"_links" in cls.inherited_backend_vars
or (parent_state_cls := cls.get_parent_state()) is None
):
return
# Mark the internal state as always dirty so the state manager
# automatically fetches this state containing the _links.
parent_state_cls._always_dirty_substates.add(cls.get_name())

def __getstate__(self):
"""Override redis serialization to remove temporary fields.

Returns:
The state dictionary without temporary fields.
"""
s = super().__getstate__()
# Don't want to persist the cached substates
s.pop("_original_substates", None)
s.pop("_previous_dirty_vars", None)
return s

@_override_base_method
def _clean(self):
"""Override BaseState._clean to track the last set of dirty vars.

This is necessary for applying dirty vars from one event to other linked states.
"""
if hasattr(self, "_previous_dirty_vars"):
self._previous_dirty_vars.clear()
self._previous_dirty_vars.update(self.dirty_vars)
super()._clean()

@_override_base_method
def _mark_dirty(self):
"""Override BaseState._mark_dirty to avoid marking certain vars as dirty.

Since these internal fields are not persisted to redis, they shouldn't cause the
state to be considered dirty either.
"""
self.dirty_vars.discard("_original_substates")
self.dirty_vars.discard("_previously_dirty_substates")
if self.dirty_vars:
super()._mark_dirty()

def _rehydrate(self):
"""Get the events to rehydrate the state.

Returns:
The events to rehydrate the state (these should be returned/yielded).
"""
return [
Event(
token=self.router.session.client_token,
name=get_hydrate_event(self._get_root_state()),
),
State.set_is_hydrated(True),
]

async def _link_to(self, token: str):
"""Link this shared state to a token.

After linking, subsequent access to this shared state will affect the
linked token's state, and cause changes to be propagated to all other
clients linked to that token.

Args:
token: The token to link to.

Returns:
The events to rehydrate the state after linking (these should be returned/yielded).
"""
# TODO: Change StateManager to accept token + class instead of combining them in a string.
if "_" in token:
msg = f"Invalid token {token} for linking state {self.get_full_name()}, cannot use underscore (_) in the token name."
raise ReflexRuntimeError(msg)
state_name = self.get_full_name()
self._links[state_name] = token
async with self._modify_linked_states() as _:
linked_state = await self.get_state(type(self))
linked_state._linked_from.add(self.router.session.client_token)
linked_state._linked_to = token
linked_state.dirty_vars.update(self.base_vars)
linked_state.dirty_vars.update(self.backend_vars)
linked_state.dirty_vars.update(self.computed_vars)
linked_state._mark_dirty()
# Apply the updates into the existing state tree, then rehydrate.
root_state = self._get_root_state()
await root_state._get_resolved_delta()
root_state._clean()
return self._rehydrate()

async def _unlink(self):
"""Unlink this shared state from its linked token.

Returns:
The events to rehydrate the state after unlinking (these should be returned/yielded
"""
state_name = self.get_full_name()
if state_name not in self._links:
msg = f"State {state_name} is not linked and cannot be unlinked."
raise ReflexRuntimeError(msg)
self._links.pop(state_name)
self._linked_from.discard(self.router.session.client_token)
# Rehydrate after unlinking to restore original values.
return self._rehydrate()

async def _restore_original_substates(self, *_exc_info) -> None:
"""Restore the original substates that were linked."""
root_state = self._get_root_state()
for linked_state_name, (
original_state,
linked_parent_state,
) in self._original_substates.items():
linked_state_cls = root_state.get_class_substate(linked_state_name)
linked_state = await root_state.get_state(linked_state_cls)
if (parent_state := linked_state.parent_state) is not None:
parent_state.substates[original_state.get_name()] = original_state
linked_state.parent_state = linked_parent_state
self._original_substates = {}

@contextlib.asynccontextmanager
async def _modify_linked_states(
self, previous_dirty_vars: dict[str, set[str]] | None = None
) -> AsyncIterator[None]:
"""Take lock, fetch all linked states, and patch them into the current state tree.

If previous_dirty_vars is NOT provided, then any dirty vars after
exiting the context will be applied to all other clients linked to this
state's linked token.

Args:
previous_dirty_vars: When apply linked state changes to other
tokens, provide mapping of state full_name to set of dirty vars.

Yields:
None.
"""
from reflex.istate.manager import get_state_manager

exit_stack = contextlib.AsyncExitStack()
held_locks: set[str] = set()
linked_states: list[BaseState] = []
current_dirty_vars: dict[str, set[str]] = {}
affected_tokens: set[str] = set()
# Go through all linked states and patch them in if they are present in the tree
for linked_state_name, linked_token in self._links.items():
linked_state_cls = self.get_root_state().get_class_substate(
linked_state_name
)
# TODO: Avoid always fetched linked states, it should be based on
# whether the state is accessed, however then `get_state` would need
# to know how to fetch in a linked state.
original_state = await self.get_state(linked_state_cls)
if linked_token not in held_locks:
linked_root_state = await exit_stack.enter_async_context(
get_state_manager().modify_state(
_substate_key(linked_token, linked_state_cls)
)
)
held_locks.add(linked_token)
else:
linked_root_state = await get_state_manager().get_state(
_substate_key(linked_token, linked_state_cls)
)
linked_state = await linked_root_state.get_state(linked_state_cls)
self._original_substates[linked_state_name] = (
original_state,
linked_state.parent_state,
)
if (parent_state := original_state.parent_state) is not None:
parent_state.substates[original_state.get_name()] = linked_state
linked_state.parent_state = parent_state
linked_states.append(linked_state)
if (
previous_dirty_vars
and (dv := previous_dirty_vars.get(linked_state_name)) is not None
):
linked_state.dirty_vars.update(dv)
linked_state._mark_dirty()
# Make sure to restore the non-linked substates after exiting the context.
if self._original_substates:
exit_stack.push_async_exit(self._restore_original_substates)
async with exit_stack:
yield None
# Collect dirty vars and other affected clients that need to be updated.
for linked_state in linked_states:
if hasattr(linked_state, "_previous_dirty_vars"):
current_dirty_vars[linked_state.get_full_name()] = set(
linked_state._previous_dirty_vars
)
if linked_state._get_was_touched():
affected_tokens.update(
token
for token in linked_state._linked_from
if token != self.router.session.client_token
)

# Only propagate dirty vars when we are not already propagating from another state.
if previous_dirty_vars is None:
from reflex.utils.prerequisites import get_app

app = get_app().app

for affected_token in affected_tokens:
# Don't send updates for disconnected clients.
if (
affected_token
not in app.event_namespace._token_manager.token_to_socket
):
continue
async with app.modify_state(
_substate_key(affected_token, type(self)),
previous_dirty_vars=current_dirty_vars,
):
pass


class SharedState(SharedStateBaseInternal, mixin=True):
"""Mixin for defining new shared states."""

_linked_from: set[str]
_linked_to: str
_previous_dirty_vars: set[str]

@classmethod
def __init_subclass__(cls, **kwargs):
"""Initialize subclass and set up shared state fields.

Args:
**kwargs: The kwargs to pass to the init_subclass method.
"""
kwargs["mixin"] = False
cls._mixin = False
super().__init_subclass__(**kwargs)
Loading
Loading