Validate ComputedVar dependencies, add tests (#3527)
This commit is contained in:
parent
f037df0977
commit
41efe12e9a
@ -37,8 +37,13 @@ def ComputedVars():
|
||||
|
||||
return self.count
|
||||
|
||||
# explicit dependency on count var
|
||||
@rx.var(cache=True, deps=["count"], auto_deps=False)
|
||||
def depends_on_count(self) -> int:
|
||||
return self.count
|
||||
|
||||
# explicit dependency on count1 var
|
||||
@rx.cached_var(deps=[count1], auto_deps=False)
|
||||
@rx.var(cache=True, deps=[count1], auto_deps=False)
|
||||
def depends_on_count1(self) -> int:
|
||||
return self.count
|
||||
|
||||
@ -70,6 +75,11 @@ def ComputedVars():
|
||||
rx.text(State.count2, id="count2"),
|
||||
rx.text("count3:"),
|
||||
rx.text(State.count3, id="count3"),
|
||||
rx.text("depends_on_count:"),
|
||||
rx.text(
|
||||
State.depends_on_count,
|
||||
id="depends_on_count",
|
||||
),
|
||||
rx.text("depends_on_count1:"),
|
||||
rx.text(
|
||||
State.depends_on_count1,
|
||||
@ -90,7 +100,7 @@ def ComputedVars():
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def computed_vars(
|
||||
tmp_path_factory,
|
||||
tmp_path_factory: pytest.TempPathFactory,
|
||||
) -> Generator[AppHarness, None, None]:
|
||||
"""Start ComputedVars app at tmp_path via AppHarness.
|
||||
|
||||
@ -177,6 +187,10 @@ def test_computed_vars(
|
||||
assert count3
|
||||
assert count3.text == "0"
|
||||
|
||||
depends_on_count = driver.find_element(By.ID, "depends_on_count")
|
||||
assert depends_on_count
|
||||
assert depends_on_count.text == "0"
|
||||
|
||||
depends_on_count1 = driver.find_element(By.ID, "depends_on_count1")
|
||||
assert depends_on_count1
|
||||
assert depends_on_count1.text == "0"
|
||||
@ -197,10 +211,14 @@ def test_computed_vars(
|
||||
assert computed_vars.poll_for_content(count, timeout=2, exp_not_equal="0") == "1"
|
||||
assert computed_vars.poll_for_content(count1, timeout=2, exp_not_equal="0") == "1"
|
||||
assert computed_vars.poll_for_content(count2, timeout=2, exp_not_equal="0") == "1"
|
||||
assert (
|
||||
computed_vars.poll_for_content(depends_on_count, timeout=2, exp_not_equal="0")
|
||||
== "1"
|
||||
)
|
||||
|
||||
mark_dirty.click()
|
||||
with pytest.raises(TimeoutError):
|
||||
computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0")
|
||||
_ = computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0")
|
||||
|
||||
time.sleep(10)
|
||||
assert count3.text == "0"
|
||||
|
@ -797,6 +797,34 @@ class App(LifespanMixin, Base):
|
||||
for render, kwargs in DECORATED_PAGES[get_config().app_name]:
|
||||
self.add_page(render, **kwargs)
|
||||
|
||||
def _validate_var_dependencies(
|
||||
self, state: Optional[Type[BaseState]] = None
|
||||
) -> None:
|
||||
"""Validate the dependencies of the vars in the app.
|
||||
|
||||
Args:
|
||||
state: The state to validate the dependencies for.
|
||||
|
||||
Raises:
|
||||
VarDependencyError: When a computed var has an invalid dependency.
|
||||
"""
|
||||
if not self.state:
|
||||
return
|
||||
|
||||
if not state:
|
||||
state = self.state
|
||||
|
||||
for var in state.computed_vars.values():
|
||||
deps = var._deps(objclass=state)
|
||||
for dep in deps:
|
||||
if dep not in state.vars and dep not in state.backend_vars:
|
||||
raise exceptions.VarDependencyError(
|
||||
f"ComputedVar {var._var_name} on state {state.__name__} has an invalid dependency {dep}"
|
||||
)
|
||||
|
||||
for substate in state.class_subclasses:
|
||||
self._validate_var_dependencies(substate)
|
||||
|
||||
def _compile(self, export: bool = False):
|
||||
"""Compile the app and output it to the pages folder.
|
||||
|
||||
@ -818,6 +846,7 @@ class App(LifespanMixin, Base):
|
||||
if not self._should_compile():
|
||||
return
|
||||
|
||||
self._validate_var_dependencies()
|
||||
self._setup_overlay_component()
|
||||
|
||||
# Create a progress bar.
|
||||
|
@ -61,6 +61,10 @@ class VarOperationTypeError(ReflexError, TypeError):
|
||||
"""Custom TypeError for when unsupported operations are performed on vars."""
|
||||
|
||||
|
||||
class VarDependencyError(ReflexError, ValueError):
|
||||
"""Custom ValueError for when a var depends on a non-existent var."""
|
||||
|
||||
|
||||
class InvalidStylePropError(ReflexError, TypeError):
|
||||
"""Custom Type Error when style props have invalid values."""
|
||||
|
||||
|
@ -1971,6 +1971,9 @@ class ComputedVar(Var, property):
|
||||
auto_deps: Whether var dependencies should be auto-determined.
|
||||
interval: Interval at which the computed var should be updated.
|
||||
**kwargs: additional attributes to set on the instance
|
||||
|
||||
Raises:
|
||||
TypeError: If the computed var dependencies are not Var instances or var names.
|
||||
"""
|
||||
self._initial_value = initial_value
|
||||
self._cache = cache
|
||||
@ -1979,6 +1982,15 @@ class ComputedVar(Var, property):
|
||||
self._update_interval = interval
|
||||
if deps is None:
|
||||
deps = []
|
||||
else:
|
||||
for dep in deps:
|
||||
if isinstance(dep, Var):
|
||||
continue
|
||||
if isinstance(dep, str) and dep != "":
|
||||
continue
|
||||
raise TypeError(
|
||||
"ComputedVar dependencies must be Var instances or var names (non-empty strings)."
|
||||
)
|
||||
self._static_deps = {
|
||||
dep._var_name if isinstance(dep, Var) else dep for dep in deps
|
||||
}
|
||||
|
@ -46,7 +46,7 @@ from reflex.state import (
|
||||
)
|
||||
from reflex.style import Style
|
||||
from reflex.utils import exceptions, format
|
||||
from reflex.vars import ComputedVar
|
||||
from reflex.vars import computed_var
|
||||
|
||||
from .conftest import chdir
|
||||
from .states import (
|
||||
@ -951,7 +951,7 @@ class DynamicState(BaseState):
|
||||
"""Increment the counter var."""
|
||||
self.counter = self.counter + 1
|
||||
|
||||
@ComputedVar
|
||||
@computed_var
|
||||
def comp_dynamic(self) -> str:
|
||||
"""A computed var that depends on the dynamic var.
|
||||
|
||||
@ -1278,7 +1278,7 @@ def compilable_app(tmp_path) -> Generator[tuple[App, Path], None, None]:
|
||||
yield app, web_dir
|
||||
|
||||
|
||||
def test_app_wrap_compile_theme(compilable_app):
|
||||
def test_app_wrap_compile_theme(compilable_app: tuple[App, Path]):
|
||||
"""Test that the radix theme component wraps the app.
|
||||
|
||||
Args:
|
||||
@ -1306,7 +1306,7 @@ def test_app_wrap_compile_theme(compilable_app):
|
||||
) in "".join(app_js_lines)
|
||||
|
||||
|
||||
def test_app_wrap_priority(compilable_app):
|
||||
def test_app_wrap_priority(compilable_app: tuple[App, Path]):
|
||||
"""Test that the app wrap components are wrapped in the correct order.
|
||||
|
||||
Args:
|
||||
@ -1490,7 +1490,7 @@ def test_add_page_component_returning_tuple():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("export", (True, False))
|
||||
def test_app_with_transpile_packages(compilable_app, export):
|
||||
def test_app_with_transpile_packages(compilable_app: tuple[App, Path], export: bool):
|
||||
class C1(rx.Component):
|
||||
library = "foo@1.2.3"
|
||||
tag = "Foo"
|
||||
@ -1539,3 +1539,35 @@ def test_app_with_transpile_packages(compilable_app, export):
|
||||
else:
|
||||
assert 'output: "export"' not in next_config
|
||||
assert f'distDir: "{constants.Dirs.STATIC}"' not in next_config
|
||||
|
||||
|
||||
def test_app_with_valid_var_dependencies(compilable_app: tuple[App, Path]):
|
||||
app, _ = compilable_app
|
||||
|
||||
class ValidDepState(BaseState):
|
||||
base: int = 0
|
||||
_backend: int = 0
|
||||
|
||||
@computed_var
|
||||
def foo(self) -> str:
|
||||
return "foo"
|
||||
|
||||
@computed_var(deps=["_backend", "base", foo])
|
||||
def bar(self) -> str:
|
||||
return "bar"
|
||||
|
||||
app.state = ValidDepState
|
||||
app._compile()
|
||||
|
||||
|
||||
def test_app_with_invalid_var_dependencies(compilable_app: tuple[App, Path]):
|
||||
app, _ = compilable_app
|
||||
|
||||
class InvalidDepState(BaseState):
|
||||
@computed_var(deps=["foolksjdf"])
|
||||
def bar(self) -> str:
|
||||
return "bar"
|
||||
|
||||
app.state = InvalidDepState
|
||||
with pytest.raises(exceptions.VarDependencyError):
|
||||
app._compile()
|
||||
|
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import typing
|
||||
from typing import Dict, List, Set, Tuple
|
||||
from typing import Dict, List, Set, Tuple, Union
|
||||
|
||||
import pytest
|
||||
from pandas import DataFrame
|
||||
@ -9,6 +9,7 @@ from reflex.base import Base
|
||||
from reflex.state import BaseState
|
||||
from reflex.vars import (
|
||||
BaseVar,
|
||||
ComputedVar,
|
||||
Var,
|
||||
computed_var,
|
||||
)
|
||||
@ -1388,3 +1389,43 @@ def test_invalid_var_operations(operand1_var: Var, operand2_var, operators: List
|
||||
)
|
||||
def test_var_name_unwrapped(var, expected):
|
||||
assert var._var_name_unwrapped == expected
|
||||
|
||||
|
||||
def cv_fget(state: BaseState) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"deps,expected",
|
||||
[
|
||||
(["a"], {"a"}),
|
||||
(["b"], {"b"}),
|
||||
([ComputedVar(fget=cv_fget)], {"cv_fget"}),
|
||||
],
|
||||
)
|
||||
def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]):
|
||||
@computed_var(
|
||||
deps=deps,
|
||||
)
|
||||
def test_var(state) -> int:
|
||||
return 1
|
||||
|
||||
assert test_var._static_deps == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"deps",
|
||||
[
|
||||
[""],
|
||||
[1],
|
||||
["", "abc"],
|
||||
],
|
||||
)
|
||||
def test_invalid_computed_var_deps(deps: List):
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
@computed_var(
|
||||
deps=deps,
|
||||
)
|
||||
def test_var(state) -> int:
|
||||
return 1
|
||||
|
Loading…
Reference in New Issue
Block a user