ComputedVar.add_dependency: explicitly dependency declaration

Allow var dependencies to be added at runtime, for example, when defining a
ComponentState that depends on vars that cannot be known statically.

Fix more pyright issues.
This commit is contained in:
Masen Furer 2025-01-29 01:15:47 -08:00
parent e74e913b4c
commit 25456a53a9
No known key found for this signature in database
GPG Key ID: 2AE2BD5531FF94F4
3 changed files with 147 additions and 27 deletions

View File

@ -2130,6 +2130,37 @@ class ComputedVar(Var[RETURN_TYPE]):
with contextlib.suppress(AttributeError):
delattr(instance, self._cache_attr)
def add_dependency(self, objclass: Type[BaseState], dep: Var):
"""Explicitly add a dependency to the ComputedVar.
After adding the dependency, when the `dep` changes, this computed var
will be marked dirty.
Args:
objclass: The class obj this ComputedVar is attached to.
dep: The dependency to add.
Raises:
VarDependencyError: If the dependency is not a Var instance with a
state and field name
"""
if all_var_data := dep._get_all_var_data():
state_name = all_var_data.state
if state_name:
var_name = all_var_data.field_name
if var_name:
self._static_deps.setdefault(state_name, set()).add(var_name)
objclass.get_root_state().get_class_substate(
state_name
)._var_dependencies.setdefault(var_name, set()).add(
(objclass.get_full_name(), self._js_expr)
)
return
raise VarDependencyError(
"ComputedVar dependencies must be Var instances with a state and "
f"field name, got {dep!r}."
)
def _determine_var_type(self) -> Type:
"""Get the type of the var.

View File

@ -8,7 +8,7 @@ import dis
import enum
import inspect
from types import CodeType, FunctionType
from typing import TYPE_CHECKING, ClassVar, Type, cast
from typing import TYPE_CHECKING, Any, ClassVar, Type, cast
from reflex.utils.exceptions import VarValueError
@ -18,6 +18,24 @@ if TYPE_CHECKING:
from .base import Var
CellEmpty = object()
def get_cell_value(cell) -> Any:
"""Get the value of a cell object.
Args:
cell: The cell object to get the value from. (func.__closure__ objects)
Returns:
The value from the cell or CellEmpty if a ValueError is raised.
"""
try:
return cell.cell_contents
except ValueError:
return CellEmpty
class ScanStatus(enum.Enum):
"""State of the dis instruction scanning loop."""
@ -124,6 +142,33 @@ class DependencyTracker:
instruction.argval
)
def _get_globals(self) -> dict[str, Any]:
"""Get the globals of the function.
Returns:
The var names and values in the globals of the function.
"""
if isinstance(self.func, CodeType):
return {}
return self.func.__globals__ # pyright: ignore[reportGeneralTypeIssues]
def _get_closure(self) -> dict[str, Any]:
"""Get the closure of the function, with unbound values omitted.
Returns:
The var names and values in the closure of the function.
"""
if isinstance(self.func, CodeType):
return {}
return {
var_name: get_cell_value(cell)
for var_name, cell in zip(
self.func.__code__.co_freevars, # pyright: ignore[reportGeneralTypeIssues]
self.func.__closure__, # pyright: ignore[reportGeneralTypeIssues]
)
if get_cell_value(cell) is not CellEmpty
}
def handle_getting_state(self, instruction: dis.Instruction) -> None:
"""Handle bytecode analysis when `get_state` was called in the function.
@ -153,9 +198,7 @@ class DependencyTracker:
if instruction.opname == "LOAD_GLOBAL":
# Special case: referencing state class from global scope.
try:
self._getting_state_class = inspect.getclosurevars(self.func).globals[
instruction.argval
]
self._getting_state_class = self._get_globals()[instruction.argval]
except (ValueError, KeyError) as ve:
raise VarValueError(
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals."
@ -163,11 +206,10 @@ class DependencyTracker:
elif instruction.opname == "LOAD_DEREF":
# Special case: referencing state class from closure.
try:
closure = inspect.getclosurevars(self.func).nonlocals
self._getting_state_class = closure[instruction.argval]
except ValueError as ve:
self._getting_state_class = self._get_closure()[instruction.argval]
except (ValueError, KeyError) as ve:
raise VarValueError(
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?."
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?"
) from ve
elif instruction.opname == "STORE_FAST":
# Storing the result of get_state in a local variable.
@ -192,15 +234,16 @@ class DependencyTracker:
"""
# Get the original source code and eval it to get the Var.
module = inspect.getmodule(self.func)
positions = self._getting_var_instructions[0].positions
if module is None or positions is None:
positions0 = self._getting_var_instructions[0].positions
positions1 = self._getting_var_instructions[-1].positions
if module is None or positions0 is None or positions1 is None:
raise VarValueError(
f"Cannot determine the source code for the var in {self.func!r}."
)
start_line = positions.lineno
start_column = positions.col_offset
end_line = positions.end_lineno
end_column = positions.end_col_offset
start_line = positions0.lineno
start_column = positions0.col_offset
end_line = positions1.end_lineno
end_column = positions1.end_col_offset
if (
start_line is None
or start_column is None
@ -217,23 +260,13 @@ class DependencyTracker:
[
*source[0][start_column:],
*(source[1:-2] if len(source) > 2 else []),
*source[-1][:end_column],
*source[-1][: end_column - 1],
]
)
else:
snipped_source = source[0][start_column:end_column]
# Fallback if the closure is not available.
globals = {}
closure = {}
try:
if not isinstance(self.func, CodeType):
closurevars = inspect.getclosurevars(self.func)
closure = closurevars.nonlocals
globals = dict(closurevars.globals)
except Exception:
pass
snipped_source = source[0][start_column : end_column - 1]
# Evaluate the string in the context of the function's globals and closure.
return eval(f"({snipped_source})", globals, closure)
return eval(f"({snipped_source})", self._get_globals(), self._get_closure())
def handle_getting_var(self, instruction: dis.Instruction) -> None:
"""Handle bytecode analysis when `get_var_value` was called in the function.

View File

@ -14,6 +14,7 @@ from typing import (
Any,
AsyncGenerator,
Callable,
ClassVar,
Dict,
List,
Optional,
@ -3883,3 +3884,58 @@ async def test_async_computed_var_get_state(mock_app: rx.App, token: str):
assert await child.v == 2
root.parent_var = 2
assert await child.v == 3
class Table(rx.ComponentState):
"""A table state."""
data: ClassVar[Var]
@rx.var(cache=True, auto_deps=False)
async def rows(self) -> List[Dict[str, Any]]:
"""Computed var over the given rows.
Returns:
The data rows.
"""
return await self.get_var_value(self.data)
@classmethod
def get_component(cls, data: Var) -> rx.Component:
"""Get the component for the table.
Args:
data: The data var.
Returns:
The component.
"""
cls.data = data
cls.computed_vars["rows"].add_dependency(cls, data)
return rx.foreach(data, lambda d: rx.text(d.to_string()))
@pytest.mark.asyncio
async def test_async_computed_var_get_var_value(mock_app: rx.App, token: str):
"""A test where an async computed var depends on a var in another state.
Args:
mock_app: An app that will be returned by `get_app()`
token: A token.
"""
class OtherState(rx.State):
"""A state with a var."""
data: List[Dict[str, Any]] = [{"foo": "bar"}]
mock_app.state_manager.state = mock_app._state = rx.State
comp = Table.create(data=OtherState.data)
state = await mock_app.state_manager.get_state(_substate_key(token, OtherState))
other_state = await state.get_state(OtherState)
assert comp.State is not None
comp_state = await state.get_state(comp.State)
assert comp_state.dirty_vars == set()
other_state.data.append({"foo": "baz"})
assert "rows" in comp_state.dirty_vars