diff --git a/pyi_hashes.json b/pyi_hashes.json index f579320dbf2..e52649c2450 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -1,5 +1,5 @@ { - "reflex/__init__.pyi": "b304ed6f7a2fa028a194cad81bd83112", + "reflex/__init__.pyi": "0a3ae880e256b9fd3b960e12a2cb51a7", "reflex/components/__init__.pyi": "ac05995852baa81062ba3d18fbc489fb", "reflex/components/base/__init__.pyi": "16e47bf19e0d62835a605baa3d039c5a", "reflex/components/base/app_wrap.pyi": "22e94feaa9fe675bcae51c412f5b67f1", diff --git a/reflex/__init__.py b/reflex/__init__.py index 48604373a60..066df110f02 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -336,6 +336,7 @@ "State", "dynamic", ], + "istate.shared": ["SharedState"], "istate.wrappers": ["get_state"], "style": ["Style", "toggle_color_mode"], "utils.imports": ["ImportDict", "ImportVar"], diff --git a/reflex/app.py b/reflex/app.py index 9fa9e2c43f9..32be4188feb 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1562,13 +1562,17 @@ def all_routes(_request: Request) -> Response: @contextlib.asynccontextmanager async def modify_state( - self, token: str, background: bool = False + self, + token: str, + background: bool = False, + previous_dirty_vars: dict[str, set[str]] | None = None, ) -> AsyncIterator[BaseState]: """Modify the state out of band. Args: token: The token to modify the state for. background: Whether the modification is happening in a background task. + previous_dirty_vars: Vars that are considered dirty from a previous operation. Yields: The state to modify. @@ -1581,7 +1585,9 @@ async def modify_state( raise RuntimeError(msg) # Get exclusive access to the state. - async with self.state_manager.modify_state(token) as state: + async with self.state_manager.modify_state_with_links( + token, previous_dirty_vars=previous_dirty_vars + ) as state: # No other event handler can modify the state while in this context. yield state delta = await state._get_resolved_delta() @@ -1769,7 +1775,7 @@ async def process( constants.RouteVar.CLIENT_IP: client_ip, }) # Get the state for the session exclusively. - async with app.state_manager.modify_state( + async with app.state_manager.modify_state_with_links( event.substate_token, event=event ) as state: # When this is a brand new instance of the state, signal the @@ -2003,7 +2009,9 @@ async def _ndjson_updates(): Each state update as JSON followed by a new line. """ # Process the event. - async with app.state_manager.modify_state(event.substate_token) as state: + async with app.state_manager.modify_state_with_links( + event.substate_token + ) as state: async for update in state._process(event): # Postprocess the event. update = await app._postprocess(state, event, update) diff --git a/reflex/istate/manager/__init__.py b/reflex/istate/manager/__init__.py index 3717bf08839..1eae7550de3 100644 --- a/reflex/istate/manager/__init__.py +++ b/reflex/istate/manager/__init__.py @@ -114,6 +114,35 @@ async def modify_state( """ yield self.state() + @contextlib.asynccontextmanager + async def modify_state_with_links( + self, + token: str, + previous_dirty_vars: dict[str, set[str]] | None = None, + **context: Unpack[StateModificationContext], + ) -> AsyncIterator[BaseState]: + """Modify the state for a token, including linked substates, while holding exclusive lock. + + Args: + token: The token to modify the state for. + previous_dirty_vars: The previously dirty vars for linked states. + context: The state modification context. + + Yields: + The state for the token with linked states patched in. + """ + async with self.modify_state(token, **context) as root_state: + if getattr(root_state, "_reflex_internal_links", None) is not None: + from reflex.istate.shared import SharedStateBaseInternal + + shared_state = await root_state.get_state(SharedStateBaseInternal) + async with shared_state._modify_linked_states( + previous_dirty_vars=previous_dirty_vars + ) as _: + yield root_state + else: + yield root_state + async def close(self): # noqa: B027 """Close the state manager.""" diff --git a/reflex/istate/shared.py b/reflex/istate/shared.py new file mode 100644 index 00000000000..b07a489050a --- /dev/null +++ b/reflex/istate/shared.py @@ -0,0 +1,402 @@ +"""Base classes for shared / linked states.""" + +import asyncio +import contextlib +from collections.abc import AsyncIterator +from typing import Self, TypeVar + +from reflex.constants import ROUTER_DATA +from reflex.event import Event, get_hydrate_event +from reflex.state import BaseState, State, _override_base_method, _substate_key +from reflex.utils import console +from reflex.utils.exceptions import ReflexRuntimeError + +UPDATE_OTHER_CLIENT_TASKS: set[asyncio.Task] = set() +LINKED_STATE = TypeVar("LINKED_STATE", bound="SharedStateBaseInternal") + + +def _log_update_client_errors(task: asyncio.Task): + """Log errors from updating other clients. + + Args: + task: The asyncio task to check for errors. + """ + try: + task.result() + except Exception as e: + console.warn(f"Error updating linked client: {e}") + finally: + UPDATE_OTHER_CLIENT_TASKS.discard(task) + + +def _do_update_other_tokens( + affected_tokens: set[str], + previous_dirty_vars: dict[str, set[str]], + state_type: type[BaseState], +) -> list[asyncio.Task]: + """Update other clients after a shared state update. + + Submit the updates in separate asyncio tasks to avoid deadlocking. + + Args: + affected_tokens: The tokens to update. + previous_dirty_vars: The dirty vars to apply to other clients. + state_type: The type of the shared state. + + Returns: + The list of asyncio tasks created to perform the updates. + """ + from reflex.utils.prerequisites import get_app + + app = get_app().app + + async def _update_client(token: str): + async with app.modify_state( + _substate_key(token, state_type), + previous_dirty_vars=previous_dirty_vars, + ): + pass + + tasks = [] + for affected_token in affected_tokens: + # Don't send updates for disconnected clients. + if affected_token not in app.event_namespace._token_manager.token_to_socket: + continue + # TODO: remove disconnected clients after some time. + t = asyncio.create_task(_update_client(affected_token)) + UPDATE_OTHER_CLIENT_TASKS.add(t) + t.add_done_callback(_log_update_client_errors) + tasks.append(t) + return tasks + + +@contextlib.asynccontextmanager +async def _patch_state( + original_state: BaseState, linked_state: BaseState, full_delta: bool = False +): + """Patch the linked state into the original state's tree, restoring it afterward. + + Args: + original_state: The original shared state. + linked_state: The linked shared state. + full_delta: If True, mark all Vars in linked_state dirty and resolve + the delta from the root. This option is used when linking or unlinking + to ensure that other computed vars in the tree pick up the newly + linked/unlinked values. + """ + if (original_parent_state := original_state.parent_state) is None: + msg = "Cannot patch root state as linked state." + raise ReflexRuntimeError(msg) + + state_name = original_state.get_name() + original_parent_state.substates[state_name] = linked_state + linked_parent_state = linked_state.parent_state + linked_state.parent_state = original_parent_state + try: + if full_delta: + linked_state.dirty_vars.update(linked_state.base_vars) + linked_state.dirty_vars.update(linked_state.backend_vars) + linked_state.dirty_vars.update(linked_state.computed_vars) + linked_state._mark_dirty() + # Apply the updates into the existing state tree for rehydrate. + root_state = original_state._get_root_state() + root_state.dirty_vars.add("router") + root_state.dirty_vars.add(ROUTER_DATA) + root_state._mark_dirty() + await root_state._get_resolved_delta() + yield + finally: + original_parent_state.substates[state_name] = original_state + linked_state.parent_state = linked_parent_state + + +class SharedStateBaseInternal(State): + """The private base state for all shared states.""" + + _exit_stack: contextlib.AsyncExitStack | None = None + _held_locks: dict[str, dict[type[BaseState], BaseState]] | None = None + + def __getstate__(self): + """Override redis serialization to remove temporary fields. + + Returns: + The state dictionary without temporary fields. + """ + s = super().__getstate__() + s.pop("_previous_dirty_vars", None) + s.pop("_exit_stack", None) + s.pop("_held_locks", None) + return s + + @_override_base_method + def _clean(self): + """Override BaseState._clean to track the last set of dirty vars. + + This is necessary for applying dirty vars from one event to other linked states. + """ + if ( + previous_dirty_vars := getattr(self, "_previous_dirty_vars", None) + ) is not None: + previous_dirty_vars.clear() + previous_dirty_vars.update(self.dirty_vars) + super()._clean() + + @_override_base_method + def _mark_dirty(self): + """Override BaseState._mark_dirty to avoid marking certain vars as dirty. + + Since these internal fields are not persisted to redis, they shouldn't cause the + state to be considered dirty either. + """ + self.dirty_vars.discard("_previous_dirty_vars") + self.dirty_vars.discard("_exit_stack") + self.dirty_vars.discard("_held_locks") + # Only mark dirty if there are still dirty vars, or any substate is dirty + if self.dirty_vars or any( + substate.dirty_vars for substate in self.substates.values() + ): + super()._mark_dirty() + + def _rehydrate(self): + """Get the events to rehydrate the state. + + Returns: + The events to rehydrate the state (these should be returned/yielded). + """ + return [ + Event( + token=self.router.session.client_token, + name=get_hydrate_event(self._get_root_state()), + ), + State.set_is_hydrated(True), + ] + + async def _link_to(self, token: str) -> Self: + """Link this shared state to a token. + + After linking, subsequent access to this shared state will affect the + linked token's state, and cause changes to be propagated to all other + clients linked to that token. + + Args: + token: The token to link to (Cannot contain underscore characters). + + Returns: + The newly linked state. + + Raises: + ReflexRuntimeError: If linking fails or token is invalid. + """ + if not token: + msg = "Cannot link shared state to empty token." + raise ReflexRuntimeError(msg) + if self._linked_to == token: + return self # already linked to this token + if self._linked_to and self._linked_to != token: + # Disassociate from previous linked token since unlink will not be called. + self._linked_from.discard(self.router.session.client_token) + # TODO: Change StateManager to accept token + class instead of combining them in a string. + if "_" in token: + msg = f"Invalid token {token} for linking state {self.get_full_name()}, cannot use underscore (_) in the token name." + raise ReflexRuntimeError(msg) + + # Associate substate with the given link token. + state_name = self.get_full_name() + if self._reflex_internal_links is None: + self._reflex_internal_links = {} + self._reflex_internal_links[state_name] = token + return await self._internal_patch_linked_state(token, full_delta=True) + + async def _unlink(self): + """Unlink this shared state from its linked token. + + Returns: + The events to rehydrate the state after unlinking (these should be returned/yielded). + """ + from reflex.istate.manager import get_state_manager + + state_name = self.get_full_name() + if ( + not self._reflex_internal_links + or state_name not in self._reflex_internal_links + ): + msg = f"State {state_name} is not linked and cannot be unlinked." + raise ReflexRuntimeError(msg) + + # Break the linkage for future events. + self._reflex_internal_links.pop(state_name) + self._linked_from.discard(self.router.session.client_token) + + # Patch in the original state, apply updates, then rehydrate. + private_root_state = await get_state_manager().get_state( + _substate_key(self.router.session.client_token, type(self)) + ) + private_state = await private_root_state.get_state(type(self)) + async with _patch_state( + original_state=self, + linked_state=private_state, + full_delta=True, + ): + return self._rehydrate() + + async def _internal_patch_linked_state( + self, token: str, full_delta: bool = False + ) -> Self: + """Load and replace this state with the linked state for a given token. + + Must be called inside a `_modify_linked_states` context, to ensure locks are + released after the event is done processing. + + Args: + token: The token of the linked state. + full_delta: If True, mark all Vars in linked_state dirty and resolve + delta to update cached computed vars + + Returns: + The state that was linked into the tree. + """ + from reflex.istate.manager import get_state_manager + + if self._exit_stack is None or self._held_locks is None: + msg = "Cannot link shared state outside of _modify_linked_states context." + raise ReflexRuntimeError(msg) + + # Get the newly linked state and update pointers/delta for subsequent events. + if token not in self._held_locks: + linked_root_state = await self._exit_stack.enter_async_context( + get_state_manager().modify_state(_substate_key(token, type(self))) + ) + self._held_locks.setdefault(token, {}) + else: + linked_root_state = await get_state_manager().get_state( + _substate_key(token, type(self)) + ) + linked_state = await linked_root_state.get_state(type(self)) + # Avoid unnecessary dirtiness of shared state when there are no changes. + if type(self) not in self._held_locks[token]: + self._held_locks[token][type(self)] = linked_state + if self.router.session.client_token not in linked_state._linked_from: + linked_state._linked_from.add(self.router.session.client_token) + if linked_state._linked_to != token: + linked_state._linked_to = token + await self._exit_stack.enter_async_context( + _patch_state( + original_state=self, + linked_state=linked_state, + full_delta=full_delta, + ) + ) + return linked_state + + def _held_locks_linked_states(self) -> list["SharedState"]: + """Get all linked states currently held by this state. + + Returns: + The list of linked states currently held. + """ + if self._held_locks is None: + return [] + return [ + linked_state + for linked_state_cls_to_instance in self._held_locks.values() + for linked_state in linked_state_cls_to_instance.values() + if isinstance(linked_state, SharedState) + ] + + @contextlib.asynccontextmanager + async def _modify_linked_states( + self, previous_dirty_vars: dict[str, set[str]] | None = None + ) -> AsyncIterator[None]: + """Take lock, fetch all linked states, and patch them into the current state tree. + + If previous_dirty_vars is NOT provided, then any dirty vars after + exiting the context will be applied to all other clients linked to this + state's linked token. + + Args: + previous_dirty_vars: When apply linked state changes to other + tokens, provide mapping of state full_name to set of dirty vars. + + Yields: + None. + """ + if self._exit_stack is not None: + msg = "Cannot nest _modify_linked_states contexts." + raise ReflexRuntimeError(msg) + if self._reflex_internal_links is None: + msg = "No linked states to modify." + raise ReflexRuntimeError(msg) + self._exit_stack = contextlib.AsyncExitStack() + self._held_locks = {} + current_dirty_vars: dict[str, set[str]] = {} + affected_tokens: set[str] = set() + try: + # Go through all linked states and patch them in if they are present in the tree + for linked_state_name, linked_token in self._reflex_internal_links.items(): + linked_state_cls: type[SharedState] = ( + self.get_root_state().get_class_substate( # pyright: ignore[reportAssignmentType] + linked_state_name + ) + ) + # TODO: Avoid always fetched linked states, it should be based on + # whether the state is accessed, however then `get_state` would need + # to know how to fetch in a linked state. + original_state = await self.get_state(linked_state_cls) + linked_state = await original_state._internal_patch_linked_state( + linked_token + ) + if ( + previous_dirty_vars + and (dv := previous_dirty_vars.get(linked_state_name)) is not None + ): + linked_state.dirty_vars.update(dv) + linked_state._mark_dirty() + async with self._exit_stack: + yield None + # Collect dirty vars and other affected clients that need to be updated. + for linked_state in self._held_locks_linked_states(): + if linked_state._previous_dirty_vars is not None: + current_dirty_vars[linked_state.get_full_name()] = set( + linked_state._previous_dirty_vars + ) + if ( + linked_state._get_was_touched() + or linked_state._previous_dirty_vars is not None + ): + affected_tokens.update( + token + for token in linked_state._linked_from + if token != self.router.session.client_token + ) + finally: + self._exit_stack = None + + # Only propagate dirty vars when we are not already propagating from another state. + if previous_dirty_vars is None: + _do_update_other_tokens( + affected_tokens=affected_tokens, + previous_dirty_vars=current_dirty_vars, + state_type=type(self), + ) + + +class SharedState(SharedStateBaseInternal, mixin=True): + """Mixin for defining new shared states.""" + + _linked_from: set[str] = set() + _linked_to: str = "" + _previous_dirty_vars: set[str] = set() + + @classmethod + def __init_subclass__(cls, **kwargs): + """Initialize subclass and set up shared state fields. + + Args: + **kwargs: The kwargs to pass to the init_subclass method. + """ + kwargs["mixin"] = False + cls._mixin = False + super().__init_subclass__(**kwargs) + root_state = cls.get_root_state() + if root_state.backend_vars["_reflex_internal_links"] is None: + root_state.backend_vars["_reflex_internal_links"] = {} diff --git a/reflex/state.py b/reflex/state.py index 4c9b9136ec2..f08327c27fb 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -21,7 +21,16 @@ from hashlib import md5 from importlib.util import find_spec from types import FunctionType -from typing import TYPE_CHECKING, Any, BinaryIO, ClassVar, TypeVar, cast, get_type_hints +from typing import ( + TYPE_CHECKING, + Any, + BinaryIO, + ClassVar, + ParamSpec, + TypeVar, + cast, + get_type_hints, +) from rich.markup import escape from typing_extensions import Self @@ -298,6 +307,23 @@ async def _resolve_delta(delta: Delta) -> Delta: return delta +RETURN = TypeVar("RETURN") +PARAMS = ParamSpec("PARAMS") + + +def _override_base_method(fn: Callable[PARAMS, RETURN]) -> Callable[PARAMS, RETURN]: + """Mark a method as overriding a base method. + + Args: + fn: The function to mark. + + Returns: + The marked function. + """ + fn.__override_base_method__ = True # pyright: ignore[reportFunctionMemberAccess] + return fn + + _deserializers = { int: int, float: float, @@ -851,6 +877,7 @@ def _check_overridden_methods(cls): not name.startswith("__") and method.__name__ in state_base_functions and state_base_functions[method.__name__] != method + and not getattr(method, "__override_base_method__", False) ): overridden_methods.add(method.__name__) @@ -2437,6 +2464,8 @@ class State(BaseState): # The hydrated bool. is_hydrated: bool = False + # Maps the state full_name to an arbitrary token it is linked to for shared state. + _reflex_internal_links: dict[str, str] | None = None @event def set_is_hydrated(self, value: bool) -> None: diff --git a/tests/integration/test_computed_vars.py b/tests/integration/test_computed_vars.py index 7877f55daf2..3b929ab149c 100644 --- a/tests/integration/test_computed_vars.py +++ b/tests/integration/test_computed_vars.py @@ -74,7 +74,7 @@ def increment(self): def mark_dirty(self): self._mark_dirty() - assert State.backend_vars == {} + assert State.backend_vars == {"_reflex_internal_links": None} def index() -> rx.Component: return rx.center( diff --git a/tests/integration/test_linked_state.py b/tests/integration/test_linked_state.py new file mode 100644 index 00000000000..c892c48d7aa --- /dev/null +++ b/tests/integration/test_linked_state.py @@ -0,0 +1,388 @@ +"""Test linked state.""" + +from __future__ import annotations + +import uuid +from collections.abc import Callable, Generator + +import pytest +from selenium.webdriver.common.by import By +from selenium.webdriver.common.keys import Keys + +from reflex.testing import AppHarness, WebDriver + +from . import utils + + +def LinkedStateApp(): + """Test that linked state works as expected.""" + import uuid + from typing import Any + + import reflex as rx + + class SharedState(rx.SharedState): + _who: str = "world" + n_changes: int = 0 + counter: int = 0 + + @rx.event + def set_counter(self, value: int) -> None: + self.counter = value + + @rx.event + def set_who(self, who: str) -> None: + self._who = who + self.n_changes += 1 + + @rx.event + async def link_to(self, token: str): + await self._link_to(token) + + @rx.event + async def link_to_and_increment(self): + linked_state = await self._link_to(f"arbitrary-token-{uuid.uuid4()}") + linked_state.counter += 1 + + @rx.event + async def unlink(self): + return await self._unlink() + + @rx.event + async def on_load_link_default(self): + linked_state = await self._link_to(self.room or "default") + if self.room: + assert linked_state._linked_to == self.room + else: + assert linked_state._linked_to == "default" + + @rx.event + async def handle_submit(self, form_data: dict[str, Any]): + if "who" in form_data: + self.set_who(form_data["who"]) + if "token" in form_data: + await self.link_to(form_data["token"]) + + class PrivateState(rx.State): + @rx.var + async def greeting(self) -> str: + ss = await self.get_state(SharedState) + return f"Hello, {ss._who}!" + + @rx.var + async def linked_to(self) -> str: + ss = await self.get_state(SharedState) + return ss._linked_to + + @rx.event(background=True) + async def bump_counter_bg(self): + for _ in range(5): + async with self: + ss = await self.get_state(SharedState) + ss.counter += 1 + async with self: + ss = await self.get_state(SharedState) + for _ in range(5): + async with ss: + ss.counter += 1 + + @rx.event + async def bump_counter_yield(self): + ss = await self.get_state(SharedState) + for _ in range(5): + ss.counter += 1 + yield + + def index() -> rx.Component: + return rx.vstack( + rx.text( + SharedState.n_changes, + id="n-changes", + ), + rx.text( + PrivateState.greeting, + id="greeting", + ), + rx.form( + rx.input(name="who", id="who-input"), + rx.button("Set Who"), + on_submit=SharedState.handle_submit, + reset_on_submit=True, + ), + rx.text(PrivateState.linked_to, id="linked-to"), + rx.button("Unlink", id="unlink-button", on_click=SharedState.unlink), + rx.form( + rx.input(name="token", id="token-input"), + rx.button("Link To Token"), + on_submit=SharedState.handle_submit, + reset_on_submit=True, + ), + rx.button( + SharedState.counter, + id="counter-button", + on_click=SharedState.set_counter(SharedState.counter + 1), + on_context_menu=SharedState.set_counter( + SharedState.counter - 1 + ).prevent_default, + ), + rx.button( + "Bump Counter in Background", + on_click=PrivateState.bump_counter_bg, + id="bg-button", + ), + rx.button( + "Bump Counter with Yield", + on_click=PrivateState.bump_counter_yield, + id="yield-button", + ), + rx.button( + "Link to arbitrary token and Increment n_changes", + on_click=SharedState.link_to_and_increment, + id="link-increment-button", + ), + ) + + app = rx.App() + app.add_page(index, route="/room/[room]", on_load=SharedState.on_load_link_default) + app.add_page(index) + + +@pytest.fixture +def linked_state( + tmp_path_factory, +) -> Generator[AppHarness, None, None]: + """Start LinkedStateApp at tmp_path via AppHarness. + + Args: + tmp_path_factory: pytest tmp_path_factory fixture + + Yields: + running AppHarness instance + + """ + with AppHarness.create( + root=tmp_path_factory.mktemp("linked_state"), + app_source=LinkedStateApp, + ) as harness: + yield harness + + +@pytest.fixture +def tab_factory( + linked_state: AppHarness, +) -> Generator[Callable[[], WebDriver], None, None]: + """Get an instance of the browser open to the linked_state app. + + Args: + linked_state: harness for LinkedStateApp + + Yields: + WebDriver instance. + + """ + assert linked_state.app_instance is not None, "app is not running" + + drivers = [] + + def driver() -> WebDriver: + d = linked_state.frontend() + drivers.append(d) + return d + + try: + yield driver + finally: + for d in drivers: + d.quit() + + +def test_linked_state( + linked_state: AppHarness, + tab_factory: Callable[[], WebDriver], +): + """Test that multiple tabs can link to and share state. + + Args: + linked_state: harness for LinkedStateApp. + tab_factory: factory to create WebDriver instances. + + """ + assert linked_state.app_instance is not None + + tab1 = tab_factory() + tab2 = tab_factory() + ss = utils.SessionStorage(tab1) + assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found" + n_changes_1 = tab1.find_element(By.ID, "n-changes") + greeting_1 = tab1.find_element(By.ID, "greeting") + ss = utils.SessionStorage(tab2) + assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found" + n_changes_2 = tab2.find_element(By.ID, "n-changes") + greeting_2 = tab2.find_element(By.ID, "greeting") + + # Initial state + assert n_changes_1.text == "0" + assert greeting_1.text == "Hello, world!" + assert n_changes_2.text == "0" + assert greeting_2.text == "Hello, world!" + + # Change state in tab 1 + tab1.find_element(By.ID, "who-input").send_keys("Alice", Keys.ENTER) + assert linked_state.poll_for_content(n_changes_1, exp_not_equal="0") == "1" + assert ( + linked_state.poll_for_content(greeting_1, exp_not_equal="Hello, world!") + == "Hello, Alice!" + ) + + # Change state in tab 2 + tab2.find_element(By.ID, "who-input").send_keys("Bob", Keys.ENTER) + assert linked_state.poll_for_content(n_changes_2, exp_not_equal="0") == "1" + assert ( + linked_state.poll_for_content(greeting_2, exp_not_equal="Hello, world!") + == "Hello, Bob!" + ) + + # Link both tabs to the same token, "shared-foo" + shared_token = f"shared-foo-{uuid.uuid4()}" + for tab in (tab1, tab2): + tab.find_element(By.ID, "token-input").send_keys(shared_token, Keys.ENTER) + assert linked_state.poll_for_content(n_changes_1, exp_not_equal="1") == "0" + assert ( + linked_state.poll_for_content(greeting_1, exp_not_equal="Hello, Alice!") + == "Hello, world!" + ) + assert linked_state.poll_for_content(n_changes_2, exp_not_equal="1") == "0" + assert ( + linked_state.poll_for_content(greeting_2, exp_not_equal="Hello, Bob!") + == "Hello, world!" + ) + + # Set a new value in tab 1, should reflect in tab 2 + tab1.find_element(By.ID, "who-input").send_keys("Charlie", Keys.ENTER) + assert linked_state.poll_for_content(n_changes_1, exp_not_equal="0") == "1" + assert ( + linked_state.poll_for_content(greeting_1, exp_not_equal="Hello, world!") + == "Hello, Charlie!" + ) + assert linked_state.poll_for_content(n_changes_2, exp_not_equal="0") == "1" + assert ( + linked_state.poll_for_content(greeting_2, exp_not_equal="Hello, world!") + == "Hello, Charlie!" + ) + + # Bump the counter in tab 2, should reflect in tab 1 + counter_button_1 = tab1.find_element(By.ID, "counter-button") + counter_button_2 = tab2.find_element(By.ID, "counter-button") + assert counter_button_1.text == "0" + assert counter_button_2.text == "0" + counter_button_2.click() + assert linked_state.poll_for_content(counter_button_1, exp_not_equal="0") == "1" + assert linked_state.poll_for_content(counter_button_2, exp_not_equal="0") == "1" + counter_button_1.click() + assert linked_state.poll_for_content(counter_button_1, exp_not_equal="1") == "2" + assert linked_state.poll_for_content(counter_button_2, exp_not_equal="1") == "2" + counter_button_2.click() + assert linked_state.poll_for_content(counter_button_1, exp_not_equal="2") == "3" + assert linked_state.poll_for_content(counter_button_2, exp_not_equal="2") == "3" + + # Unlink tab 2, should revert to previous private values + tab2.find_element(By.ID, "unlink-button").click() + assert n_changes_2.text == "1" + assert ( + linked_state.poll_for_content(greeting_2, exp_not_equal="Hello, Charlie!") + == "Hello, Bob!" + ) + assert linked_state.poll_for_content(counter_button_2, exp_not_equal="3") == "0" + + # Relink tab 2, should go back to shared values + tab2.find_element(By.ID, "token-input").send_keys(shared_token, Keys.ENTER) + assert n_changes_2.text == "1" + assert ( + linked_state.poll_for_content(greeting_2, exp_not_equal="Hello, Bob!") + == "Hello, Charlie!" + ) + assert linked_state.poll_for_content(counter_button_2, exp_not_equal="0") == "3" + + # Unlink tab 1, change the shared value in tab 2, and relink tab 1 + tab1.find_element(By.ID, "unlink-button").click() + assert n_changes_1.text == "1" + assert ( + linked_state.poll_for_content(greeting_1, exp_not_equal="Hello, Charlie!") + == "Hello, Alice!" + ) + tab2.find_element(By.ID, "who-input").send_keys("Diana", Keys.ENTER) + assert linked_state.poll_for_content(n_changes_2, exp_not_equal="1") == "2" + assert ( + linked_state.poll_for_content(greeting_2, exp_not_equal="Hello, Charlie!") + == "Hello, Diana!" + ) + assert counter_button_2.text == "3" + assert n_changes_1.text == "1" + assert greeting_1.text == "Hello, Alice!" + tab1.find_element(By.ID, "token-input").send_keys(shared_token, Keys.ENTER) + assert linked_state.poll_for_content(n_changes_1, exp_not_equal="1") == "2" + assert ( + linked_state.poll_for_content(greeting_1, exp_not_equal="Hello, Alice!") + == "Hello, Diana!" + ) + assert linked_state.poll_for_content(counter_button_1, exp_not_equal="0") == "3" + + # Open a third tab linked to the shared token on_load + tab3 = tab_factory() + tab3.get(f"{linked_state.frontend_url}room/{shared_token}") + ss = utils.SessionStorage(tab3) + assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found" + n_changes_3 = AppHarness._poll_for(lambda: tab3.find_element(By.ID, "n-changes")) + assert n_changes_3 + greeting_3 = tab3.find_element(By.ID, "greeting") + counter_button_3 = tab3.find_element(By.ID, "counter-button") + assert linked_state.poll_for_content(n_changes_3, exp_not_equal="0") == "2" + assert ( + linked_state.poll_for_content(greeting_3, exp_not_equal="Hello, world!") + == "Hello, Diana!" + ) + assert linked_state.poll_for_content(counter_button_3, exp_not_equal="0") == "3" + assert tab3.find_element(By.ID, "linked-to").text == shared_token + + # Trigger a background task in all shared states, assert on final value + tab1.find_element(By.ID, "bg-button").click() + tab2.find_element(By.ID, "bg-button").click() + tab3.find_element(By.ID, "bg-button").click() + assert AppHarness._poll_for(lambda: counter_button_1.text == "33") + assert AppHarness._poll_for(lambda: counter_button_2.text == "33") + assert AppHarness._poll_for(lambda: counter_button_3.text == "33") + + # Trigger a yield-based task in all shared states, assert on final value + tab1.find_element(By.ID, "yield-button").click() + tab2.find_element(By.ID, "yield-button").click() + tab3.find_element(By.ID, "yield-button").click() + assert AppHarness._poll_for(lambda: counter_button_1.text == "48") + assert AppHarness._poll_for(lambda: counter_button_2.text == "48") + assert AppHarness._poll_for(lambda: counter_button_3.text == "48") + + # Link to a new token when we're already linked + new_shared_token = f"shared-bar-{uuid.uuid4()}" + tab1.find_element(By.ID, "token-input").send_keys(new_shared_token, Keys.ENTER) + assert linked_state.poll_for_content(n_changes_1, exp_not_equal="2") == "0" + assert ( + linked_state.poll_for_content(greeting_1, exp_not_equal="Hello, Diana!") + == "Hello, world!" + ) + assert linked_state.poll_for_content(counter_button_1, exp_not_equal="48") == "0" + counter_button_1.click() + assert linked_state.poll_for_content(counter_button_1, exp_not_equal="0") == "1" + counter_button_1.click() + assert linked_state.poll_for_content(counter_button_1, exp_not_equal="1") == "2" + counter_button_1.click() + assert linked_state.poll_for_content(counter_button_1, exp_not_equal="2") == "3" + # Ensure other tabs are unaffected + assert n_changes_2.text == "2" + assert greeting_2.text == "Hello, Diana!" + assert counter_button_2.text == "48" + assert n_changes_3.text == "2" + assert greeting_3.text == "Hello, Diana!" + assert counter_button_3.text == "48" + + # Link to a new state and increment the counter in the same event + tab1.find_element(By.ID, "link-increment-button").click() + assert linked_state.poll_for_content(counter_button_1, exp_not_equal="3") == "1" diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 941c569095f..c186fba1f89 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3656,7 +3656,11 @@ def test_mixin_state() -> None: """Test that a mixin state works correctly.""" assert "num" in UsesMixinState.base_vars assert "num" in UsesMixinState.vars - assert UsesMixinState.backend_vars == {"_backend": 0, "_backend_no_default": {}} + assert UsesMixinState.backend_vars == { + "_backend": 0, + "_backend_no_default": {}, + "_reflex_internal_links": None, + } assert "computed" in UsesMixinState.computed_vars assert "computed" in UsesMixinState.vars