Move computed var dep tracking to separate module
This commit is contained in:
parent
8f4b12d29f
commit
aa5b3037d7
@ -797,7 +797,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
*defining_state_cls.inherited_vars,
|
||||
*defining_state_cls.inherited_backend_vars,
|
||||
}:
|
||||
defining_state_cls = defining_state_cls.get_parent_state()
|
||||
parent_state = defining_state_cls.get_parent_state()
|
||||
if parent_state is not None:
|
||||
defining_state_cls = parent_state
|
||||
defining_state_cls._var_dependencies.setdefault(dvar, set()).add(
|
||||
(cls.get_full_name(), cvar_name)
|
||||
)
|
||||
@ -2721,7 +2723,7 @@ class StateProxy(wrapt.ObjectProxy):
|
||||
await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
|
||||
)
|
||||
|
||||
def _as_state_update(self, *args, **kwargs) -> StateUpdate:
|
||||
async def _as_state_update(self, *args, **kwargs) -> StateUpdate:
|
||||
"""Temporarily allow mutability to access parent_state.
|
||||
|
||||
Args:
|
||||
@ -2734,7 +2736,7 @@ class StateProxy(wrapt.ObjectProxy):
|
||||
original_mutable = self._self_mutable
|
||||
self._self_mutable = True
|
||||
try:
|
||||
return self.__wrapped__._as_state_update(*args, **kwargs)
|
||||
return await self.__wrapped__._as_state_update(*args, **kwargs)
|
||||
finally:
|
||||
self._self_mutable = original_mutable
|
||||
|
||||
@ -3366,10 +3368,10 @@ class StateManagerRedis(StateManager):
|
||||
state.parent_state = parent_state
|
||||
|
||||
# To retain compatibility with previous implementation, by default, we return
|
||||
# the top-level state by chasing `parent_state` pointers up the tree.
|
||||
# the top-level state which should always be fetched or already cached.
|
||||
if top_level:
|
||||
return state._get_root_state()
|
||||
return state
|
||||
return flat_state_tree[self.state.get_full_name()]
|
||||
return flat_state_tree[state_cls.get_full_name()]
|
||||
|
||||
@override
|
||||
async def set_state(
|
||||
@ -4070,7 +4072,7 @@ def reload_state_module(
|
||||
if subclass.__module__ == module and module is not None:
|
||||
state.class_subclasses.remove(subclass)
|
||||
state._always_dirty_substates.discard(subclass.get_name())
|
||||
state._potentially_dirty_substates.discard(subclass.get_name())
|
||||
state._potentially_dirty_states.discard(subclass.get_name())
|
||||
state._var_dependencies = {}
|
||||
state._init_var_dependency_dicts()
|
||||
state.get_class_substate.cache_clear()
|
||||
|
@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import datetime
|
||||
import dis
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
@ -20,6 +19,7 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Coroutine,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
Generic,
|
||||
@ -50,7 +50,6 @@ from reflex.utils.exceptions import (
|
||||
VarAttributeError,
|
||||
VarDependencyError,
|
||||
VarTypeError,
|
||||
VarValueError,
|
||||
)
|
||||
from reflex.utils.format import format_state_name
|
||||
from reflex.utils.imports import (
|
||||
@ -2073,9 +2072,8 @@ class ComputedVar(Var[RETURN_TYPE]):
|
||||
|
||||
def _deps(
|
||||
self,
|
||||
objclass: BaseState,
|
||||
objclass: Type[BaseState],
|
||||
obj: FunctionType | CodeType | None = None,
|
||||
self_names: Optional[dict[str, str]] = None,
|
||||
) -> dict[str, set[str]]:
|
||||
"""Determine var dependencies of this ComputedVar.
|
||||
|
||||
@ -2087,17 +2085,12 @@ class ComputedVar(Var[RETURN_TYPE]):
|
||||
Args:
|
||||
objclass: the class obj this ComputedVar is attached to.
|
||||
obj: the object to disassemble (defaults to the fget function).
|
||||
self_names: if specified, look for these names in LOAD_FAST and LOAD_DEREF instructions.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping state names to the set of variable names
|
||||
accessed by the given obj.
|
||||
|
||||
Raises:
|
||||
VarValueError: if the function references the get_state, parent_state, or substates attributes
|
||||
(cannot track deps in a related state, only implicitly via parent state).
|
||||
"""
|
||||
from reflex.state import BaseState
|
||||
from .dep_tracking import DependencyTracker
|
||||
|
||||
d = {}
|
||||
if self._static_deps:
|
||||
@ -2108,158 +2101,26 @@ class ComputedVar(Var[RETURN_TYPE]):
|
||||
|
||||
if not self._auto_deps:
|
||||
return d
|
||||
|
||||
if obj is None:
|
||||
fget = self._fget
|
||||
if fget is not None:
|
||||
obj = cast(FunctionType, fget)
|
||||
else:
|
||||
return d
|
||||
with contextlib.suppress(AttributeError):
|
||||
# unbox functools.partial
|
||||
obj = cast(FunctionType, obj.func) # type: ignore
|
||||
with contextlib.suppress(AttributeError):
|
||||
# unbox EventHandler
|
||||
obj = cast(FunctionType, obj.fn) # type: ignore
|
||||
|
||||
if self_names is None and isinstance(obj, FunctionType):
|
||||
try:
|
||||
# the first argument to the function is the name of "self" arg
|
||||
self_names = {obj.__code__.co_varnames[0]: objclass.get_full_name()}
|
||||
except (AttributeError, IndexError):
|
||||
self_names = None
|
||||
if self_names is None:
|
||||
# cannot reference attributes on self if method takes no args
|
||||
try:
|
||||
return DependencyTracker(
|
||||
func=obj, state_cls=objclass, dependencies=d
|
||||
).dependencies
|
||||
except Exception as e:
|
||||
console.warn(
|
||||
"Failed to automatically determine dependencies for computed var "
|
||||
f"{objclass.__name__}.{self._js_expr}: {e}. "
|
||||
"Provide static_deps and set auto_deps=False to suppress this warning."
|
||||
)
|
||||
return d
|
||||
|
||||
invalid_names = ["parent_state", "substates", "get_substate"]
|
||||
self_on_top_of_stack = None
|
||||
getting_state = False
|
||||
getting_var = False
|
||||
for instruction in dis.get_instructions(obj):
|
||||
if getting_state:
|
||||
if instruction.opname == "LOAD_FAST":
|
||||
raise VarValueError(
|
||||
f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
|
||||
)
|
||||
if instruction.opname == "LOAD_GLOBAL":
|
||||
# Special case: referencing state class from global scope.
|
||||
getting_state = obj.__globals__.get(instruction.argval)
|
||||
elif instruction.opname == "LOAD_DEREF":
|
||||
# Special case: referencing state class from closure.
|
||||
closure = dict(zip(obj.__code__.co_freevars, obj.__closure__))
|
||||
try:
|
||||
getting_state = closure[instruction.argval].cell_contents
|
||||
except ValueError as ve:
|
||||
raise VarValueError(
|
||||
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?."
|
||||
) from ve
|
||||
elif instruction.opname == "STORE_FAST":
|
||||
# Storing the result of get_state in a local variable.
|
||||
if not isinstance(getting_state, type) or not issubclass(
|
||||
getting_state, BaseState
|
||||
):
|
||||
raise VarValueError(
|
||||
f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`."
|
||||
)
|
||||
self_names[instruction.argval] = getting_state.get_full_name()
|
||||
getting_state = False
|
||||
continue # nothing else happens until we have identified the local var
|
||||
if getting_var:
|
||||
if instruction.opname == "CALL":
|
||||
# get the original source code and eval it
|
||||
start_line = getting_var[0].positions.lineno
|
||||
start_column = getting_var[0].positions.col_offset
|
||||
end_line = getting_var[-1].positions.end_lineno
|
||||
end_column = getting_var[-1].positions.end_col_offset
|
||||
source = inspect.getsource(inspect.getmodule(obj)).splitlines(True)[
|
||||
start_line - 1 : end_line
|
||||
]
|
||||
if len(source) > 1:
|
||||
snipped_source = "".join(
|
||||
[
|
||||
source[0][start_column:],
|
||||
source[1:-2] if len(source) > 2 else "",
|
||||
source[-1][:end_column],
|
||||
]
|
||||
)
|
||||
else:
|
||||
snipped_source = source[0][start_column:end_column]
|
||||
the_var = eval(f"({snipped_source})", obj.__globals__)
|
||||
the_var_data = the_var._get_all_var_data()
|
||||
d.setdefault(the_var_data.state, set()).add(the_var_data.field_name)
|
||||
getting_var = False
|
||||
elif isinstance(getting_var, list):
|
||||
getting_var.append(instruction)
|
||||
else:
|
||||
getting_var = [instruction]
|
||||
continue
|
||||
if (
|
||||
instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
|
||||
and instruction.argval in self_names
|
||||
):
|
||||
# bytecode loaded the class instance to the top of stack, next load instruction
|
||||
# is referencing an attribute on self
|
||||
self_on_top_of_stack = self_names[instruction.argval]
|
||||
continue
|
||||
if self_on_top_of_stack and instruction.opname in (
|
||||
"LOAD_ATTR",
|
||||
"LOAD_METHOD",
|
||||
):
|
||||
if instruction.argval in invalid_names:
|
||||
raise VarValueError(
|
||||
f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
|
||||
)
|
||||
if instruction.argval == "get_state":
|
||||
# Special case: arbitrary state access requested.
|
||||
getting_state = True
|
||||
continue
|
||||
if instruction.argval == "get_var_value":
|
||||
# Special case: arbitrary var access requested.
|
||||
getting_var = True
|
||||
continue
|
||||
target_state = objclass.get_root_state().get_class_substate(
|
||||
self_on_top_of_stack
|
||||
)
|
||||
try:
|
||||
ref_obj = getattr(target_state, instruction.argval)
|
||||
except Exception:
|
||||
ref_obj = None
|
||||
if callable(ref_obj):
|
||||
# recurse into callable attributes
|
||||
for state_name, dep_name in self._deps(
|
||||
objclass=target_state,
|
||||
obj=ref_obj,
|
||||
).items():
|
||||
d.setdefault(state_name, set()).update(dep_name)
|
||||
# recurse into property fget functions
|
||||
elif isinstance(ref_obj, property) and not isinstance(
|
||||
ref_obj, ComputedVar
|
||||
):
|
||||
for state_name, dep_name in self._deps(
|
||||
objclass=target_state,
|
||||
obj=ref_obj.fget, # type: ignore
|
||||
).items():
|
||||
d.setdefault(state_name, set()).update(dep_name)
|
||||
elif (
|
||||
instruction.argval in target_state.backend_vars
|
||||
or instruction.argval in target_state.vars
|
||||
):
|
||||
# var access
|
||||
d.setdefault(self_on_top_of_stack, set()).add(instruction.argval)
|
||||
elif instruction.opname == "LOAD_CONST" and isinstance(
|
||||
instruction.argval, CodeType
|
||||
):
|
||||
# recurse into nested functions / comprehensions, which can reference
|
||||
# instance attributes from the outer scope
|
||||
for state_name, dep_name in self._deps(
|
||||
objclass=objclass,
|
||||
obj=instruction.argval,
|
||||
self_names=self_names,
|
||||
).items():
|
||||
d.setdefault(state_name, set()).update(dep_name)
|
||||
self_on_top_of_stack = None
|
||||
return d
|
||||
|
||||
def mark_dirty(self, instance) -> None:
|
||||
"""Mark this ComputedVar as dirty.
|
||||
|
||||
@ -2314,11 +2175,63 @@ class DynamicRouteVar(ComputedVar[Union[str, List[str]]]):
|
||||
class AsyncComputedVar(ComputedVar[RETURN_TYPE]):
|
||||
"""A computed var that wraps a coroutinefunction."""
|
||||
|
||||
_fget: Callable[[BaseState], RETURN_TYPE] = dataclasses.field(
|
||||
default_factory=lambda: lambda _: None
|
||||
) # type: ignore
|
||||
_fget: Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]] = (
|
||||
dataclasses.field()
|
||||
)
|
||||
|
||||
def __get__(self, instance: BaseState | None, owner):
|
||||
@overload
|
||||
def __get__(
|
||||
self: AsyncComputedVar[int] | AsyncComputedVar[float],
|
||||
instance: None,
|
||||
owner: Type,
|
||||
) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: AsyncComputedVar[str],
|
||||
instance: None,
|
||||
owner: Type,
|
||||
) -> StringVar: ...
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: AsyncComputedVar[Mapping[DICT_KEY, DICT_VAL]],
|
||||
instance: None,
|
||||
owner: Type,
|
||||
) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ...
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: AsyncComputedVar[list[LIST_INSIDE]],
|
||||
instance: None,
|
||||
owner: Type,
|
||||
) -> ArrayVar[list[LIST_INSIDE]]: ...
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: AsyncComputedVar[set[LIST_INSIDE]],
|
||||
instance: None,
|
||||
owner: Type,
|
||||
) -> ArrayVar[set[LIST_INSIDE]]: ...
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: AsyncComputedVar[tuple[LIST_INSIDE, ...]],
|
||||
instance: None,
|
||||
owner: Type,
|
||||
) -> ArrayVar[tuple[LIST_INSIDE, ...]]: ...
|
||||
|
||||
@overload
|
||||
def __get__(self, instance: None, owner: Type) -> AsyncComputedVar[RETURN_TYPE]: ...
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self, instance: BaseState, owner: Type
|
||||
) -> Coroutine[None, None, RETURN_TYPE]: ...
|
||||
|
||||
def __get__(
|
||||
self, instance: BaseState | None, owner
|
||||
) -> Var | Coroutine[None, None, RETURN_TYPE]:
|
||||
"""Get the ComputedVar value.
|
||||
|
||||
If the value is already cached on the instance, return the cached value.
|
||||
@ -2335,14 +2248,15 @@ class AsyncComputedVar(ComputedVar[RETURN_TYPE]):
|
||||
|
||||
if not self._cache:
|
||||
|
||||
async def _awaitable_result():
|
||||
async def _awaitable_result(instance=instance) -> RETURN_TYPE:
|
||||
value = await self.fget(instance)
|
||||
self._check_deprecated_return_type(instance, value)
|
||||
return value
|
||||
|
||||
return _awaitable_result()
|
||||
else:
|
||||
# handle caching
|
||||
async def _awaitable_result():
|
||||
async def _awaitable_result(instance=instance) -> RETURN_TYPE:
|
||||
if not hasattr(instance, self._cache_attr) or self.needs_update(
|
||||
instance
|
||||
):
|
||||
@ -2356,7 +2270,16 @@ class AsyncComputedVar(ComputedVar[RETURN_TYPE]):
|
||||
self._check_deprecated_return_type(instance, value)
|
||||
return value
|
||||
|
||||
return _awaitable_result()
|
||||
return _awaitable_result()
|
||||
|
||||
@property
|
||||
def fget(self) -> Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]]:
|
||||
"""Get the getter function.
|
||||
|
||||
Returns:
|
||||
The getter function.
|
||||
"""
|
||||
return self._fget
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
265
reflex/vars/dep_tracking.py
Normal file
265
reflex/vars/dep_tracking.py
Normal file
@ -0,0 +1,265 @@
|
||||
"""Collection of base classes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import dis
|
||||
import enum
|
||||
import inspect
|
||||
from types import CodeType, FunctionType
|
||||
from typing import TYPE_CHECKING, ClassVar, Type, cast
|
||||
|
||||
from reflex.utils.exceptions import VarValueError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from reflex.state import BaseState
|
||||
|
||||
from .base import Var
|
||||
|
||||
|
||||
class ScanStatus(enum.Enum):
|
||||
"""State of the dis instruction scanning loop."""
|
||||
|
||||
SCANNING = enum.auto()
|
||||
GETTING_ATTR = enum.auto()
|
||||
GETTING_STATE = enum.auto()
|
||||
GETTING_VAR = enum.auto()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DependencyTracker:
|
||||
"""State machine for identifying state attributes that are accessed by a function."""
|
||||
|
||||
func: FunctionType | CodeType = dataclasses.field()
|
||||
state_cls: Type[BaseState] = dataclasses.field()
|
||||
|
||||
dependencies: dict[str, set[str]] = dataclasses.field(default_factory=dict)
|
||||
|
||||
scan_status: ScanStatus = dataclasses.field(default=ScanStatus.SCANNING)
|
||||
top_of_stack: str | None = dataclasses.field(default=None)
|
||||
|
||||
tracked_locals: dict[str, Type[BaseState]] = dataclasses.field(default_factory=dict)
|
||||
|
||||
_getting_state_class: Type[BaseState] | None = dataclasses.field(default=None)
|
||||
_getting_var_instructions: list[dis.Instruction] = dataclasses.field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
INVALID_NAMES: ClassVar[list[str]] = ["parent_state", "substates", "get_substate"]
|
||||
|
||||
def __post_init__(self):
|
||||
"""After initializing, populate the dependencies dict."""
|
||||
with contextlib.suppress(AttributeError):
|
||||
# unbox functools.partial
|
||||
self.func = cast(FunctionType, self.func.func) # type: ignore
|
||||
with contextlib.suppress(AttributeError):
|
||||
# unbox EventHandler
|
||||
self.func = cast(FunctionType, self.func.fn) # type: ignore
|
||||
|
||||
if isinstance(self.func, FunctionType):
|
||||
with contextlib.suppress(AttributeError, IndexError):
|
||||
# the first argument to the function is the name of "self" arg
|
||||
self.tracked_locals[self.func.__code__.co_varnames[0]] = self.state_cls
|
||||
|
||||
self._populate_dependencies()
|
||||
|
||||
def _merge_deps(self, tracker: DependencyTracker) -> None:
|
||||
"""Merge dependencies from another DependencyTracker.
|
||||
|
||||
Args:
|
||||
tracker: The DependencyTracker to merge dependencies from.
|
||||
"""
|
||||
for state_name, dep_name in tracker.dependencies.items():
|
||||
self.dependencies.setdefault(state_name, set()).update(dep_name)
|
||||
|
||||
def load_attr_or_method(self, instruction: dis.Instruction) -> None:
|
||||
"""Handle loading an attribute or method from the object on top of the stack.
|
||||
|
||||
This method directly tracks attributes and recursively merges
|
||||
dependencies from analyzing the dependencies of any methods called.
|
||||
|
||||
Args:
|
||||
instruction: The dis instruction to process.
|
||||
"""
|
||||
from .base import ComputedVar
|
||||
|
||||
if instruction.argval in self.INVALID_NAMES:
|
||||
raise VarValueError(
|
||||
f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
|
||||
)
|
||||
if instruction.argval == "get_state":
|
||||
# Special case: arbitrary state access requested.
|
||||
self.scan_status = ScanStatus.GETTING_STATE
|
||||
return
|
||||
if instruction.argval == "get_var_value":
|
||||
# Special case: arbitrary var access requested.
|
||||
self.scan_status = ScanStatus.GETTING_VAR
|
||||
return
|
||||
|
||||
# Reset status back to SCANNING after attribute is accessed.
|
||||
self.scan_status = ScanStatus.SCANNING
|
||||
if not self.top_of_stack:
|
||||
return
|
||||
target_state = self.tracked_locals[self.top_of_stack]
|
||||
ref_obj = getattr(target_state, instruction.argval)
|
||||
|
||||
if isinstance(ref_obj, property) and not isinstance(ref_obj, ComputedVar):
|
||||
# recurse into property fget functions
|
||||
ref_obj = ref_obj.fget # type: ignore
|
||||
if callable(ref_obj):
|
||||
# recurse into callable attributes
|
||||
self._merge_deps(
|
||||
type(self)(func=cast(FunctionType, ref_obj), state_cls=target_state)
|
||||
)
|
||||
elif (
|
||||
instruction.argval in target_state.backend_vars
|
||||
or instruction.argval in target_state.vars
|
||||
):
|
||||
# var access
|
||||
self.dependencies.setdefault(target_state.get_full_name(), set()).add(
|
||||
instruction.argval
|
||||
)
|
||||
|
||||
def handle_getting_state(self, instruction: dis.Instruction) -> None:
|
||||
"""Handle bytecode analysis when `get_state` was called in the function.
|
||||
|
||||
If the wrapped function is getting an arbitrary state and saving it to a
|
||||
local variable, this method associates the local variable name with the
|
||||
state class in self.tracked_locals.
|
||||
|
||||
When an attribute/method is accessed on a tracked local, it will be
|
||||
associated with this state.
|
||||
|
||||
Args:
|
||||
instruction: The dis instruction to process.
|
||||
"""
|
||||
from reflex.state import BaseState
|
||||
|
||||
if instruction.opname == "LOAD_FAST":
|
||||
raise VarValueError(
|
||||
f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
|
||||
)
|
||||
if instruction.opname == "LOAD_GLOBAL":
|
||||
# Special case: referencing state class from global scope.
|
||||
self._getting_state_class = self.func.__globals__.get(instruction.argval)
|
||||
elif instruction.opname == "LOAD_DEREF":
|
||||
# Special case: referencing state class from closure.
|
||||
closure = dict(zip(self.func.__code__.co_freevars, self.func.__closure__))
|
||||
try:
|
||||
self._getting_state_class = closure[instruction.argval].cell_contents
|
||||
except ValueError as ve:
|
||||
raise VarValueError(
|
||||
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?."
|
||||
) from ve
|
||||
elif instruction.opname == "STORE_FAST":
|
||||
# Storing the result of get_state in a local variable.
|
||||
if not isinstance(self._getting_state_class, type) or not issubclass(
|
||||
self._getting_state_class, BaseState
|
||||
):
|
||||
raise VarValueError(
|
||||
f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`."
|
||||
)
|
||||
self.tracked_locals[instruction.argval] = self._getting_state_class
|
||||
self.scan_status = ScanStatus.SCANNING
|
||||
self._getting_state_class = None
|
||||
|
||||
def _eval_var(self) -> Var:
|
||||
"""Evaluate instructions from the wrapped function to get the Var object.
|
||||
|
||||
Returns:
|
||||
The Var object.
|
||||
"""
|
||||
# Get the original source code and eval it to get the Var.
|
||||
start_line = self._getting_var_instructions[0].positions.lineno
|
||||
start_column = self._getting_var_instructions[0].positions.col_offset
|
||||
end_line = self._getting_var_instructions[-1].positions.end_lineno
|
||||
end_column = self._getting_var_instructions[-1].positions.end_col_offset
|
||||
source = inspect.getsource(inspect.getmodule(self.func)).splitlines(True)[
|
||||
start_line - 1 : end_line
|
||||
]
|
||||
# Create a python source string snippet.
|
||||
if len(source) > 1:
|
||||
snipped_source = "".join(
|
||||
[
|
||||
source[0][start_column:],
|
||||
source[1:-2] if len(source) > 2 else "",
|
||||
source[-1][:end_column],
|
||||
]
|
||||
)
|
||||
else:
|
||||
snipped_source = source[0][start_column:end_column]
|
||||
try:
|
||||
closure = dict(zip(self.func.__code__.co_freevars, self.func.__closure__))
|
||||
except Exception:
|
||||
# Fallback if the closure is not available.
|
||||
closure = {}
|
||||
# Evaluate the string in the context of the function's globals and closure.
|
||||
return eval(f"({snipped_source})", self.func.__globals__, closure)
|
||||
|
||||
def handle_getting_var(self, instruction: dis.Instruction) -> None:
|
||||
"""Handle bytecode analysis when `get_var_value` was called in the function.
|
||||
|
||||
This only really works if the expression passed to `get_var_value` is
|
||||
evaluable in the function's global scope or closure, so getting the var
|
||||
value from a var saved in a local variable or in the class instance is
|
||||
not possible.
|
||||
|
||||
Args:
|
||||
instruction: The dis instruction to process.
|
||||
"""
|
||||
if instruction.opname == "CALL" and self._getting_var_instructions:
|
||||
if self._getting_var_instructions:
|
||||
the_var = self._eval_var()
|
||||
the_var_data = the_var._get_all_var_data()
|
||||
self.dependencies.setdefault(the_var_data.state, set()).add(
|
||||
the_var_data.field_name
|
||||
)
|
||||
self._getting_var_instructions.clear()
|
||||
self.scan_status = ScanStatus.SCANNING
|
||||
else:
|
||||
self._getting_var_instructions.append(instruction)
|
||||
|
||||
def _populate_dependencies(self) -> None:
|
||||
"""Update self.dependencies based on the disassembly of self.func.
|
||||
|
||||
Save references to attributes accessed on "self" or other fetched states.
|
||||
|
||||
Recursively called when the function makes a method call on "self" or
|
||||
define comprehensions or nested functions that may reference "self".
|
||||
|
||||
Raises:
|
||||
VarValueError: if the function references the get_state, parent_state, or substates attributes
|
||||
(cannot track deps in a related state, only implicitly via parent state).
|
||||
"""
|
||||
for instruction in dis.get_instructions(self.func):
|
||||
if self.scan_status == ScanStatus.GETTING_STATE:
|
||||
self.handle_getting_state(instruction)
|
||||
elif self.scan_status == ScanStatus.GETTING_VAR:
|
||||
self.handle_getting_var(instruction)
|
||||
elif (
|
||||
instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
|
||||
and instruction.argval in self.tracked_locals
|
||||
):
|
||||
# bytecode loaded the class instance to the top of stack, next load instruction
|
||||
# is referencing an attribute on self
|
||||
self.top_of_stack = instruction.argval
|
||||
self.scan_status = ScanStatus.GETTING_ATTR
|
||||
elif self.scan_status == ScanStatus.GETTING_ATTR and instruction.opname in (
|
||||
"LOAD_ATTR",
|
||||
"LOAD_METHOD",
|
||||
):
|
||||
self.load_attr_or_method(instruction)
|
||||
self.top_of_stack = None
|
||||
elif instruction.opname == "LOAD_CONST" and isinstance(
|
||||
instruction.argval, CodeType
|
||||
):
|
||||
# recurse into nested functions / comprehensions, which can reference
|
||||
# instance attributes from the outer scope
|
||||
self._merge_deps(
|
||||
type(self)(
|
||||
func=instruction.argval,
|
||||
state_cls=self.state_cls,
|
||||
tracked_locals=self.tracked_locals,
|
||||
)
|
||||
)
|
Loading…
Reference in New Issue
Block a user