diff --git a/integration/test_computed_vars.py b/integration/test_computed_vars.py index 2523248a9..5b8da7571 100644 --- a/integration/test_computed_vars.py +++ b/integration/test_computed_vars.py @@ -37,8 +37,13 @@ def ComputedVars(): return self.count + # explicit dependency on count var + @rx.var(cache=True, deps=["count"], auto_deps=False) + def depends_on_count(self) -> int: + return self.count + # explicit dependency on count1 var - @rx.cached_var(deps=[count1], auto_deps=False) + @rx.var(cache=True, deps=[count1], auto_deps=False) def depends_on_count1(self) -> int: return self.count @@ -70,6 +75,11 @@ def ComputedVars(): rx.text(State.count2, id="count2"), rx.text("count3:"), rx.text(State.count3, id="count3"), + rx.text("depends_on_count:"), + rx.text( + State.depends_on_count, + id="depends_on_count", + ), rx.text("depends_on_count1:"), rx.text( State.depends_on_count1, @@ -90,7 +100,7 @@ def ComputedVars(): @pytest.fixture(scope="module") def computed_vars( - tmp_path_factory, + tmp_path_factory: pytest.TempPathFactory, ) -> Generator[AppHarness, None, None]: """Start ComputedVars app at tmp_path via AppHarness. @@ -177,6 +187,10 @@ def test_computed_vars( assert count3 assert count3.text == "0" + depends_on_count = driver.find_element(By.ID, "depends_on_count") + assert depends_on_count + assert depends_on_count.text == "0" + depends_on_count1 = driver.find_element(By.ID, "depends_on_count1") assert depends_on_count1 assert depends_on_count1.text == "0" @@ -197,10 +211,14 @@ def test_computed_vars( assert computed_vars.poll_for_content(count, timeout=2, exp_not_equal="0") == "1" assert computed_vars.poll_for_content(count1, timeout=2, exp_not_equal="0") == "1" assert computed_vars.poll_for_content(count2, timeout=2, exp_not_equal="0") == "1" + assert ( + computed_vars.poll_for_content(depends_on_count, timeout=2, exp_not_equal="0") + == "1" + ) mark_dirty.click() with pytest.raises(TimeoutError): - computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0") + _ = computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0") time.sleep(10) assert count3.text == "0" diff --git a/reflex/app.py b/reflex/app.py index 937aee62a..db57222f7 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -797,6 +797,34 @@ class App(LifespanMixin, Base): for render, kwargs in DECORATED_PAGES[get_config().app_name]: self.add_page(render, **kwargs) + def _validate_var_dependencies( + self, state: Optional[Type[BaseState]] = None + ) -> None: + """Validate the dependencies of the vars in the app. + + Args: + state: The state to validate the dependencies for. + + Raises: + VarDependencyError: When a computed var has an invalid dependency. + """ + if not self.state: + return + + if not state: + state = self.state + + for var in state.computed_vars.values(): + deps = var._deps(objclass=state) + for dep in deps: + if dep not in state.vars and dep not in state.backend_vars: + raise exceptions.VarDependencyError( + f"ComputedVar {var._var_name} on state {state.__name__} has an invalid dependency {dep}" + ) + + for substate in state.class_subclasses: + self._validate_var_dependencies(substate) + def _compile(self, export: bool = False): """Compile the app and output it to the pages folder. @@ -818,6 +846,7 @@ class App(LifespanMixin, Base): if not self._should_compile(): return + self._validate_var_dependencies() self._setup_overlay_component() # Create a progress bar. diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index aabaaef14..d219dcf0c 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -61,6 +61,10 @@ class VarOperationTypeError(ReflexError, TypeError): """Custom TypeError for when unsupported operations are performed on vars.""" +class VarDependencyError(ReflexError, ValueError): + """Custom ValueError for when a var depends on a non-existent var.""" + + class InvalidStylePropError(ReflexError, TypeError): """Custom Type Error when style props have invalid values.""" diff --git a/reflex/vars.py b/reflex/vars.py index 6d002871c..93918e682 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -1971,6 +1971,9 @@ class ComputedVar(Var, property): auto_deps: Whether var dependencies should be auto-determined. interval: Interval at which the computed var should be updated. **kwargs: additional attributes to set on the instance + + Raises: + TypeError: If the computed var dependencies are not Var instances or var names. """ self._initial_value = initial_value self._cache = cache @@ -1979,6 +1982,15 @@ class ComputedVar(Var, property): self._update_interval = interval if deps is None: deps = [] + else: + for dep in deps: + if isinstance(dep, Var): + continue + if isinstance(dep, str) and dep != "": + continue + raise TypeError( + "ComputedVar dependencies must be Var instances or var names (non-empty strings)." + ) self._static_deps = { dep._var_name if isinstance(dep, Var) else dep for dep in deps } diff --git a/tests/test_app.py b/tests/test_app.py index 142f0db0b..02982a10d 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -46,7 +46,7 @@ from reflex.state import ( ) from reflex.style import Style from reflex.utils import exceptions, format -from reflex.vars import ComputedVar +from reflex.vars import computed_var from .conftest import chdir from .states import ( @@ -951,7 +951,7 @@ class DynamicState(BaseState): """Increment the counter var.""" self.counter = self.counter + 1 - @ComputedVar + @computed_var def comp_dynamic(self) -> str: """A computed var that depends on the dynamic var. @@ -1278,7 +1278,7 @@ def compilable_app(tmp_path) -> Generator[tuple[App, Path], None, None]: yield app, web_dir -def test_app_wrap_compile_theme(compilable_app): +def test_app_wrap_compile_theme(compilable_app: tuple[App, Path]): """Test that the radix theme component wraps the app. Args: @@ -1306,7 +1306,7 @@ def test_app_wrap_compile_theme(compilable_app): ) in "".join(app_js_lines) -def test_app_wrap_priority(compilable_app): +def test_app_wrap_priority(compilable_app: tuple[App, Path]): """Test that the app wrap components are wrapped in the correct order. Args: @@ -1490,7 +1490,7 @@ def test_add_page_component_returning_tuple(): @pytest.mark.parametrize("export", (True, False)) -def test_app_with_transpile_packages(compilable_app, export): +def test_app_with_transpile_packages(compilable_app: tuple[App, Path], export: bool): class C1(rx.Component): library = "foo@1.2.3" tag = "Foo" @@ -1539,3 +1539,35 @@ def test_app_with_transpile_packages(compilable_app, export): else: assert 'output: "export"' not in next_config assert f'distDir: "{constants.Dirs.STATIC}"' not in next_config + + +def test_app_with_valid_var_dependencies(compilable_app: tuple[App, Path]): + app, _ = compilable_app + + class ValidDepState(BaseState): + base: int = 0 + _backend: int = 0 + + @computed_var + def foo(self) -> str: + return "foo" + + @computed_var(deps=["_backend", "base", foo]) + def bar(self) -> str: + return "bar" + + app.state = ValidDepState + app._compile() + + +def test_app_with_invalid_var_dependencies(compilable_app: tuple[App, Path]): + app, _ = compilable_app + + class InvalidDepState(BaseState): + @computed_var(deps=["foolksjdf"]) + def bar(self) -> str: + return "bar" + + app.state = InvalidDepState + with pytest.raises(exceptions.VarDependencyError): + app._compile() diff --git a/tests/test_var.py b/tests/test_var.py index 3a1fb08a8..b57b6919a 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -1,6 +1,6 @@ import json import typing -from typing import Dict, List, Set, Tuple +from typing import Dict, List, Set, Tuple, Union import pytest from pandas import DataFrame @@ -9,6 +9,7 @@ from reflex.base import Base from reflex.state import BaseState from reflex.vars import ( BaseVar, + ComputedVar, Var, computed_var, ) @@ -1388,3 +1389,43 @@ def test_invalid_var_operations(operand1_var: Var, operand2_var, operators: List ) def test_var_name_unwrapped(var, expected): assert var._var_name_unwrapped == expected + + +def cv_fget(state: BaseState) -> int: + return 1 + + +@pytest.mark.parametrize( + "deps,expected", + [ + (["a"], {"a"}), + (["b"], {"b"}), + ([ComputedVar(fget=cv_fget)], {"cv_fget"}), + ], +) +def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]): + @computed_var( + deps=deps, + ) + def test_var(state) -> int: + return 1 + + assert test_var._static_deps == expected + + +@pytest.mark.parametrize( + "deps", + [ + [""], + [1], + ["", "abc"], + ], +) +def test_invalid_computed_var_deps(deps: List): + with pytest.raises(TypeError): + + @computed_var( + deps=deps, + ) + def test_var(state) -> int: + return 1