Track var dependencies in comprehensions and nested functions (#1728)
This commit is contained in:
parent
41e97bbc46
commit
b44c2176e0
@ -7,7 +7,7 @@ import json
|
|||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from types import FunctionType
|
from types import CodeType, FunctionType
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
@ -983,16 +983,19 @@ class ComputedVar(Var, property):
|
|||||||
def deps(
|
def deps(
|
||||||
self,
|
self,
|
||||||
objclass: Type,
|
objclass: Type,
|
||||||
obj: FunctionType | None = None,
|
obj: FunctionType | CodeType | None = None,
|
||||||
|
self_name: Optional[str] = None,
|
||||||
) -> set[str]:
|
) -> set[str]:
|
||||||
"""Determine var dependencies of this ComputedVar.
|
"""Determine var dependencies of this ComputedVar.
|
||||||
|
|
||||||
Save references to attributes accessed on "self". Recursively called
|
Save references to attributes accessed on "self". Recursively called
|
||||||
when the function makes a method call on "self".
|
when the function makes a method call on "self" or define comprehensions
|
||||||
|
or nested functions that may reference "self".
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
objclass: the class obj this ComputedVar is attached to.
|
objclass: the class obj this ComputedVar is attached to.
|
||||||
obj: the object to disassemble (defaults to the fget function).
|
obj: the object to disassemble (defaults to the fget function).
|
||||||
|
self_name: if specified, look for this name in LOAD_FAST and LOAD_DEREF instructions.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A set of variable names accessed by the given obj.
|
A set of variable names accessed by the given obj.
|
||||||
@ -1010,25 +1013,48 @@ class ComputedVar(Var, property):
|
|||||||
# unbox EventHandler
|
# unbox EventHandler
|
||||||
obj = cast(FunctionType, obj.fn) # type: ignore
|
obj = cast(FunctionType, obj.fn) # type: ignore
|
||||||
|
|
||||||
try:
|
if self_name is None and isinstance(obj, FunctionType):
|
||||||
self_name = obj.__code__.co_varnames[0]
|
try:
|
||||||
except (AttributeError, IndexError):
|
# the first argument to the function is the name of "self" arg
|
||||||
# cannot reference self if method takes no args
|
self_name = obj.__code__.co_varnames[0]
|
||||||
|
except (AttributeError, IndexError):
|
||||||
|
self_name = None
|
||||||
|
if self_name is None:
|
||||||
|
# cannot reference attributes on self if method takes no args
|
||||||
return set()
|
return set()
|
||||||
self_is_top_of_stack = False
|
self_is_top_of_stack = False
|
||||||
for instruction in dis.get_instructions(obj):
|
for instruction in dis.get_instructions(obj):
|
||||||
if instruction.opname == "LOAD_FAST" and instruction.argval == self_name:
|
if (
|
||||||
|
instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
|
||||||
|
and instruction.argval == self_name
|
||||||
|
):
|
||||||
|
# bytecode loaded the class instance to the top of stack, next load instruction
|
||||||
|
# is referencing an attribute on self
|
||||||
self_is_top_of_stack = True
|
self_is_top_of_stack = True
|
||||||
continue
|
continue
|
||||||
if self_is_top_of_stack and instruction.opname == "LOAD_ATTR":
|
if self_is_top_of_stack and instruction.opname == "LOAD_ATTR":
|
||||||
|
# direct attribute access
|
||||||
d.add(instruction.argval)
|
d.add(instruction.argval)
|
||||||
elif self_is_top_of_stack and instruction.opname == "LOAD_METHOD":
|
elif self_is_top_of_stack and instruction.opname == "LOAD_METHOD":
|
||||||
|
# method call on self
|
||||||
d.update(
|
d.update(
|
||||||
self.deps(
|
self.deps(
|
||||||
objclass=objclass,
|
objclass=objclass,
|
||||||
obj=getattr(objclass, instruction.argval),
|
obj=getattr(objclass, instruction.argval),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
elif instruction.opname == "LOAD_CONST" and isinstance(
|
||||||
|
instruction.argval, CodeType
|
||||||
|
):
|
||||||
|
# recurse into nested functions / comprehensions, which can reference
|
||||||
|
# instance attributes from the outer scope
|
||||||
|
d.update(
|
||||||
|
self.deps(
|
||||||
|
objclass=objclass,
|
||||||
|
obj=instruction.argval,
|
||||||
|
self_name=self_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
self_is_top_of_stack = False
|
self_is_top_of_stack = False
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
@ -1149,6 +1151,73 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
|
|||||||
assert s.x == 45
|
assert s.x == 45
|
||||||
|
|
||||||
|
|
||||||
|
def test_computed_var_dependencies():
|
||||||
|
"""Test that a ComputedVar correctly tracks its dependencies."""
|
||||||
|
|
||||||
|
class ComputedState(State):
|
||||||
|
v: int = 0
|
||||||
|
w: int = 0
|
||||||
|
x: int = 0
|
||||||
|
y: List[int] = [1, 2, 3]
|
||||||
|
_z: List[int] = [1, 2, 3]
|
||||||
|
|
||||||
|
@rx.cached_var
|
||||||
|
def comp_v(self) -> int:
|
||||||
|
"""Direct access.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The value of self.v.
|
||||||
|
"""
|
||||||
|
return self.v
|
||||||
|
|
||||||
|
@rx.cached_var
|
||||||
|
def comp_w(self):
|
||||||
|
"""Nested lambda.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A lambda that returns the value of self.w.
|
||||||
|
"""
|
||||||
|
return lambda: self.w
|
||||||
|
|
||||||
|
@rx.cached_var
|
||||||
|
def comp_x(self):
|
||||||
|
"""Nested function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A function that returns the value of self.x.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _():
|
||||||
|
return self.x
|
||||||
|
|
||||||
|
return _
|
||||||
|
|
||||||
|
@rx.cached_var
|
||||||
|
def comp_y(self) -> List[int]:
|
||||||
|
"""Comprehension iterating over attribute.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of the values of self.y.
|
||||||
|
"""
|
||||||
|
return [round(y) for y in self.y]
|
||||||
|
|
||||||
|
@rx.cached_var
|
||||||
|
def comp_z(self) -> List[bool]:
|
||||||
|
"""Comprehension accesses attribute.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of whether the values 0-4 are in self._z.
|
||||||
|
"""
|
||||||
|
return [z in self._z for z in range(5)]
|
||||||
|
|
||||||
|
cs = ComputedState()
|
||||||
|
assert cs.computed_var_dependencies["v"] == {"comp_v"}
|
||||||
|
assert cs.computed_var_dependencies["w"] == {"comp_w"}
|
||||||
|
assert cs.computed_var_dependencies["x"] == {"comp_x"}
|
||||||
|
assert cs.computed_var_dependencies["y"] == {"comp_y"}
|
||||||
|
assert cs.computed_var_dependencies["_z"] == {"comp_z"}
|
||||||
|
|
||||||
|
|
||||||
def test_backend_method():
|
def test_backend_method():
|
||||||
"""A method with leading underscore should be callable from event handler."""
|
"""A method with leading underscore should be callable from event handler."""
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user