Track var dependencies in comprehensions and nested functions (#1728)

This commit is contained in:
Masen Furer 2023-09-04 14:31:17 -07:00 committed by GitHub
parent 41e97bbc46
commit b44c2176e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 103 additions and 8 deletions

View File

@ -7,7 +7,7 @@ import json
import random import random
import string import string
from abc import ABC from abc import ABC
from types import FunctionType from types import CodeType, FunctionType
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -983,16 +983,19 @@ class ComputedVar(Var, property):
def deps( def deps(
self, self,
objclass: Type, objclass: Type,
obj: FunctionType | None = None, obj: FunctionType | CodeType | None = None,
self_name: Optional[str] = None,
) -> set[str]: ) -> set[str]:
"""Determine var dependencies of this ComputedVar. """Determine var dependencies of this ComputedVar.
Save references to attributes accessed on "self". Recursively called Save references to attributes accessed on "self". Recursively called
when the function makes a method call on "self". when the function makes a method call on "self" or define comprehensions
or nested functions that may reference "self".
Args: Args:
objclass: the class obj this ComputedVar is attached to. objclass: the class obj this ComputedVar is attached to.
obj: the object to disassemble (defaults to the fget function). obj: the object to disassemble (defaults to the fget function).
self_name: if specified, look for this name in LOAD_FAST and LOAD_DEREF instructions.
Returns: Returns:
A set of variable names accessed by the given obj. A set of variable names accessed by the given obj.
@ -1010,25 +1013,48 @@ class ComputedVar(Var, property):
# unbox EventHandler # unbox EventHandler
obj = cast(FunctionType, obj.fn) # type: ignore obj = cast(FunctionType, obj.fn) # type: ignore
try: if self_name is None and isinstance(obj, FunctionType):
self_name = obj.__code__.co_varnames[0] try:
except (AttributeError, IndexError): # the first argument to the function is the name of "self" arg
# cannot reference self if method takes no args self_name = obj.__code__.co_varnames[0]
except (AttributeError, IndexError):
self_name = None
if self_name is None:
# cannot reference attributes on self if method takes no args
return set() return set()
self_is_top_of_stack = False self_is_top_of_stack = False
for instruction in dis.get_instructions(obj): for instruction in dis.get_instructions(obj):
if instruction.opname == "LOAD_FAST" and instruction.argval == self_name: if (
instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
and instruction.argval == self_name
):
# bytecode loaded the class instance to the top of stack, next load instruction
# is referencing an attribute on self
self_is_top_of_stack = True self_is_top_of_stack = True
continue continue
if self_is_top_of_stack and instruction.opname == "LOAD_ATTR": if self_is_top_of_stack and instruction.opname == "LOAD_ATTR":
# direct attribute access
d.add(instruction.argval) d.add(instruction.argval)
elif self_is_top_of_stack and instruction.opname == "LOAD_METHOD": elif self_is_top_of_stack and instruction.opname == "LOAD_METHOD":
# method call on self
d.update( d.update(
self.deps( self.deps(
objclass=objclass, objclass=objclass,
obj=getattr(objclass, instruction.argval), obj=getattr(objclass, instruction.argval),
) )
) )
elif instruction.opname == "LOAD_CONST" and isinstance(
instruction.argval, CodeType
):
# recurse into nested functions / comprehensions, which can reference
# instance attributes from the outer scope
d.update(
self.deps(
objclass=objclass,
obj=instruction.argval,
self_name=self_name,
)
)
self_is_top_of_stack = False self_is_top_of_stack = False
return d return d

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import functools import functools
from typing import Dict, List from typing import Dict, List
@ -1149,6 +1151,73 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
assert s.x == 45 assert s.x == 45
def test_computed_var_dependencies():
"""Test that a ComputedVar correctly tracks its dependencies."""
class ComputedState(State):
v: int = 0
w: int = 0
x: int = 0
y: List[int] = [1, 2, 3]
_z: List[int] = [1, 2, 3]
@rx.cached_var
def comp_v(self) -> int:
"""Direct access.
Returns:
The value of self.v.
"""
return self.v
@rx.cached_var
def comp_w(self):
"""Nested lambda.
Returns:
A lambda that returns the value of self.w.
"""
return lambda: self.w
@rx.cached_var
def comp_x(self):
"""Nested function.
Returns:
A function that returns the value of self.x.
"""
def _():
return self.x
return _
@rx.cached_var
def comp_y(self) -> List[int]:
"""Comprehension iterating over attribute.
Returns:
A list of the values of self.y.
"""
return [round(y) for y in self.y]
@rx.cached_var
def comp_z(self) -> List[bool]:
"""Comprehension accesses attribute.
Returns:
A list of whether the values 0-4 are in self._z.
"""
return [z in self._z for z in range(5)]
cs = ComputedState()
assert cs.computed_var_dependencies["v"] == {"comp_v"}
assert cs.computed_var_dependencies["w"] == {"comp_w"}
assert cs.computed_var_dependencies["x"] == {"comp_x"}
assert cs.computed_var_dependencies["y"] == {"comp_y"}
assert cs.computed_var_dependencies["_z"] == {"comp_z"}
def test_backend_method(): def test_backend_method():
"""A method with leading underscore should be callable from event handler.""" """A method with leading underscore should be callable from event handler."""