ComputedVar.add_dependency: explicitly dependency declaration
Allow var dependencies to be added at runtime, for example, when defining a ComponentState that depends on vars that cannot be known statically. Fix more pyright issues.
This commit is contained in:
parent
e74e913b4c
commit
25456a53a9
@ -2130,6 +2130,37 @@ class ComputedVar(Var[RETURN_TYPE]):
|
||||
with contextlib.suppress(AttributeError):
|
||||
delattr(instance, self._cache_attr)
|
||||
|
||||
def add_dependency(self, objclass: Type[BaseState], dep: Var):
|
||||
"""Explicitly add a dependency to the ComputedVar.
|
||||
|
||||
After adding the dependency, when the `dep` changes, this computed var
|
||||
will be marked dirty.
|
||||
|
||||
Args:
|
||||
objclass: The class obj this ComputedVar is attached to.
|
||||
dep: The dependency to add.
|
||||
|
||||
Raises:
|
||||
VarDependencyError: If the dependency is not a Var instance with a
|
||||
state and field name
|
||||
"""
|
||||
if all_var_data := dep._get_all_var_data():
|
||||
state_name = all_var_data.state
|
||||
if state_name:
|
||||
var_name = all_var_data.field_name
|
||||
if var_name:
|
||||
self._static_deps.setdefault(state_name, set()).add(var_name)
|
||||
objclass.get_root_state().get_class_substate(
|
||||
state_name
|
||||
)._var_dependencies.setdefault(var_name, set()).add(
|
||||
(objclass.get_full_name(), self._js_expr)
|
||||
)
|
||||
return
|
||||
raise VarDependencyError(
|
||||
"ComputedVar dependencies must be Var instances with a state and "
|
||||
f"field name, got {dep!r}."
|
||||
)
|
||||
|
||||
def _determine_var_type(self) -> Type:
|
||||
"""Get the type of the var.
|
||||
|
||||
|
@ -8,7 +8,7 @@ import dis
|
||||
import enum
|
||||
import inspect
|
||||
from types import CodeType, FunctionType
|
||||
from typing import TYPE_CHECKING, ClassVar, Type, cast
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Type, cast
|
||||
|
||||
from reflex.utils.exceptions import VarValueError
|
||||
|
||||
@ -18,6 +18,24 @@ if TYPE_CHECKING:
|
||||
from .base import Var
|
||||
|
||||
|
||||
CellEmpty = object()
|
||||
|
||||
|
||||
def get_cell_value(cell) -> Any:
|
||||
"""Get the value of a cell object.
|
||||
|
||||
Args:
|
||||
cell: The cell object to get the value from. (func.__closure__ objects)
|
||||
|
||||
Returns:
|
||||
The value from the cell or CellEmpty if a ValueError is raised.
|
||||
"""
|
||||
try:
|
||||
return cell.cell_contents
|
||||
except ValueError:
|
||||
return CellEmpty
|
||||
|
||||
|
||||
class ScanStatus(enum.Enum):
|
||||
"""State of the dis instruction scanning loop."""
|
||||
|
||||
@ -124,6 +142,33 @@ class DependencyTracker:
|
||||
instruction.argval
|
||||
)
|
||||
|
||||
def _get_globals(self) -> dict[str, Any]:
|
||||
"""Get the globals of the function.
|
||||
|
||||
Returns:
|
||||
The var names and values in the globals of the function.
|
||||
"""
|
||||
if isinstance(self.func, CodeType):
|
||||
return {}
|
||||
return self.func.__globals__ # pyright: ignore[reportGeneralTypeIssues]
|
||||
|
||||
def _get_closure(self) -> dict[str, Any]:
|
||||
"""Get the closure of the function, with unbound values omitted.
|
||||
|
||||
Returns:
|
||||
The var names and values in the closure of the function.
|
||||
"""
|
||||
if isinstance(self.func, CodeType):
|
||||
return {}
|
||||
return {
|
||||
var_name: get_cell_value(cell)
|
||||
for var_name, cell in zip(
|
||||
self.func.__code__.co_freevars, # pyright: ignore[reportGeneralTypeIssues]
|
||||
self.func.__closure__, # pyright: ignore[reportGeneralTypeIssues]
|
||||
)
|
||||
if get_cell_value(cell) is not CellEmpty
|
||||
}
|
||||
|
||||
def handle_getting_state(self, instruction: dis.Instruction) -> None:
|
||||
"""Handle bytecode analysis when `get_state` was called in the function.
|
||||
|
||||
@ -153,9 +198,7 @@ class DependencyTracker:
|
||||
if instruction.opname == "LOAD_GLOBAL":
|
||||
# Special case: referencing state class from global scope.
|
||||
try:
|
||||
self._getting_state_class = inspect.getclosurevars(self.func).globals[
|
||||
instruction.argval
|
||||
]
|
||||
self._getting_state_class = self._get_globals()[instruction.argval]
|
||||
except (ValueError, KeyError) as ve:
|
||||
raise VarValueError(
|
||||
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals."
|
||||
@ -163,11 +206,10 @@ class DependencyTracker:
|
||||
elif instruction.opname == "LOAD_DEREF":
|
||||
# Special case: referencing state class from closure.
|
||||
try:
|
||||
closure = inspect.getclosurevars(self.func).nonlocals
|
||||
self._getting_state_class = closure[instruction.argval]
|
||||
except ValueError as ve:
|
||||
self._getting_state_class = self._get_closure()[instruction.argval]
|
||||
except (ValueError, KeyError) as ve:
|
||||
raise VarValueError(
|
||||
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?."
|
||||
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?"
|
||||
) from ve
|
||||
elif instruction.opname == "STORE_FAST":
|
||||
# Storing the result of get_state in a local variable.
|
||||
@ -192,15 +234,16 @@ class DependencyTracker:
|
||||
"""
|
||||
# Get the original source code and eval it to get the Var.
|
||||
module = inspect.getmodule(self.func)
|
||||
positions = self._getting_var_instructions[0].positions
|
||||
if module is None or positions is None:
|
||||
positions0 = self._getting_var_instructions[0].positions
|
||||
positions1 = self._getting_var_instructions[-1].positions
|
||||
if module is None or positions0 is None or positions1 is None:
|
||||
raise VarValueError(
|
||||
f"Cannot determine the source code for the var in {self.func!r}."
|
||||
)
|
||||
start_line = positions.lineno
|
||||
start_column = positions.col_offset
|
||||
end_line = positions.end_lineno
|
||||
end_column = positions.end_col_offset
|
||||
start_line = positions0.lineno
|
||||
start_column = positions0.col_offset
|
||||
end_line = positions1.end_lineno
|
||||
end_column = positions1.end_col_offset
|
||||
if (
|
||||
start_line is None
|
||||
or start_column is None
|
||||
@ -217,23 +260,13 @@ class DependencyTracker:
|
||||
[
|
||||
*source[0][start_column:],
|
||||
*(source[1:-2] if len(source) > 2 else []),
|
||||
*source[-1][:end_column],
|
||||
*source[-1][: end_column - 1],
|
||||
]
|
||||
)
|
||||
else:
|
||||
snipped_source = source[0][start_column:end_column]
|
||||
# Fallback if the closure is not available.
|
||||
globals = {}
|
||||
closure = {}
|
||||
try:
|
||||
if not isinstance(self.func, CodeType):
|
||||
closurevars = inspect.getclosurevars(self.func)
|
||||
closure = closurevars.nonlocals
|
||||
globals = dict(closurevars.globals)
|
||||
except Exception:
|
||||
pass
|
||||
snipped_source = source[0][start_column : end_column - 1]
|
||||
# Evaluate the string in the context of the function's globals and closure.
|
||||
return eval(f"({snipped_source})", globals, closure)
|
||||
return eval(f"({snipped_source})", self._get_globals(), self._get_closure())
|
||||
|
||||
def handle_getting_var(self, instruction: dis.Instruction) -> None:
|
||||
"""Handle bytecode analysis when `get_var_value` was called in the function.
|
||||
|
@ -14,6 +14,7 @@ from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
@ -3883,3 +3884,58 @@ async def test_async_computed_var_get_state(mock_app: rx.App, token: str):
|
||||
assert await child.v == 2
|
||||
root.parent_var = 2
|
||||
assert await child.v == 3
|
||||
|
||||
|
||||
class Table(rx.ComponentState):
|
||||
"""A table state."""
|
||||
|
||||
data: ClassVar[Var]
|
||||
|
||||
@rx.var(cache=True, auto_deps=False)
|
||||
async def rows(self) -> List[Dict[str, Any]]:
|
||||
"""Computed var over the given rows.
|
||||
|
||||
Returns:
|
||||
The data rows.
|
||||
"""
|
||||
return await self.get_var_value(self.data)
|
||||
|
||||
@classmethod
|
||||
def get_component(cls, data: Var) -> rx.Component:
|
||||
"""Get the component for the table.
|
||||
|
||||
Args:
|
||||
data: The data var.
|
||||
|
||||
Returns:
|
||||
The component.
|
||||
"""
|
||||
cls.data = data
|
||||
cls.computed_vars["rows"].add_dependency(cls, data)
|
||||
return rx.foreach(data, lambda d: rx.text(d.to_string()))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_computed_var_get_var_value(mock_app: rx.App, token: str):
|
||||
"""A test where an async computed var depends on a var in another state.
|
||||
|
||||
Args:
|
||||
mock_app: An app that will be returned by `get_app()`
|
||||
token: A token.
|
||||
"""
|
||||
|
||||
class OtherState(rx.State):
|
||||
"""A state with a var."""
|
||||
|
||||
data: List[Dict[str, Any]] = [{"foo": "bar"}]
|
||||
|
||||
mock_app.state_manager.state = mock_app._state = rx.State
|
||||
comp = Table.create(data=OtherState.data)
|
||||
state = await mock_app.state_manager.get_state(_substate_key(token, OtherState))
|
||||
other_state = await state.get_state(OtherState)
|
||||
assert comp.State is not None
|
||||
comp_state = await state.get_state(comp.State)
|
||||
assert comp_state.dirty_vars == set()
|
||||
|
||||
other_state.data.append({"foo": "baz"})
|
||||
assert "rows" in comp_state.dirty_vars
|
||||
|
Loading…
Reference in New Issue
Block a user