Explicit deps and interval for computed vars (#3231)
This commit is contained in:
parent
ac1c660bf0
commit
93de407007
210
integration/test_computed_vars.py
Normal file
210
integration/test_computed_vars.py
Normal file
@ -0,0 +1,210 @@
|
||||
"""Test computed vars."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from selenium.webdriver.common.by import By
|
||||
|
||||
from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
|
||||
|
||||
|
||||
def ComputedVars():
|
||||
"""Test app for computed vars."""
|
||||
import reflex as rx
|
||||
|
||||
class State(rx.State):
|
||||
count: int = 0
|
||||
|
||||
# cached var with dep on count
|
||||
@rx.cached_var(interval=15)
|
||||
def count1(self) -> int:
|
||||
return self.count
|
||||
|
||||
# same as above, different notation
|
||||
@rx.var(interval=15, cache=True)
|
||||
def count2(self) -> int:
|
||||
return self.count
|
||||
|
||||
# explicit disabled auto_deps
|
||||
@rx.var(interval=15, cache=True, auto_deps=False)
|
||||
def count3(self) -> int:
|
||||
# this will not add deps, because auto_deps is False
|
||||
print(self.count1)
|
||||
print(self.count2)
|
||||
|
||||
return self.count
|
||||
|
||||
# explicit dependency on count1 var
|
||||
@rx.cached_var(deps=[count1], auto_deps=False)
|
||||
def depends_on_count1(self) -> int:
|
||||
return self.count
|
||||
|
||||
@rx.var(deps=[count3], auto_deps=False, cache=True)
|
||||
def depends_on_count3(self) -> int:
|
||||
return self.count
|
||||
|
||||
def increment(self):
|
||||
self.count += 1
|
||||
|
||||
def mark_dirty(self):
|
||||
self._mark_dirty()
|
||||
|
||||
def index() -> rx.Component:
|
||||
return rx.center(
|
||||
rx.vstack(
|
||||
rx.input(
|
||||
id="token",
|
||||
value=State.router.session.client_token,
|
||||
is_read_only=True,
|
||||
),
|
||||
rx.button("Increment", on_click=State.increment, id="increment"),
|
||||
rx.button("Do nothing", on_click=State.mark_dirty, id="mark_dirty"),
|
||||
rx.text("count:"),
|
||||
rx.text(State.count, id="count"),
|
||||
rx.text("count1:"),
|
||||
rx.text(State.count1, id="count1"),
|
||||
rx.text("count2:"),
|
||||
rx.text(State.count2, id="count2"),
|
||||
rx.text("count3:"),
|
||||
rx.text(State.count3, id="count3"),
|
||||
rx.text("depends_on_count1:"),
|
||||
rx.text(
|
||||
State.depends_on_count1,
|
||||
id="depends_on_count1",
|
||||
),
|
||||
rx.text("depends_on_count3:"),
|
||||
rx.text(
|
||||
State.depends_on_count3,
|
||||
id="depends_on_count3",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# raise Exception(State.count3._deps(objclass=State))
|
||||
app = rx.App()
|
||||
app.add_page(index)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def computed_vars(
|
||||
tmp_path_factory,
|
||||
) -> Generator[AppHarness, None, None]:
|
||||
"""Start ComputedVars app at tmp_path via AppHarness.
|
||||
|
||||
Args:
|
||||
tmp_path_factory: pytest tmp_path_factory fixture
|
||||
|
||||
Yields:
|
||||
running AppHarness instance
|
||||
"""
|
||||
with AppHarness.create(
|
||||
root=tmp_path_factory.mktemp(f"computed_vars"),
|
||||
app_source=ComputedVars, # type: ignore
|
||||
) as harness:
|
||||
yield harness
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def driver(computed_vars: AppHarness) -> Generator[WebDriver, None, None]:
|
||||
"""Get an instance of the browser open to the computed_vars app.
|
||||
|
||||
Args:
|
||||
computed_vars: harness for ComputedVars app
|
||||
|
||||
Yields:
|
||||
WebDriver instance.
|
||||
"""
|
||||
assert computed_vars.app_instance is not None, "app is not running"
|
||||
driver = computed_vars.frontend()
|
||||
try:
|
||||
yield driver
|
||||
finally:
|
||||
driver.quit()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def token(computed_vars: AppHarness, driver: WebDriver) -> str:
|
||||
"""Get a function that returns the active token.
|
||||
|
||||
Args:
|
||||
computed_vars: harness for ComputedVars app.
|
||||
driver: WebDriver instance.
|
||||
|
||||
Returns:
|
||||
The token for the connected client
|
||||
"""
|
||||
assert computed_vars.app_instance is not None
|
||||
token_input = driver.find_element(By.ID, "token")
|
||||
assert token_input
|
||||
|
||||
# wait for the backend connection to send the token
|
||||
token = computed_vars.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
|
||||
assert token is not None
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def test_computed_vars(
|
||||
computed_vars: AppHarness,
|
||||
driver: WebDriver,
|
||||
token: str,
|
||||
):
|
||||
"""Test that computed vars are working as expected.
|
||||
|
||||
Args:
|
||||
computed_vars: harness for ComputedVars app.
|
||||
driver: WebDriver instance.
|
||||
token: The token for the connected client.
|
||||
"""
|
||||
assert computed_vars.app_instance is not None
|
||||
|
||||
count = driver.find_element(By.ID, "count")
|
||||
assert count
|
||||
assert count.text == "0"
|
||||
|
||||
count1 = driver.find_element(By.ID, "count1")
|
||||
assert count1
|
||||
assert count1.text == "0"
|
||||
|
||||
count2 = driver.find_element(By.ID, "count2")
|
||||
assert count2
|
||||
assert count2.text == "0"
|
||||
|
||||
count3 = driver.find_element(By.ID, "count3")
|
||||
assert count3
|
||||
assert count3.text == "0"
|
||||
|
||||
depends_on_count1 = driver.find_element(By.ID, "depends_on_count1")
|
||||
assert depends_on_count1
|
||||
assert depends_on_count1.text == "0"
|
||||
|
||||
depends_on_count3 = driver.find_element(By.ID, "depends_on_count3")
|
||||
assert depends_on_count3
|
||||
assert depends_on_count3.text == "0"
|
||||
|
||||
increment = driver.find_element(By.ID, "increment")
|
||||
assert increment.is_enabled()
|
||||
|
||||
mark_dirty = driver.find_element(By.ID, "mark_dirty")
|
||||
assert mark_dirty.is_enabled()
|
||||
|
||||
mark_dirty.click()
|
||||
|
||||
increment.click()
|
||||
assert computed_vars.poll_for_content(count, timeout=2, exp_not_equal="0") == "1"
|
||||
assert computed_vars.poll_for_content(count1, timeout=2, exp_not_equal="0") == "1"
|
||||
assert computed_vars.poll_for_content(count2, timeout=2, exp_not_equal="0") == "1"
|
||||
|
||||
mark_dirty.click()
|
||||
with pytest.raises(TimeoutError):
|
||||
computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0")
|
||||
|
||||
time.sleep(10)
|
||||
assert count3.text == "0"
|
||||
assert depends_on_count3.text == "0"
|
||||
mark_dirty.click()
|
||||
assert computed_vars.poll_for_content(count3, timeout=2, exp_not_equal="0") == "1"
|
||||
assert depends_on_count3.text == "1"
|
@ -1536,6 +1536,18 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
if actual_var is not None:
|
||||
actual_var.mark_dirty(instance=self)
|
||||
|
||||
def _expired_computed_vars(self) -> set[str]:
|
||||
"""Determine ComputedVars that need to be recalculated based on the expiration time.
|
||||
|
||||
Returns:
|
||||
Set of computed vars to include in the delta.
|
||||
"""
|
||||
return set(
|
||||
cvar
|
||||
for cvar in self.computed_vars
|
||||
if self.computed_vars[cvar].needs_update(instance=self)
|
||||
)
|
||||
|
||||
def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]:
|
||||
"""Determine ComputedVars that need to be recalculated based on the given vars.
|
||||
|
||||
@ -1588,6 +1600,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# and always dirty computed vars (cache=False)
|
||||
delta_vars = (
|
||||
self.dirty_vars.intersection(self.base_vars)
|
||||
.union(self.dirty_vars.intersection(self.computed_vars))
|
||||
.union(self._dirty_computed_vars())
|
||||
.union(self._always_dirty_computed_vars)
|
||||
)
|
||||
@ -1621,6 +1634,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
self.parent_state.dirty_substates.add(self.get_name())
|
||||
self.parent_state._mark_dirty()
|
||||
|
||||
# Append expired computed vars to dirty_vars to trigger recalculation
|
||||
self.dirty_vars.update(self._expired_computed_vars())
|
||||
|
||||
# have to mark computed vars dirty to allow access to newly computed
|
||||
# values within the same ComputedVar function
|
||||
self._mark_dirty_computed_vars()
|
||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import datetime
|
||||
import dis
|
||||
import functools
|
||||
import inspect
|
||||
@ -1873,13 +1874,26 @@ class ComputedVar(Var, property):
|
||||
# Whether to track dependencies and cache computed values
|
||||
_cache: bool = dataclasses.field(default=False)
|
||||
|
||||
_initial_value: Any | types.Unset = dataclasses.field(default_factory=types.Unset)
|
||||
# The initial value of the computed var
|
||||
_initial_value: Any | types.Unset = dataclasses.field(default=types.Unset())
|
||||
|
||||
# Explicit var dependencies to track
|
||||
_static_deps: set[str] = dataclasses.field(default_factory=set)
|
||||
|
||||
# Whether var dependencies should be auto-determined
|
||||
_auto_deps: bool = dataclasses.field(default=True)
|
||||
|
||||
# Interval at which the computed var should be updated
|
||||
_update_interval: Optional[datetime.timedelta] = dataclasses.field(default=None)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fget: Callable[[BaseState], Any],
|
||||
initial_value: Any | types.Unset = types.Unset(),
|
||||
cache: bool = False,
|
||||
deps: Optional[List[Union[str, Var]]] = None,
|
||||
auto_deps: bool = True,
|
||||
interval: Optional[Union[int, datetime.timedelta]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize a ComputedVar.
|
||||
@ -1888,10 +1902,22 @@ class ComputedVar(Var, property):
|
||||
fget: The getter function.
|
||||
initial_value: The initial value of the computed var.
|
||||
cache: Whether to cache the computed value.
|
||||
deps: Explicit var dependencies to track.
|
||||
auto_deps: Whether var dependencies should be auto-determined.
|
||||
interval: Interval at which the computed var should be updated.
|
||||
**kwargs: additional attributes to set on the instance
|
||||
"""
|
||||
self._initial_value = initial_value
|
||||
self._cache = cache
|
||||
if isinstance(interval, int):
|
||||
interval = datetime.timedelta(seconds=interval)
|
||||
self._update_interval = interval
|
||||
if deps is None:
|
||||
deps = []
|
||||
self._static_deps = {
|
||||
dep._var_name if isinstance(dep, Var) else dep for dep in deps
|
||||
}
|
||||
self._auto_deps = auto_deps
|
||||
property.__init__(self, fget)
|
||||
kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__)
|
||||
kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type())
|
||||
@ -1912,6 +1938,9 @@ class ComputedVar(Var, property):
|
||||
fget=kwargs.get("fget", self.fget),
|
||||
initial_value=kwargs.get("initial_value", self._initial_value),
|
||||
cache=kwargs.get("cache", self._cache),
|
||||
deps=kwargs.get("deps", self._static_deps),
|
||||
auto_deps=kwargs.get("auto_deps", self._auto_deps),
|
||||
interval=kwargs.get("interval", self._update_interval),
|
||||
_var_name=kwargs.get("_var_name", self._var_name),
|
||||
_var_type=kwargs.get("_var_type", self._var_type),
|
||||
_var_is_local=kwargs.get("_var_is_local", self._var_is_local),
|
||||
@ -1932,7 +1961,32 @@ class ComputedVar(Var, property):
|
||||
"""
|
||||
return f"__cached_{self._var_name}"
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
@property
|
||||
def _last_updated_attr(self) -> str:
|
||||
"""Get the attribute used to store the last updated timestamp.
|
||||
|
||||
Returns:
|
||||
An attribute name.
|
||||
"""
|
||||
return f"__last_updated_{self._var_name}"
|
||||
|
||||
def needs_update(self, instance: BaseState) -> bool:
|
||||
"""Check if the computed var needs to be updated.
|
||||
|
||||
Args:
|
||||
instance: The state instance that the computed var is attached to.
|
||||
|
||||
Returns:
|
||||
True if the computed var needs to be updated, False otherwise.
|
||||
"""
|
||||
if self._update_interval is None:
|
||||
return False
|
||||
last_updated = getattr(instance, self._last_updated_attr, None)
|
||||
if last_updated is None:
|
||||
return True
|
||||
return datetime.datetime.now() - last_updated > self._update_interval
|
||||
|
||||
def __get__(self, instance: BaseState | None, owner):
|
||||
"""Get the ComputedVar value.
|
||||
|
||||
If the value is already cached on the instance, return the cached value.
|
||||
@ -1948,10 +2002,13 @@ class ComputedVar(Var, property):
|
||||
return super().__get__(instance, owner)
|
||||
|
||||
# handle caching
|
||||
if not hasattr(instance, self._cache_attr):
|
||||
if not hasattr(instance, self._cache_attr) or self.needs_update(instance):
|
||||
# Set cache attr on state instance.
|
||||
setattr(instance, self._cache_attr, super().__get__(instance, owner))
|
||||
# Ensure the computed var gets serialized to redis.
|
||||
instance._was_touched = True
|
||||
# Set the last updated timestamp on the state instance.
|
||||
setattr(instance, self._last_updated_attr, datetime.datetime.now())
|
||||
return getattr(instance, self._cache_attr)
|
||||
|
||||
def _deps(
|
||||
@ -1978,7 +2035,9 @@ class ComputedVar(Var, property):
|
||||
VarValueError: if the function references the get_state, parent_state, or substates attributes
|
||||
(cannot track deps in a related state, only implicitly via parent state).
|
||||
"""
|
||||
d = set()
|
||||
if not self._auto_deps:
|
||||
return self._static_deps
|
||||
d = self._static_deps.copy()
|
||||
if obj is None:
|
||||
fget = property.__getattribute__(self, "fget")
|
||||
if fget is not None:
|
||||
@ -2076,6 +2135,9 @@ def computed_var(
|
||||
fget: Callable[[BaseState], Any] | None = None,
|
||||
initial_value: Any | None = None,
|
||||
cache: bool = False,
|
||||
deps: Optional[List[Union[str, Var]]] = None,
|
||||
auto_deps: bool = True,
|
||||
interval: Optional[Union[datetime.timedelta, int]] = None,
|
||||
**kwargs,
|
||||
) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]:
|
||||
"""A ComputedVar decorator with or without kwargs.
|
||||
@ -2084,19 +2146,31 @@ def computed_var(
|
||||
fget: The getter function.
|
||||
initial_value: The initial value of the computed var.
|
||||
cache: Whether to cache the computed value.
|
||||
deps: Explicit var dependencies to track.
|
||||
auto_deps: Whether var dependencies should be auto-determined.
|
||||
interval: Interval at which the computed var should be updated.
|
||||
**kwargs: additional attributes to set on the instance
|
||||
|
||||
Returns:
|
||||
A ComputedVar instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If caching is disabled and an update interval is set.
|
||||
"""
|
||||
if cache is False and interval is not None:
|
||||
raise ValueError("Cannot set update interval without caching.")
|
||||
|
||||
if fget is not None:
|
||||
return ComputedVar(fget=fget, cache=cache)
|
||||
|
||||
def wrapper(fget):
|
||||
def wrapper(fget: Callable[[BaseState], Any]) -> ComputedVar:
|
||||
return ComputedVar(
|
||||
fget=fget,
|
||||
initial_value=initial_value,
|
||||
cache=cache,
|
||||
deps=deps,
|
||||
auto_deps=auto_deps,
|
||||
interval=interval,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from dataclasses import dataclass
|
||||
from _typeshed import Incomplete
|
||||
from reflex import constants as constants
|
||||
@ -141,6 +142,7 @@ class ComputedVar(Var):
|
||||
def _deps(self, objclass: Type, obj: Optional[FunctionType] = ...) -> Set[str]: ...
|
||||
def _replace(self, merge_var_data=None, **kwargs: Any) -> ComputedVar: ...
|
||||
def mark_dirty(self, instance) -> None: ...
|
||||
def needs_update(self, instance) -> bool: ...
|
||||
def _determine_var_type(self) -> Type: ...
|
||||
@overload
|
||||
def __init__(
|
||||
@ -155,10 +157,24 @@ class ComputedVar(Var):
|
||||
def computed_var(
|
||||
fget: Callable[[BaseState], Any] | None = None,
|
||||
initial_value: Any | None = None,
|
||||
cache: bool = False,
|
||||
deps: Optional[List[Union[str, Var]]] = None,
|
||||
auto_deps: bool = True,
|
||||
interval: Optional[Union[datetime.timedelta, int]] = None,
|
||||
**kwargs,
|
||||
) -> Callable[[Callable[[Any], Any]], ComputedVar]: ...
|
||||
@overload
|
||||
def computed_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
|
||||
@overload
|
||||
def cached_var(
|
||||
fget: Callable[[BaseState], Any] | None = None,
|
||||
initial_value: Any | None = None,
|
||||
deps: Optional[List[Union[str, Var]]] = None,
|
||||
auto_deps: bool = True,
|
||||
interval: Optional[Union[datetime.timedelta, int]] = None,
|
||||
**kwargs,
|
||||
) -> Callable[[Callable[[Any], Any]], ComputedVar]: ...
|
||||
@overload
|
||||
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
|
||||
|
||||
class CallableVar(BaseVar):
|
||||
|
Loading…
Reference in New Issue
Block a user