ComputedVar.add_dependency: explicitly dependency declaration
Allow var dependencies to be added at runtime, for example, when defining a ComponentState that depends on vars that cannot be known statically. Fix more pyright issues.
This commit is contained in:
parent
e74e913b4c
commit
25456a53a9
@ -2130,6 +2130,37 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
with contextlib.suppress(AttributeError):
|
with contextlib.suppress(AttributeError):
|
||||||
delattr(instance, self._cache_attr)
|
delattr(instance, self._cache_attr)
|
||||||
|
|
||||||
|
def add_dependency(self, objclass: Type[BaseState], dep: Var):
|
||||||
|
"""Explicitly add a dependency to the ComputedVar.
|
||||||
|
|
||||||
|
After adding the dependency, when the `dep` changes, this computed var
|
||||||
|
will be marked dirty.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
objclass: The class obj this ComputedVar is attached to.
|
||||||
|
dep: The dependency to add.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
VarDependencyError: If the dependency is not a Var instance with a
|
||||||
|
state and field name
|
||||||
|
"""
|
||||||
|
if all_var_data := dep._get_all_var_data():
|
||||||
|
state_name = all_var_data.state
|
||||||
|
if state_name:
|
||||||
|
var_name = all_var_data.field_name
|
||||||
|
if var_name:
|
||||||
|
self._static_deps.setdefault(state_name, set()).add(var_name)
|
||||||
|
objclass.get_root_state().get_class_substate(
|
||||||
|
state_name
|
||||||
|
)._var_dependencies.setdefault(var_name, set()).add(
|
||||||
|
(objclass.get_full_name(), self._js_expr)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
raise VarDependencyError(
|
||||||
|
"ComputedVar dependencies must be Var instances with a state and "
|
||||||
|
f"field name, got {dep!r}."
|
||||||
|
)
|
||||||
|
|
||||||
def _determine_var_type(self) -> Type:
|
def _determine_var_type(self) -> Type:
|
||||||
"""Get the type of the var.
|
"""Get the type of the var.
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ import dis
|
|||||||
import enum
|
import enum
|
||||||
import inspect
|
import inspect
|
||||||
from types import CodeType, FunctionType
|
from types import CodeType, FunctionType
|
||||||
from typing import TYPE_CHECKING, ClassVar, Type, cast
|
from typing import TYPE_CHECKING, Any, ClassVar, Type, cast
|
||||||
|
|
||||||
from reflex.utils.exceptions import VarValueError
|
from reflex.utils.exceptions import VarValueError
|
||||||
|
|
||||||
@ -18,6 +18,24 @@ if TYPE_CHECKING:
|
|||||||
from .base import Var
|
from .base import Var
|
||||||
|
|
||||||
|
|
||||||
|
CellEmpty = object()
|
||||||
|
|
||||||
|
|
||||||
|
def get_cell_value(cell) -> Any:
|
||||||
|
"""Get the value of a cell object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cell: The cell object to get the value from. (func.__closure__ objects)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The value from the cell or CellEmpty if a ValueError is raised.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return cell.cell_contents
|
||||||
|
except ValueError:
|
||||||
|
return CellEmpty
|
||||||
|
|
||||||
|
|
||||||
class ScanStatus(enum.Enum):
|
class ScanStatus(enum.Enum):
|
||||||
"""State of the dis instruction scanning loop."""
|
"""State of the dis instruction scanning loop."""
|
||||||
|
|
||||||
@ -124,6 +142,33 @@ class DependencyTracker:
|
|||||||
instruction.argval
|
instruction.argval
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_globals(self) -> dict[str, Any]:
|
||||||
|
"""Get the globals of the function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The var names and values in the globals of the function.
|
||||||
|
"""
|
||||||
|
if isinstance(self.func, CodeType):
|
||||||
|
return {}
|
||||||
|
return self.func.__globals__ # pyright: ignore[reportGeneralTypeIssues]
|
||||||
|
|
||||||
|
def _get_closure(self) -> dict[str, Any]:
|
||||||
|
"""Get the closure of the function, with unbound values omitted.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The var names and values in the closure of the function.
|
||||||
|
"""
|
||||||
|
if isinstance(self.func, CodeType):
|
||||||
|
return {}
|
||||||
|
return {
|
||||||
|
var_name: get_cell_value(cell)
|
||||||
|
for var_name, cell in zip(
|
||||||
|
self.func.__code__.co_freevars, # pyright: ignore[reportGeneralTypeIssues]
|
||||||
|
self.func.__closure__, # pyright: ignore[reportGeneralTypeIssues]
|
||||||
|
)
|
||||||
|
if get_cell_value(cell) is not CellEmpty
|
||||||
|
}
|
||||||
|
|
||||||
def handle_getting_state(self, instruction: dis.Instruction) -> None:
|
def handle_getting_state(self, instruction: dis.Instruction) -> None:
|
||||||
"""Handle bytecode analysis when `get_state` was called in the function.
|
"""Handle bytecode analysis when `get_state` was called in the function.
|
||||||
|
|
||||||
@ -153,9 +198,7 @@ class DependencyTracker:
|
|||||||
if instruction.opname == "LOAD_GLOBAL":
|
if instruction.opname == "LOAD_GLOBAL":
|
||||||
# Special case: referencing state class from global scope.
|
# Special case: referencing state class from global scope.
|
||||||
try:
|
try:
|
||||||
self._getting_state_class = inspect.getclosurevars(self.func).globals[
|
self._getting_state_class = self._get_globals()[instruction.argval]
|
||||||
instruction.argval
|
|
||||||
]
|
|
||||||
except (ValueError, KeyError) as ve:
|
except (ValueError, KeyError) as ve:
|
||||||
raise VarValueError(
|
raise VarValueError(
|
||||||
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals."
|
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals."
|
||||||
@ -163,11 +206,10 @@ class DependencyTracker:
|
|||||||
elif instruction.opname == "LOAD_DEREF":
|
elif instruction.opname == "LOAD_DEREF":
|
||||||
# Special case: referencing state class from closure.
|
# Special case: referencing state class from closure.
|
||||||
try:
|
try:
|
||||||
closure = inspect.getclosurevars(self.func).nonlocals
|
self._getting_state_class = self._get_closure()[instruction.argval]
|
||||||
self._getting_state_class = closure[instruction.argval]
|
except (ValueError, KeyError) as ve:
|
||||||
except ValueError as ve:
|
|
||||||
raise VarValueError(
|
raise VarValueError(
|
||||||
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?."
|
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?"
|
||||||
) from ve
|
) from ve
|
||||||
elif instruction.opname == "STORE_FAST":
|
elif instruction.opname == "STORE_FAST":
|
||||||
# Storing the result of get_state in a local variable.
|
# Storing the result of get_state in a local variable.
|
||||||
@ -192,15 +234,16 @@ class DependencyTracker:
|
|||||||
"""
|
"""
|
||||||
# Get the original source code and eval it to get the Var.
|
# Get the original source code and eval it to get the Var.
|
||||||
module = inspect.getmodule(self.func)
|
module = inspect.getmodule(self.func)
|
||||||
positions = self._getting_var_instructions[0].positions
|
positions0 = self._getting_var_instructions[0].positions
|
||||||
if module is None or positions is None:
|
positions1 = self._getting_var_instructions[-1].positions
|
||||||
|
if module is None or positions0 is None or positions1 is None:
|
||||||
raise VarValueError(
|
raise VarValueError(
|
||||||
f"Cannot determine the source code for the var in {self.func!r}."
|
f"Cannot determine the source code for the var in {self.func!r}."
|
||||||
)
|
)
|
||||||
start_line = positions.lineno
|
start_line = positions0.lineno
|
||||||
start_column = positions.col_offset
|
start_column = positions0.col_offset
|
||||||
end_line = positions.end_lineno
|
end_line = positions1.end_lineno
|
||||||
end_column = positions.end_col_offset
|
end_column = positions1.end_col_offset
|
||||||
if (
|
if (
|
||||||
start_line is None
|
start_line is None
|
||||||
or start_column is None
|
or start_column is None
|
||||||
@ -217,23 +260,13 @@ class DependencyTracker:
|
|||||||
[
|
[
|
||||||
*source[0][start_column:],
|
*source[0][start_column:],
|
||||||
*(source[1:-2] if len(source) > 2 else []),
|
*(source[1:-2] if len(source) > 2 else []),
|
||||||
*source[-1][:end_column],
|
*source[-1][: end_column - 1],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
snipped_source = source[0][start_column:end_column]
|
snipped_source = source[0][start_column : end_column - 1]
|
||||||
# Fallback if the closure is not available.
|
|
||||||
globals = {}
|
|
||||||
closure = {}
|
|
||||||
try:
|
|
||||||
if not isinstance(self.func, CodeType):
|
|
||||||
closurevars = inspect.getclosurevars(self.func)
|
|
||||||
closure = closurevars.nonlocals
|
|
||||||
globals = dict(closurevars.globals)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
# Evaluate the string in the context of the function's globals and closure.
|
# Evaluate the string in the context of the function's globals and closure.
|
||||||
return eval(f"({snipped_source})", globals, closure)
|
return eval(f"({snipped_source})", self._get_globals(), self._get_closure())
|
||||||
|
|
||||||
def handle_getting_var(self, instruction: dis.Instruction) -> None:
|
def handle_getting_var(self, instruction: dis.Instruction) -> None:
|
||||||
"""Handle bytecode analysis when `get_var_value` was called in the function.
|
"""Handle bytecode analysis when `get_var_value` was called in the function.
|
||||||
|
@ -14,6 +14,7 @@ from typing import (
|
|||||||
Any,
|
Any,
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
Callable,
|
Callable,
|
||||||
|
ClassVar,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
@ -3883,3 +3884,58 @@ async def test_async_computed_var_get_state(mock_app: rx.App, token: str):
|
|||||||
assert await child.v == 2
|
assert await child.v == 2
|
||||||
root.parent_var = 2
|
root.parent_var = 2
|
||||||
assert await child.v == 3
|
assert await child.v == 3
|
||||||
|
|
||||||
|
|
||||||
|
class Table(rx.ComponentState):
|
||||||
|
"""A table state."""
|
||||||
|
|
||||||
|
data: ClassVar[Var]
|
||||||
|
|
||||||
|
@rx.var(cache=True, auto_deps=False)
|
||||||
|
async def rows(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Computed var over the given rows.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The data rows.
|
||||||
|
"""
|
||||||
|
return await self.get_var_value(self.data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_component(cls, data: Var) -> rx.Component:
|
||||||
|
"""Get the component for the table.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: The data var.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The component.
|
||||||
|
"""
|
||||||
|
cls.data = data
|
||||||
|
cls.computed_vars["rows"].add_dependency(cls, data)
|
||||||
|
return rx.foreach(data, lambda d: rx.text(d.to_string()))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_computed_var_get_var_value(mock_app: rx.App, token: str):
|
||||||
|
"""A test where an async computed var depends on a var in another state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mock_app: An app that will be returned by `get_app()`
|
||||||
|
token: A token.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class OtherState(rx.State):
|
||||||
|
"""A state with a var."""
|
||||||
|
|
||||||
|
data: List[Dict[str, Any]] = [{"foo": "bar"}]
|
||||||
|
|
||||||
|
mock_app.state_manager.state = mock_app._state = rx.State
|
||||||
|
comp = Table.create(data=OtherState.data)
|
||||||
|
state = await mock_app.state_manager.get_state(_substate_key(token, OtherState))
|
||||||
|
other_state = await state.get_state(OtherState)
|
||||||
|
assert comp.State is not None
|
||||||
|
comp_state = await state.get_state(comp.State)
|
||||||
|
assert comp_state.dirty_vars == set()
|
||||||
|
|
||||||
|
other_state.data.append({"foo": "baz"})
|
||||||
|
assert "rows" in comp_state.dirty_vars
|
||||||
|
Loading…
Reference in New Issue
Block a user