Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyi_hashes.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"reflex/__init__.pyi": "b304ed6f7a2fa028a194cad81bd83112",
"reflex/__init__.pyi": "0a3ae880e256b9fd3b960e12a2cb51a7",
"reflex/components/__init__.pyi": "ac05995852baa81062ba3d18fbc489fb",
"reflex/components/base/__init__.pyi": "16e47bf19e0d62835a605baa3d039c5a",
"reflex/components/base/app_wrap.pyi": "22e94feaa9fe675bcae51c412f5b67f1",
Expand Down
1 change: 1 addition & 0 deletions reflex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@
"State",
"dynamic",
],
"istate.shared": ["SharedState"],
"istate.wrappers": ["get_state"],
"style": ["Style", "toggle_color_mode"],
"utils.imports": ["ImportDict", "ImportVar"],
Expand Down
16 changes: 12 additions & 4 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1562,13 +1562,17 @@ def all_routes(_request: Request) -> Response:

@contextlib.asynccontextmanager
async def modify_state(
self, token: str, background: bool = False
self,
token: str,
background: bool = False,
previous_dirty_vars: dict[str, set[str]] | None = None,
) -> AsyncIterator[BaseState]:
"""Modify the state out of band.

Args:
token: The token to modify the state for.
background: Whether the modification is happening in a background task.
previous_dirty_vars: Vars that are considered dirty from a previous operation.

Yields:
The state to modify.
Expand All @@ -1581,7 +1585,9 @@ async def modify_state(
raise RuntimeError(msg)

# Get exclusive access to the state.
async with self.state_manager.modify_state(token) as state:
async with self.state_manager.modify_state_with_links(
token, previous_dirty_vars=previous_dirty_vars
) as state:
# No other event handler can modify the state while in this context.
yield state
delta = await state._get_resolved_delta()
Expand Down Expand Up @@ -1769,7 +1775,7 @@ async def process(
constants.RouteVar.CLIENT_IP: client_ip,
})
# Get the state for the session exclusively.
async with app.state_manager.modify_state(
async with app.state_manager.modify_state_with_links(
event.substate_token, event=event
) as state:
# When this is a brand new instance of the state, signal the
Expand Down Expand Up @@ -2003,7 +2009,9 @@ async def _ndjson_updates():
Each state update as JSON followed by a new line.
"""
# Process the event.
async with app.state_manager.modify_state(event.substate_token) as state:
async with app.state_manager.modify_state_with_links(
event.substate_token
) as state:
async for update in state._process(event):
# Postprocess the event.
update = await app._postprocess(state, event, update)
Expand Down
29 changes: 29 additions & 0 deletions reflex/istate/manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,35 @@ async def modify_state(
"""
yield self.state()

@contextlib.asynccontextmanager
async def modify_state_with_links(
self,
token: str,
previous_dirty_vars: dict[str, set[str]] | None = None,
**context: Unpack[StateModificationContext],
) -> AsyncIterator[BaseState]:
"""Modify the state for a token, including linked substates, while holding exclusive lock.

Args:
token: The token to modify the state for.
previous_dirty_vars: The previously dirty vars for linked states.
context: The state modification context.

Yields:
The state for the token with linked states patched in.
"""
async with self.modify_state(token, **context) as root_state:
if getattr(root_state, "_reflex_internal_links", None) is not None:
from reflex.istate.shared import SharedStateBaseInternal

shared_state = await root_state.get_state(SharedStateBaseInternal)
async with shared_state._modify_linked_states(
previous_dirty_vars=previous_dirty_vars
) as _:
yield root_state
else:
yield root_state

async def close(self): # noqa: B027
"""Close the state manager."""

Expand Down
Loading
Loading