Explicit deps and interval for computed vars (#3231)

This commit is contained in:
benedikt-bartscher 2024-05-28 21:27:27 +02:00 committed by GitHub
parent ac1c660bf0
commit 93de407007
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 321 additions and 5 deletions

View File

@ -0,0 +1,210 @@
"""Test computed vars."""
from __future__ import annotations
import time
from typing import Generator
import pytest
from selenium.webdriver.common.by import By
from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
def ComputedVars():
"""Test app for computed vars."""
import reflex as rx
class State(rx.State):
count: int = 0
# cached var with dep on count
@rx.cached_var(interval=15)
def count1(self) -> int:
return self.count
# same as above, different notation
@rx.var(interval=15, cache=True)
def count2(self) -> int:
return self.count
# explicit disabled auto_deps
@rx.var(interval=15, cache=True, auto_deps=False)
def count3(self) -> int:
# this will not add deps, because auto_deps is False
print(self.count1)
print(self.count2)
return self.count
# explicit dependency on count1 var
@rx.cached_var(deps=[count1], auto_deps=False)
def depends_on_count1(self) -> int:
return self.count
@rx.var(deps=[count3], auto_deps=False, cache=True)
def depends_on_count3(self) -> int:
return self.count
def increment(self):
self.count += 1
def mark_dirty(self):
self._mark_dirty()
def index() -> rx.Component:
return rx.center(
rx.vstack(
rx.input(
id="token",
value=State.router.session.client_token,
is_read_only=True,
),
rx.button("Increment", on_click=State.increment, id="increment"),
rx.button("Do nothing", on_click=State.mark_dirty, id="mark_dirty"),
rx.text("count:"),
rx.text(State.count, id="count"),
rx.text("count1:"),
rx.text(State.count1, id="count1"),
rx.text("count2:"),
rx.text(State.count2, id="count2"),
rx.text("count3:"),
rx.text(State.count3, id="count3"),
rx.text("depends_on_count1:"),
rx.text(
State.depends_on_count1,
id="depends_on_count1",
),
rx.text("depends_on_count3:"),
rx.text(
State.depends_on_count3,
id="depends_on_count3",
),
),
)
# raise Exception(State.count3._deps(objclass=State))
app = rx.App()
app.add_page(index)
@pytest.fixture(scope="module")
def computed_vars(
tmp_path_factory,
) -> Generator[AppHarness, None, None]:
"""Start ComputedVars app at tmp_path via AppHarness.
Args:
tmp_path_factory: pytest tmp_path_factory fixture
Yields:
running AppHarness instance
"""
with AppHarness.create(
root=tmp_path_factory.mktemp(f"computed_vars"),
app_source=ComputedVars, # type: ignore
) as harness:
yield harness
@pytest.fixture
def driver(computed_vars: AppHarness) -> Generator[WebDriver, None, None]:
"""Get an instance of the browser open to the computed_vars app.
Args:
computed_vars: harness for ComputedVars app
Yields:
WebDriver instance.
"""
assert computed_vars.app_instance is not None, "app is not running"
driver = computed_vars.frontend()
try:
yield driver
finally:
driver.quit()
@pytest.fixture()
def token(computed_vars: AppHarness, driver: WebDriver) -> str:
"""Get a function that returns the active token.
Args:
computed_vars: harness for ComputedVars app.
driver: WebDriver instance.
Returns:
The token for the connected client
"""
assert computed_vars.app_instance is not None
token_input = driver.find_element(By.ID, "token")
assert token_input
# wait for the backend connection to send the token
token = computed_vars.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
assert token is not None
return token
def test_computed_vars(
computed_vars: AppHarness,
driver: WebDriver,
token: str,
):
"""Test that computed vars are working as expected.
Args:
computed_vars: harness for ComputedVars app.
driver: WebDriver instance.
token: The token for the connected client.
"""
assert computed_vars.app_instance is not None
count = driver.find_element(By.ID, "count")
assert count
assert count.text == "0"
count1 = driver.find_element(By.ID, "count1")
assert count1
assert count1.text == "0"
count2 = driver.find_element(By.ID, "count2")
assert count2
assert count2.text == "0"
count3 = driver.find_element(By.ID, "count3")
assert count3
assert count3.text == "0"
depends_on_count1 = driver.find_element(By.ID, "depends_on_count1")
assert depends_on_count1
assert depends_on_count1.text == "0"
depends_on_count3 = driver.find_element(By.ID, "depends_on_count3")
assert depends_on_count3
assert depends_on_count3.text == "0"
increment = driver.find_element(By.ID, "increment")
assert increment.is_enabled()
mark_dirty = driver.find_element(By.ID, "mark_dirty")
assert mark_dirty.is_enabled()
mark_dirty.click()
increment.click()
assert computed_vars.poll_for_content(count, timeout=2, exp_not_equal="0") == "1"
assert computed_vars.poll_for_content(count1, timeout=2, exp_not_equal="0") == "1"
assert computed_vars.poll_for_content(count2, timeout=2, exp_not_equal="0") == "1"
mark_dirty.click()
with pytest.raises(TimeoutError):
computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0")
time.sleep(10)
assert count3.text == "0"
assert depends_on_count3.text == "0"
mark_dirty.click()
assert computed_vars.poll_for_content(count3, timeout=2, exp_not_equal="0") == "1"
assert depends_on_count3.text == "1"

View File

@ -1536,6 +1536,18 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if actual_var is not None:
actual_var.mark_dirty(instance=self)
def _expired_computed_vars(self) -> set[str]:
"""Determine ComputedVars that need to be recalculated based on the expiration time.
Returns:
Set of computed vars to include in the delta.
"""
return set(
cvar
for cvar in self.computed_vars
if self.computed_vars[cvar].needs_update(instance=self)
)
def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]:
"""Determine ComputedVars that need to be recalculated based on the given vars.
@ -1588,6 +1600,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# and always dirty computed vars (cache=False)
delta_vars = (
self.dirty_vars.intersection(self.base_vars)
.union(self.dirty_vars.intersection(self.computed_vars))
.union(self._dirty_computed_vars())
.union(self._always_dirty_computed_vars)
)
@ -1621,6 +1634,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
self.parent_state.dirty_substates.add(self.get_name())
self.parent_state._mark_dirty()
# Append expired computed vars to dirty_vars to trigger recalculation
self.dirty_vars.update(self._expired_computed_vars())
# have to mark computed vars dirty to allow access to newly computed
# values within the same ComputedVar function
self._mark_dirty_computed_vars()

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import contextlib
import dataclasses
import datetime
import dis
import functools
import inspect
@ -1873,13 +1874,26 @@ class ComputedVar(Var, property):
# Whether to track dependencies and cache computed values
_cache: bool = dataclasses.field(default=False)
_initial_value: Any | types.Unset = dataclasses.field(default_factory=types.Unset)
# The initial value of the computed var
_initial_value: Any | types.Unset = dataclasses.field(default=types.Unset())
# Explicit var dependencies to track
_static_deps: set[str] = dataclasses.field(default_factory=set)
# Whether var dependencies should be auto-determined
_auto_deps: bool = dataclasses.field(default=True)
# Interval at which the computed var should be updated
_update_interval: Optional[datetime.timedelta] = dataclasses.field(default=None)
def __init__(
self,
fget: Callable[[BaseState], Any],
initial_value: Any | types.Unset = types.Unset(),
cache: bool = False,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[int, datetime.timedelta]] = None,
**kwargs,
):
"""Initialize a ComputedVar.
@ -1888,10 +1902,22 @@ class ComputedVar(Var, property):
fget: The getter function.
initial_value: The initial value of the computed var.
cache: Whether to cache the computed value.
deps: Explicit var dependencies to track.
auto_deps: Whether var dependencies should be auto-determined.
interval: Interval at which the computed var should be updated.
**kwargs: additional attributes to set on the instance
"""
self._initial_value = initial_value
self._cache = cache
if isinstance(interval, int):
interval = datetime.timedelta(seconds=interval)
self._update_interval = interval
if deps is None:
deps = []
self._static_deps = {
dep._var_name if isinstance(dep, Var) else dep for dep in deps
}
self._auto_deps = auto_deps
property.__init__(self, fget)
kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__)
kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type())
@ -1912,6 +1938,9 @@ class ComputedVar(Var, property):
fget=kwargs.get("fget", self.fget),
initial_value=kwargs.get("initial_value", self._initial_value),
cache=kwargs.get("cache", self._cache),
deps=kwargs.get("deps", self._static_deps),
auto_deps=kwargs.get("auto_deps", self._auto_deps),
interval=kwargs.get("interval", self._update_interval),
_var_name=kwargs.get("_var_name", self._var_name),
_var_type=kwargs.get("_var_type", self._var_type),
_var_is_local=kwargs.get("_var_is_local", self._var_is_local),
@ -1932,7 +1961,32 @@ class ComputedVar(Var, property):
"""
return f"__cached_{self._var_name}"
def __get__(self, instance, owner):
@property
def _last_updated_attr(self) -> str:
"""Get the attribute used to store the last updated timestamp.
Returns:
An attribute name.
"""
return f"__last_updated_{self._var_name}"
def needs_update(self, instance: BaseState) -> bool:
"""Check if the computed var needs to be updated.
Args:
instance: The state instance that the computed var is attached to.
Returns:
True if the computed var needs to be updated, False otherwise.
"""
if self._update_interval is None:
return False
last_updated = getattr(instance, self._last_updated_attr, None)
if last_updated is None:
return True
return datetime.datetime.now() - last_updated > self._update_interval
def __get__(self, instance: BaseState | None, owner):
"""Get the ComputedVar value.
If the value is already cached on the instance, return the cached value.
@ -1948,10 +2002,13 @@ class ComputedVar(Var, property):
return super().__get__(instance, owner)
# handle caching
if not hasattr(instance, self._cache_attr):
if not hasattr(instance, self._cache_attr) or self.needs_update(instance):
# Set cache attr on state instance.
setattr(instance, self._cache_attr, super().__get__(instance, owner))
# Ensure the computed var gets serialized to redis.
instance._was_touched = True
# Set the last updated timestamp on the state instance.
setattr(instance, self._last_updated_attr, datetime.datetime.now())
return getattr(instance, self._cache_attr)
def _deps(
@ -1978,7 +2035,9 @@ class ComputedVar(Var, property):
VarValueError: if the function references the get_state, parent_state, or substates attributes
(cannot track deps in a related state, only implicitly via parent state).
"""
d = set()
if not self._auto_deps:
return self._static_deps
d = self._static_deps.copy()
if obj is None:
fget = property.__getattribute__(self, "fget")
if fget is not None:
@ -2076,6 +2135,9 @@ def computed_var(
fget: Callable[[BaseState], Any] | None = None,
initial_value: Any | None = None,
cache: bool = False,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
**kwargs,
) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]:
"""A ComputedVar decorator with or without kwargs.
@ -2084,19 +2146,31 @@ def computed_var(
fget: The getter function.
initial_value: The initial value of the computed var.
cache: Whether to cache the computed value.
deps: Explicit var dependencies to track.
auto_deps: Whether var dependencies should be auto-determined.
interval: Interval at which the computed var should be updated.
**kwargs: additional attributes to set on the instance
Returns:
A ComputedVar instance.
Raises:
ValueError: If caching is disabled and an update interval is set.
"""
if cache is False and interval is not None:
raise ValueError("Cannot set update interval without caching.")
if fget is not None:
return ComputedVar(fget=fget, cache=cache)
def wrapper(fget):
def wrapper(fget: Callable[[BaseState], Any]) -> ComputedVar:
return ComputedVar(
fget=fget,
initial_value=initial_value,
cache=cache,
deps=deps,
auto_deps=auto_deps,
interval=interval,
**kwargs,
)

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import datetime
from dataclasses import dataclass
from _typeshed import Incomplete
from reflex import constants as constants
@ -141,6 +142,7 @@ class ComputedVar(Var):
def _deps(self, objclass: Type, obj: Optional[FunctionType] = ...) -> Set[str]: ...
def _replace(self, merge_var_data=None, **kwargs: Any) -> ComputedVar: ...
def mark_dirty(self, instance) -> None: ...
def needs_update(self, instance) -> bool: ...
def _determine_var_type(self) -> Type: ...
@overload
def __init__(
@ -155,10 +157,24 @@ class ComputedVar(Var):
def computed_var(
fget: Callable[[BaseState], Any] | None = None,
initial_value: Any | None = None,
cache: bool = False,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
**kwargs,
) -> Callable[[Callable[[Any], Any]], ComputedVar]: ...
@overload
def computed_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
@overload
def cached_var(
fget: Callable[[BaseState], Any] | None = None,
initial_value: Any | None = None,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
**kwargs,
) -> Callable[[Callable[[Any], Any]], ComputedVar]: ...
@overload
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
class CallableVar(BaseVar):