Track var dependencies in comprehensions and nested functions (#1728)

This commit is contained in:
Masen Furer 2023-09-04 14:31:17 -07:00 committed by GitHub
parent 41e97bbc46
commit b44c2176e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 103 additions and 8 deletions

View File

@ -7,7 +7,7 @@ import json
import random
import string
from abc import ABC
from types import FunctionType
from types import CodeType, FunctionType
from typing import (
TYPE_CHECKING,
Any,
@ -983,16 +983,19 @@ class ComputedVar(Var, property):
def deps(
self,
objclass: Type,
obj: FunctionType | None = None,
obj: FunctionType | CodeType | None = None,
self_name: Optional[str] = None,
) -> set[str]:
"""Determine var dependencies of this ComputedVar.
Save references to attributes accessed on "self". Recursively called
when the function makes a method call on "self".
when the function makes a method call on "self" or define comprehensions
or nested functions that may reference "self".
Args:
objclass: the class obj this ComputedVar is attached to.
obj: the object to disassemble (defaults to the fget function).
self_name: if specified, look for this name in LOAD_FAST and LOAD_DEREF instructions.
Returns:
A set of variable names accessed by the given obj.
@ -1010,25 +1013,48 @@ class ComputedVar(Var, property):
# unbox EventHandler
obj = cast(FunctionType, obj.fn) # type: ignore
try:
self_name = obj.__code__.co_varnames[0]
except (AttributeError, IndexError):
# cannot reference self if method takes no args
if self_name is None and isinstance(obj, FunctionType):
try:
# the first argument to the function is the name of "self" arg
self_name = obj.__code__.co_varnames[0]
except (AttributeError, IndexError):
self_name = None
if self_name is None:
# cannot reference attributes on self if method takes no args
return set()
self_is_top_of_stack = False
for instruction in dis.get_instructions(obj):
if instruction.opname == "LOAD_FAST" and instruction.argval == self_name:
if (
instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
and instruction.argval == self_name
):
# bytecode loaded the class instance to the top of stack, next load instruction
# is referencing an attribute on self
self_is_top_of_stack = True
continue
if self_is_top_of_stack and instruction.opname == "LOAD_ATTR":
# direct attribute access
d.add(instruction.argval)
elif self_is_top_of_stack and instruction.opname == "LOAD_METHOD":
# method call on self
d.update(
self.deps(
objclass=objclass,
obj=getattr(objclass, instruction.argval),
)
)
elif instruction.opname == "LOAD_CONST" and isinstance(
instruction.argval, CodeType
):
# recurse into nested functions / comprehensions, which can reference
# instance attributes from the outer scope
d.update(
self.deps(
objclass=objclass,
obj=instruction.argval,
self_name=self_name,
)
)
self_is_top_of_stack = False
return d

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import functools
from typing import Dict, List
@ -1149,6 +1151,73 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
assert s.x == 45
def test_computed_var_dependencies():
"""Test that a ComputedVar correctly tracks its dependencies."""
class ComputedState(State):
v: int = 0
w: int = 0
x: int = 0
y: List[int] = [1, 2, 3]
_z: List[int] = [1, 2, 3]
@rx.cached_var
def comp_v(self) -> int:
"""Direct access.
Returns:
The value of self.v.
"""
return self.v
@rx.cached_var
def comp_w(self):
"""Nested lambda.
Returns:
A lambda that returns the value of self.w.
"""
return lambda: self.w
@rx.cached_var
def comp_x(self):
"""Nested function.
Returns:
A function that returns the value of self.x.
"""
def _():
return self.x
return _
@rx.cached_var
def comp_y(self) -> List[int]:
"""Comprehension iterating over attribute.
Returns:
A list of the values of self.y.
"""
return [round(y) for y in self.y]
@rx.cached_var
def comp_z(self) -> List[bool]:
"""Comprehension accesses attribute.
Returns:
A list of whether the values 0-4 are in self._z.
"""
return [z in self._z for z in range(5)]
cs = ComputedState()
assert cs.computed_var_dependencies["v"] == {"comp_v"}
assert cs.computed_var_dependencies["w"] == {"comp_w"}
assert cs.computed_var_dependencies["x"] == {"comp_x"}
assert cs.computed_var_dependencies["y"] == {"comp_y"}
assert cs.computed_var_dependencies["_z"] == {"comp_z"}
def test_backend_method():
"""A method with leading underscore should be callable from event handler."""