diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 3f5e1ac86..44d1764a3 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -2130,6 +2130,37 @@ class ComputedVar(Var[RETURN_TYPE]): with contextlib.suppress(AttributeError): delattr(instance, self._cache_attr) + def add_dependency(self, objclass: Type[BaseState], dep: Var): + """Explicitly add a dependency to the ComputedVar. + + After adding the dependency, when the `dep` changes, this computed var + will be marked dirty. + + Args: + objclass: The class obj this ComputedVar is attached to. + dep: The dependency to add. + + Raises: + VarDependencyError: If the dependency is not a Var instance with a + state and field name + """ + if all_var_data := dep._get_all_var_data(): + state_name = all_var_data.state + if state_name: + var_name = all_var_data.field_name + if var_name: + self._static_deps.setdefault(state_name, set()).add(var_name) + objclass.get_root_state().get_class_substate( + state_name + )._var_dependencies.setdefault(var_name, set()).add( + (objclass.get_full_name(), self._js_expr) + ) + return + raise VarDependencyError( + "ComputedVar dependencies must be Var instances with a state and " + f"field name, got {dep!r}." + ) + def _determine_var_type(self) -> Type: """Get the type of the var. diff --git a/reflex/vars/dep_tracking.py b/reflex/vars/dep_tracking.py index 387de5a66..8b0028120 100644 --- a/reflex/vars/dep_tracking.py +++ b/reflex/vars/dep_tracking.py @@ -8,7 +8,7 @@ import dis import enum import inspect from types import CodeType, FunctionType -from typing import TYPE_CHECKING, ClassVar, Type, cast +from typing import TYPE_CHECKING, Any, ClassVar, Type, cast from reflex.utils.exceptions import VarValueError @@ -18,6 +18,24 @@ if TYPE_CHECKING: from .base import Var +CellEmpty = object() + + +def get_cell_value(cell) -> Any: + """Get the value of a cell object. + + Args: + cell: The cell object to get the value from. (func.__closure__ objects) + + Returns: + The value from the cell or CellEmpty if a ValueError is raised. + """ + try: + return cell.cell_contents + except ValueError: + return CellEmpty + + class ScanStatus(enum.Enum): """State of the dis instruction scanning loop.""" @@ -124,6 +142,33 @@ class DependencyTracker: instruction.argval ) + def _get_globals(self) -> dict[str, Any]: + """Get the globals of the function. + + Returns: + The var names and values in the globals of the function. + """ + if isinstance(self.func, CodeType): + return {} + return self.func.__globals__ # pyright: ignore[reportGeneralTypeIssues] + + def _get_closure(self) -> dict[str, Any]: + """Get the closure of the function, with unbound values omitted. + + Returns: + The var names and values in the closure of the function. + """ + if isinstance(self.func, CodeType): + return {} + return { + var_name: get_cell_value(cell) + for var_name, cell in zip( + self.func.__code__.co_freevars, # pyright: ignore[reportGeneralTypeIssues] + self.func.__closure__, # pyright: ignore[reportGeneralTypeIssues] + ) + if get_cell_value(cell) is not CellEmpty + } + def handle_getting_state(self, instruction: dis.Instruction) -> None: """Handle bytecode analysis when `get_state` was called in the function. @@ -153,9 +198,7 @@ class DependencyTracker: if instruction.opname == "LOAD_GLOBAL": # Special case: referencing state class from global scope. try: - self._getting_state_class = inspect.getclosurevars(self.func).globals[ - instruction.argval - ] + self._getting_state_class = self._get_globals()[instruction.argval] except (ValueError, KeyError) as ve: raise VarValueError( f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals." @@ -163,11 +206,10 @@ class DependencyTracker: elif instruction.opname == "LOAD_DEREF": # Special case: referencing state class from closure. try: - closure = inspect.getclosurevars(self.func).nonlocals - self._getting_state_class = closure[instruction.argval] - except ValueError as ve: + self._getting_state_class = self._get_closure()[instruction.argval] + except (ValueError, KeyError) as ve: raise VarValueError( - f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?." + f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?" ) from ve elif instruction.opname == "STORE_FAST": # Storing the result of get_state in a local variable. @@ -192,15 +234,16 @@ class DependencyTracker: """ # Get the original source code and eval it to get the Var. module = inspect.getmodule(self.func) - positions = self._getting_var_instructions[0].positions - if module is None or positions is None: + positions0 = self._getting_var_instructions[0].positions + positions1 = self._getting_var_instructions[-1].positions + if module is None or positions0 is None or positions1 is None: raise VarValueError( f"Cannot determine the source code for the var in {self.func!r}." ) - start_line = positions.lineno - start_column = positions.col_offset - end_line = positions.end_lineno - end_column = positions.end_col_offset + start_line = positions0.lineno + start_column = positions0.col_offset + end_line = positions1.end_lineno + end_column = positions1.end_col_offset if ( start_line is None or start_column is None @@ -217,23 +260,13 @@ class DependencyTracker: [ *source[0][start_column:], *(source[1:-2] if len(source) > 2 else []), - *source[-1][:end_column], + *source[-1][: end_column - 1], ] ) else: - snipped_source = source[0][start_column:end_column] - # Fallback if the closure is not available. - globals = {} - closure = {} - try: - if not isinstance(self.func, CodeType): - closurevars = inspect.getclosurevars(self.func) - closure = closurevars.nonlocals - globals = dict(closurevars.globals) - except Exception: - pass + snipped_source = source[0][start_column : end_column - 1] # Evaluate the string in the context of the function's globals and closure. - return eval(f"({snipped_source})", globals, closure) + return eval(f"({snipped_source})", self._get_globals(), self._get_closure()) def handle_getting_var(self, instruction: dis.Instruction) -> None: """Handle bytecode analysis when `get_var_value` was called in the function. diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 2a07f1b2e..00b1ac9a0 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -14,6 +14,7 @@ from typing import ( Any, AsyncGenerator, Callable, + ClassVar, Dict, List, Optional, @@ -3883,3 +3884,58 @@ async def test_async_computed_var_get_state(mock_app: rx.App, token: str): assert await child.v == 2 root.parent_var = 2 assert await child.v == 3 + + +class Table(rx.ComponentState): + """A table state.""" + + data: ClassVar[Var] + + @rx.var(cache=True, auto_deps=False) + async def rows(self) -> List[Dict[str, Any]]: + """Computed var over the given rows. + + Returns: + The data rows. + """ + return await self.get_var_value(self.data) + + @classmethod + def get_component(cls, data: Var) -> rx.Component: + """Get the component for the table. + + Args: + data: The data var. + + Returns: + The component. + """ + cls.data = data + cls.computed_vars["rows"].add_dependency(cls, data) + return rx.foreach(data, lambda d: rx.text(d.to_string())) + + +@pytest.mark.asyncio +async def test_async_computed_var_get_var_value(mock_app: rx.App, token: str): + """A test where an async computed var depends on a var in another state. + + Args: + mock_app: An app that will be returned by `get_app()` + token: A token. + """ + + class OtherState(rx.State): + """A state with a var.""" + + data: List[Dict[str, Any]] = [{"foo": "bar"}] + + mock_app.state_manager.state = mock_app._state = rx.State + comp = Table.create(data=OtherState.data) + state = await mock_app.state_manager.get_state(_substate_key(token, OtherState)) + other_state = await state.get_state(OtherState) + assert comp.State is not None + comp_state = await state.get_state(comp.State) + assert comp_state.dirty_vars == set() + + other_state.data.append({"foo": "baz"}) + assert "rows" in comp_state.dirty_vars