Skip to content

Commit

Permalink
simple pytest benchmark for measuring event <=> state update round tr…
Browse files Browse the repository at this point in the history
…ip time (reflex-dev#2489)
  • Loading branch information
jackie-pc authored Jan 30, 2024
1 parent 6cf411a commit 032017d
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 9 deletions.
90 changes: 90 additions & 0 deletions integration/test_large_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Test large state."""
import time

import jinja2
import pytest
from selenium.webdriver.common.by import By

from reflex.testing import AppHarness, WebDriver

LARGE_STATE_APP_TEMPLATE = """
import reflex as rx
class State(rx.State):
var0: int = 0
{% for i in range(1, var_count) %}
var{{ i }}: str = "{{ i }}" * 10000
{% endfor %}
def increment_var0(self):
self.var0 += 1
def index() -> rx.Component:
return rx.box(rx.button(State.var0, on_click=State.increment_var0, id="button"))
app = rx.App()
app.add_page(index)
"""


def get_driver(large_state) -> WebDriver:
"""Get an instance of the browser open to the large_state app.
Args:
large_state: harness for LargeState app
Returns:
WebDriver instance.
"""
assert large_state.app_instance is not None, "app is not running"
return large_state.frontend()


@pytest.mark.parametrize("var_count", [1, 10, 100, 1000, 10000])
def test_large_state(var_count: int, tmp_path_factory, benchmark):
"""Measure how long it takes for button click => state update to round trip.
Args:
var_count: number of variables to store in the state
tmp_path_factory: pytest fixture
benchmark: pytest fixture
Raises:
TimeoutError: if the state doesn't update within 30 seconds
"""
template = jinja2.Template(LARGE_STATE_APP_TEMPLATE)
large_state_rendered = template.render(var_count=var_count)

with AppHarness.create(
root=tmp_path_factory.mktemp(f"large_state"),
app_source=large_state_rendered,
app_name="large_state",
) as large_state:
driver = get_driver(large_state)
try:
assert large_state.app_instance is not None
button = driver.find_element(By.ID, "button")

t = time.time()
while button.text != "0":
time.sleep(0.1)
if time.time() - t > 30.0:
raise TimeoutError("Timeout waiting for initial state")

times_clicked = 0

def round_trip(clicks: int, timeout: float):
t = time.time()
for _ in range(clicks):
button.click()
nonlocal times_clicked
times_clicked += clicks
while button.text != str(times_clicked):
time.sleep(0.005)
if time.time() - t > timeout:
raise TimeoutError("Timeout waiting for state update")

benchmark(round_trip, clicks=10, timeout=30.0)
finally:
driver.quit()
27 changes: 18 additions & 9 deletions reflex/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class AppHarness:
"""AppHarness executes a reflex app in-process for testing."""

app_name: str
app_source: Optional[types.FunctionType | types.ModuleType]
app_source: Optional[types.FunctionType | types.ModuleType] | str
app_path: pathlib.Path
app_module_path: pathlib.Path
app_module: Optional[types.ModuleType] = None
Expand All @@ -119,18 +119,21 @@ class AppHarness:
def create(
cls,
root: pathlib.Path,
app_source: Optional[types.FunctionType | types.ModuleType] = None,
app_source: Optional[types.FunctionType | types.ModuleType | str] = None,
app_name: Optional[str] = None,
) -> "AppHarness":
"""Create an AppHarness instance at root.
Args:
root: the directory that will contain the app under test.
app_source: if specified, the source code from this function or module is used
as the main module for the app. If unspecified, then root must already
contain a working reflex app and will be used directly.
as the main module for the app. It may also be the raw source code text, as a str.
If unspecified, then root must already contain a working reflex app and will be used directly.
app_name: provide the name of the app, otherwise will be derived from app_source or root.
Raises:
ValueError: when app_source is a string and app_name is not provided.
Returns:
AppHarness instance
"""
Expand All @@ -139,6 +142,10 @@ def create(
app_name = root.name.lower()
elif isinstance(app_source, functools.partial):
app_name = app_source.func.__name__.lower()
elif isinstance(app_source, str):
raise ValueError(
"app_name must be provided when app_source is a string."
)
else:
app_name = app_source.__name__.lower()
return cls(
Expand Down Expand Up @@ -170,16 +177,18 @@ def _get_globals_from_signature(self, func: Any) -> dict[str, Any]:
glbs.update(overrides)
return glbs

def _get_source_from_func(self, func: Any) -> str:
"""Get the source from a function or module object.
def _get_source_from_app_source(self, app_source: Any) -> str:
"""Get the source from app_source.
Args:
func: function or module object
app_source: function or module or str
Returns:
source code
"""
source = inspect.getsource(func)
if isinstance(app_source, str):
return app_source
source = inspect.getsource(app_source)
source = re.sub(r"^\s*def\s+\w+\s*\(.*?\):", "", source, flags=re.DOTALL)
return textwrap.dedent(source)

Expand All @@ -194,7 +203,7 @@ def _initialize_app(self):
source_code = "\n".join(
[
"\n".join(f"{k} = {v!r}" for k, v in app_globals.items()),
self._get_source_from_func(self.app_source),
self._get_source_from_app_source(self.app_source),
]
)
with chdir(self.app_path):
Expand Down

0 comments on commit 032017d

Please sign in to comment.