Validate ComputedVar dependencies, add tests (#3527)

This commit is contained in:
benedikt-bartscher 2024-06-25 01:11:42 +02:00 committed by GitHub
parent f037df0977
commit 41efe12e9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 145 additions and 9 deletions

View File

@ -37,8 +37,13 @@ def ComputedVars():
return self.count return self.count
# explicit dependency on count var
@rx.var(cache=True, deps=["count"], auto_deps=False)
def depends_on_count(self) -> int:
return self.count
# explicit dependency on count1 var # explicit dependency on count1 var
@rx.cached_var(deps=[count1], auto_deps=False) @rx.var(cache=True, deps=[count1], auto_deps=False)
def depends_on_count1(self) -> int: def depends_on_count1(self) -> int:
return self.count return self.count
@ -70,6 +75,11 @@ def ComputedVars():
rx.text(State.count2, id="count2"), rx.text(State.count2, id="count2"),
rx.text("count3:"), rx.text("count3:"),
rx.text(State.count3, id="count3"), rx.text(State.count3, id="count3"),
rx.text("depends_on_count:"),
rx.text(
State.depends_on_count,
id="depends_on_count",
),
rx.text("depends_on_count1:"), rx.text("depends_on_count1:"),
rx.text( rx.text(
State.depends_on_count1, State.depends_on_count1,
@ -90,7 +100,7 @@ def ComputedVars():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def computed_vars( def computed_vars(
tmp_path_factory, tmp_path_factory: pytest.TempPathFactory,
) -> Generator[AppHarness, None, None]: ) -> Generator[AppHarness, None, None]:
"""Start ComputedVars app at tmp_path via AppHarness. """Start ComputedVars app at tmp_path via AppHarness.
@ -177,6 +187,10 @@ def test_computed_vars(
assert count3 assert count3
assert count3.text == "0" assert count3.text == "0"
depends_on_count = driver.find_element(By.ID, "depends_on_count")
assert depends_on_count
assert depends_on_count.text == "0"
depends_on_count1 = driver.find_element(By.ID, "depends_on_count1") depends_on_count1 = driver.find_element(By.ID, "depends_on_count1")
assert depends_on_count1 assert depends_on_count1
assert depends_on_count1.text == "0" assert depends_on_count1.text == "0"
@ -197,10 +211,14 @@ def test_computed_vars(
assert computed_vars.poll_for_content(count, timeout=2, exp_not_equal="0") == "1" assert computed_vars.poll_for_content(count, timeout=2, exp_not_equal="0") == "1"
assert computed_vars.poll_for_content(count1, timeout=2, exp_not_equal="0") == "1" assert computed_vars.poll_for_content(count1, timeout=2, exp_not_equal="0") == "1"
assert computed_vars.poll_for_content(count2, timeout=2, exp_not_equal="0") == "1" assert computed_vars.poll_for_content(count2, timeout=2, exp_not_equal="0") == "1"
assert (
computed_vars.poll_for_content(depends_on_count, timeout=2, exp_not_equal="0")
== "1"
)
mark_dirty.click() mark_dirty.click()
with pytest.raises(TimeoutError): with pytest.raises(TimeoutError):
computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0") _ = computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0")
time.sleep(10) time.sleep(10)
assert count3.text == "0" assert count3.text == "0"

View File

@ -797,6 +797,34 @@ class App(LifespanMixin, Base):
for render, kwargs in DECORATED_PAGES[get_config().app_name]: for render, kwargs in DECORATED_PAGES[get_config().app_name]:
self.add_page(render, **kwargs) self.add_page(render, **kwargs)
def _validate_var_dependencies(
self, state: Optional[Type[BaseState]] = None
) -> None:
"""Validate the dependencies of the vars in the app.
Args:
state: The state to validate the dependencies for.
Raises:
VarDependencyError: When a computed var has an invalid dependency.
"""
if not self.state:
return
if not state:
state = self.state
for var in state.computed_vars.values():
deps = var._deps(objclass=state)
for dep in deps:
if dep not in state.vars and dep not in state.backend_vars:
raise exceptions.VarDependencyError(
f"ComputedVar {var._var_name} on state {state.__name__} has an invalid dependency {dep}"
)
for substate in state.class_subclasses:
self._validate_var_dependencies(substate)
def _compile(self, export: bool = False): def _compile(self, export: bool = False):
"""Compile the app and output it to the pages folder. """Compile the app and output it to the pages folder.
@ -818,6 +846,7 @@ class App(LifespanMixin, Base):
if not self._should_compile(): if not self._should_compile():
return return
self._validate_var_dependencies()
self._setup_overlay_component() self._setup_overlay_component()
# Create a progress bar. # Create a progress bar.

View File

@ -61,6 +61,10 @@ class VarOperationTypeError(ReflexError, TypeError):
"""Custom TypeError for when unsupported operations are performed on vars.""" """Custom TypeError for when unsupported operations are performed on vars."""
class VarDependencyError(ReflexError, ValueError):
"""Custom ValueError for when a var depends on a non-existent var."""
class InvalidStylePropError(ReflexError, TypeError): class InvalidStylePropError(ReflexError, TypeError):
"""Custom Type Error when style props have invalid values.""" """Custom Type Error when style props have invalid values."""

View File

@ -1971,6 +1971,9 @@ class ComputedVar(Var, property):
auto_deps: Whether var dependencies should be auto-determined. auto_deps: Whether var dependencies should be auto-determined.
interval: Interval at which the computed var should be updated. interval: Interval at which the computed var should be updated.
**kwargs: additional attributes to set on the instance **kwargs: additional attributes to set on the instance
Raises:
TypeError: If the computed var dependencies are not Var instances or var names.
""" """
self._initial_value = initial_value self._initial_value = initial_value
self._cache = cache self._cache = cache
@ -1979,6 +1982,15 @@ class ComputedVar(Var, property):
self._update_interval = interval self._update_interval = interval
if deps is None: if deps is None:
deps = [] deps = []
else:
for dep in deps:
if isinstance(dep, Var):
continue
if isinstance(dep, str) and dep != "":
continue
raise TypeError(
"ComputedVar dependencies must be Var instances or var names (non-empty strings)."
)
self._static_deps = { self._static_deps = {
dep._var_name if isinstance(dep, Var) else dep for dep in deps dep._var_name if isinstance(dep, Var) else dep for dep in deps
} }

View File

@ -46,7 +46,7 @@ from reflex.state import (
) )
from reflex.style import Style from reflex.style import Style
from reflex.utils import exceptions, format from reflex.utils import exceptions, format
from reflex.vars import ComputedVar from reflex.vars import computed_var
from .conftest import chdir from .conftest import chdir
from .states import ( from .states import (
@ -951,7 +951,7 @@ class DynamicState(BaseState):
"""Increment the counter var.""" """Increment the counter var."""
self.counter = self.counter + 1 self.counter = self.counter + 1
@ComputedVar @computed_var
def comp_dynamic(self) -> str: def comp_dynamic(self) -> str:
"""A computed var that depends on the dynamic var. """A computed var that depends on the dynamic var.
@ -1278,7 +1278,7 @@ def compilable_app(tmp_path) -> Generator[tuple[App, Path], None, None]:
yield app, web_dir yield app, web_dir
def test_app_wrap_compile_theme(compilable_app): def test_app_wrap_compile_theme(compilable_app: tuple[App, Path]):
"""Test that the radix theme component wraps the app. """Test that the radix theme component wraps the app.
Args: Args:
@ -1306,7 +1306,7 @@ def test_app_wrap_compile_theme(compilable_app):
) in "".join(app_js_lines) ) in "".join(app_js_lines)
def test_app_wrap_priority(compilable_app): def test_app_wrap_priority(compilable_app: tuple[App, Path]):
"""Test that the app wrap components are wrapped in the correct order. """Test that the app wrap components are wrapped in the correct order.
Args: Args:
@ -1490,7 +1490,7 @@ def test_add_page_component_returning_tuple():
@pytest.mark.parametrize("export", (True, False)) @pytest.mark.parametrize("export", (True, False))
def test_app_with_transpile_packages(compilable_app, export): def test_app_with_transpile_packages(compilable_app: tuple[App, Path], export: bool):
class C1(rx.Component): class C1(rx.Component):
library = "foo@1.2.3" library = "foo@1.2.3"
tag = "Foo" tag = "Foo"
@ -1539,3 +1539,35 @@ def test_app_with_transpile_packages(compilable_app, export):
else: else:
assert 'output: "export"' not in next_config assert 'output: "export"' not in next_config
assert f'distDir: "{constants.Dirs.STATIC}"' not in next_config assert f'distDir: "{constants.Dirs.STATIC}"' not in next_config
def test_app_with_valid_var_dependencies(compilable_app: tuple[App, Path]):
app, _ = compilable_app
class ValidDepState(BaseState):
base: int = 0
_backend: int = 0
@computed_var
def foo(self) -> str:
return "foo"
@computed_var(deps=["_backend", "base", foo])
def bar(self) -> str:
return "bar"
app.state = ValidDepState
app._compile()
def test_app_with_invalid_var_dependencies(compilable_app: tuple[App, Path]):
app, _ = compilable_app
class InvalidDepState(BaseState):
@computed_var(deps=["foolksjdf"])
def bar(self) -> str:
return "bar"
app.state = InvalidDepState
with pytest.raises(exceptions.VarDependencyError):
app._compile()

View File

@ -1,6 +1,6 @@
import json import json
import typing import typing
from typing import Dict, List, Set, Tuple from typing import Dict, List, Set, Tuple, Union
import pytest import pytest
from pandas import DataFrame from pandas import DataFrame
@ -9,6 +9,7 @@ from reflex.base import Base
from reflex.state import BaseState from reflex.state import BaseState
from reflex.vars import ( from reflex.vars import (
BaseVar, BaseVar,
ComputedVar,
Var, Var,
computed_var, computed_var,
) )
@ -1388,3 +1389,43 @@ def test_invalid_var_operations(operand1_var: Var, operand2_var, operators: List
) )
def test_var_name_unwrapped(var, expected): def test_var_name_unwrapped(var, expected):
assert var._var_name_unwrapped == expected assert var._var_name_unwrapped == expected
def cv_fget(state: BaseState) -> int:
return 1
@pytest.mark.parametrize(
"deps,expected",
[
(["a"], {"a"}),
(["b"], {"b"}),
([ComputedVar(fget=cv_fget)], {"cv_fget"}),
],
)
def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]):
@computed_var(
deps=deps,
)
def test_var(state) -> int:
return 1
assert test_var._static_deps == expected
@pytest.mark.parametrize(
"deps",
[
[""],
[1],
["", "abc"],
],
)
def test_invalid_computed_var_deps(deps: List):
with pytest.raises(TypeError):
@computed_var(
deps=deps,
)
def test_var(state) -> int:
return 1