diff --git a/reflex/state.py b/reflex/state.py index f47cb207113..0761b0978f6 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -366,6 +366,9 @@ class BaseState(EvenMoreBasicBaseState): # The parent state. parent_state: BaseState | None = field(default=None, is_var=False) + # Events triggered by computed vars. + _computed_var_events: list[EventSpec] = field(default_factory=list, is_var=False) + # The substates of the state. substates: builtins.dict[str, BaseState] = field( default_factory=builtins.dict, is_var=False @@ -441,6 +444,9 @@ def __init__( # Create a fresh copy of the backend variables for this instance self._backend_vars = copy.deepcopy(self.backend_vars) + # Initialize computed var events list + self._computed_var_events = [] + def __repr__(self) -> str: """Get the string representation of the state. @@ -1795,6 +1801,14 @@ async def _as_state_update( try: # Get the delta after processing the event. delta = await state._get_resolved_delta() + + # Collect events from computed vars + computed_var_events = state._collect_computed_var_events() + if computed_var_events: + fixed_events.extend( + fix_events(self._check_valid(handler, computed_var_events), token) + ) + state._clean() return StateUpdate( @@ -1964,14 +1978,21 @@ async def _process_event( final=True, ) - def _mark_dirty_computed_vars(self) -> None: - """Mark ComputedVars that need to be recalculated based on dirty_vars.""" - # Append expired computed vars to dirty_vars to trigger recalculation - self.dirty_vars.update(self._expired_computed_vars()) - # Append always dirty computed vars to dirty_vars to trigger recalculation - self.dirty_vars.update(self._always_dirty_computed_vars) + def _mark_dirty_computed_vars(self, from_vars: set[str] | None = None) -> None: + """Mark ComputedVars that need to be recalculated based on dirty_vars. + + Args: + from_vars: The vars to start the propagation from. + """ + if from_vars is None: + # Append expired computed vars to dirty_vars to trigger recalculation + self.dirty_vars.update(self._expired_computed_vars()) + # Append always dirty computed vars to dirty_vars to trigger recalculation + self.dirty_vars.update(self._always_dirty_computed_vars) + dirty_vars = self.dirty_vars + else: + dirty_vars = from_vars - dirty_vars = self.dirty_vars while dirty_vars: calc_vars, dirty_vars = dirty_vars, set() for state_name, cvar in self._dirty_computed_vars(from_vars=calc_vars): @@ -2022,6 +2043,21 @@ def _dirty_computed_vars( if include_backend or not self.computed_vars[cvar]._backend } + def _collect_computed_var_events(self) -> list[EventSpec]: + """Collect events triggered by computed vars. + + Returns: + The list of events. + """ + events = self._computed_var_events + self._computed_var_events = [] + + for substate in self.dirty_substates.union(self._always_dirty_substates): + if substate in self.substates: + events.extend(self.substates[substate]._collect_computed_var_events()) + + return events + def get_delta(self) -> Delta: """Get the delta for the state. @@ -2030,22 +2066,36 @@ def get_delta(self) -> Delta: """ delta = {} - self._mark_dirty_computed_vars() - frontend_computed_vars: set[str] = { - name for name, cv in self.computed_vars.items() if not cv._backend - } + # Loop to stabilize state + # We limit iterations to avoid infinite loops (e.g. oscillating states) + + previous_dirty_vars = self.dirty_vars.copy() - # Return the dirty vars for this instance, any cached/dependent computed vars, - # and always dirty computed vars (cache=False) - delta_vars = self.dirty_vars.intersection(self.base_vars).union( - self.dirty_vars.intersection(frontend_computed_vars) - ) + for i in range(10): + if i == 0: + self._mark_dirty_computed_vars() + else: + new_dirty_vars = self.dirty_vars - previous_dirty_vars + if not new_dirty_vars: + break + self._mark_dirty_computed_vars(from_vars=new_dirty_vars) + previous_dirty_vars = self.dirty_vars.copy() - subdelta: dict[str, Any] = { - prop + FIELD_MARKER: self.get_value(prop) - for prop in delta_vars - if not types.is_backend_base_variable(prop, type(self)) - } + frontend_computed_vars: set[str] = { + name for name, cv in self.computed_vars.items() if not cv._backend + } + + # Return the dirty vars for this instance, any cached/dependent computed vars, + # and always dirty computed vars (cache=False) + delta_vars = self.dirty_vars.intersection(self.base_vars).union( + self.dirty_vars.intersection(frontend_computed_vars) + ) + + subdelta: dict[str, Any] = { + prop + FIELD_MARKER: self.get_value(prop) + for prop in delta_vars + if not types.is_backend_base_variable(prop, type(self)) + } if len(subdelta) > 0: delta[self.get_full_name()] = subdelta diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 412c0e41c70..bcd059dd3a9 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -2273,17 +2273,47 @@ def needs_update(self, instance: BaseState) -> bool: """Check if the computed var needs to be updated. Args: - instance: The state instance that the computed var is attached to. + instance: The state instance. Returns: - True if the computed var needs to be updated, False otherwise. + Whether the computed var needs to be updated. """ - if self._update_interval is None: - return False - last_updated = getattr(instance, self._last_updated_attr, None) - if last_updated is None: + # If the var is dirty, it needs to be updated. + if self._name in instance.dirty_vars: return True - return datetime.datetime.now() - last_updated > self._update_interval + + # If the var is expired, it needs to be updated. + if self._update_interval is not None: + last_updated = getattr( + instance, self._last_updated_attr, datetime.datetime.min + ) + if datetime.datetime.now() - last_updated > self._update_interval: + return True + + return False + + def _needs_update_check(self, instance: BaseState) -> bool: + """Check if the computed var needs to be updated, respecting cache. + + This checks for expiration but ignores dirty_vars if cache is present. + The assumption is that if cache is present, mark_dirty was NOT called, + so we are valid despite dirty_vars persisting in get_delta loop. + + Args: + instance: The state instance. + + Returns: + Whether the computed var needs to be updated. + """ + # If the var is expired, it needs to be updated. + if self._update_interval is not None: + last_updated = getattr( + instance, self._last_updated_attr, datetime.datetime.min + ) + if datetime.datetime.now() - last_updated > self._update_interval: + return True + + return False @overload def __get__( @@ -2385,12 +2415,15 @@ def __get__(self, instance: BaseState | None, owner: type): ) if not self._cache: - value = self.fget(instance) + value = self._get_value(instance) else: # handle caching - if not hasattr(instance, self._cache_attr) or self.needs_update(instance): + # If we have a cache, and we haven't been explicitly marked for update (via mark_dirty), + # we shouldn't update just because we are in dirty_vars (which persists across get_delta loop). + + if not hasattr(instance, self._cache_attr) or self._needs_update_check(instance): # Set cache attr on state instance. - setattr(instance, self._cache_attr, self.fget(instance)) + setattr(instance, self._cache_attr, self._get_value(instance)) # Ensure the computed var gets serialized to redis. instance._was_touched = True # Set the last updated timestamp on the state instance. @@ -2401,6 +2434,27 @@ def __get__(self, instance: BaseState | None, owner: type): return value + def _get_value(self, instance: BaseState) -> Any: + """Get the value of the computed var, handling generators. + + Args: + instance: The state instance. + + Returns: + The value of the computed var. + """ + print(f"DEBUG: Computing {self._name} for {type(instance).__name__}") + value = self.fget(instance) + if inspect.isgenerator(value): + try: + while True: + event = next(value) + if hasattr(instance, "_computed_var_events"): + instance._computed_var_events.append(event) + except StopIteration as e: + return e.value + return value + def _check_deprecated_return_type(self, instance: BaseState, value: Any) -> None: if not _isinstance(value, self._var_type, nested=1, treat_var_as_type=False): console.error( diff --git a/tests/units/test_computed_var_side_effects.py b/tests/units/test_computed_var_side_effects.py new file mode 100644 index 00000000000..7a54bb2f65b --- /dev/null +++ b/tests/units/test_computed_var_side_effects.py @@ -0,0 +1,76 @@ +from typing import AsyncIterator, Generator, List + +import pytest + +import reflex as rx +from reflex.constants.state import FIELD_MARKER +from reflex.state import BaseState, StateUpdate +from reflex.vars.base import computed_var + + +class SideEffectState(BaseState): + """State for testing computed var side effects.""" + + count: int = 0 + triggered: bool = False + side_effect_value: str = "" + + @computed_var + def computed_with_side_effect(self) -> int: + if self.count > 0: + self.triggered = True + yield rx.window_alert("Triggered!") + return self.count * 2 + return 0 + + @computed_var + def computed_modifying_other_var(self) -> str: + if self.count == 5: + self.side_effect_value = "Five" + return "Modified" + return "Not Modified" + + +@pytest.mark.asyncio +async def test_computed_var_yields_event(): + """Test that a computed var can yield an event.""" + state = SideEffectState() + state.count = 1 + + # This should trigger the computed var + # In a real app, this happens via get_delta, but we can simulate the process + # The key is that accessing the var triggers the generator and collection + + # Manually trigger calculation as get_delta would + state._mark_dirty_computed_vars() + + # Accessing the property should run the getter + val = state.computed_with_side_effect + assert val == 2 + assert state.triggered is True + + # Check if event was collected + assert hasattr(state, "_computed_var_events") + assert len(state._computed_var_events) > 0 + # window_alert uses run_script which uses call_function which creates an EventHandler with _call_function + # so checking the handler name is tricky. We check if the event spec is returned. + event = state._computed_var_events[0] + assert event.handler.fn.__qualname__ == "_call_function" + + +@pytest.mark.asyncio +async def test_computed_var_modifies_state(): + """Test that a computed var can modify other state variables.""" + state = SideEffectState() + state.count = 5 + + # This call to get_delta mimics the backend processing loop + delta = state.get_delta() + + full_name = state.get_full_name() + # Check that the computed var was calculated + assert delta[full_name]["computed_modifying_other_var" + FIELD_MARKER] == "Modified" + + # Check that the side effect on 'side_effect_value' was captured in the delta + # The fix involves iterating in get_delta to capture these changes + assert delta[full_name]["side_effect_value" + FIELD_MARKER] == "Five"