reflex/reflex/vars/dep_tracking.py

341 lines
13 KiB
Python

"""Collection of base classes."""
from __future__ import annotations
import contextlib
import dataclasses
import dis
import enum
import inspect
from types import CellType, CodeType, FunctionType
from typing import TYPE_CHECKING, Any, ClassVar, Type, cast
from reflex.utils.exceptions import VarValueError
if TYPE_CHECKING:
from reflex.state import BaseState
from .base import Var
# Sentinel handed back by get_cell_value when a closure cell holds no value.
CellEmpty = object()


def get_cell_value(cell: CellType) -> Any:
    """Read the contents of a closure cell.

    Args:
        cell: The cell object to read, as found in ``func.__closure__``.

    Returns:
        The value stored in the cell, or CellEmpty when reading it raises
        ValueError.
    """
    with contextlib.suppress(ValueError):
        return cell.cell_contents
    # Only reached when accessing cell_contents raised ValueError.
    return CellEmpty
class ScanStatus(enum.Enum):
    """Phases of the dis instruction scanning loop."""

    # Default phase: no special handling is pending.
    SCANNING = 1
    # A tracked local was just loaded; the next attribute load refers to it.
    GETTING_ATTR = 2
    # `get_state` was referenced; the fetched state class is being resolved.
    GETTING_STATE = 3
    # `get_var_value` was referenced; its argument instructions are collected.
    GETTING_VAR = 4
@dataclasses.dataclass
class DependencyTracker:
    """State machine for identifying state attributes that are accessed by a function.

    The bytecode of ``func`` is scanned once, from ``__post_init__``, and the
    result is stored in ``dependencies``: a mapping of state name to the set of
    var names that the function reads from that state.
    """

    # The function (or nested code object) whose bytecode is scanned.
    func: FunctionType | CodeType = dataclasses.field()

    # The state class associated with the function's first local name ("self").
    state_cls: Type[BaseState] = dataclasses.field()

    # Result of the scan: state name -> names of the vars accessed on it.
    dependencies: dict[str, set[str]] = dataclasses.field(default_factory=dict)

    # Current phase of the scanning loop.
    scan_status: ScanStatus = dataclasses.field(default=ScanStatus.SCANNING)

    # Name of the tracked local most recently loaded, if any.
    top_of_stack: str | None = dataclasses.field(default=None)

    # Local variable name -> state class that the local refers to.
    tracked_locals: dict[str, Type[BaseState]] = dataclasses.field(default_factory=dict)

    # State class referenced by a pending `get_state` call, until it is stored.
    _getting_state_class: Type[BaseState] | None = dataclasses.field(default=None)

    # Instructions collected between `get_var_value` and its CALL instruction.
    _getting_var_instructions: list[dis.Instruction] = dataclasses.field(
        default_factory=list
    )

    # Attribute names that a tracked function is not allowed to access.
    INVALID_NAMES: ClassVar[list[str]] = ["parent_state", "substates", "get_substate"]

    # Opcodes that push a local variable on the stack. LOAD_FAST_CHECK
    # (Python 3.12+) and LOAD_FAST_BORROW (Python 3.14+) are variants of
    # LOAD_FAST emitted by newer interpreters.
    _LOAD_FAST_OPNAMES: ClassVar[tuple[str, ...]] = (
        "LOAD_FAST",
        "LOAD_FAST_CHECK",
        "LOAD_FAST_BORROW",
    )

    def __post_init__(self):
        """After initializing, populate the dependencies dict."""
        with contextlib.suppress(AttributeError):
            # unbox functools.partial
            self.func = cast(FunctionType, self.func.func)  # pyright: ignore[reportGeneralTypeIssues]
        with contextlib.suppress(AttributeError):
            # unbox EventHandler
            self.func = cast(FunctionType, self.func.fn)  # pyright: ignore[reportGeneralTypeIssues]

        if isinstance(self.func, FunctionType):
            with contextlib.suppress(AttributeError, IndexError):
                # the first argument to the function is the name of "self" arg
                self.tracked_locals[self.func.__code__.co_varnames[0]] = self.state_cls

        self._populate_dependencies()

    def _merge_deps(self, tracker: DependencyTracker) -> None:
        """Merge dependencies from another DependencyTracker.

        Args:
            tracker: The DependencyTracker to merge dependencies from.
        """
        for state_name, dep_names in tracker.dependencies.items():
            self.dependencies.setdefault(state_name, set()).update(dep_names)

    def load_attr_or_method(self, instruction: dis.Instruction) -> None:
        """Handle loading an attribute or method from the object on top of the stack.

        This method directly tracks attributes and recursively merges
        dependencies from analyzing the dependencies of any methods called.

        Args:
            instruction: The dis instruction to process.

        Raises:
            VarValueError: if the attribute is a disallowed name.
        """
        from .base import ComputedVar

        attr_name = instruction.argval

        if attr_name in self.INVALID_NAMES:
            raise VarValueError(
                f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
            )
        if attr_name == "get_state":
            # Special case: arbitrary state access requested.
            self.scan_status = ScanStatus.GETTING_STATE
            return
        if attr_name == "get_var_value":
            # Special case: arbitrary var access requested.
            self.scan_status = ScanStatus.GETTING_VAR
            return

        # Reset status back to SCANNING after attribute is accessed.
        self.scan_status = ScanStatus.SCANNING
        if not self.top_of_stack:
            return

        target_state = self.tracked_locals[self.top_of_stack]
        ref_obj = getattr(target_state, attr_name)

        if isinstance(ref_obj, property) and not isinstance(ref_obj, ComputedVar):
            # recurse into property fget functions
            ref_obj = ref_obj.fget

        if callable(ref_obj):
            # recurse into callable attributes
            self._merge_deps(
                type(self)(func=cast(FunctionType, ref_obj), state_cls=target_state)
            )
        elif attr_name in target_state.backend_vars or attr_name in target_state.vars:
            # var access
            self.dependencies.setdefault(target_state.get_full_name(), set()).add(
                attr_name
            )

    def _get_globals(self) -> dict[str, Any]:
        """Get the globals of the function.

        Returns:
            The var names and values in the globals of the function. Empty
            when the tracker wraps a bare code object.
        """
        if isinstance(self.func, CodeType):
            return {}
        return self.func.__globals__  # pyright: ignore[reportGeneralTypeIssues]

    def _get_closure(self) -> dict[str, Any]:
        """Get the closure of the function, with unbound values omitted.

        Returns:
            The var names and values in the closure of the function. Empty
            when the tracker wraps a bare code object.
        """
        if isinstance(self.func, CodeType):
            return {}
        closure: dict[str, Any] = {}
        for var_name, cell in zip(
            self.func.__code__.co_freevars,  # pyright: ignore[reportGeneralTypeIssues]
            self.func.__closure__ or (),
            strict=False,
        ):
            value = get_cell_value(cell)
            if value is not CellEmpty:
                closure[var_name] = value
        return closure

    def handle_getting_state(self, instruction: dis.Instruction) -> None:
        """Handle bytecode analysis when `get_state` was called in the function.

        If the wrapped function is getting an arbitrary state and saving it to a
        local variable, this method associates the local variable name with the
        state class in self.tracked_locals.

        When an attribute/method is accessed on a tracked local, it will be
        associated with this state.

        Args:
            instruction: The dis instruction to process.

        Raises:
            VarValueError: if the state class cannot be determined from the instruction.
        """
        from reflex.state import BaseState

        opname = instruction.opname

        if opname in self._LOAD_FAST_OPNAMES:
            raise VarValueError(
                f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
            )
        if isinstance(self.func, CodeType):
            raise VarValueError(
                "Dependency detection cannot identify get_state class from a code object."
            )
        if opname == "LOAD_GLOBAL":
            # Special case: referencing state class from global scope.
            try:
                self._getting_state_class = self._get_globals()[instruction.argval]
            except (ValueError, KeyError) as ve:
                raise VarValueError(
                    f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals."
                ) from ve
        elif opname == "LOAD_DEREF":
            # Special case: referencing state class from closure.
            try:
                self._getting_state_class = self._get_closure()[instruction.argval]
            except (ValueError, KeyError) as ve:
                raise VarValueError(
                    f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?"
                ) from ve
        elif opname == "STORE_FAST":
            # Storing the result of get_state in a local variable.
            if not isinstance(self._getting_state_class, type) or not issubclass(
                self._getting_state_class, BaseState
            ):
                raise VarValueError(
                    f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`."
                )
            self.tracked_locals[instruction.argval] = self._getting_state_class
            self.scan_status = ScanStatus.SCANNING
            self._getting_state_class = None

    @staticmethod
    def _slice_source_line(
        line: str, start: int | None = None, end: int | None = None
    ) -> str:
        """Slice a source line using UTF-8 byte offsets.

        Column offsets in instruction positions are byte offsets into the
        UTF-8 encoded line, not character indices, so the slice is taken on
        the encoded form.

        Args:
            line: The source line to slice.
            start: Byte offset of the first byte to keep (None for line start).
            end: Byte offset one past the last byte to keep (None for line end).

        Returns:
            The selected part of the line.
        """
        return line.encode("utf-8")[start:end].decode("utf-8")

    def _eval_var(self) -> Var:
        """Evaluate instructions from the wrapped function to get the Var object.

        The source text covered by the collected `get_var_value` argument
        instructions is cut out of the module source and evaluated with the
        function's globals and closure.

        Returns:
            The Var object.

        Raises:
            VarValueError: if the source code for the var cannot be determined.
        """
        cannot_determine = (
            f"Cannot determine the source code for the var in {self.func!r}."
        )

        # Python 3.11 emits PRECALL right before CALL. Its position spans the
        # whole call expression (through the closing parenthesis) rather than
        # the argument, so it must not be used to locate the argument's end.
        arg_instructions = [
            ins for ins in self._getting_var_instructions if ins.opname != "PRECALL"
        ]

        # Get the original source code and eval it to get the Var.
        module = inspect.getmodule(self.func)
        if module is None or not arg_instructions:
            raise VarValueError(cannot_determine)

        # `positions` does not exist before Python 3.11 and may be None.
        first_pos = getattr(arg_instructions[0], "positions", None)
        last_pos = getattr(arg_instructions[-1], "positions", None)
        if first_pos is None or last_pos is None:
            raise VarValueError(cannot_determine)

        start_line = first_pos.lineno
        start_column = first_pos.col_offset
        end_line = last_pos.end_lineno
        end_column = last_pos.end_col_offset
        if (
            start_line is None
            or start_column is None
            or end_line is None
            or end_column is None
        ):
            raise VarValueError(cannot_determine)

        try:
            module_source = inspect.getsource(module)
        except (OSError, TypeError) as err:
            # Source is unavailable (e.g. no source file for the module).
            raise VarValueError(cannot_determine) from err

        lines = module_source.splitlines(True)[start_line - 1 : end_line]
        if not lines:
            raise VarValueError(cannot_determine)

        # Create a python source string snippet. End offsets are exclusive.
        if len(lines) == 1:
            snipped_source = self._slice_source_line(
                lines[0], start_column, end_column
            )
        else:
            snipped_source = "".join(
                [
                    self._slice_source_line(lines[0], start_column),
                    # every line strictly between the first and the last
                    *lines[1:-1],
                    self._slice_source_line(lines[-1], end=end_column),
                ]
            )

        # Evaluate the string in the context of the function's globals and closure.
        # NOTE(review): eval executes text read from the module's source file;
        # this is only safe as long as that file is trusted application code.
        return eval(f"({snipped_source})", self._get_globals(), self._get_closure())

    def handle_getting_var(self, instruction: dis.Instruction) -> None:
        """Handle bytecode analysis when `get_var_value` was called in the function.

        This only really works if the expression passed to `get_var_value` is
        evaluable in the function's global scope or closure, so getting the var
        value from a var saved in a local variable or in the class instance is
        not possible.

        Instructions are collected until the first CALL instruction is seen,
        at which point the collected span is evaluated and scanning resumes.

        Args:
            instruction: The dis instruction to process.

        Raises:
            VarValueError: if the source code for the var cannot be determined.
        """
        if instruction.opname != "CALL" or not self._getting_var_instructions:
            # Still inside the argument expression: keep collecting.
            self._getting_var_instructions.append(instruction)
            return

        the_var = self._eval_var()
        the_var_data = the_var._get_all_var_data()
        if the_var_data is None:
            raise VarValueError(
                f"Cannot determine the source code for the var in {self.func!r}."
            )
        self.dependencies.setdefault(the_var_data.state, set()).add(
            the_var_data.field_name
        )
        self._getting_var_instructions.clear()
        self.scan_status = ScanStatus.SCANNING

    def _populate_dependencies(self) -> None:
        """Update self.dependencies based on the disassembly of self.func.

        Save references to attributes accessed on "self" or other fetched states.

        Recursively called when the function makes a method call on "self" or
        defines comprehensions or nested functions that may reference "self".
        """
        for instruction in dis.get_instructions(self.func):
            opname = instruction.opname
            if self.scan_status == ScanStatus.GETTING_STATE:
                self.handle_getting_state(instruction)
            elif self.scan_status == ScanStatus.GETTING_VAR:
                self.handle_getting_var(instruction)
            elif (
                opname == "LOAD_DEREF" or opname in self._LOAD_FAST_OPNAMES
            ) and instruction.argval in self.tracked_locals:
                # bytecode loaded the class instance to the top of stack, next load instruction
                # is referencing an attribute on self
                self.top_of_stack = instruction.argval
                self.scan_status = ScanStatus.GETTING_ATTR
            elif self.scan_status == ScanStatus.GETTING_ATTR and opname in (
                "LOAD_ATTR",
                "LOAD_METHOD",
            ):
                self.load_attr_or_method(instruction)
                self.top_of_stack = None
            elif opname == "LOAD_CONST" and isinstance(instruction.argval, CodeType):
                # recurse into nested functions / comprehensions, which can reference
                # instance attributes from the outer scope
                self._merge_deps(
                    type(self)(
                        func=instruction.argval,
                        state_cls=self.state_cls,
                        tracked_locals=self.tracked_locals,
                    )
                )