Explicit deps and interval for computed vars (#3231)

This commit is contained in:
benedikt-bartscher 2024-05-28 21:27:27 +02:00 committed by GitHub
parent ac1c660bf0
commit 93de407007
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 321 additions and 5 deletions

View File

@ -0,0 +1,210 @@
"""Test computed vars."""
from __future__ import annotations
import time
from typing import Generator
import pytest
from selenium.webdriver.common.by import By
from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
def ComputedVars():
"""Test app for computed vars."""
import reflex as rx
class State(rx.State):
count: int = 0
# cached var with dep on count
@rx.cached_var(interval=15)
def count1(self) -> int:
return self.count
# same as above, different notation
@rx.var(interval=15, cache=True)
def count2(self) -> int:
return self.count
# explicit disabled auto_deps
@rx.var(interval=15, cache=True, auto_deps=False)
def count3(self) -> int:
# this will not add deps, because auto_deps is False
print(self.count1)
print(self.count2)
return self.count
# explicit dependency on count1 var
@rx.cached_var(deps=[count1], auto_deps=False)
def depends_on_count1(self) -> int:
return self.count
@rx.var(deps=[count3], auto_deps=False, cache=True)
def depends_on_count3(self) -> int:
return self.count
def increment(self):
self.count += 1
def mark_dirty(self):
self._mark_dirty()
def index() -> rx.Component:
return rx.center(
rx.vstack(
rx.input(
id="token",
value=State.router.session.client_token,
is_read_only=True,
),
rx.button("Increment", on_click=State.increment, id="increment"),
rx.button("Do nothing", on_click=State.mark_dirty, id="mark_dirty"),
rx.text("count:"),
rx.text(State.count, id="count"),
rx.text("count1:"),
rx.text(State.count1, id="count1"),
rx.text("count2:"),
rx.text(State.count2, id="count2"),
rx.text("count3:"),
rx.text(State.count3, id="count3"),
rx.text("depends_on_count1:"),
rx.text(
State.depends_on_count1,
id="depends_on_count1",
),
rx.text("depends_on_count3:"),
rx.text(
State.depends_on_count3,
id="depends_on_count3",
),
),
)
# raise Exception(State.count3._deps(objclass=State))
app = rx.App()
app.add_page(index)
@pytest.fixture(scope="module")
def computed_vars(
tmp_path_factory,
) -> Generator[AppHarness, None, None]:
"""Start ComputedVars app at tmp_path via AppHarness.
Args:
tmp_path_factory: pytest tmp_path_factory fixture
Yields:
running AppHarness instance
"""
with AppHarness.create(
root=tmp_path_factory.mktemp(f"computed_vars"),
app_source=ComputedVars, # type: ignore
) as harness:
yield harness
@pytest.fixture
def driver(computed_vars: AppHarness) -> Generator[WebDriver, None, None]:
"""Get an instance of the browser open to the computed_vars app.
Args:
computed_vars: harness for ComputedVars app
Yields:
WebDriver instance.
"""
assert computed_vars.app_instance is not None, "app is not running"
driver = computed_vars.frontend()
try:
yield driver
finally:
driver.quit()
@pytest.fixture()
def token(computed_vars: AppHarness, driver: WebDriver) -> str:
"""Get a function that returns the active token.
Args:
computed_vars: harness for ComputedVars app.
driver: WebDriver instance.
Returns:
The token for the connected client
"""
assert computed_vars.app_instance is not None
token_input = driver.find_element(By.ID, "token")
assert token_input
# wait for the backend connection to send the token
token = computed_vars.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
assert token is not None
return token
def test_computed_vars(
computed_vars: AppHarness,
driver: WebDriver,
token: str,
):
"""Test that computed vars are working as expected.
Args:
computed_vars: harness for ComputedVars app.
driver: WebDriver instance.
token: The token for the connected client.
"""
assert computed_vars.app_instance is not None
count = driver.find_element(By.ID, "count")
assert count
assert count.text == "0"
count1 = driver.find_element(By.ID, "count1")
assert count1
assert count1.text == "0"
count2 = driver.find_element(By.ID, "count2")
assert count2
assert count2.text == "0"
count3 = driver.find_element(By.ID, "count3")
assert count3
assert count3.text == "0"
depends_on_count1 = driver.find_element(By.ID, "depends_on_count1")
assert depends_on_count1
assert depends_on_count1.text == "0"
depends_on_count3 = driver.find_element(By.ID, "depends_on_count3")
assert depends_on_count3
assert depends_on_count3.text == "0"
increment = driver.find_element(By.ID, "increment")
assert increment.is_enabled()
mark_dirty = driver.find_element(By.ID, "mark_dirty")
assert mark_dirty.is_enabled()
mark_dirty.click()
increment.click()
assert computed_vars.poll_for_content(count, timeout=2, exp_not_equal="0") == "1"
assert computed_vars.poll_for_content(count1, timeout=2, exp_not_equal="0") == "1"
assert computed_vars.poll_for_content(count2, timeout=2, exp_not_equal="0") == "1"
mark_dirty.click()
with pytest.raises(TimeoutError):
computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0")
time.sleep(10)
assert count3.text == "0"
assert depends_on_count3.text == "0"
mark_dirty.click()
assert computed_vars.poll_for_content(count3, timeout=2, exp_not_equal="0") == "1"
assert depends_on_count3.text == "1"

View File

@ -1536,6 +1536,18 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if actual_var is not None: if actual_var is not None:
actual_var.mark_dirty(instance=self) actual_var.mark_dirty(instance=self)
def _expired_computed_vars(self) -> set[str]:
"""Determine ComputedVars that need to be recalculated based on the expiration time.
Returns:
Set of computed vars to include in the delta.
"""
return set(
cvar
for cvar in self.computed_vars
if self.computed_vars[cvar].needs_update(instance=self)
)
def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]: def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]:
"""Determine ComputedVars that need to be recalculated based on the given vars. """Determine ComputedVars that need to be recalculated based on the given vars.
@ -1588,6 +1600,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# and always dirty computed vars (cache=False) # and always dirty computed vars (cache=False)
delta_vars = ( delta_vars = (
self.dirty_vars.intersection(self.base_vars) self.dirty_vars.intersection(self.base_vars)
.union(self.dirty_vars.intersection(self.computed_vars))
.union(self._dirty_computed_vars()) .union(self._dirty_computed_vars())
.union(self._always_dirty_computed_vars) .union(self._always_dirty_computed_vars)
) )
@ -1621,6 +1634,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
self.parent_state.dirty_substates.add(self.get_name()) self.parent_state.dirty_substates.add(self.get_name())
self.parent_state._mark_dirty() self.parent_state._mark_dirty()
# Append expired computed vars to dirty_vars to trigger recalculation
self.dirty_vars.update(self._expired_computed_vars())
# have to mark computed vars dirty to allow access to newly computed # have to mark computed vars dirty to allow access to newly computed
# values within the same ComputedVar function # values within the same ComputedVar function
self._mark_dirty_computed_vars() self._mark_dirty_computed_vars()

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import contextlib import contextlib
import dataclasses import dataclasses
import datetime
import dis import dis
import functools import functools
import inspect import inspect
@ -1873,13 +1874,26 @@ class ComputedVar(Var, property):
# Whether to track dependencies and cache computed values # Whether to track dependencies and cache computed values
_cache: bool = dataclasses.field(default=False) _cache: bool = dataclasses.field(default=False)
_initial_value: Any | types.Unset = dataclasses.field(default_factory=types.Unset) # The initial value of the computed var
_initial_value: Any | types.Unset = dataclasses.field(default=types.Unset())
# Explicit var dependencies to track
_static_deps: set[str] = dataclasses.field(default_factory=set)
# Whether var dependencies should be auto-determined
_auto_deps: bool = dataclasses.field(default=True)
# Interval at which the computed var should be updated
_update_interval: Optional[datetime.timedelta] = dataclasses.field(default=None)
def __init__( def __init__(
self, self,
fget: Callable[[BaseState], Any], fget: Callable[[BaseState], Any],
initial_value: Any | types.Unset = types.Unset(), initial_value: Any | types.Unset = types.Unset(),
cache: bool = False, cache: bool = False,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[int, datetime.timedelta]] = None,
**kwargs, **kwargs,
): ):
"""Initialize a ComputedVar. """Initialize a ComputedVar.
@ -1888,10 +1902,22 @@ class ComputedVar(Var, property):
fget: The getter function. fget: The getter function.
initial_value: The initial value of the computed var. initial_value: The initial value of the computed var.
cache: Whether to cache the computed value. cache: Whether to cache the computed value.
deps: Explicit var dependencies to track.
auto_deps: Whether var dependencies should be auto-determined.
interval: Interval at which the computed var should be updated.
**kwargs: additional attributes to set on the instance **kwargs: additional attributes to set on the instance
""" """
self._initial_value = initial_value self._initial_value = initial_value
self._cache = cache self._cache = cache
if isinstance(interval, int):
interval = datetime.timedelta(seconds=interval)
self._update_interval = interval
if deps is None:
deps = []
self._static_deps = {
dep._var_name if isinstance(dep, Var) else dep for dep in deps
}
self._auto_deps = auto_deps
property.__init__(self, fget) property.__init__(self, fget)
kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__) kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__)
kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type()) kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type())
@ -1912,6 +1938,9 @@ class ComputedVar(Var, property):
fget=kwargs.get("fget", self.fget), fget=kwargs.get("fget", self.fget),
initial_value=kwargs.get("initial_value", self._initial_value), initial_value=kwargs.get("initial_value", self._initial_value),
cache=kwargs.get("cache", self._cache), cache=kwargs.get("cache", self._cache),
deps=kwargs.get("deps", self._static_deps),
auto_deps=kwargs.get("auto_deps", self._auto_deps),
interval=kwargs.get("interval", self._update_interval),
_var_name=kwargs.get("_var_name", self._var_name), _var_name=kwargs.get("_var_name", self._var_name),
_var_type=kwargs.get("_var_type", self._var_type), _var_type=kwargs.get("_var_type", self._var_type),
_var_is_local=kwargs.get("_var_is_local", self._var_is_local), _var_is_local=kwargs.get("_var_is_local", self._var_is_local),
@ -1932,7 +1961,32 @@ class ComputedVar(Var, property):
""" """
return f"__cached_{self._var_name}" return f"__cached_{self._var_name}"
def __get__(self, instance, owner): @property
def _last_updated_attr(self) -> str:
"""Get the attribute used to store the last updated timestamp.
Returns:
An attribute name.
"""
return f"__last_updated_{self._var_name}"
def needs_update(self, instance: BaseState) -> bool:
"""Check if the computed var needs to be updated.
Args:
instance: The state instance that the computed var is attached to.
Returns:
True if the computed var needs to be updated, False otherwise.
"""
if self._update_interval is None:
return False
last_updated = getattr(instance, self._last_updated_attr, None)
if last_updated is None:
return True
return datetime.datetime.now() - last_updated > self._update_interval
def __get__(self, instance: BaseState | None, owner):
"""Get the ComputedVar value. """Get the ComputedVar value.
If the value is already cached on the instance, return the cached value. If the value is already cached on the instance, return the cached value.
@ -1948,10 +2002,13 @@ class ComputedVar(Var, property):
return super().__get__(instance, owner) return super().__get__(instance, owner)
# handle caching # handle caching
if not hasattr(instance, self._cache_attr): if not hasattr(instance, self._cache_attr) or self.needs_update(instance):
# Set cache attr on state instance.
setattr(instance, self._cache_attr, super().__get__(instance, owner)) setattr(instance, self._cache_attr, super().__get__(instance, owner))
# Ensure the computed var gets serialized to redis. # Ensure the computed var gets serialized to redis.
instance._was_touched = True instance._was_touched = True
# Set the last updated timestamp on the state instance.
setattr(instance, self._last_updated_attr, datetime.datetime.now())
return getattr(instance, self._cache_attr) return getattr(instance, self._cache_attr)
def _deps( def _deps(
@ -1978,7 +2035,9 @@ class ComputedVar(Var, property):
VarValueError: if the function references the get_state, parent_state, or substates attributes VarValueError: if the function references the get_state, parent_state, or substates attributes
(cannot track deps in a related state, only implicitly via parent state). (cannot track deps in a related state, only implicitly via parent state).
""" """
d = set() if not self._auto_deps:
return self._static_deps
d = self._static_deps.copy()
if obj is None: if obj is None:
fget = property.__getattribute__(self, "fget") fget = property.__getattribute__(self, "fget")
if fget is not None: if fget is not None:
@ -2076,6 +2135,9 @@ def computed_var(
fget: Callable[[BaseState], Any] | None = None, fget: Callable[[BaseState], Any] | None = None,
initial_value: Any | None = None, initial_value: Any | None = None,
cache: bool = False, cache: bool = False,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
**kwargs, **kwargs,
) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]: ) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]:
"""A ComputedVar decorator with or without kwargs. """A ComputedVar decorator with or without kwargs.
@ -2084,19 +2146,31 @@ def computed_var(
fget: The getter function. fget: The getter function.
initial_value: The initial value of the computed var. initial_value: The initial value of the computed var.
cache: Whether to cache the computed value. cache: Whether to cache the computed value.
deps: Explicit var dependencies to track.
auto_deps: Whether var dependencies should be auto-determined.
interval: Interval at which the computed var should be updated.
**kwargs: additional attributes to set on the instance **kwargs: additional attributes to set on the instance
Returns: Returns:
A ComputedVar instance. A ComputedVar instance.
Raises:
ValueError: If caching is disabled and an update interval is set.
""" """
if cache is False and interval is not None:
raise ValueError("Cannot set update interval without caching.")
if fget is not None: if fget is not None:
return ComputedVar(fget=fget, cache=cache) return ComputedVar(fget=fget, cache=cache)
def wrapper(fget): def wrapper(fget: Callable[[BaseState], Any]) -> ComputedVar:
return ComputedVar( return ComputedVar(
fget=fget, fget=fget,
initial_value=initial_value, initial_value=initial_value,
cache=cache, cache=cache,
deps=deps,
auto_deps=auto_deps,
interval=interval,
**kwargs, **kwargs,
) )

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import datetime
from dataclasses import dataclass from dataclasses import dataclass
from _typeshed import Incomplete from _typeshed import Incomplete
from reflex import constants as constants from reflex import constants as constants
@ -141,6 +142,7 @@ class ComputedVar(Var):
def _deps(self, objclass: Type, obj: Optional[FunctionType] = ...) -> Set[str]: ... def _deps(self, objclass: Type, obj: Optional[FunctionType] = ...) -> Set[str]: ...
def _replace(self, merge_var_data=None, **kwargs: Any) -> ComputedVar: ... def _replace(self, merge_var_data=None, **kwargs: Any) -> ComputedVar: ...
def mark_dirty(self, instance) -> None: ... def mark_dirty(self, instance) -> None: ...
def needs_update(self, instance) -> bool: ...
def _determine_var_type(self) -> Type: ... def _determine_var_type(self) -> Type: ...
@overload @overload
def __init__( def __init__(
@ -155,10 +157,24 @@ class ComputedVar(Var):
def computed_var( def computed_var(
fget: Callable[[BaseState], Any] | None = None, fget: Callable[[BaseState], Any] | None = None,
initial_value: Any | None = None, initial_value: Any | None = None,
cache: bool = False,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
**kwargs, **kwargs,
) -> Callable[[Callable[[Any], Any]], ComputedVar]: ... ) -> Callable[[Callable[[Any], Any]], ComputedVar]: ...
@overload @overload
def computed_var(fget: Callable[[Any], Any]) -> ComputedVar: ... def computed_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
@overload
def cached_var(
fget: Callable[[BaseState], Any] | None = None,
initial_value: Any | None = None,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
**kwargs,
) -> Callable[[Callable[[Any], Any]], ComputedVar]: ...
@overload
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ... def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
class CallableVar(BaseVar): class CallableVar(BaseVar):