-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Allow computed vars to yield events and update state reliably #5990
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -366,6 +366,9 @@ class BaseState(EvenMoreBasicBaseState): | |||||||||||||||||||||
| # The parent state. | ||||||||||||||||||||||
| parent_state: BaseState | None = field(default=None, is_var=False) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Events triggered by computed vars. | ||||||||||||||||||||||
| _computed_var_events: list[EventSpec] = field(default_factory=list, is_var=False) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # The substates of the state. | ||||||||||||||||||||||
| substates: builtins.dict[str, BaseState] = field( | ||||||||||||||||||||||
| default_factory=builtins.dict, is_var=False | ||||||||||||||||||||||
|
|
@@ -441,6 +444,9 @@ def __init__( | |||||||||||||||||||||
| # Create a fresh copy of the backend variables for this instance | ||||||||||||||||||||||
| self._backend_vars = copy.deepcopy(self.backend_vars) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Initialize computed var events list | ||||||||||||||||||||||
| self._computed_var_events = [] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def __repr__(self) -> str: | ||||||||||||||||||||||
| """Get the string representation of the state. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -1795,6 +1801,14 @@ async def _as_state_update( | |||||||||||||||||||||
| try: | ||||||||||||||||||||||
| # Get the delta after processing the event. | ||||||||||||||||||||||
| delta = await state._get_resolved_delta() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Collect events from computed vars | ||||||||||||||||||||||
| computed_var_events = state._collect_computed_var_events() | ||||||||||||||||||||||
| if computed_var_events: | ||||||||||||||||||||||
| fixed_events.extend( | ||||||||||||||||||||||
| fix_events(self._check_valid(handler, computed_var_events), token) | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| state._clean() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return StateUpdate( | ||||||||||||||||||||||
|
|
@@ -1964,14 +1978,21 @@ async def _process_event( | |||||||||||||||||||||
| final=True, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _mark_dirty_computed_vars(self) -> None: | ||||||||||||||||||||||
| """Mark ComputedVars that need to be recalculated based on dirty_vars.""" | ||||||||||||||||||||||
| # Append expired computed vars to dirty_vars to trigger recalculation | ||||||||||||||||||||||
| self.dirty_vars.update(self._expired_computed_vars()) | ||||||||||||||||||||||
| # Append always dirty computed vars to dirty_vars to trigger recalculation | ||||||||||||||||||||||
| self.dirty_vars.update(self._always_dirty_computed_vars) | ||||||||||||||||||||||
| def _mark_dirty_computed_vars(self, from_vars: set[str] | None = None) -> None: | ||||||||||||||||||||||
| """Mark ComputedVars that need to be recalculated based on dirty_vars. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Args: | ||||||||||||||||||||||
| from_vars: The vars to start the propagation from. | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| if from_vars is None: | ||||||||||||||||||||||
| # Append expired computed vars to dirty_vars to trigger recalculation | ||||||||||||||||||||||
| self.dirty_vars.update(self._expired_computed_vars()) | ||||||||||||||||||||||
| # Append always dirty computed vars to dirty_vars to trigger recalculation | ||||||||||||||||||||||
| self.dirty_vars.update(self._always_dirty_computed_vars) | ||||||||||||||||||||||
| dirty_vars = self.dirty_vars | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| dirty_vars = from_vars | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| dirty_vars = self.dirty_vars | ||||||||||||||||||||||
| while dirty_vars: | ||||||||||||||||||||||
| calc_vars, dirty_vars = dirty_vars, set() | ||||||||||||||||||||||
| for state_name, cvar in self._dirty_computed_vars(from_vars=calc_vars): | ||||||||||||||||||||||
|
|
@@ -2022,6 +2043,21 @@ def _dirty_computed_vars( | |||||||||||||||||||||
| if include_backend or not self.computed_vars[cvar]._backend | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _collect_computed_var_events(self) -> list[EventSpec]: | ||||||||||||||||||||||
| """Collect events triggered by computed vars. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||
| The list of events. | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| events = self._computed_var_events | ||||||||||||||||||||||
| self._computed_var_events = [] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| for substate in self.dirty_substates.union(self._always_dirty_substates): | ||||||||||||||||||||||
| if substate in self.substates: | ||||||||||||||||||||||
| events.extend(self.substates[substate]._collect_computed_var_events()) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return events | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def get_delta(self) -> Delta: | ||||||||||||||||||||||
| """Get the delta for the state. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -2030,22 +2066,36 @@ def get_delta(self) -> Delta: | |||||||||||||||||||||
| """ | ||||||||||||||||||||||
| delta = {} | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| self._mark_dirty_computed_vars() | ||||||||||||||||||||||
| frontend_computed_vars: set[str] = { | ||||||||||||||||||||||
| name for name, cv in self.computed_vars.items() if not cv._backend | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| # Loop to stabilize state | ||||||||||||||||||||||
| # We limit iterations to avoid infinite loops (e.g. oscillating states) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| previous_dirty_vars = self.dirty_vars.copy() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Return the dirty vars for this instance, any cached/dependent computed vars, | ||||||||||||||||||||||
| # and always dirty computed vars (cache=False) | ||||||||||||||||||||||
| delta_vars = self.dirty_vars.intersection(self.base_vars).union( | ||||||||||||||||||||||
| self.dirty_vars.intersection(frontend_computed_vars) | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| for i in range(10): | ||||||||||||||||||||||
| if i == 0: | ||||||||||||||||||||||
| self._mark_dirty_computed_vars() | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| new_dirty_vars = self.dirty_vars - previous_dirty_vars | ||||||||||||||||||||||
| if not new_dirty_vars: | ||||||||||||||||||||||
| break | ||||||||||||||||||||||
| self._mark_dirty_computed_vars(from_vars=new_dirty_vars) | ||||||||||||||||||||||
| previous_dirty_vars = self.dirty_vars.copy() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| subdelta: dict[str, Any] = { | ||||||||||||||||||||||
| prop + FIELD_MARKER: self.get_value(prop) | ||||||||||||||||||||||
| for prop in delta_vars | ||||||||||||||||||||||
| if not types.is_backend_base_variable(prop, type(self)) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| frontend_computed_vars: set[str] = { | ||||||||||||||||||||||
| name for name, cv in self.computed_vars.items() if not cv._backend | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Return the dirty vars for this instance, any cached/dependent computed vars, | ||||||||||||||||||||||
| # and always dirty computed vars (cache=False) | ||||||||||||||||||||||
| delta_vars = self.dirty_vars.intersection(self.base_vars).union( | ||||||||||||||||||||||
| self.dirty_vars.intersection(frontend_computed_vars) | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| subdelta: dict[str, Any] = { | ||||||||||||||||||||||
| prop + FIELD_MARKER: self.get_value(prop) | ||||||||||||||||||||||
| for prop in delta_vars | ||||||||||||||||||||||
| if not types.is_backend_base_variable(prop, type(self)) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
Comment on lines
+2094
to
+2098
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: Move this outside the loop and update it incrementally:
Suggested change
Prompt To Fix With AIThis is a comment left during a code review.
Path: reflex/state.py
Line: 2094:2098
Comment:
**logic:** `subdelta` is recreated in each loop iteration, so only the last iteration's changes will be included in the final delta. This will cause state updates from earlier iterations to be lost.
Move this outside the loop and update it incrementally:
```suggestion
# Only compute delta on first iteration or when there are new dirty vars
if i == 0 or new_dirty_vars:
for prop in delta_vars:
if not types.is_backend_base_variable(prop, type(self)):
subdelta[prop + FIELD_MARKER] = self.get_value(prop)
```
How can I resolve this? If you propose a fix, please make it concise. |
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if len(subdelta) > 0: | ||||||||||||||||||||||
| delta[self.get_full_name()] = subdelta | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -2273,17 +2273,47 @@ def needs_update(self, instance: BaseState) -> bool: | |||
| """Check if the computed var needs to be updated. | ||||
|
|
||||
| Args: | ||||
| instance: The state instance that the computed var is attached to. | ||||
| instance: The state instance. | ||||
|
|
||||
| Returns: | ||||
| True if the computed var needs to be updated, False otherwise. | ||||
| Whether the computed var needs to be updated. | ||||
| """ | ||||
| if self._update_interval is None: | ||||
| return False | ||||
| last_updated = getattr(instance, self._last_updated_attr, None) | ||||
| if last_updated is None: | ||||
| # If the var is dirty, it needs to be updated. | ||||
| if self._name in instance.dirty_vars: | ||||
| return True | ||||
| return datetime.datetime.now() - last_updated > self._update_interval | ||||
|
|
||||
| # If the var is expired, it needs to be updated. | ||||
| if self._update_interval is not None: | ||||
| last_updated = getattr( | ||||
| instance, self._last_updated_attr, datetime.datetime.min | ||||
| ) | ||||
| if datetime.datetime.now() - last_updated > self._update_interval: | ||||
| return True | ||||
|
|
||||
| return False | ||||
|
|
||||
| def _needs_update_check(self, instance: BaseState) -> bool: | ||||
| """Check if the computed var needs to be updated, respecting cache. | ||||
|
|
||||
| This checks for expiration but ignores dirty_vars if cache is present. | ||||
| The assumption is that if cache is present, mark_dirty was NOT called, | ||||
| so we are valid despite dirty_vars persisting in get_delta loop. | ||||
|
|
||||
| Args: | ||||
| instance: The state instance. | ||||
|
|
||||
| Returns: | ||||
| Whether the computed var needs to be updated. | ||||
| """ | ||||
| # If the var is expired, it needs to be updated. | ||||
| if self._update_interval is not None: | ||||
| last_updated = getattr( | ||||
| instance, self._last_updated_attr, datetime.datetime.min | ||||
| ) | ||||
| if datetime.datetime.now() - last_updated > self._update_interval: | ||||
| return True | ||||
|
|
||||
| return False | ||||
|
|
||||
| @overload | ||||
| def __get__( | ||||
|
|
@@ -2385,12 +2415,15 @@ def __get__(self, instance: BaseState | None, owner: type): | |||
| ) | ||||
|
|
||||
| if not self._cache: | ||||
| value = self.fget(instance) | ||||
| value = self._get_value(instance) | ||||
| else: | ||||
| # handle caching | ||||
| if not hasattr(instance, self._cache_attr) or self.needs_update(instance): | ||||
| # If we have a cache, and we haven't been explicitly marked for update (via mark_dirty), | ||||
| # we shouldn't update just because we are in dirty_vars (which persists across get_delta loop). | ||||
|
|
||||
| if not hasattr(instance, self._cache_attr) or self._needs_update_check(instance): | ||||
| # Set cache attr on state instance. | ||||
| setattr(instance, self._cache_attr, self.fget(instance)) | ||||
| setattr(instance, self._cache_attr, self._get_value(instance)) | ||||
| # Ensure the computed var gets serialized to redis. | ||||
| instance._was_touched = True | ||||
| # Set the last updated timestamp on the state instance. | ||||
|
|
@@ -2401,6 +2434,27 @@ def __get__(self, instance: BaseState | None, owner: type): | |||
|
|
||||
| return value | ||||
|
|
||||
| def _get_value(self, instance: BaseState) -> Any: | ||||
| """Get the value of the computed var, handling generators. | ||||
|
|
||||
| Args: | ||||
| instance: The state instance. | ||||
|
|
||||
| Returns: | ||||
| The value of the computed var. | ||||
| """ | ||||
| print(f"DEBUG: Computing {self._name} for {type(instance).__name__}") | ||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. syntax: Debug print statement left in production code
Suggested change
Context Used: Rule from Prompt To Fix With AIThis is a comment left during a code review.
Path: reflex/vars/base.py
Line: 2446:2446
Comment:
**syntax:** Debug print statement left in production code
```suggestion
```
**Context Used:** Rule from `dashboard` - Remove commented-out code before merging PRs. ([source](https://app.greptile.com/review/custom-context?memory=d49e2a0e-27a4-4cd6-b764-58c8a6fc4032))
How can I resolve this? If you propose a fix, please make it concise. |
||||
| value = self.fget(instance) | ||||
| if inspect.isgenerator(value): | ||||
| try: | ||||
| while True: | ||||
| event = next(value) | ||||
| if hasattr(instance, "_computed_var_events"): | ||||
| instance._computed_var_events.append(event) | ||||
| except StopIteration as e: | ||||
| return e.value | ||||
| return value | ||||
|
|
||||
| def _check_deprecated_return_type(self, instance: BaseState, value: Any) -> None: | ||||
| if not _isinstance(value, self._var_type, nested=1, treat_var_as_type=False): | ||||
| console.error( | ||||
|
|
||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| from typing import AsyncIterator, Generator, List | ||
|
|
||
| import pytest | ||
|
|
||
| import reflex as rx | ||
| from reflex.constants.state import FIELD_MARKER | ||
| from reflex.state import BaseState, StateUpdate | ||
| from reflex.vars.base import computed_var | ||
|
|
||
|
|
||
| class SideEffectState(BaseState): | ||
| """State for testing computed var side effects.""" | ||
|
|
||
| count: int = 0 | ||
| triggered: bool = False | ||
| side_effect_value: str = "" | ||
|
|
||
| @computed_var | ||
| def computed_with_side_effect(self) -> int: | ||
| if self.count > 0: | ||
| self.triggered = True | ||
| yield rx.window_alert("Triggered!") | ||
| return self.count * 2 | ||
| return 0 | ||
|
|
||
| @computed_var | ||
| def computed_modifying_other_var(self) -> str: | ||
| if self.count == 5: | ||
| self.side_effect_value = "Five" | ||
| return "Modified" | ||
| return "Not Modified" | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_computed_var_yields_event(): | ||
| """Test that a computed var can yield an event.""" | ||
| state = SideEffectState() | ||
| state.count = 1 | ||
|
|
||
| # This should trigger the computed var | ||
| # In a real app, this happens via get_delta, but we can simulate the process | ||
| # The key is that accessing the var triggers the generator and collection | ||
|
|
||
| # Manually trigger calculation as get_delta would | ||
| state._mark_dirty_computed_vars() | ||
|
|
||
| # Accessing the property should run the getter | ||
| val = state.computed_with_side_effect | ||
| assert val == 2 | ||
| assert state.triggered is True | ||
|
|
||
| # Check if event was collected | ||
| assert hasattr(state, "_computed_var_events") | ||
| assert len(state._computed_var_events) > 0 | ||
| # window_alert uses run_script which uses call_function which creates an EventHandler with _call_function | ||
| # so checking the handler name is tricky. We check if the event spec is returned. | ||
| event = state._computed_var_events[0] | ||
| assert event.handler.fn.__qualname__ == "_call_function" | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_computed_var_modifies_state(): | ||
| """Test that a computed var can modify other state variables.""" | ||
| state = SideEffectState() | ||
| state.count = 5 | ||
|
|
||
| # This call to get_delta mimics the backend processing loop | ||
| delta = state.get_delta() | ||
|
|
||
| full_name = state.get_full_name() | ||
| # Check that the computed var was calculated | ||
| assert delta[full_name]["computed_modifying_other_var" + FIELD_MARKER] == "Modified" | ||
|
|
||
| # Check that the side effect on 'side_effect_value' was captured in the delta | ||
| # The fix involves iterating in get_delta to capture these changes | ||
| assert delta[full_name]["side_effect_value" + FIELD_MARKER] == "Five" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic:
subdeltaneeds to be initialized outside the loop to accumulate changes across iterationsPrompt To Fix With AI