Move computed var dep tracking to separate module
This commit is contained in:
parent
8f4b12d29f
commit
aa5b3037d7
@ -797,7 +797,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
*defining_state_cls.inherited_vars,
|
*defining_state_cls.inherited_vars,
|
||||||
*defining_state_cls.inherited_backend_vars,
|
*defining_state_cls.inherited_backend_vars,
|
||||||
}:
|
}:
|
||||||
defining_state_cls = defining_state_cls.get_parent_state()
|
parent_state = defining_state_cls.get_parent_state()
|
||||||
|
if parent_state is not None:
|
||||||
|
defining_state_cls = parent_state
|
||||||
defining_state_cls._var_dependencies.setdefault(dvar, set()).add(
|
defining_state_cls._var_dependencies.setdefault(dvar, set()).add(
|
||||||
(cls.get_full_name(), cvar_name)
|
(cls.get_full_name(), cvar_name)
|
||||||
)
|
)
|
||||||
@ -2721,7 +2723,7 @@ class StateProxy(wrapt.ObjectProxy):
|
|||||||
await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
|
await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
|
||||||
)
|
)
|
||||||
|
|
||||||
def _as_state_update(self, *args, **kwargs) -> StateUpdate:
|
async def _as_state_update(self, *args, **kwargs) -> StateUpdate:
|
||||||
"""Temporarily allow mutability to access parent_state.
|
"""Temporarily allow mutability to access parent_state.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -2734,7 +2736,7 @@ class StateProxy(wrapt.ObjectProxy):
|
|||||||
original_mutable = self._self_mutable
|
original_mutable = self._self_mutable
|
||||||
self._self_mutable = True
|
self._self_mutable = True
|
||||||
try:
|
try:
|
||||||
return self.__wrapped__._as_state_update(*args, **kwargs)
|
return await self.__wrapped__._as_state_update(*args, **kwargs)
|
||||||
finally:
|
finally:
|
||||||
self._self_mutable = original_mutable
|
self._self_mutable = original_mutable
|
||||||
|
|
||||||
@ -3366,10 +3368,10 @@ class StateManagerRedis(StateManager):
|
|||||||
state.parent_state = parent_state
|
state.parent_state = parent_state
|
||||||
|
|
||||||
# To retain compatibility with previous implementation, by default, we return
|
# To retain compatibility with previous implementation, by default, we return
|
||||||
# the top-level state by chasing `parent_state` pointers up the tree.
|
# the top-level state which should always be fetched or already cached.
|
||||||
if top_level:
|
if top_level:
|
||||||
return state._get_root_state()
|
return flat_state_tree[self.state.get_full_name()]
|
||||||
return state
|
return flat_state_tree[state_cls.get_full_name()]
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def set_state(
|
async def set_state(
|
||||||
@ -4070,7 +4072,7 @@ def reload_state_module(
|
|||||||
if subclass.__module__ == module and module is not None:
|
if subclass.__module__ == module and module is not None:
|
||||||
state.class_subclasses.remove(subclass)
|
state.class_subclasses.remove(subclass)
|
||||||
state._always_dirty_substates.discard(subclass.get_name())
|
state._always_dirty_substates.discard(subclass.get_name())
|
||||||
state._potentially_dirty_substates.discard(subclass.get_name())
|
state._potentially_dirty_states.discard(subclass.get_name())
|
||||||
state._var_dependencies = {}
|
state._var_dependencies = {}
|
||||||
state._init_var_dependency_dicts()
|
state._init_var_dependency_dicts()
|
||||||
state.get_class_substate.cache_clear()
|
state.get_class_substate.cache_clear()
|
||||||
|
@ -5,7 +5,6 @@ from __future__ import annotations
|
|||||||
import contextlib
|
import contextlib
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import datetime
|
import datetime
|
||||||
import dis
|
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
@ -20,6 +19,7 @@ from typing import (
|
|||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
|
Coroutine,
|
||||||
Dict,
|
Dict,
|
||||||
FrozenSet,
|
FrozenSet,
|
||||||
Generic,
|
Generic,
|
||||||
@ -50,7 +50,6 @@ from reflex.utils.exceptions import (
|
|||||||
VarAttributeError,
|
VarAttributeError,
|
||||||
VarDependencyError,
|
VarDependencyError,
|
||||||
VarTypeError,
|
VarTypeError,
|
||||||
VarValueError,
|
|
||||||
)
|
)
|
||||||
from reflex.utils.format import format_state_name
|
from reflex.utils.format import format_state_name
|
||||||
from reflex.utils.imports import (
|
from reflex.utils.imports import (
|
||||||
@ -2073,9 +2072,8 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
|
|
||||||
def _deps(
|
def _deps(
|
||||||
self,
|
self,
|
||||||
objclass: BaseState,
|
objclass: Type[BaseState],
|
||||||
obj: FunctionType | CodeType | None = None,
|
obj: FunctionType | CodeType | None = None,
|
||||||
self_names: Optional[dict[str, str]] = None,
|
|
||||||
) -> dict[str, set[str]]:
|
) -> dict[str, set[str]]:
|
||||||
"""Determine var dependencies of this ComputedVar.
|
"""Determine var dependencies of this ComputedVar.
|
||||||
|
|
||||||
@ -2087,17 +2085,12 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
Args:
|
Args:
|
||||||
objclass: the class obj this ComputedVar is attached to.
|
objclass: the class obj this ComputedVar is attached to.
|
||||||
obj: the object to disassemble (defaults to the fget function).
|
obj: the object to disassemble (defaults to the fget function).
|
||||||
self_names: if specified, look for these names in LOAD_FAST and LOAD_DEREF instructions.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary mapping state names to the set of variable names
|
A dictionary mapping state names to the set of variable names
|
||||||
accessed by the given obj.
|
accessed by the given obj.
|
||||||
|
|
||||||
Raises:
|
|
||||||
VarValueError: if the function references the get_state, parent_state, or substates attributes
|
|
||||||
(cannot track deps in a related state, only implicitly via parent state).
|
|
||||||
"""
|
"""
|
||||||
from reflex.state import BaseState
|
from .dep_tracking import DependencyTracker
|
||||||
|
|
||||||
d = {}
|
d = {}
|
||||||
if self._static_deps:
|
if self._static_deps:
|
||||||
@ -2108,158 +2101,26 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
|
|
||||||
if not self._auto_deps:
|
if not self._auto_deps:
|
||||||
return d
|
return d
|
||||||
|
|
||||||
if obj is None:
|
if obj is None:
|
||||||
fget = self._fget
|
fget = self._fget
|
||||||
if fget is not None:
|
if fget is not None:
|
||||||
obj = cast(FunctionType, fget)
|
obj = cast(FunctionType, fget)
|
||||||
else:
|
else:
|
||||||
return d
|
return d
|
||||||
with contextlib.suppress(AttributeError):
|
|
||||||
# unbox functools.partial
|
|
||||||
obj = cast(FunctionType, obj.func) # type: ignore
|
|
||||||
with contextlib.suppress(AttributeError):
|
|
||||||
# unbox EventHandler
|
|
||||||
obj = cast(FunctionType, obj.fn) # type: ignore
|
|
||||||
|
|
||||||
if self_names is None and isinstance(obj, FunctionType):
|
try:
|
||||||
try:
|
return DependencyTracker(
|
||||||
# the first argument to the function is the name of "self" arg
|
func=obj, state_cls=objclass, dependencies=d
|
||||||
self_names = {obj.__code__.co_varnames[0]: objclass.get_full_name()}
|
).dependencies
|
||||||
except (AttributeError, IndexError):
|
except Exception as e:
|
||||||
self_names = None
|
console.warn(
|
||||||
if self_names is None:
|
"Failed to automatically determine dependencies for computed var "
|
||||||
# cannot reference attributes on self if method takes no args
|
f"{objclass.__name__}.{self._js_expr}: {e}. "
|
||||||
|
"Provide static_deps and set auto_deps=False to suppress this warning."
|
||||||
|
)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
invalid_names = ["parent_state", "substates", "get_substate"]
|
|
||||||
self_on_top_of_stack = None
|
|
||||||
getting_state = False
|
|
||||||
getting_var = False
|
|
||||||
for instruction in dis.get_instructions(obj):
|
|
||||||
if getting_state:
|
|
||||||
if instruction.opname == "LOAD_FAST":
|
|
||||||
raise VarValueError(
|
|
||||||
f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
|
|
||||||
)
|
|
||||||
if instruction.opname == "LOAD_GLOBAL":
|
|
||||||
# Special case: referencing state class from global scope.
|
|
||||||
getting_state = obj.__globals__.get(instruction.argval)
|
|
||||||
elif instruction.opname == "LOAD_DEREF":
|
|
||||||
# Special case: referencing state class from closure.
|
|
||||||
closure = dict(zip(obj.__code__.co_freevars, obj.__closure__))
|
|
||||||
try:
|
|
||||||
getting_state = closure[instruction.argval].cell_contents
|
|
||||||
except ValueError as ve:
|
|
||||||
raise VarValueError(
|
|
||||||
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?."
|
|
||||||
) from ve
|
|
||||||
elif instruction.opname == "STORE_FAST":
|
|
||||||
# Storing the result of get_state in a local variable.
|
|
||||||
if not isinstance(getting_state, type) or not issubclass(
|
|
||||||
getting_state, BaseState
|
|
||||||
):
|
|
||||||
raise VarValueError(
|
|
||||||
f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`."
|
|
||||||
)
|
|
||||||
self_names[instruction.argval] = getting_state.get_full_name()
|
|
||||||
getting_state = False
|
|
||||||
continue # nothing else happens until we have identified the local var
|
|
||||||
if getting_var:
|
|
||||||
if instruction.opname == "CALL":
|
|
||||||
# get the original source code and eval it
|
|
||||||
start_line = getting_var[0].positions.lineno
|
|
||||||
start_column = getting_var[0].positions.col_offset
|
|
||||||
end_line = getting_var[-1].positions.end_lineno
|
|
||||||
end_column = getting_var[-1].positions.end_col_offset
|
|
||||||
source = inspect.getsource(inspect.getmodule(obj)).splitlines(True)[
|
|
||||||
start_line - 1 : end_line
|
|
||||||
]
|
|
||||||
if len(source) > 1:
|
|
||||||
snipped_source = "".join(
|
|
||||||
[
|
|
||||||
source[0][start_column:],
|
|
||||||
source[1:-2] if len(source) > 2 else "",
|
|
||||||
source[-1][:end_column],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
snipped_source = source[0][start_column:end_column]
|
|
||||||
the_var = eval(f"({snipped_source})", obj.__globals__)
|
|
||||||
the_var_data = the_var._get_all_var_data()
|
|
||||||
d.setdefault(the_var_data.state, set()).add(the_var_data.field_name)
|
|
||||||
getting_var = False
|
|
||||||
elif isinstance(getting_var, list):
|
|
||||||
getting_var.append(instruction)
|
|
||||||
else:
|
|
||||||
getting_var = [instruction]
|
|
||||||
continue
|
|
||||||
if (
|
|
||||||
instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
|
|
||||||
and instruction.argval in self_names
|
|
||||||
):
|
|
||||||
# bytecode loaded the class instance to the top of stack, next load instruction
|
|
||||||
# is referencing an attribute on self
|
|
||||||
self_on_top_of_stack = self_names[instruction.argval]
|
|
||||||
continue
|
|
||||||
if self_on_top_of_stack and instruction.opname in (
|
|
||||||
"LOAD_ATTR",
|
|
||||||
"LOAD_METHOD",
|
|
||||||
):
|
|
||||||
if instruction.argval in invalid_names:
|
|
||||||
raise VarValueError(
|
|
||||||
f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
|
|
||||||
)
|
|
||||||
if instruction.argval == "get_state":
|
|
||||||
# Special case: arbitrary state access requested.
|
|
||||||
getting_state = True
|
|
||||||
continue
|
|
||||||
if instruction.argval == "get_var_value":
|
|
||||||
# Special case: arbitrary var access requested.
|
|
||||||
getting_var = True
|
|
||||||
continue
|
|
||||||
target_state = objclass.get_root_state().get_class_substate(
|
|
||||||
self_on_top_of_stack
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
ref_obj = getattr(target_state, instruction.argval)
|
|
||||||
except Exception:
|
|
||||||
ref_obj = None
|
|
||||||
if callable(ref_obj):
|
|
||||||
# recurse into callable attributes
|
|
||||||
for state_name, dep_name in self._deps(
|
|
||||||
objclass=target_state,
|
|
||||||
obj=ref_obj,
|
|
||||||
).items():
|
|
||||||
d.setdefault(state_name, set()).update(dep_name)
|
|
||||||
# recurse into property fget functions
|
|
||||||
elif isinstance(ref_obj, property) and not isinstance(
|
|
||||||
ref_obj, ComputedVar
|
|
||||||
):
|
|
||||||
for state_name, dep_name in self._deps(
|
|
||||||
objclass=target_state,
|
|
||||||
obj=ref_obj.fget, # type: ignore
|
|
||||||
).items():
|
|
||||||
d.setdefault(state_name, set()).update(dep_name)
|
|
||||||
elif (
|
|
||||||
instruction.argval in target_state.backend_vars
|
|
||||||
or instruction.argval in target_state.vars
|
|
||||||
):
|
|
||||||
# var access
|
|
||||||
d.setdefault(self_on_top_of_stack, set()).add(instruction.argval)
|
|
||||||
elif instruction.opname == "LOAD_CONST" and isinstance(
|
|
||||||
instruction.argval, CodeType
|
|
||||||
):
|
|
||||||
# recurse into nested functions / comprehensions, which can reference
|
|
||||||
# instance attributes from the outer scope
|
|
||||||
for state_name, dep_name in self._deps(
|
|
||||||
objclass=objclass,
|
|
||||||
obj=instruction.argval,
|
|
||||||
self_names=self_names,
|
|
||||||
).items():
|
|
||||||
d.setdefault(state_name, set()).update(dep_name)
|
|
||||||
self_on_top_of_stack = None
|
|
||||||
return d
|
|
||||||
|
|
||||||
def mark_dirty(self, instance) -> None:
|
def mark_dirty(self, instance) -> None:
|
||||||
"""Mark this ComputedVar as dirty.
|
"""Mark this ComputedVar as dirty.
|
||||||
|
|
||||||
@ -2314,11 +2175,63 @@ class DynamicRouteVar(ComputedVar[Union[str, List[str]]]):
|
|||||||
class AsyncComputedVar(ComputedVar[RETURN_TYPE]):
|
class AsyncComputedVar(ComputedVar[RETURN_TYPE]):
|
||||||
"""A computed var that wraps a coroutinefunction."""
|
"""A computed var that wraps a coroutinefunction."""
|
||||||
|
|
||||||
_fget: Callable[[BaseState], RETURN_TYPE] = dataclasses.field(
|
_fget: Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]] = (
|
||||||
default_factory=lambda: lambda _: None
|
dataclasses.field()
|
||||||
) # type: ignore
|
)
|
||||||
|
|
||||||
def __get__(self, instance: BaseState | None, owner):
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self: AsyncComputedVar[int] | AsyncComputedVar[float],
|
||||||
|
instance: None,
|
||||||
|
owner: Type,
|
||||||
|
) -> NumberVar: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self: AsyncComputedVar[str],
|
||||||
|
instance: None,
|
||||||
|
owner: Type,
|
||||||
|
) -> StringVar: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self: AsyncComputedVar[Mapping[DICT_KEY, DICT_VAL]],
|
||||||
|
instance: None,
|
||||||
|
owner: Type,
|
||||||
|
) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self: AsyncComputedVar[list[LIST_INSIDE]],
|
||||||
|
instance: None,
|
||||||
|
owner: Type,
|
||||||
|
) -> ArrayVar[list[LIST_INSIDE]]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self: AsyncComputedVar[set[LIST_INSIDE]],
|
||||||
|
instance: None,
|
||||||
|
owner: Type,
|
||||||
|
) -> ArrayVar[set[LIST_INSIDE]]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self: AsyncComputedVar[tuple[LIST_INSIDE, ...]],
|
||||||
|
instance: None,
|
||||||
|
owner: Type,
|
||||||
|
) -> ArrayVar[tuple[LIST_INSIDE, ...]]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(self, instance: None, owner: Type) -> AsyncComputedVar[RETURN_TYPE]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self, instance: BaseState, owner: Type
|
||||||
|
) -> Coroutine[None, None, RETURN_TYPE]: ...
|
||||||
|
|
||||||
|
def __get__(
|
||||||
|
self, instance: BaseState | None, owner
|
||||||
|
) -> Var | Coroutine[None, None, RETURN_TYPE]:
|
||||||
"""Get the ComputedVar value.
|
"""Get the ComputedVar value.
|
||||||
|
|
||||||
If the value is already cached on the instance, return the cached value.
|
If the value is already cached on the instance, return the cached value.
|
||||||
@ -2335,14 +2248,15 @@ class AsyncComputedVar(ComputedVar[RETURN_TYPE]):
|
|||||||
|
|
||||||
if not self._cache:
|
if not self._cache:
|
||||||
|
|
||||||
async def _awaitable_result():
|
async def _awaitable_result(instance=instance) -> RETURN_TYPE:
|
||||||
value = await self.fget(instance)
|
value = await self.fget(instance)
|
||||||
self._check_deprecated_return_type(instance, value)
|
self._check_deprecated_return_type(instance, value)
|
||||||
|
return value
|
||||||
|
|
||||||
return _awaitable_result()
|
return _awaitable_result()
|
||||||
else:
|
else:
|
||||||
# handle caching
|
# handle caching
|
||||||
async def _awaitable_result():
|
async def _awaitable_result(instance=instance) -> RETURN_TYPE:
|
||||||
if not hasattr(instance, self._cache_attr) or self.needs_update(
|
if not hasattr(instance, self._cache_attr) or self.needs_update(
|
||||||
instance
|
instance
|
||||||
):
|
):
|
||||||
@ -2356,7 +2270,16 @@ class AsyncComputedVar(ComputedVar[RETURN_TYPE]):
|
|||||||
self._check_deprecated_return_type(instance, value)
|
self._check_deprecated_return_type(instance, value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
return _awaitable_result()
|
return _awaitable_result()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fget(self) -> Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]]:
|
||||||
|
"""Get the getter function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The getter function.
|
||||||
|
"""
|
||||||
|
return self._fget
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
265
reflex/vars/dep_tracking.py
Normal file
265
reflex/vars/dep_tracking.py
Normal file
@ -0,0 +1,265 @@
|
|||||||
|
"""Collection of base classes."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import dataclasses
|
||||||
|
import dis
|
||||||
|
import enum
|
||||||
|
import inspect
|
||||||
|
from types import CodeType, FunctionType
|
||||||
|
from typing import TYPE_CHECKING, ClassVar, Type, cast
|
||||||
|
|
||||||
|
from reflex.utils.exceptions import VarValueError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from reflex.state import BaseState
|
||||||
|
|
||||||
|
from .base import Var
|
||||||
|
|
||||||
|
|
||||||
|
class ScanStatus(enum.Enum):
|
||||||
|
"""State of the dis instruction scanning loop."""
|
||||||
|
|
||||||
|
SCANNING = enum.auto()
|
||||||
|
GETTING_ATTR = enum.auto()
|
||||||
|
GETTING_STATE = enum.auto()
|
||||||
|
GETTING_VAR = enum.auto()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class DependencyTracker:
|
||||||
|
"""State machine for identifying state attributes that are accessed by a function."""
|
||||||
|
|
||||||
|
func: FunctionType | CodeType = dataclasses.field()
|
||||||
|
state_cls: Type[BaseState] = dataclasses.field()
|
||||||
|
|
||||||
|
dependencies: dict[str, set[str]] = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
|
scan_status: ScanStatus = dataclasses.field(default=ScanStatus.SCANNING)
|
||||||
|
top_of_stack: str | None = dataclasses.field(default=None)
|
||||||
|
|
||||||
|
tracked_locals: dict[str, Type[BaseState]] = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
|
_getting_state_class: Type[BaseState] | None = dataclasses.field(default=None)
|
||||||
|
_getting_var_instructions: list[dis.Instruction] = dataclasses.field(
|
||||||
|
default_factory=list
|
||||||
|
)
|
||||||
|
|
||||||
|
INVALID_NAMES: ClassVar[list[str]] = ["parent_state", "substates", "get_substate"]
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""After initializing, populate the dependencies dict."""
|
||||||
|
with contextlib.suppress(AttributeError):
|
||||||
|
# unbox functools.partial
|
||||||
|
self.func = cast(FunctionType, self.func.func) # type: ignore
|
||||||
|
with contextlib.suppress(AttributeError):
|
||||||
|
# unbox EventHandler
|
||||||
|
self.func = cast(FunctionType, self.func.fn) # type: ignore
|
||||||
|
|
||||||
|
if isinstance(self.func, FunctionType):
|
||||||
|
with contextlib.suppress(AttributeError, IndexError):
|
||||||
|
# the first argument to the function is the name of "self" arg
|
||||||
|
self.tracked_locals[self.func.__code__.co_varnames[0]] = self.state_cls
|
||||||
|
|
||||||
|
self._populate_dependencies()
|
||||||
|
|
||||||
|
def _merge_deps(self, tracker: DependencyTracker) -> None:
|
||||||
|
"""Merge dependencies from another DependencyTracker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tracker: The DependencyTracker to merge dependencies from.
|
||||||
|
"""
|
||||||
|
for state_name, dep_name in tracker.dependencies.items():
|
||||||
|
self.dependencies.setdefault(state_name, set()).update(dep_name)
|
||||||
|
|
||||||
|
def load_attr_or_method(self, instruction: dis.Instruction) -> None:
|
||||||
|
"""Handle loading an attribute or method from the object on top of the stack.
|
||||||
|
|
||||||
|
This method directly tracks attributes and recursively merges
|
||||||
|
dependencies from analyzing the dependencies of any methods called.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instruction: The dis instruction to process.
|
||||||
|
"""
|
||||||
|
from .base import ComputedVar
|
||||||
|
|
||||||
|
if instruction.argval in self.INVALID_NAMES:
|
||||||
|
raise VarValueError(
|
||||||
|
f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
|
||||||
|
)
|
||||||
|
if instruction.argval == "get_state":
|
||||||
|
# Special case: arbitrary state access requested.
|
||||||
|
self.scan_status = ScanStatus.GETTING_STATE
|
||||||
|
return
|
||||||
|
if instruction.argval == "get_var_value":
|
||||||
|
# Special case: arbitrary var access requested.
|
||||||
|
self.scan_status = ScanStatus.GETTING_VAR
|
||||||
|
return
|
||||||
|
|
||||||
|
# Reset status back to SCANNING after attribute is accessed.
|
||||||
|
self.scan_status = ScanStatus.SCANNING
|
||||||
|
if not self.top_of_stack:
|
||||||
|
return
|
||||||
|
target_state = self.tracked_locals[self.top_of_stack]
|
||||||
|
ref_obj = getattr(target_state, instruction.argval)
|
||||||
|
|
||||||
|
if isinstance(ref_obj, property) and not isinstance(ref_obj, ComputedVar):
|
||||||
|
# recurse into property fget functions
|
||||||
|
ref_obj = ref_obj.fget # type: ignore
|
||||||
|
if callable(ref_obj):
|
||||||
|
# recurse into callable attributes
|
||||||
|
self._merge_deps(
|
||||||
|
type(self)(func=cast(FunctionType, ref_obj), state_cls=target_state)
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
instruction.argval in target_state.backend_vars
|
||||||
|
or instruction.argval in target_state.vars
|
||||||
|
):
|
||||||
|
# var access
|
||||||
|
self.dependencies.setdefault(target_state.get_full_name(), set()).add(
|
||||||
|
instruction.argval
|
||||||
|
)
|
||||||
|
|
||||||
|
def handle_getting_state(self, instruction: dis.Instruction) -> None:
|
||||||
|
"""Handle bytecode analysis when `get_state` was called in the function.
|
||||||
|
|
||||||
|
If the wrapped function is getting an arbitrary state and saving it to a
|
||||||
|
local variable, this method associates the local variable name with the
|
||||||
|
state class in self.tracked_locals.
|
||||||
|
|
||||||
|
When an attribute/method is accessed on a tracked local, it will be
|
||||||
|
associated with this state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instruction: The dis instruction to process.
|
||||||
|
"""
|
||||||
|
from reflex.state import BaseState
|
||||||
|
|
||||||
|
if instruction.opname == "LOAD_FAST":
|
||||||
|
raise VarValueError(
|
||||||
|
f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
|
||||||
|
)
|
||||||
|
if instruction.opname == "LOAD_GLOBAL":
|
||||||
|
# Special case: referencing state class from global scope.
|
||||||
|
self._getting_state_class = self.func.__globals__.get(instruction.argval)
|
||||||
|
elif instruction.opname == "LOAD_DEREF":
|
||||||
|
# Special case: referencing state class from closure.
|
||||||
|
closure = dict(zip(self.func.__code__.co_freevars, self.func.__closure__))
|
||||||
|
try:
|
||||||
|
self._getting_state_class = closure[instruction.argval].cell_contents
|
||||||
|
except ValueError as ve:
|
||||||
|
raise VarValueError(
|
||||||
|
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?."
|
||||||
|
) from ve
|
||||||
|
elif instruction.opname == "STORE_FAST":
|
||||||
|
# Storing the result of get_state in a local variable.
|
||||||
|
if not isinstance(self._getting_state_class, type) or not issubclass(
|
||||||
|
self._getting_state_class, BaseState
|
||||||
|
):
|
||||||
|
raise VarValueError(
|
||||||
|
f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`."
|
||||||
|
)
|
||||||
|
self.tracked_locals[instruction.argval] = self._getting_state_class
|
||||||
|
self.scan_status = ScanStatus.SCANNING
|
||||||
|
self._getting_state_class = None
|
||||||
|
|
||||||
|
def _eval_var(self) -> Var:
|
||||||
|
"""Evaluate instructions from the wrapped function to get the Var object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The Var object.
|
||||||
|
"""
|
||||||
|
# Get the original source code and eval it to get the Var.
|
||||||
|
start_line = self._getting_var_instructions[0].positions.lineno
|
||||||
|
start_column = self._getting_var_instructions[0].positions.col_offset
|
||||||
|
end_line = self._getting_var_instructions[-1].positions.end_lineno
|
||||||
|
end_column = self._getting_var_instructions[-1].positions.end_col_offset
|
||||||
|
source = inspect.getsource(inspect.getmodule(self.func)).splitlines(True)[
|
||||||
|
start_line - 1 : end_line
|
||||||
|
]
|
||||||
|
# Create a python source string snippet.
|
||||||
|
if len(source) > 1:
|
||||||
|
snipped_source = "".join(
|
||||||
|
[
|
||||||
|
source[0][start_column:],
|
||||||
|
source[1:-2] if len(source) > 2 else "",
|
||||||
|
source[-1][:end_column],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
snipped_source = source[0][start_column:end_column]
|
||||||
|
try:
|
||||||
|
closure = dict(zip(self.func.__code__.co_freevars, self.func.__closure__))
|
||||||
|
except Exception:
|
||||||
|
# Fallback if the closure is not available.
|
||||||
|
closure = {}
|
||||||
|
# Evaluate the string in the context of the function's globals and closure.
|
||||||
|
return eval(f"({snipped_source})", self.func.__globals__, closure)
|
||||||
|
|
||||||
|
def handle_getting_var(self, instruction: dis.Instruction) -> None:
|
||||||
|
"""Handle bytecode analysis when `get_var_value` was called in the function.
|
||||||
|
|
||||||
|
This only really works if the expression passed to `get_var_value` is
|
||||||
|
evaluable in the function's global scope or closure, so getting the var
|
||||||
|
value from a var saved in a local variable or in the class instance is
|
||||||
|
not possible.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instruction: The dis instruction to process.
|
||||||
|
"""
|
||||||
|
if instruction.opname == "CALL" and self._getting_var_instructions:
|
||||||
|
if self._getting_var_instructions:
|
||||||
|
the_var = self._eval_var()
|
||||||
|
the_var_data = the_var._get_all_var_data()
|
||||||
|
self.dependencies.setdefault(the_var_data.state, set()).add(
|
||||||
|
the_var_data.field_name
|
||||||
|
)
|
||||||
|
self._getting_var_instructions.clear()
|
||||||
|
self.scan_status = ScanStatus.SCANNING
|
||||||
|
else:
|
||||||
|
self._getting_var_instructions.append(instruction)
|
||||||
|
|
||||||
|
def _populate_dependencies(self) -> None:
|
||||||
|
"""Update self.dependencies based on the disassembly of self.func.
|
||||||
|
|
||||||
|
Save references to attributes accessed on "self" or other fetched states.
|
||||||
|
|
||||||
|
Recursively called when the function makes a method call on "self" or
|
||||||
|
define comprehensions or nested functions that may reference "self".
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
VarValueError: if the function references the get_state, parent_state, or substates attributes
|
||||||
|
(cannot track deps in a related state, only implicitly via parent state).
|
||||||
|
"""
|
||||||
|
for instruction in dis.get_instructions(self.func):
|
||||||
|
if self.scan_status == ScanStatus.GETTING_STATE:
|
||||||
|
self.handle_getting_state(instruction)
|
||||||
|
elif self.scan_status == ScanStatus.GETTING_VAR:
|
||||||
|
self.handle_getting_var(instruction)
|
||||||
|
elif (
|
||||||
|
instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
|
||||||
|
and instruction.argval in self.tracked_locals
|
||||||
|
):
|
||||||
|
# bytecode loaded the class instance to the top of stack, next load instruction
|
||||||
|
# is referencing an attribute on self
|
||||||
|
self.top_of_stack = instruction.argval
|
||||||
|
self.scan_status = ScanStatus.GETTING_ATTR
|
||||||
|
elif self.scan_status == ScanStatus.GETTING_ATTR and instruction.opname in (
|
||||||
|
"LOAD_ATTR",
|
||||||
|
"LOAD_METHOD",
|
||||||
|
):
|
||||||
|
self.load_attr_or_method(instruction)
|
||||||
|
self.top_of_stack = None
|
||||||
|
elif instruction.opname == "LOAD_CONST" and isinstance(
|
||||||
|
instruction.argval, CodeType
|
||||||
|
):
|
||||||
|
# recurse into nested functions / comprehensions, which can reference
|
||||||
|
# instance attributes from the outer scope
|
||||||
|
self._merge_deps(
|
||||||
|
type(self)(
|
||||||
|
func=instruction.argval,
|
||||||
|
state_cls=self.state_cls,
|
||||||
|
tracked_locals=self.tracked_locals,
|
||||||
|
)
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user