Track var dependencies in comprehensions and nested functions (#1728)
This commit is contained in:
parent
41e97bbc46
commit
b44c2176e0
@ -7,7 +7,7 @@ import json
|
||||
import random
|
||||
import string
|
||||
from abc import ABC
|
||||
from types import FunctionType
|
||||
from types import CodeType, FunctionType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -983,16 +983,19 @@ class ComputedVar(Var, property):
|
||||
def deps(
|
||||
self,
|
||||
objclass: Type,
|
||||
obj: FunctionType | None = None,
|
||||
obj: FunctionType | CodeType | None = None,
|
||||
self_name: Optional[str] = None,
|
||||
) -> set[str]:
|
||||
"""Determine var dependencies of this ComputedVar.
|
||||
|
||||
Save references to attributes accessed on "self". Recursively called
|
||||
when the function makes a method call on "self".
|
||||
when the function makes a method call on "self" or define comprehensions
|
||||
or nested functions that may reference "self".
|
||||
|
||||
Args:
|
||||
objclass: the class obj this ComputedVar is attached to.
|
||||
obj: the object to disassemble (defaults to the fget function).
|
||||
self_name: if specified, look for this name in LOAD_FAST and LOAD_DEREF instructions.
|
||||
|
||||
Returns:
|
||||
A set of variable names accessed by the given obj.
|
||||
@ -1010,25 +1013,48 @@ class ComputedVar(Var, property):
|
||||
# unbox EventHandler
|
||||
obj = cast(FunctionType, obj.fn) # type: ignore
|
||||
|
||||
try:
|
||||
self_name = obj.__code__.co_varnames[0]
|
||||
except (AttributeError, IndexError):
|
||||
# cannot reference self if method takes no args
|
||||
if self_name is None and isinstance(obj, FunctionType):
|
||||
try:
|
||||
# the first argument to the function is the name of "self" arg
|
||||
self_name = obj.__code__.co_varnames[0]
|
||||
except (AttributeError, IndexError):
|
||||
self_name = None
|
||||
if self_name is None:
|
||||
# cannot reference attributes on self if method takes no args
|
||||
return set()
|
||||
self_is_top_of_stack = False
|
||||
for instruction in dis.get_instructions(obj):
|
||||
if instruction.opname == "LOAD_FAST" and instruction.argval == self_name:
|
||||
if (
|
||||
instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
|
||||
and instruction.argval == self_name
|
||||
):
|
||||
# bytecode loaded the class instance to the top of stack, next load instruction
|
||||
# is referencing an attribute on self
|
||||
self_is_top_of_stack = True
|
||||
continue
|
||||
if self_is_top_of_stack and instruction.opname == "LOAD_ATTR":
|
||||
# direct attribute access
|
||||
d.add(instruction.argval)
|
||||
elif self_is_top_of_stack and instruction.opname == "LOAD_METHOD":
|
||||
# method call on self
|
||||
d.update(
|
||||
self.deps(
|
||||
objclass=objclass,
|
||||
obj=getattr(objclass, instruction.argval),
|
||||
)
|
||||
)
|
||||
elif instruction.opname == "LOAD_CONST" and isinstance(
|
||||
instruction.argval, CodeType
|
||||
):
|
||||
# recurse into nested functions / comprehensions, which can reference
|
||||
# instance attributes from the outer scope
|
||||
d.update(
|
||||
self.deps(
|
||||
objclass=objclass,
|
||||
obj=instruction.argval,
|
||||
self_name=self_name,
|
||||
)
|
||||
)
|
||||
self_is_top_of_stack = False
|
||||
return d
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from typing import Dict, List
|
||||
|
||||
@ -1149,6 +1151,73 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
|
||||
assert s.x == 45
|
||||
|
||||
|
||||
def test_computed_var_dependencies():
|
||||
"""Test that a ComputedVar correctly tracks its dependencies."""
|
||||
|
||||
class ComputedState(State):
|
||||
v: int = 0
|
||||
w: int = 0
|
||||
x: int = 0
|
||||
y: List[int] = [1, 2, 3]
|
||||
_z: List[int] = [1, 2, 3]
|
||||
|
||||
@rx.cached_var
|
||||
def comp_v(self) -> int:
|
||||
"""Direct access.
|
||||
|
||||
Returns:
|
||||
The value of self.v.
|
||||
"""
|
||||
return self.v
|
||||
|
||||
@rx.cached_var
|
||||
def comp_w(self):
|
||||
"""Nested lambda.
|
||||
|
||||
Returns:
|
||||
A lambda that returns the value of self.w.
|
||||
"""
|
||||
return lambda: self.w
|
||||
|
||||
@rx.cached_var
|
||||
def comp_x(self):
|
||||
"""Nested function.
|
||||
|
||||
Returns:
|
||||
A function that returns the value of self.x.
|
||||
"""
|
||||
|
||||
def _():
|
||||
return self.x
|
||||
|
||||
return _
|
||||
|
||||
@rx.cached_var
|
||||
def comp_y(self) -> List[int]:
|
||||
"""Comprehension iterating over attribute.
|
||||
|
||||
Returns:
|
||||
A list of the values of self.y.
|
||||
"""
|
||||
return [round(y) for y in self.y]
|
||||
|
||||
@rx.cached_var
|
||||
def comp_z(self) -> List[bool]:
|
||||
"""Comprehension accesses attribute.
|
||||
|
||||
Returns:
|
||||
A list of whether the values 0-4 are in self._z.
|
||||
"""
|
||||
return [z in self._z for z in range(5)]
|
||||
|
||||
cs = ComputedState()
|
||||
assert cs.computed_var_dependencies["v"] == {"comp_v"}
|
||||
assert cs.computed_var_dependencies["w"] == {"comp_w"}
|
||||
assert cs.computed_var_dependencies["x"] == {"comp_x"}
|
||||
assert cs.computed_var_dependencies["y"] == {"comp_y"}
|
||||
assert cs.computed_var_dependencies["_z"] == {"comp_z"}
|
||||
|
||||
|
||||
def test_backend_method():
|
||||
"""A method with leading underscore should be callable from event handler."""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user