vars: unbox EventHandler and functools.partial for dep analysis (#1305)

When calculating the variable dependencies of a cached_var, reach into objects
with .func or .fn attributes and perform analysis on those.

Fix #1303
This commit is contained in:
Masen Furer 2023-07-05 16:05:49 -07:00 committed by GitHub
parent e9592553c4
commit 20e2a25c9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 2 deletions

View File

@ -869,10 +869,18 @@ class ComputedVar(Var, property):
obj = cast(FunctionType, self.fget)
else:
return set()
if not obj.__code__.co_varnames:
with contextlib.suppress(AttributeError):
# unbox functools.partial
obj = cast(FunctionType, obj.func) # type: ignore
with contextlib.suppress(AttributeError):
# unbox EventHandler
obj = cast(FunctionType, obj.fn) # type: ignore
try:
self_name = obj.__code__.co_varnames[0]
except (AttributeError, IndexError):
# cannot reference self if method takes no args
return set()
self_name = obj.__code__.co_varnames[0]
self_is_top_of_stack = False
for instruction in dis.get_instructions(obj):
if instruction.opname == "LOAD_FAST" and instruction.argval == self_name:

View File

@ -1,3 +1,4 @@
import functools
from typing import Dict, List
import pytest
@ -1089,6 +1090,43 @@ def test_computed_var_depends_on_parent_non_cached():
assert counter == 6
@pytest.mark.parametrize("use_partial", [True, False])
def test_cached_var_depends_on_event_handler(use_partial: bool):
"""A cached_var that calls an event handler calculates deps correctly.
Args:
use_partial: if true, replace the EventHandler with functools.partial
"""
counter = 0
class HandlerState(State):
x: int = 42
def handler(self):
self.x = self.x + 1
@rx.cached_var
def cached_x_side_effect(self) -> int:
self.handler()
nonlocal counter
counter += 1
return counter
if use_partial:
HandlerState.handler = functools.partial(HandlerState.handler.fn)
assert isinstance(HandlerState.handler, functools.partial)
else:
assert isinstance(HandlerState.handler, EventHandler)
s = HandlerState()
assert "cached_x_side_effect" in s.computed_var_dependencies["x"]
assert s.cached_x_side_effect == 1
assert s.x == 43
s.handler()
assert s.cached_x_side_effect == 2
assert s.x == 45
def test_backend_method():
"""A method with leading underscore should be callable from event handler."""