diff --git a/reflex/state.py b/reflex/state.py index afa9dedb4..e2bdb79f5 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -797,7 +797,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): *defining_state_cls.inherited_vars, *defining_state_cls.inherited_backend_vars, }: - defining_state_cls = defining_state_cls.get_parent_state() + parent_state = defining_state_cls.get_parent_state() + if parent_state is not None: + defining_state_cls = parent_state defining_state_cls._var_dependencies.setdefault(dvar, set()).add( (cls.get_full_name(), cvar_name) ) @@ -2721,7 +2723,7 @@ class StateProxy(wrapt.ObjectProxy): await self.__wrapped__.get_state(state_cls), parent_state_proxy=self ) - def _as_state_update(self, *args, **kwargs) -> StateUpdate: + async def _as_state_update(self, *args, **kwargs) -> StateUpdate: """Temporarily allow mutability to access parent_state. Args: @@ -2734,7 +2736,7 @@ class StateProxy(wrapt.ObjectProxy): original_mutable = self._self_mutable self._self_mutable = True try: - return self.__wrapped__._as_state_update(*args, **kwargs) + return await self.__wrapped__._as_state_update(*args, **kwargs) finally: self._self_mutable = original_mutable @@ -3366,10 +3368,10 @@ class StateManagerRedis(StateManager): state.parent_state = parent_state # To retain compatibility with previous implementation, by default, we return - # the top-level state by chasing `parent_state` pointers up the tree. + # the top-level state which should always be fetched or already cached. 
if top_level: - return state._get_root_state() - return state + return flat_state_tree[self.state.get_full_name()] + return flat_state_tree[state_cls.get_full_name()] @override async def set_state( @@ -4070,7 +4072,7 @@ def reload_state_module( if subclass.__module__ == module and module is not None: state.class_subclasses.remove(subclass) state._always_dirty_substates.discard(subclass.get_name()) - state._potentially_dirty_substates.discard(subclass.get_name()) + state._potentially_dirty_states.discard(subclass.get_name()) state._var_dependencies = {} state._init_var_dependency_dicts() state.get_class_substate.cache_clear() diff --git a/reflex/vars/base.py b/reflex/vars/base.py index ada70aa6d..3f5e1ac86 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -5,7 +5,6 @@ from __future__ import annotations import contextlib import dataclasses import datetime -import dis import functools import inspect import json @@ -20,6 +19,7 @@ from typing import ( Any, Callable, ClassVar, + Coroutine, Dict, FrozenSet, Generic, @@ -50,7 +50,6 @@ from reflex.utils.exceptions import ( VarAttributeError, VarDependencyError, VarTypeError, - VarValueError, ) from reflex.utils.format import format_state_name from reflex.utils.imports import ( @@ -2073,9 +2072,8 @@ class ComputedVar(Var[RETURN_TYPE]): def _deps( self, - objclass: BaseState, + objclass: Type[BaseState], obj: FunctionType | CodeType | None = None, - self_names: Optional[dict[str, str]] = None, ) -> dict[str, set[str]]: """Determine var dependencies of this ComputedVar. @@ -2087,17 +2085,12 @@ class ComputedVar(Var[RETURN_TYPE]): Args: objclass: the class obj this ComputedVar is attached to. obj: the object to disassemble (defaults to the fget function). - self_names: if specified, look for these names in LOAD_FAST and LOAD_DEREF instructions. Returns: A dictionary mapping state names to the set of variable names accessed by the given obj. 
- - Raises: - VarValueError: if the function references the get_state, parent_state, or substates attributes - (cannot track deps in a related state, only implicitly via parent state). """ - from reflex.state import BaseState + from .dep_tracking import DependencyTracker d = {} if self._static_deps: @@ -2108,158 +2101,26 @@ class ComputedVar(Var[RETURN_TYPE]): if not self._auto_deps: return d + if obj is None: fget = self._fget if fget is not None: obj = cast(FunctionType, fget) else: return d - with contextlib.suppress(AttributeError): - # unbox functools.partial - obj = cast(FunctionType, obj.func) # type: ignore - with contextlib.suppress(AttributeError): - # unbox EventHandler - obj = cast(FunctionType, obj.fn) # type: ignore - if self_names is None and isinstance(obj, FunctionType): - try: - # the first argument to the function is the name of "self" arg - self_names = {obj.__code__.co_varnames[0]: objclass.get_full_name()} - except (AttributeError, IndexError): - self_names = None - if self_names is None: - # cannot reference attributes on self if method takes no args + try: + return DependencyTracker( + func=obj, state_cls=objclass, dependencies=d + ).dependencies + except Exception as e: + console.warn( + "Failed to automatically determine dependencies for computed var " + f"{objclass.__name__}.{self._js_expr}: {e}. " + "Provide static_deps and set auto_deps=False to suppress this warning." + ) return d - invalid_names = ["parent_state", "substates", "get_substate"] - self_on_top_of_stack = None - getting_state = False - getting_var = False - for instruction in dis.get_instructions(obj): - if getting_state: - if instruction.opname == "LOAD_FAST": - raise VarValueError( - f"Dependency detection cannot identify get_state class from local var {instruction.argval}." - ) - if instruction.opname == "LOAD_GLOBAL": - # Special case: referencing state class from global scope. 
- getting_state = obj.__globals__.get(instruction.argval) - elif instruction.opname == "LOAD_DEREF": - # Special case: referencing state class from closure. - closure = dict(zip(obj.__code__.co_freevars, obj.__closure__)) - try: - getting_state = closure[instruction.argval].cell_contents - except ValueError as ve: - raise VarValueError( - f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?." - ) from ve - elif instruction.opname == "STORE_FAST": - # Storing the result of get_state in a local variable. - if not isinstance(getting_state, type) or not issubclass( - getting_state, BaseState - ): - raise VarValueError( - f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`." - ) - self_names[instruction.argval] = getting_state.get_full_name() - getting_state = False - continue # nothing else happens until we have identified the local var - if getting_var: - if instruction.opname == "CALL": - # get the original source code and eval it - start_line = getting_var[0].positions.lineno - start_column = getting_var[0].positions.col_offset - end_line = getting_var[-1].positions.end_lineno - end_column = getting_var[-1].positions.end_col_offset - source = inspect.getsource(inspect.getmodule(obj)).splitlines(True)[ - start_line - 1 : end_line - ] - if len(source) > 1: - snipped_source = "".join( - [ - source[0][start_column:], - source[1:-2] if len(source) > 2 else "", - source[-1][:end_column], - ] - ) - else: - snipped_source = source[0][start_column:end_column] - the_var = eval(f"({snipped_source})", obj.__globals__) - the_var_data = the_var._get_all_var_data() - d.setdefault(the_var_data.state, set()).add(the_var_data.field_name) - getting_var = False - elif isinstance(getting_var, list): - getting_var.append(instruction) - else: - getting_var = [instruction] - continue - if ( - instruction.opname in ("LOAD_FAST", "LOAD_DEREF") - and instruction.argval in self_names - ): - # bytecode loaded the 
class instance to the top of stack, next load instruction - # is referencing an attribute on self - self_on_top_of_stack = self_names[instruction.argval] - continue - if self_on_top_of_stack and instruction.opname in ( - "LOAD_ATTR", - "LOAD_METHOD", - ): - if instruction.argval in invalid_names: - raise VarValueError( - f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`." - ) - if instruction.argval == "get_state": - # Special case: arbitrary state access requested. - getting_state = True - continue - if instruction.argval == "get_var_value": - # Special case: arbitrary var access requested. - getting_var = True - continue - target_state = objclass.get_root_state().get_class_substate( - self_on_top_of_stack - ) - try: - ref_obj = getattr(target_state, instruction.argval) - except Exception: - ref_obj = None - if callable(ref_obj): - # recurse into callable attributes - for state_name, dep_name in self._deps( - objclass=target_state, - obj=ref_obj, - ).items(): - d.setdefault(state_name, set()).update(dep_name) - # recurse into property fget functions - elif isinstance(ref_obj, property) and not isinstance( - ref_obj, ComputedVar - ): - for state_name, dep_name in self._deps( - objclass=target_state, - obj=ref_obj.fget, # type: ignore - ).items(): - d.setdefault(state_name, set()).update(dep_name) - elif ( - instruction.argval in target_state.backend_vars - or instruction.argval in target_state.vars - ): - # var access - d.setdefault(self_on_top_of_stack, set()).add(instruction.argval) - elif instruction.opname == "LOAD_CONST" and isinstance( - instruction.argval, CodeType - ): - # recurse into nested functions / comprehensions, which can reference - # instance attributes from the outer scope - for state_name, dep_name in self._deps( - objclass=objclass, - obj=instruction.argval, - self_names=self_names, - ).items(): - d.setdefault(state_name, set()).update(dep_name) - self_on_top_of_stack = None - return d - def mark_dirty(self, 
instance) -> None: """Mark this ComputedVar as dirty. @@ -2314,11 +2175,63 @@ class DynamicRouteVar(ComputedVar[Union[str, List[str]]]): class AsyncComputedVar(ComputedVar[RETURN_TYPE]): """A computed var that wraps a coroutinefunction.""" - _fget: Callable[[BaseState], RETURN_TYPE] = dataclasses.field( - default_factory=lambda: lambda _: None - ) # type: ignore + _fget: Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]] = ( + dataclasses.field() + ) - def __get__(self, instance: BaseState | None, owner): + @overload + def __get__( + self: AsyncComputedVar[int] | AsyncComputedVar[float], + instance: None, + owner: Type, + ) -> NumberVar: ... + + @overload + def __get__( + self: AsyncComputedVar[str], + instance: None, + owner: Type, + ) -> StringVar: ... + + @overload + def __get__( + self: AsyncComputedVar[Mapping[DICT_KEY, DICT_VAL]], + instance: None, + owner: Type, + ) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ... + + @overload + def __get__( + self: AsyncComputedVar[list[LIST_INSIDE]], + instance: None, + owner: Type, + ) -> ArrayVar[list[LIST_INSIDE]]: ... + + @overload + def __get__( + self: AsyncComputedVar[set[LIST_INSIDE]], + instance: None, + owner: Type, + ) -> ArrayVar[set[LIST_INSIDE]]: ... + + @overload + def __get__( + self: AsyncComputedVar[tuple[LIST_INSIDE, ...]], + instance: None, + owner: Type, + ) -> ArrayVar[tuple[LIST_INSIDE, ...]]: ... + + @overload + def __get__(self, instance: None, owner: Type) -> AsyncComputedVar[RETURN_TYPE]: ... + + @overload + def __get__( + self, instance: BaseState, owner: Type + ) -> Coroutine[None, None, RETURN_TYPE]: ... + + def __get__( + self, instance: BaseState | None, owner + ) -> Var | Coroutine[None, None, RETURN_TYPE]: """Get the ComputedVar value. If the value is already cached on the instance, return the cached value. 
@@ -2335,14 +2248,15 @@ class AsyncComputedVar(ComputedVar[RETURN_TYPE]): if not self._cache: - async def _awaitable_result(): + async def _awaitable_result(instance=instance) -> RETURN_TYPE: value = await self.fget(instance) self._check_deprecated_return_type(instance, value) + return value return _awaitable_result() else: # handle caching - async def _awaitable_result(): + async def _awaitable_result(instance=instance) -> RETURN_TYPE: if not hasattr(instance, self._cache_attr) or self.needs_update( instance ): @@ -2356,7 +2270,16 @@ class AsyncComputedVar(ComputedVar[RETURN_TYPE]): self._check_deprecated_return_type(instance, value) return value - return _awaitable_result() + return _awaitable_result() + + @property + def fget(self) -> Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]]: + """Get the getter function. + + Returns: + The getter function. + """ + return self._fget if TYPE_CHECKING:
diff --git a/reflex/vars/dep_tracking.py b/reflex/vars/dep_tracking.py
new file mode 100644
index 000000000..3bd8a6bf0
--- /dev/null
+++ b/reflex/vars/dep_tracking.py
@@ -0,0 +1,312 @@
+"""Bytecode-based tracking of the state vars accessed by a function."""
+
+from __future__ import annotations
+
+import contextlib
+import dataclasses
+import dis
+import enum
+import inspect
+from types import CodeType, FunctionType
+from typing import TYPE_CHECKING, ClassVar, Type, cast
+
+from reflex.utils.exceptions import VarValueError
+
+if TYPE_CHECKING:
+    from reflex.state import BaseState
+
+    from .base import Var
+
+
+class ScanStatus(enum.Enum):
+    """State of the dis instruction scanning loop."""
+
+    SCANNING = enum.auto()
+    GETTING_ATTR = enum.auto()
+    GETTING_STATE = enum.auto()
+    GETTING_VAR = enum.auto()
+
+
+@dataclasses.dataclass
+class DependencyTracker:
+    """State machine for identifying state attributes that are accessed by a function.
+
+    The bytecode of `func` is scanned as soon as the tracker is created; the
+    result is collected in `dependencies`, which maps the full name of a state
+    to the names of the vars accessed on that state.
+    """
+
+    # The function (or nested code object) whose bytecode is scanned.
+    func: FunctionType | CodeType = dataclasses.field()
+    # The state class that the first argument of `func` ("self") refers to.
+    state_cls: Type[BaseState] = dataclasses.field()
+
+    # Result: state full name -> names of the vars accessed on that state.
+    dependencies: dict[str, set[str]] = dataclasses.field(default_factory=dict)
+
+    scan_status: ScanStatus = dataclasses.field(default=ScanStatus.SCANNING)
+    # Name of the tracked local most recently loaded onto the stack, if any.
+    top_of_stack: str | None = dataclasses.field(default=None)
+
+    # Local/closure variable name -> state class that the variable refers to.
+    tracked_locals: dict[str, Type[BaseState]] = dataclasses.field(default_factory=dict)
+
+    # Scratch space used while scan_status is GETTING_STATE / GETTING_VAR.
+    _getting_state_class: Type[BaseState] | None = dataclasses.field(default=None)
+    _getting_var_instructions: list[dis.Instruction] = dataclasses.field(
+        default_factory=list
+    )
+
+    INVALID_NAMES: ClassVar[list[str]] = ["parent_state", "substates", "get_substate"]
+
+    def __post_init__(self):
+        """After initializing, populate the dependencies dict."""
+        with contextlib.suppress(AttributeError):
+            # unbox functools.partial
+            self.func = cast(FunctionType, self.func.func)  # type: ignore
+        with contextlib.suppress(AttributeError):
+            # unbox EventHandler
+            self.func = cast(FunctionType, self.func.fn)  # type: ignore
+
+        if isinstance(self.func, FunctionType):
+            with contextlib.suppress(AttributeError, IndexError):
+                # the first argument to the function is the name of "self" arg
+                self.tracked_locals[self.func.__code__.co_varnames[0]] = self.state_cls
+
+        self._populate_dependencies()
+
+    def _describe_func(self) -> str:
+        """Get a short name for the scanned function, for use in error messages.
+
+        Returns:
+            The qualified (or plain) name of the function or code object when it
+            has one, otherwise its repr.
+        """
+        for attr in ("__qualname__", "co_qualname", "__name__", "co_name"):
+            name = getattr(self.func, attr, None)
+            if name:
+                return str(name)
+        return repr(self.func)
+
+    def _merge_deps(self, tracker: DependencyTracker) -> None:
+        """Merge dependencies from another DependencyTracker.
+
+        Args:
+            tracker: The DependencyTracker to merge dependencies from.
+        """
+        for state_name, dep_name in tracker.dependencies.items():
+            self.dependencies.setdefault(state_name, set()).update(dep_name)
+
+    def load_attr_or_method(self, instruction: dis.Instruction) -> None:
+        """Handle loading an attribute or method from the object on top of the stack.
+
+        This method directly tracks attributes and recursively merges
+        dependencies from analyzing the dependencies of any methods called.
+
+        Args:
+            instruction: The dis instruction to process.
+
+        Raises:
+            VarValueError: if the attribute is one of INVALID_NAMES.
+        """
+        from .base import ComputedVar
+
+        if instruction.argval in self.INVALID_NAMES:
+            raise VarValueError(
+                f"Cached var {self._describe_func()} cannot access arbitrary state via `{instruction.argval}`."
+            )
+        if instruction.argval == "get_state":
+            # Special case: arbitrary state access requested.
+            self.scan_status = ScanStatus.GETTING_STATE
+            return
+        if instruction.argval == "get_var_value":
+            # Special case: arbitrary var access requested.
+            self.scan_status = ScanStatus.GETTING_VAR
+            return
+
+        # Reset status back to SCANNING after attribute is accessed.
+        self.scan_status = ScanStatus.SCANNING
+        if not self.top_of_stack:
+            return
+        target_state = self.tracked_locals[self.top_of_stack]
+        try:
+            ref_obj = getattr(target_state, instruction.argval)
+        except Exception:
+            # The name may not resolve on the class (or a descriptor may raise);
+            # treat it as "nothing to recurse into" instead of aborting the scan.
+            ref_obj = None
+
+        if isinstance(ref_obj, property) and not isinstance(ref_obj, ComputedVar):
+            # recurse into property fget functions
+            ref_obj = ref_obj.fget  # type: ignore
+        if callable(ref_obj):
+            # recurse into callable attributes
+            self._merge_deps(
+                type(self)(func=cast(FunctionType, ref_obj), state_cls=target_state)
+            )
+        elif (
+            instruction.argval in target_state.backend_vars
+            or instruction.argval in target_state.vars
+        ):
+            # var access
+            self.dependencies.setdefault(target_state.get_full_name(), set()).add(
+                instruction.argval
+            )
+
+    def handle_getting_state(self, instruction: dis.Instruction) -> None:
+        """Handle bytecode analysis when `get_state` was called in the function.
+
+        If the wrapped function is getting an arbitrary state and saving it to a
+        local variable, this method associates the local variable name with the
+        state class in self.tracked_locals.
+
+        When an attribute/method is accessed on a tracked local, it will be
+        associated with this state.
+
+        Args:
+            instruction: The dis instruction to process.
+
+        Raises:
+            VarValueError: if the state class cannot be determined from the bytecode.
+        """
+        from reflex.state import BaseState
+
+        if instruction.opname == "LOAD_FAST":
+            raise VarValueError(
+                f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
+            )
+        if instruction.opname == "LOAD_GLOBAL":
+            # Special case: referencing state class from global scope.
+            self._getting_state_class = self.func.__globals__.get(instruction.argval)
+        elif instruction.opname == "LOAD_DEREF":
+            # Special case: referencing state class from closure.
+            closure = dict(zip(self.func.__code__.co_freevars, self.func.__closure__))
+            try:
+                self._getting_state_class = closure[instruction.argval].cell_contents
+            except ValueError as ve:
+                raise VarValueError(
+                    f"Cached var {self._describe_func()} cannot access arbitrary state `{instruction.argval}`, is it defined yet?."
+                ) from ve
+        elif instruction.opname == "STORE_FAST":
+            # Storing the result of get_state in a local variable.
+            if not isinstance(self._getting_state_class, type) or not issubclass(
+                self._getting_state_class, BaseState
+            ):
+                raise VarValueError(
+                    f"Cached var {self._describe_func()} cannot determine dependencies in fetched state `{instruction.argval}`."
+                )
+            self.tracked_locals[instruction.argval] = self._getting_state_class
+            self.scan_status = ScanStatus.SCANNING
+            self._getting_state_class = None
+
+    def _eval_var(self) -> Var:
+        """Evaluate instructions from the wrapped function to get the Var object.
+
+        Returns:
+            The Var object.
+        """
+        # Get the original source code and eval it to get the Var.
+        start_line = self._getting_var_instructions[0].positions.lineno
+        start_column = self._getting_var_instructions[0].positions.col_offset
+        end_line = self._getting_var_instructions[-1].positions.end_lineno
+        end_column = self._getting_var_instructions[-1].positions.end_col_offset
+        source = inspect.getsource(inspect.getmodule(self.func)).splitlines(True)[
+            start_line - 1 : end_line
+        ]
+        # Create a python source string snippet.
+        if len(source) > 1:
+            # Trim the first and last lines to the expression and keep every line
+            # in between whole.
+            snipped_source = "".join(
+                [
+                    source[0][start_column:],
+                    *source[1:-1],
+                    source[-1][:end_column],
+                ]
+            )
+        else:
+            snipped_source = source[0][start_column:end_column]
+        # Map free variable names to the values they are bound to; cells that are
+        # still empty (ValueError) or a missing closure leave the name undefined.
+        closure = {}
+        code = getattr(self.func, "__code__", None)
+        cells = getattr(self.func, "__closure__", None) or ()
+        for name, cell in zip(getattr(code, "co_freevars", ()), cells):
+            with contextlib.suppress(ValueError):
+                closure[name] = cell.cell_contents
+        # Evaluate the string in the context of the function's globals and closure.
+        # NOTE: the evaluated text comes from the source of the module that defines
+        # the function, not from external input.
+        return eval(f"({snipped_source})", self.func.__globals__, closure)
+
+    def handle_getting_var(self, instruction: dis.Instruction) -> None:
+        """Handle bytecode analysis when `get_var_value` was called in the function.
+
+        This only really works if the expression passed to `get_var_value` is
+        evaluable in the function's global scope or closure, so getting the var
+        value from a var saved in a local variable or in the class instance is
+        not possible.
+
+        Args:
+            instruction: The dis instruction to process.
+        """
+        if instruction.opname == "CALL" and self._getting_var_instructions:
+            the_var = self._eval_var()
+            the_var_data = the_var._get_all_var_data()
+            self.dependencies.setdefault(the_var_data.state, set()).add(
+                the_var_data.field_name
+            )
+            self._getting_var_instructions.clear()
+            self.scan_status = ScanStatus.SCANNING
+        else:
+            self._getting_var_instructions.append(instruction)
+
+    def _populate_dependencies(self) -> None:
+        """Update self.dependencies based on the disassembly of self.func.
+
+        Save references to attributes accessed on "self" or other fetched states.
+
+        Recursively called when the function makes a method call on "self" or
+        defines comprehensions or nested functions that may reference "self".
+
+        Raises:
+            VarValueError: if the function references the get_state, parent_state, or substates attributes
+                (cannot track deps in a related state, only implicitly via parent state).
+        """
+        for instruction in dis.get_instructions(self.func):
+            if self.scan_status == ScanStatus.GETTING_STATE:
+                self.handle_getting_state(instruction)
+                continue
+            if self.scan_status == ScanStatus.GETTING_VAR:
+                self.handle_getting_var(instruction)
+                continue
+            if (
+                instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
+                and instruction.argval in self.tracked_locals
+            ):
+                # bytecode loaded the class instance to the top of stack, next load instruction
+                # is referencing an attribute on self
+                self.top_of_stack = instruction.argval
+                self.scan_status = ScanStatus.GETTING_ATTR
+                continue
+            if self.scan_status == ScanStatus.GETTING_ATTR:
+                if instruction.opname in ("LOAD_ATTR", "LOAD_METHOD"):
+                    self.load_attr_or_method(instruction)
+                else:
+                    # The tracked local was used for something other than an attribute
+                    # access (e.g. passed as an argument); forget it so that a later
+                    # attribute load on an unrelated object is not attributed to it.
+                    self.scan_status = ScanStatus.SCANNING
+                self.top_of_stack = None
+            if instruction.opname == "LOAD_CONST" and isinstance(
+                instruction.argval, CodeType
+            ):
+                # recurse into nested functions / comprehensions, which can reference
+                # instance attributes from the outer scope
+                self._merge_deps(
+                    type(self)(
+                        func=instruction.argval,
+                        state_cls=self.state_cls,
+                        tracked_locals=self.tracked_locals,
+                    )
+                )