diff --git a/reflex/vars.py b/reflex/vars.py index c5c31d401..e0a4bf506 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -7,7 +7,7 @@ import json import random import string from abc import ABC -from types import FunctionType +from types import CodeType, FunctionType from typing import ( TYPE_CHECKING, Any, @@ -983,16 +983,19 @@ class ComputedVar(Var, property): def deps( self, objclass: Type, - obj: FunctionType | None = None, + obj: FunctionType | CodeType | None = None, + self_name: Optional[str] = None, ) -> set[str]: """Determine var dependencies of this ComputedVar. Save references to attributes accessed on "self". Recursively called - when the function makes a method call on "self". + when the function makes a method call on "self" or define comprehensions + or nested functions that may reference "self". Args: objclass: the class obj this ComputedVar is attached to. obj: the object to disassemble (defaults to the fget function). + self_name: if specified, look for this name in LOAD_FAST and LOAD_DEREF instructions. Returns: A set of variable names accessed by the given obj. @@ -1010,25 +1013,48 @@ class ComputedVar(Var, property): # unbox EventHandler obj = cast(FunctionType, obj.fn) # type: ignore - try: - self_name = obj.__code__.co_varnames[0] - except (AttributeError, IndexError): - # cannot reference self if method takes no args + if self_name is None and isinstance(obj, FunctionType): + try: + # the first argument to the function is the name of "self" arg + self_name = obj.__code__.co_varnames[0] + except (AttributeError, IndexError): + self_name = None + if self_name is None: + # cannot reference attributes on self if method takes no args return set() self_is_top_of_stack = False for instruction in dis.get_instructions(obj): - if instruction.opname == "LOAD_FAST" and instruction.argval == self_name: + if ( + instruction.opname in ("LOAD_FAST", "LOAD_DEREF") + and instruction.argval == self_name + ): + # bytecode loaded the class instance to the top of stack, next load instruction + # is referencing an attribute on self self_is_top_of_stack = True continue if self_is_top_of_stack and instruction.opname == "LOAD_ATTR": + # direct attribute access d.add(instruction.argval) elif self_is_top_of_stack and instruction.opname == "LOAD_METHOD": + # method call on self d.update( self.deps( objclass=objclass, obj=getattr(objclass, instruction.argval), ) ) + elif instruction.opname == "LOAD_CONST" and isinstance( + instruction.argval, CodeType + ): + # recurse into nested functions / comprehensions, which can reference + # instance attributes from the outer scope + d.update( + self.deps( + objclass=objclass, + obj=instruction.argval, + self_name=self_name, + ) + ) self_is_top_of_stack = False return d diff --git a/tests/test_state.py b/tests/test_state.py index f5650fc0f..2c60e377d 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools from typing import Dict, List @@ -1149,6 +1151,73 @@ def test_cached_var_depends_on_event_handler(use_partial: bool): assert s.x == 45 +def test_computed_var_dependencies(): + """Test that a ComputedVar correctly tracks its dependencies.""" + + class ComputedState(State): + v: int = 0 + w: int = 0 + x: int = 0 + y: List[int] = [1, 2, 3] + _z: List[int] = [1, 2, 3] + + @rx.cached_var + def comp_v(self) -> int: + """Direct access. + + Returns: + The value of self.v. + """ + return self.v + + @rx.cached_var + def comp_w(self): + """Nested lambda. + + Returns: + A lambda that returns the value of self.w. + """ + return lambda: self.w + + @rx.cached_var + def comp_x(self): + """Nested function. + + Returns: + A function that returns the value of self.x. + """ + + def _(): + return self.x + + return _ + + @rx.cached_var + def comp_y(self) -> List[int]: + """Comprehension iterating over attribute. + + Returns: + A list of the values of self.y. + """ + return [round(y) for y in self.y] + + @rx.cached_var + def comp_z(self) -> List[bool]: + """Comprehension accesses attribute. + + Returns: + A list of whether the values 0-4 are in self._z. + """ + return [z in self._z for z in range(5)] + + cs = ComputedState() + assert cs.computed_var_dependencies["v"] == {"comp_v"} + assert cs.computed_var_dependencies["w"] == {"comp_w"} + assert cs.computed_var_dependencies["x"] == {"comp_x"} + assert cs.computed_var_dependencies["y"] == {"comp_y"} + assert cs.computed_var_dependencies["_z"] == {"comp_z"} + + def test_backend_method(): """A method with leading underscore should be callable from event handler."""