diff --git a/reflex/vars.py b/reflex/vars.py index 10dc772a3..ba81bffea 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -869,10 +869,18 @@ class ComputedVar(Var, property): obj = cast(FunctionType, self.fget) else: return set() - if not obj.__code__.co_varnames: + with contextlib.suppress(AttributeError): + # unbox functools.partial + obj = cast(FunctionType, obj.func) # type: ignore + with contextlib.suppress(AttributeError): + # unbox EventHandler + obj = cast(FunctionType, obj.fn) # type: ignore + + try: + self_name = obj.__code__.co_varnames[0] + except (AttributeError, IndexError): # cannot reference self if method takes no args return set() - self_name = obj.__code__.co_varnames[0] self_is_top_of_stack = False for instruction in dis.get_instructions(obj): if instruction.opname == "LOAD_FAST" and instruction.argval == self_name: diff --git a/tests/test_state.py b/tests/test_state.py index d91bb19be..e6332dfcd 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1,3 +1,4 @@ +import functools from typing import Dict, List import pytest @@ -1089,6 +1090,43 @@ def test_computed_var_depends_on_parent_non_cached(): assert counter == 6 +@pytest.mark.parametrize("use_partial", [True, False]) +def test_cached_var_depends_on_event_handler(use_partial: bool): + """A cached_var that calls an event handler calculates deps correctly. + + Args: + use_partial: if true, replace the EventHandler with functools.partial + """ + counter = 0 + + class HandlerState(State): + x: int = 42 + + def handler(self): + self.x = self.x + 1 + + @rx.cached_var + def cached_x_side_effect(self) -> int: + self.handler() + nonlocal counter + counter += 1 + return counter + + if use_partial: + HandlerState.handler = functools.partial(HandlerState.handler.fn) + assert isinstance(HandlerState.handler, functools.partial) + else: + assert isinstance(HandlerState.handler, EventHandler) + + s = HandlerState() + assert "cached_x_side_effect" in s.computed_var_dependencies["x"] + assert s.cached_x_side_effect == 1 + assert s.x == 43 + s.handler() + assert s.cached_x_side_effect == 2 + assert s.x == 45 + + def test_backend_method(): """A method with leading underscore should be callable from event handler."""