Move computed var dep tracking to separate module

This commit is contained in:
Masen Furer 2025-01-28 15:30:27 -08:00
parent 8f4b12d29f
commit aa5b3037d7
No known key found for this signature in database
GPG Key ID: 2AE2BD5531FF94F4
3 changed files with 357 additions and 167 deletions

View File

@ -797,7 +797,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
*defining_state_cls.inherited_vars,
*defining_state_cls.inherited_backend_vars,
}:
defining_state_cls = defining_state_cls.get_parent_state()
parent_state = defining_state_cls.get_parent_state()
if parent_state is not None:
defining_state_cls = parent_state
defining_state_cls._var_dependencies.setdefault(dvar, set()).add(
(cls.get_full_name(), cvar_name)
)
@ -2721,7 +2723,7 @@ class StateProxy(wrapt.ObjectProxy):
await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
)
def _as_state_update(self, *args, **kwargs) -> StateUpdate:
async def _as_state_update(self, *args, **kwargs) -> StateUpdate:
"""Temporarily allow mutability to access parent_state.
Args:
@ -2734,7 +2736,7 @@ class StateProxy(wrapt.ObjectProxy):
original_mutable = self._self_mutable
self._self_mutable = True
try:
return self.__wrapped__._as_state_update(*args, **kwargs)
return await self.__wrapped__._as_state_update(*args, **kwargs)
finally:
self._self_mutable = original_mutable
@ -3366,10 +3368,10 @@ class StateManagerRedis(StateManager):
state.parent_state = parent_state
# To retain compatibility with previous implementation, by default, we return
# the top-level state by chasing `parent_state` pointers up the tree.
# the top-level state which should always be fetched or already cached.
if top_level:
return state._get_root_state()
return state
return flat_state_tree[self.state.get_full_name()]
return flat_state_tree[state_cls.get_full_name()]
@override
async def set_state(
@ -4070,7 +4072,7 @@ def reload_state_module(
if subclass.__module__ == module and module is not None:
state.class_subclasses.remove(subclass)
state._always_dirty_substates.discard(subclass.get_name())
state._potentially_dirty_substates.discard(subclass.get_name())
state._potentially_dirty_states.discard(subclass.get_name())
state._var_dependencies = {}
state._init_var_dependency_dicts()
state.get_class_substate.cache_clear()

View File

@ -5,7 +5,6 @@ from __future__ import annotations
import contextlib
import dataclasses
import datetime
import dis
import functools
import inspect
import json
@ -20,6 +19,7 @@ from typing import (
Any,
Callable,
ClassVar,
Coroutine,
Dict,
FrozenSet,
Generic,
@ -50,7 +50,6 @@ from reflex.utils.exceptions import (
VarAttributeError,
VarDependencyError,
VarTypeError,
VarValueError,
)
from reflex.utils.format import format_state_name
from reflex.utils.imports import (
@ -2073,9 +2072,8 @@ class ComputedVar(Var[RETURN_TYPE]):
def _deps(
self,
objclass: BaseState,
objclass: Type[BaseState],
obj: FunctionType | CodeType | None = None,
self_names: Optional[dict[str, str]] = None,
) -> dict[str, set[str]]:
"""Determine var dependencies of this ComputedVar.
@ -2087,17 +2085,12 @@ class ComputedVar(Var[RETURN_TYPE]):
Args:
objclass: the class obj this ComputedVar is attached to.
obj: the object to disassemble (defaults to the fget function).
self_names: if specified, look for these names in LOAD_FAST and LOAD_DEREF instructions.
Returns:
A dictionary mapping state names to the set of variable names
accessed by the given obj.
Raises:
VarValueError: if the function references the get_state, parent_state, or substates attributes
(cannot track deps in a related state, only implicitly via parent state).
"""
from reflex.state import BaseState
from .dep_tracking import DependencyTracker
d = {}
if self._static_deps:
@ -2108,158 +2101,26 @@ class ComputedVar(Var[RETURN_TYPE]):
if not self._auto_deps:
return d
if obj is None:
fget = self._fget
if fget is not None:
obj = cast(FunctionType, fget)
else:
return d
with contextlib.suppress(AttributeError):
# unbox functools.partial
obj = cast(FunctionType, obj.func) # type: ignore
with contextlib.suppress(AttributeError):
# unbox EventHandler
obj = cast(FunctionType, obj.fn) # type: ignore
if self_names is None and isinstance(obj, FunctionType):
try:
# the first argument to the function is the name of "self" arg
self_names = {obj.__code__.co_varnames[0]: objclass.get_full_name()}
except (AttributeError, IndexError):
self_names = None
if self_names is None:
# cannot reference attributes on self if method takes no args
try:
return DependencyTracker(
func=obj, state_cls=objclass, dependencies=d
).dependencies
except Exception as e:
console.warn(
"Failed to automatically determine dependencies for computed var "
f"{objclass.__name__}.{self._js_expr}: {e}. "
"Provide static_deps and set auto_deps=False to suppress this warning."
)
return d
invalid_names = ["parent_state", "substates", "get_substate"]
self_on_top_of_stack = None
getting_state = False
getting_var = False
for instruction in dis.get_instructions(obj):
if getting_state:
if instruction.opname == "LOAD_FAST":
raise VarValueError(
f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
)
if instruction.opname == "LOAD_GLOBAL":
# Special case: referencing state class from global scope.
getting_state = obj.__globals__.get(instruction.argval)
elif instruction.opname == "LOAD_DEREF":
# Special case: referencing state class from closure.
closure = dict(zip(obj.__code__.co_freevars, obj.__closure__))
try:
getting_state = closure[instruction.argval].cell_contents
except ValueError as ve:
raise VarValueError(
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?."
) from ve
elif instruction.opname == "STORE_FAST":
# Storing the result of get_state in a local variable.
if not isinstance(getting_state, type) or not issubclass(
getting_state, BaseState
):
raise VarValueError(
f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`."
)
self_names[instruction.argval] = getting_state.get_full_name()
getting_state = False
continue # nothing else happens until we have identified the local var
if getting_var:
if instruction.opname == "CALL":
# get the original source code and eval it
start_line = getting_var[0].positions.lineno
start_column = getting_var[0].positions.col_offset
end_line = getting_var[-1].positions.end_lineno
end_column = getting_var[-1].positions.end_col_offset
source = inspect.getsource(inspect.getmodule(obj)).splitlines(True)[
start_line - 1 : end_line
]
if len(source) > 1:
snipped_source = "".join(
[
source[0][start_column:],
source[1:-2] if len(source) > 2 else "",
source[-1][:end_column],
]
)
else:
snipped_source = source[0][start_column:end_column]
the_var = eval(f"({snipped_source})", obj.__globals__)
the_var_data = the_var._get_all_var_data()
d.setdefault(the_var_data.state, set()).add(the_var_data.field_name)
getting_var = False
elif isinstance(getting_var, list):
getting_var.append(instruction)
else:
getting_var = [instruction]
continue
if (
instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
and instruction.argval in self_names
):
# bytecode loaded the class instance to the top of stack, next load instruction
# is referencing an attribute on self
self_on_top_of_stack = self_names[instruction.argval]
continue
if self_on_top_of_stack and instruction.opname in (
"LOAD_ATTR",
"LOAD_METHOD",
):
if instruction.argval in invalid_names:
raise VarValueError(
f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
)
if instruction.argval == "get_state":
# Special case: arbitrary state access requested.
getting_state = True
continue
if instruction.argval == "get_var_value":
# Special case: arbitrary var access requested.
getting_var = True
continue
target_state = objclass.get_root_state().get_class_substate(
self_on_top_of_stack
)
try:
ref_obj = getattr(target_state, instruction.argval)
except Exception:
ref_obj = None
if callable(ref_obj):
# recurse into callable attributes
for state_name, dep_name in self._deps(
objclass=target_state,
obj=ref_obj,
).items():
d.setdefault(state_name, set()).update(dep_name)
# recurse into property fget functions
elif isinstance(ref_obj, property) and not isinstance(
ref_obj, ComputedVar
):
for state_name, dep_name in self._deps(
objclass=target_state,
obj=ref_obj.fget, # type: ignore
).items():
d.setdefault(state_name, set()).update(dep_name)
elif (
instruction.argval in target_state.backend_vars
or instruction.argval in target_state.vars
):
# var access
d.setdefault(self_on_top_of_stack, set()).add(instruction.argval)
elif instruction.opname == "LOAD_CONST" and isinstance(
instruction.argval, CodeType
):
# recurse into nested functions / comprehensions, which can reference
# instance attributes from the outer scope
for state_name, dep_name in self._deps(
objclass=objclass,
obj=instruction.argval,
self_names=self_names,
).items():
d.setdefault(state_name, set()).update(dep_name)
self_on_top_of_stack = None
return d
def mark_dirty(self, instance) -> None:
"""Mark this ComputedVar as dirty.
@ -2314,11 +2175,63 @@ class DynamicRouteVar(ComputedVar[Union[str, List[str]]]):
class AsyncComputedVar(ComputedVar[RETURN_TYPE]):
"""A computed var that wraps a coroutinefunction."""
_fget: Callable[[BaseState], RETURN_TYPE] = dataclasses.field(
default_factory=lambda: lambda _: None
) # type: ignore
_fget: Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]] = (
dataclasses.field()
)
def __get__(self, instance: BaseState | None, owner):
@overload
def __get__(
self: AsyncComputedVar[int] | AsyncComputedVar[float],
instance: None,
owner: Type,
) -> NumberVar: ...
@overload
def __get__(
self: AsyncComputedVar[str],
instance: None,
owner: Type,
) -> StringVar: ...
@overload
def __get__(
self: AsyncComputedVar[Mapping[DICT_KEY, DICT_VAL]],
instance: None,
owner: Type,
) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ...
@overload
def __get__(
self: AsyncComputedVar[list[LIST_INSIDE]],
instance: None,
owner: Type,
) -> ArrayVar[list[LIST_INSIDE]]: ...
@overload
def __get__(
self: AsyncComputedVar[set[LIST_INSIDE]],
instance: None,
owner: Type,
) -> ArrayVar[set[LIST_INSIDE]]: ...
@overload
def __get__(
self: AsyncComputedVar[tuple[LIST_INSIDE, ...]],
instance: None,
owner: Type,
) -> ArrayVar[tuple[LIST_INSIDE, ...]]: ...
@overload
def __get__(self, instance: None, owner: Type) -> AsyncComputedVar[RETURN_TYPE]: ...
@overload
def __get__(
self, instance: BaseState, owner: Type
) -> Coroutine[None, None, RETURN_TYPE]: ...
def __get__(
self, instance: BaseState | None, owner
) -> Var | Coroutine[None, None, RETURN_TYPE]:
"""Get the ComputedVar value.
If the value is already cached on the instance, return the cached value.
@ -2335,14 +2248,15 @@ class AsyncComputedVar(ComputedVar[RETURN_TYPE]):
if not self._cache:
async def _awaitable_result():
async def _awaitable_result(instance=instance) -> RETURN_TYPE:
value = await self.fget(instance)
self._check_deprecated_return_type(instance, value)
return value
return _awaitable_result()
else:
# handle caching
async def _awaitable_result():
async def _awaitable_result(instance=instance) -> RETURN_TYPE:
if not hasattr(instance, self._cache_attr) or self.needs_update(
instance
):
@ -2356,7 +2270,16 @@ class AsyncComputedVar(ComputedVar[RETURN_TYPE]):
self._check_deprecated_return_type(instance, value)
return value
return _awaitable_result()
return _awaitable_result()
@property
def fget(self) -> Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]]:
"""Get the getter function.
Returns:
The getter function.
"""
return self._fget
if TYPE_CHECKING:

265
reflex/vars/dep_tracking.py Normal file
View File

@ -0,0 +1,265 @@
"""Bytecode-based dependency tracking for functions that access state attributes."""
from __future__ import annotations
import contextlib
import dataclasses
import dis
import enum
import inspect
from types import CodeType, FunctionType
from typing import TYPE_CHECKING, ClassVar, Type, cast
from reflex.utils.exceptions import VarValueError
if TYPE_CHECKING:
from reflex.state import BaseState
from .base import Var
class ScanStatus(enum.Enum):
    """Phases of the dis instruction scanning loop."""

    # Default phase: walking instructions, nothing special pending.
    SCANNING = 1
    # A tracked local was just loaded; an attribute/method load may follow.
    GETTING_ATTR = 2
    # `get_state` was referenced; resolving which state class is fetched.
    GETTING_STATE = 3
    # `get_var_value` was referenced; collecting the argument instructions.
    GETTING_VAR = 4
@dataclasses.dataclass
class DependencyTracker:
    """State machine for identifying state attributes that are accessed by a function.

    The bytecode of ``func`` is scanned when the instance is created and the
    result is accumulated in ``dependencies``, which maps a state's full name
    to the set of var names accessed on that state.
    """

    # The function (or nested code object) whose bytecode is scanned.
    func: FunctionType | CodeType = dataclasses.field()
    # The state class that the function's first argument ("self") refers to.
    state_cls: Type[BaseState] = dataclasses.field()

    # Mapping of state full name -> names of vars accessed on that state.
    dependencies: dict[str, set[str]] = dataclasses.field(default_factory=dict)

    # Current phase of the scanning loop.
    scan_status: ScanStatus = dataclasses.field(default=ScanStatus.SCANNING)
    # Name of the tracked local most recently loaded, if any.
    top_of_stack: str | None = dataclasses.field(default=None)

    # Local variable names known to hold a state instance -> that state's class.
    tracked_locals: dict[str, Type[BaseState]] = dataclasses.field(default_factory=dict)

    # State class seen while resolving the argument of a `get_state` call.
    _getting_state_class: Type[BaseState] | None = dataclasses.field(default=None)
    # Instructions collected as the argument expression of a `get_var_value` call.
    _getting_var_instructions: list[dis.Instruction] = dataclasses.field(
        default_factory=list
    )

    # Attribute names that may not be accessed on a tracked state.
    INVALID_NAMES: ClassVar[list[str]] = ["parent_state", "substates", "get_substate"]

    def __post_init__(self) -> None:
        """After initializing, populate the dependencies dict."""
        with contextlib.suppress(AttributeError):
            # unbox functools.partial
            self.func = cast(FunctionType, self.func.func)  # type: ignore
        with contextlib.suppress(AttributeError):
            # unbox EventHandler
            self.func = cast(FunctionType, self.func.fn)  # type: ignore

        if isinstance(self.func, FunctionType):
            with contextlib.suppress(AttributeError, IndexError):
                # the first argument to the function is the name of "self" arg
                self.tracked_locals[self.func.__code__.co_varnames[0]] = self.state_cls

        self._populate_dependencies()

    def _merge_deps(self, tracker: DependencyTracker) -> None:
        """Merge dependencies from another DependencyTracker.

        Args:
            tracker: The DependencyTracker to merge dependencies from.
        """
        for state_name, dep_name in tracker.dependencies.items():
            self.dependencies.setdefault(state_name, set()).update(dep_name)

    def load_attr_or_method(self, instruction: dis.Instruction) -> None:
        """Handle loading an attribute or method from the object on top of the stack.

        This method directly tracks attributes and recursively merges
        dependencies from analyzing the dependencies of any methods called.

        Args:
            instruction: The dis instruction to process.

        Raises:
            VarValueError: if the attribute name is one of INVALID_NAMES.
        """
        from .base import ComputedVar

        if instruction.argval in self.INVALID_NAMES:
            raise VarValueError(
                f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
            )
        if instruction.argval == "get_state":
            # Special case: arbitrary state access requested.
            self.scan_status = ScanStatus.GETTING_STATE
            return
        if instruction.argval == "get_var_value":
            # Special case: arbitrary var access requested.
            self.scan_status = ScanStatus.GETTING_VAR
            return

        # Reset status back to SCANNING after attribute is accessed.
        self.scan_status = ScanStatus.SCANNING
        if not self.top_of_stack:
            return
        target_state = self.tracked_locals[self.top_of_stack]
        try:
            ref_obj = getattr(target_state, instruction.argval)
        except AttributeError:
            # The name is not resolvable on the class object; it may still be
            # listed in `vars` / `backend_vars`, which is checked below.
            ref_obj = None

        if isinstance(ref_obj, property) and not isinstance(ref_obj, ComputedVar):
            # recurse into property fget functions
            ref_obj = ref_obj.fget  # type: ignore
        if callable(ref_obj):
            # recurse into callable attributes
            self._merge_deps(
                type(self)(func=cast(FunctionType, ref_obj), state_cls=target_state)
            )
        elif (
            instruction.argval in target_state.backend_vars
            or instruction.argval in target_state.vars
        ):
            # var access
            self.dependencies.setdefault(target_state.get_full_name(), set()).add(
                instruction.argval
            )

    def handle_getting_state(self, instruction: dis.Instruction) -> None:
        """Handle bytecode analysis when `get_state` was called in the function.

        If the wrapped function is getting an arbitrary state and saving it to a
        local variable, this method associates the local variable name with the
        state class in self.tracked_locals.

        When an attribute/method is accessed on a tracked local, it will be
        associated with this state.

        Args:
            instruction: The dis instruction to process.

        Raises:
            VarValueError: if the state class is passed via a local variable,
                refers to an empty closure cell, or the fetched object is not
                a BaseState subclass when it is stored to a local.
        """
        from reflex.state import BaseState

        if instruction.opname == "LOAD_FAST":
            raise VarValueError(
                f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
            )
        if instruction.opname == "LOAD_GLOBAL":
            # Special case: referencing state class from global scope.
            self._getting_state_class = self.func.__globals__.get(instruction.argval)
        elif instruction.opname == "LOAD_DEREF":
            # Special case: referencing state class from closure.
            closure = dict(zip(self.func.__code__.co_freevars, self.func.__closure__))
            try:
                self._getting_state_class = closure[instruction.argval].cell_contents
            except ValueError as ve:
                # Reading `cell_contents` of an empty cell raises ValueError.
                raise VarValueError(
                    f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?."
                ) from ve
        elif instruction.opname == "STORE_FAST":
            # Storing the result of get_state in a local variable.
            if not isinstance(self._getting_state_class, type) or not issubclass(
                self._getting_state_class, BaseState
            ):
                raise VarValueError(
                    f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`."
                )
            self.tracked_locals[instruction.argval] = self._getting_state_class
            self.scan_status = ScanStatus.SCANNING
            self._getting_state_class = None

    def _closure_namespace(self) -> dict[str, object]:
        """Map the wrapped function's free variable names to their current values.

        Returns:
            A dict of free variable name to cell contents. Empty cells are
            omitted, and the dict is empty when the wrapped object has no
            code object or no closure.
        """
        code = getattr(self.func, "__code__", None)
        cells = getattr(self.func, "__closure__", None)
        if code is None or not cells:
            return {}
        namespace: dict[str, object] = {}
        for name, cell in zip(code.co_freevars, cells):
            # An empty cell raises ValueError when its contents are read.
            with contextlib.suppress(ValueError):
                namespace[name] = cell.cell_contents
        return namespace

    def _eval_var(self) -> Var:
        """Evaluate instructions from the wrapped function to get the Var object.

        The source text spanned by the collected instructions is cut out of
        the defining module's source and evaluated.

        Returns:
            The Var object.
        """
        # Get the original source code and eval it to get the Var.
        start_line = self._getting_var_instructions[0].positions.lineno
        start_column = self._getting_var_instructions[0].positions.col_offset
        end_line = self._getting_var_instructions[-1].positions.end_lineno
        end_column = self._getting_var_instructions[-1].positions.end_col_offset
        source = inspect.getsource(inspect.getmodule(self.func)).splitlines(True)[
            start_line - 1 : end_line
        ]
        # Create a python source string snippet.
        if len(source) > 1:
            snipped_source = "".join(
                [
                    # tail of the first line, every line in between, head of the last line
                    source[0][start_column:],
                    *source[1:-1],
                    source[-1][:end_column],
                ]
            )
        else:
            snipped_source = source[0][start_column:end_column]
        # Evaluate the string in the context of the function's globals and closure.
        # NOTE(review): the evaluated text comes from the module that defines
        # the wrapped function, not from external input.
        return eval(
            f"({snipped_source})", self.func.__globals__, self._closure_namespace()
        )

    def handle_getting_var(self, instruction: dis.Instruction) -> None:
        """Handle bytecode analysis when `get_var_value` was called in the function.

        This only really works if the expression passed to `get_var_value` is
        evaluable in the function's global scope or closure, so getting the var
        value from a var saved in a local variable or in the class instance is
        not possible.

        Args:
            instruction: The dis instruction to process.
        """
        if instruction.opname == "CALL" and self._getting_var_instructions:
            # The argument expression is complete; resolve it to a Var and
            # record the state/field it refers to.
            the_var = self._eval_var()
            the_var_data = the_var._get_all_var_data()
            self.dependencies.setdefault(the_var_data.state, set()).add(
                the_var_data.field_name
            )
            self._getting_var_instructions.clear()
            self.scan_status = ScanStatus.SCANNING
        else:
            self._getting_var_instructions.append(instruction)

    def _populate_dependencies(self) -> None:
        """Update self.dependencies based on the disassembly of self.func.

        Save references to attributes accessed on "self" or other fetched states.

        Recursively called when the function makes a method call on "self" or
        define comprehensions or nested functions that may reference "self".

        Raises:
            VarValueError: if the function references the get_state, parent_state, or substates attributes
                (cannot track deps in a related state, only implicitly via parent state).
        """
        for instruction in dis.get_instructions(self.func):
            if self.scan_status == ScanStatus.GETTING_STATE:
                self.handle_getting_state(instruction)
            elif self.scan_status == ScanStatus.GETTING_VAR:
                self.handle_getting_var(instruction)
            elif (
                instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
                and instruction.argval in self.tracked_locals
            ):
                # bytecode loaded the class instance to the top of stack, next load instruction
                # is referencing an attribute on self
                self.top_of_stack = instruction.argval
                self.scan_status = ScanStatus.GETTING_ATTR
            elif self.scan_status == ScanStatus.GETTING_ATTR and instruction.opname in (
                "LOAD_ATTR",
                "LOAD_METHOD",
            ):
                self.load_attr_or_method(instruction)
                self.top_of_stack = None
            elif instruction.opname == "LOAD_CONST" and isinstance(
                instruction.argval, CodeType
            ):
                # recurse into nested functions / comprehensions, which can reference
                # instance attributes from the outer scope
                self._merge_deps(
                    type(self)(
                        func=instruction.argval,
                        state_cls=self.state_cls,
                        tracked_locals=self.tracked_locals,
                    )
                )