diff --git a/integration/test_computed_vars.py b/integration/test_computed_vars.py new file mode 100644 index 000000000..2523248a9 --- /dev/null +++ b/integration/test_computed_vars.py @@ -0,0 +1,210 @@ +"""Test computed vars.""" + +from __future__ import annotations + +import time +from typing import Generator + +import pytest +from selenium.webdriver.common.by import By + +from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver + + +def ComputedVars(): + """Test app for computed vars.""" + import reflex as rx + + class State(rx.State): + count: int = 0 + + # cached var with dep on count + @rx.cached_var(interval=15) + def count1(self) -> int: + return self.count + + # same as above, different notation + @rx.var(interval=15, cache=True) + def count2(self) -> int: + return self.count + + # explicit disabled auto_deps + @rx.var(interval=15, cache=True, auto_deps=False) + def count3(self) -> int: + # this will not add deps, because auto_deps is False + print(self.count1) + print(self.count2) + + return self.count + + # explicit dependency on count1 var + @rx.cached_var(deps=[count1], auto_deps=False) + def depends_on_count1(self) -> int: + return self.count + + @rx.var(deps=[count3], auto_deps=False, cache=True) + def depends_on_count3(self) -> int: + return self.count + + def increment(self): + self.count += 1 + + def mark_dirty(self): + self._mark_dirty() + + def index() -> rx.Component: + return rx.center( + rx.vstack( + rx.input( + id="token", + value=State.router.session.client_token, + is_read_only=True, + ), + rx.button("Increment", on_click=State.increment, id="increment"), + rx.button("Do nothing", on_click=State.mark_dirty, id="mark_dirty"), + rx.text("count:"), + rx.text(State.count, id="count"), + rx.text("count1:"), + rx.text(State.count1, id="count1"), + rx.text("count2:"), + rx.text(State.count2, id="count2"), + rx.text("count3:"), + rx.text(State.count3, id="count3"), + rx.text("depends_on_count1:"), + rx.text( + State.depends_on_count1, + id="depends_on_count1", + ), + rx.text("depends_on_count3:"), + rx.text( + State.depends_on_count3, + id="depends_on_count3", + ), + ), + ) + + # raise Exception(State.count3._deps(objclass=State)) + app = rx.App() + app.add_page(index) + + +@pytest.fixture(scope="module") +def computed_vars( + tmp_path_factory, +) -> Generator[AppHarness, None, None]: + """Start ComputedVars app at tmp_path via AppHarness. + + Args: + tmp_path_factory: pytest tmp_path_factory fixture + + Yields: + running AppHarness instance + """ + with AppHarness.create( + root=tmp_path_factory.mktemp(f"computed_vars"), + app_source=ComputedVars, # type: ignore + ) as harness: + yield harness + + +@pytest.fixture +def driver(computed_vars: AppHarness) -> Generator[WebDriver, None, None]: + """Get an instance of the browser open to the computed_vars app. + + Args: + computed_vars: harness for ComputedVars app + + Yields: + WebDriver instance. + """ + assert computed_vars.app_instance is not None, "app is not running" + driver = computed_vars.frontend() + try: + yield driver + finally: + driver.quit() + + +@pytest.fixture() +def token(computed_vars: AppHarness, driver: WebDriver) -> str: + """Get a function that returns the active token. + + Args: + computed_vars: harness for ComputedVars app. + driver: WebDriver instance. + + Returns: + The token for the connected client + """ + assert computed_vars.app_instance is not None + token_input = driver.find_element(By.ID, "token") + assert token_input + + # wait for the backend connection to send the token + token = computed_vars.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2) + assert token is not None + + return token + + +def test_computed_vars( + computed_vars: AppHarness, + driver: WebDriver, + token: str, +): + """Test that computed vars are working as expected. + + Args: + computed_vars: harness for ComputedVars app. + driver: WebDriver instance. + token: The token for the connected client. + """ + assert computed_vars.app_instance is not None + + count = driver.find_element(By.ID, "count") + assert count + assert count.text == "0" + + count1 = driver.find_element(By.ID, "count1") + assert count1 + assert count1.text == "0" + + count2 = driver.find_element(By.ID, "count2") + assert count2 + assert count2.text == "0" + + count3 = driver.find_element(By.ID, "count3") + assert count3 + assert count3.text == "0" + + depends_on_count1 = driver.find_element(By.ID, "depends_on_count1") + assert depends_on_count1 + assert depends_on_count1.text == "0" + + depends_on_count3 = driver.find_element(By.ID, "depends_on_count3") + assert depends_on_count3 + assert depends_on_count3.text == "0" + + increment = driver.find_element(By.ID, "increment") + assert increment.is_enabled() + + mark_dirty = driver.find_element(By.ID, "mark_dirty") + assert mark_dirty.is_enabled() + + mark_dirty.click() + + increment.click() + assert computed_vars.poll_for_content(count, timeout=2, exp_not_equal="0") == "1" + assert computed_vars.poll_for_content(count1, timeout=2, exp_not_equal="0") == "1" + assert computed_vars.poll_for_content(count2, timeout=2, exp_not_equal="0") == "1" + + mark_dirty.click() + with pytest.raises(TimeoutError): + computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0") + + time.sleep(10) + assert count3.text == "0" + assert depends_on_count3.text == "0" + mark_dirty.click() + assert computed_vars.poll_for_content(count3, timeout=2, exp_not_equal="0") == "1" + assert depends_on_count3.text == "1" diff --git a/reflex/state.py b/reflex/state.py index e778946c0..e889fe42e 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1536,6 +1536,18 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if actual_var is not None: actual_var.mark_dirty(instance=self) + def _expired_computed_vars(self) -> set[str]: + """Determine ComputedVars that need to be recalculated based on the expiration time. + + Returns: + Set of computed vars to include in the delta. + """ + return set( + cvar + for cvar in self.computed_vars + if self.computed_vars[cvar].needs_update(instance=self) + ) + def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]: """Determine ComputedVars that need to be recalculated based on the given vars. @@ -1588,6 +1600,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # and always dirty computed vars (cache=False) delta_vars = ( self.dirty_vars.intersection(self.base_vars) + .union(self.dirty_vars.intersection(self.computed_vars)) .union(self._dirty_computed_vars()) .union(self._always_dirty_computed_vars) ) @@ -1621,6 +1634,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): self.parent_state.dirty_substates.add(self.get_name()) self.parent_state._mark_dirty() + # Append expired computed vars to dirty_vars to trigger recalculation + self.dirty_vars.update(self._expired_computed_vars()) + # have to mark computed vars dirty to allow access to newly computed # values within the same ComputedVar function self._mark_dirty_computed_vars() diff --git a/reflex/vars.py b/reflex/vars.py index 287502b43..cf6f2eed6 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -4,6 +4,7 @@ from __future__ import annotations import contextlib import dataclasses +import datetime import dis import functools import inspect @@ -1873,13 +1874,26 @@ class ComputedVar(Var, property): # Whether to track dependencies and cache computed values _cache: bool = dataclasses.field(default=False) - _initial_value: Any | types.Unset = dataclasses.field(default_factory=types.Unset) + # The initial value of the computed var + _initial_value: Any | types.Unset = dataclasses.field(default=types.Unset()) + + # Explicit var dependencies to track + _static_deps: set[str] = dataclasses.field(default_factory=set) + + # Whether var dependencies should be auto-determined + _auto_deps: bool = dataclasses.field(default=True) + + # Interval at which the computed var should be updated + _update_interval: Optional[datetime.timedelta] = dataclasses.field(default=None) def __init__( self, fget: Callable[[BaseState], Any], initial_value: Any | types.Unset = types.Unset(), cache: bool = False, + deps: Optional[List[Union[str, Var]]] = None, + auto_deps: bool = True, + interval: Optional[Union[int, datetime.timedelta]] = None, **kwargs, ): """Initialize a ComputedVar. @@ -1888,10 +1902,22 @@ class ComputedVar(Var, property): fget: The getter function. initial_value: The initial value of the computed var. cache: Whether to cache the computed value. + deps: Explicit var dependencies to track. + auto_deps: Whether var dependencies should be auto-determined. + interval: Interval at which the computed var should be updated. **kwargs: additional attributes to set on the instance """ self._initial_value = initial_value self._cache = cache + if isinstance(interval, int): + interval = datetime.timedelta(seconds=interval) + self._update_interval = interval + if deps is None: + deps = [] + self._static_deps = { + dep._var_name if isinstance(dep, Var) else dep for dep in deps + } + self._auto_deps = auto_deps property.__init__(self, fget) kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__) kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type()) @@ -1912,6 +1938,9 @@ class ComputedVar(Var, property): fget=kwargs.get("fget", self.fget), initial_value=kwargs.get("initial_value", self._initial_value), cache=kwargs.get("cache", self._cache), + deps=kwargs.get("deps", self._static_deps), + auto_deps=kwargs.get("auto_deps", self._auto_deps), + interval=kwargs.get("interval", self._update_interval), _var_name=kwargs.get("_var_name", self._var_name), _var_type=kwargs.get("_var_type", self._var_type), _var_is_local=kwargs.get("_var_is_local", self._var_is_local), @@ -1932,7 +1961,32 @@ class ComputedVar(Var, property): """ return f"__cached_{self._var_name}" - def __get__(self, instance, owner): + @property + def _last_updated_attr(self) -> str: + """Get the attribute used to store the last updated timestamp. + + Returns: + An attribute name. + """ + return f"__last_updated_{self._var_name}" + + def needs_update(self, instance: BaseState) -> bool: + """Check if the computed var needs to be updated. + + Args: + instance: The state instance that the computed var is attached to. + + Returns: + True if the computed var needs to be updated, False otherwise. + """ + if self._update_interval is None: + return False + last_updated = getattr(instance, self._last_updated_attr, None) + if last_updated is None: + return True + return datetime.datetime.now() - last_updated > self._update_interval + + def __get__(self, instance: BaseState | None, owner): """Get the ComputedVar value. If the value is already cached on the instance, return the cached value. @@ -1948,10 +2002,13 @@ class ComputedVar(Var, property): return super().__get__(instance, owner) # handle caching - if not hasattr(instance, self._cache_attr): + if not hasattr(instance, self._cache_attr) or self.needs_update(instance): + # Set cache attr on state instance. setattr(instance, self._cache_attr, super().__get__(instance, owner)) # Ensure the computed var gets serialized to redis. instance._was_touched = True + # Set the last updated timestamp on the state instance. + setattr(instance, self._last_updated_attr, datetime.datetime.now()) return getattr(instance, self._cache_attr) def _deps( @@ -1978,7 +2035,9 @@ class ComputedVar(Var, property): VarValueError: if the function references the get_state, parent_state, or substates attributes (cannot track deps in a related state, only implicitly via parent state). """ - d = set() + if not self._auto_deps: + return self._static_deps + d = self._static_deps.copy() if obj is None: fget = property.__getattribute__(self, "fget") if fget is not None: @@ -2076,6 +2135,9 @@ def computed_var( fget: Callable[[BaseState], Any] | None = None, initial_value: Any | None = None, cache: bool = False, + deps: Optional[List[Union[str, Var]]] = None, + auto_deps: bool = True, + interval: Optional[Union[datetime.timedelta, int]] = None, **kwargs, ) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]: """A ComputedVar decorator with or without kwargs. @@ -2084,19 +2146,31 @@ def computed_var( fget: The getter function. initial_value: The initial value of the computed var. cache: Whether to cache the computed value. + deps: Explicit var dependencies to track. + auto_deps: Whether var dependencies should be auto-determined. + interval: Interval at which the computed var should be updated. **kwargs: additional attributes to set on the instance Returns: A ComputedVar instance. + + Raises: + ValueError: If caching is disabled and an update interval is set. """ + if cache is False and interval is not None: + raise ValueError("Cannot set update interval without caching.") + if fget is not None: return ComputedVar(fget=fget, cache=cache) - def wrapper(fget): + def wrapper(fget: Callable[[BaseState], Any]) -> ComputedVar: return ComputedVar( fget=fget, initial_value=initial_value, cache=cache, + deps=deps, + auto_deps=auto_deps, + interval=interval, **kwargs, ) diff --git a/reflex/vars.pyi b/reflex/vars.pyi index 169e2d919..01b276342 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -2,6 +2,7 @@ from __future__ import annotations +import datetime from dataclasses import dataclass from _typeshed import Incomplete from reflex import constants as constants @@ -141,6 +142,7 @@ class ComputedVar(Var): def _deps(self, objclass: Type, obj: Optional[FunctionType] = ...) -> Set[str]: ... def _replace(self, merge_var_data=None, **kwargs: Any) -> ComputedVar: ... def mark_dirty(self, instance) -> None: ... + def needs_update(self, instance) -> bool: ... def _determine_var_type(self) -> Type: ... @overload def __init__( @@ -155,10 +157,24 @@ class ComputedVar(Var): def computed_var( fget: Callable[[BaseState], Any] | None = None, initial_value: Any | None = None, + cache: bool = False, + deps: Optional[List[Union[str, Var]]] = None, + auto_deps: bool = True, + interval: Optional[Union[datetime.timedelta, int]] = None, **kwargs, ) -> Callable[[Callable[[Any], Any]], ComputedVar]: ... @overload def computed_var(fget: Callable[[Any], Any]) -> ComputedVar: ... +@overload +def cached_var( + fget: Callable[[BaseState], Any] | None = None, + initial_value: Any | None = None, + deps: Optional[List[Union[str, Var]]] = None, + auto_deps: bool = True, + interval: Optional[Union[datetime.timedelta, int]] = None, + **kwargs, +) -> Callable[[Callable[[Any], Any]], ComputedVar]: ... +@overload def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ... class CallableVar(BaseVar):