From 25456a53a9380620e560b87f2b087d17405d39ba Mon Sep 17 00:00:00 2001
From: Masen Furer <m_github@0x26.net>
Date: Wed, 29 Jan 2025 01:15:47 -0800
Subject: [PATCH] ComputedVar.add_dependency: explicitly dependency declaration

Allow var dependencies to be added at runtime, for example, when defining a
ComponentState that depends on vars that cannot be known statically.

Fix more pyright issues.
---
 reflex/vars/base.py         | 31 +++++++++++++
 reflex/vars/dep_tracking.py | 87 +++++++++++++++++++++++++------------
 tests/units/test_state.py   | 56 ++++++++++++++++++++++++
 3 files changed, 147 insertions(+), 27 deletions(-)

diff --git a/reflex/vars/base.py b/reflex/vars/base.py
index 3f5e1ac86..44d1764a3 100644
--- a/reflex/vars/base.py
+++ b/reflex/vars/base.py
@@ -2130,6 +2130,37 @@ class ComputedVar(Var[RETURN_TYPE]):
         with contextlib.suppress(AttributeError):
             delattr(instance, self._cache_attr)
 
+    def add_dependency(self, objclass: Type[BaseState], dep: Var):
+        """Explicitly add a dependency to the ComputedVar.
+
+        After adding the dependency, when the `dep` changes, this computed var
+        will be marked dirty.
+
+        Args:
+            objclass: The class obj this ComputedVar is attached to.
+            dep: The dependency to add.
+
+        Raises:
+            VarDependencyError: If the dependency is not a Var instance with a
+                state and field name
+        """
+        if all_var_data := dep._get_all_var_data():
+            state_name = all_var_data.state
+            if state_name:
+                var_name = all_var_data.field_name
+                if var_name:
+                    self._static_deps.setdefault(state_name, set()).add(var_name)
+                    objclass.get_root_state().get_class_substate(
+                        state_name
+                    )._var_dependencies.setdefault(var_name, set()).add(
+                        (objclass.get_full_name(), self._js_expr)
+                    )
+                    return
+        raise VarDependencyError(
+            "ComputedVar dependencies must be Var instances with a state and "
+            f"field name, got {dep!r}."
+        )
+
     def _determine_var_type(self) -> Type:
         """Get the type of the var.
 
diff --git a/reflex/vars/dep_tracking.py b/reflex/vars/dep_tracking.py
index 387de5a66..8b0028120 100644
--- a/reflex/vars/dep_tracking.py
+++ b/reflex/vars/dep_tracking.py
@@ -8,7 +8,7 @@ import dis
 import enum
 import inspect
 from types import CodeType, FunctionType
-from typing import TYPE_CHECKING, ClassVar, Type, cast
+from typing import TYPE_CHECKING, Any, ClassVar, Type, cast
 
 from reflex.utils.exceptions import VarValueError
 
@@ -18,6 +18,24 @@ if TYPE_CHECKING:
     from .base import Var
 
 
+CellEmpty = object()
+
+
+def get_cell_value(cell) -> Any:
+    """Get the value of a cell object.
+
+    Args:
+        cell: The cell object to get the value from. (func.__closure__ objects)
+
+    Returns:
+        The value from the cell or CellEmpty if a ValueError is raised.
+    """
+    try:
+        return cell.cell_contents
+    except ValueError:
+        return CellEmpty
+
+
 class ScanStatus(enum.Enum):
     """State of the dis instruction scanning loop."""
 
@@ -124,6 +142,33 @@ class DependencyTracker:
                 instruction.argval
             )
 
+    def _get_globals(self) -> dict[str, Any]:
+        """Get the globals of the function.
+
+        Returns:
+            The var names and values in the globals of the function.
+        """
+        if isinstance(self.func, CodeType):
+            return {}
+        return self.func.__globals__  # pyright: ignore[reportGeneralTypeIssues]
+
+    def _get_closure(self) -> dict[str, Any]:
+        """Get the closure of the function, with unbound values omitted.
+
+        Returns:
+            The var names and values in the closure of the function.
+        """
+        if isinstance(self.func, CodeType):
+            return {}
+        return {
+            var_name: get_cell_value(cell)
+            for var_name, cell in zip(
+                self.func.__code__.co_freevars,  # pyright: ignore[reportGeneralTypeIssues]
+                self.func.__closure__,  # pyright: ignore[reportGeneralTypeIssues]
+            )
+            if get_cell_value(cell) is not CellEmpty
+        }
+
     def handle_getting_state(self, instruction: dis.Instruction) -> None:
         """Handle bytecode analysis when `get_state` was called in the function.
 
@@ -153,9 +198,7 @@ class DependencyTracker:
         if instruction.opname == "LOAD_GLOBAL":
             # Special case: referencing state class from global scope.
             try:
-                self._getting_state_class = inspect.getclosurevars(self.func).globals[
-                    instruction.argval
-                ]
+                self._getting_state_class = self._get_globals()[instruction.argval]
             except (ValueError, KeyError) as ve:
                 raise VarValueError(
                     f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals."
@@ -163,11 +206,10 @@ class DependencyTracker:
         elif instruction.opname == "LOAD_DEREF":
             # Special case: referencing state class from closure.
             try:
-                closure = inspect.getclosurevars(self.func).nonlocals
-                self._getting_state_class = closure[instruction.argval]
-            except ValueError as ve:
+                self._getting_state_class = self._get_closure()[instruction.argval]
+            except (ValueError, KeyError) as ve:
                 raise VarValueError(
-                    f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?."
+                    f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?"
                 ) from ve
         elif instruction.opname == "STORE_FAST":
             # Storing the result of get_state in a local variable.
@@ -192,15 +234,16 @@ class DependencyTracker:
         """
         # Get the original source code and eval it to get the Var.
         module = inspect.getmodule(self.func)
-        positions = self._getting_var_instructions[0].positions
-        if module is None or positions is None:
+        positions0 = self._getting_var_instructions[0].positions
+        positions1 = self._getting_var_instructions[-1].positions
+        if module is None or positions0 is None or positions1 is None:
             raise VarValueError(
                 f"Cannot determine the source code for the var in {self.func!r}."
             )
-        start_line = positions.lineno
-        start_column = positions.col_offset
-        end_line = positions.end_lineno
-        end_column = positions.end_col_offset
+        start_line = positions0.lineno
+        start_column = positions0.col_offset
+        end_line = positions1.end_lineno
+        end_column = positions1.end_col_offset
         if (
             start_line is None
             or start_column is None
@@ -217,23 +260,13 @@ class DependencyTracker:
                 [
                     *source[0][start_column:],
                     *(source[1:-2] if len(source) > 2 else []),
-                    *source[-1][:end_column],
+                    *source[-1][: end_column - 1],
                 ]
             )
         else:
-            snipped_source = source[0][start_column:end_column]
-        # Fallback if the closure is not available.
-        globals = {}
-        closure = {}
-        try:
-            if not isinstance(self.func, CodeType):
-                closurevars = inspect.getclosurevars(self.func)
-                closure = closurevars.nonlocals
-                globals = dict(closurevars.globals)
-        except Exception:
-            pass
+            snipped_source = source[0][start_column : end_column - 1]
         # Evaluate the string in the context of the function's globals and closure.
-        return eval(f"({snipped_source})", globals, closure)
+        return eval(f"({snipped_source})", self._get_globals(), self._get_closure())
 
     def handle_getting_var(self, instruction: dis.Instruction) -> None:
         """Handle bytecode analysis when `get_var_value` was called in the function.
diff --git a/tests/units/test_state.py b/tests/units/test_state.py
index 2a07f1b2e..00b1ac9a0 100644
--- a/tests/units/test_state.py
+++ b/tests/units/test_state.py
@@ -14,6 +14,7 @@ from typing import (
     Any,
     AsyncGenerator,
     Callable,
+    ClassVar,
     Dict,
     List,
     Optional,
@@ -3883,3 +3884,58 @@ async def test_async_computed_var_get_state(mock_app: rx.App, token: str):
     assert await child.v == 2
     root.parent_var = 2
     assert await child.v == 3
+
+
+class Table(rx.ComponentState):
+    """A table state."""
+
+    data: ClassVar[Var]
+
+    @rx.var(cache=True, auto_deps=False)
+    async def rows(self) -> List[Dict[str, Any]]:
+        """Computed var over the given rows.
+
+        Returns:
+            The data rows.
+        """
+        return await self.get_var_value(self.data)
+
+    @classmethod
+    def get_component(cls, data: Var) -> rx.Component:
+        """Get the component for the table.
+
+        Args:
+            data: The data var.
+
+        Returns:
+            The component.
+        """
+        cls.data = data
+        cls.computed_vars["rows"].add_dependency(cls, data)
+        return rx.foreach(data, lambda d: rx.text(d.to_string()))
+
+
+@pytest.mark.asyncio
+async def test_async_computed_var_get_var_value(mock_app: rx.App, token: str):
+    """A test where an async computed var depends on a var in another state.
+
+    Args:
+        mock_app: An app that will be returned by `get_app()`
+        token: A token.
+    """
+
+    class OtherState(rx.State):
+        """A state with a var."""
+
+        data: List[Dict[str, Any]] = [{"foo": "bar"}]
+
+    mock_app.state_manager.state = mock_app._state = rx.State
+    comp = Table.create(data=OtherState.data)
+    state = await mock_app.state_manager.get_state(_substate_key(token, OtherState))
+    other_state = await state.get_state(OtherState)
+    assert comp.State is not None
+    comp_state = await state.get_state(comp.State)
+    assert comp_state.dirty_vars == set()
+
+    other_state.data.append({"foo": "baz"})
+    assert "rows" in comp_state.dirty_vars