ComputedVar dependency tracking: require caller to pass objclass (#963)

Avoid issue where a ComputedVar is added to to a class dynamically, but does
not have a reference to the class its attached to, but requiring callers of the
`deps()` method to provide the objclass for looking up recursive method calls.

This allows for safer and more simplified determination of dependencies, even
in highly dynamic environments.
This commit is contained in:
Masen Furer 2023-05-09 14:36:45 -07:00 committed by GitHub
parent 3b88e7c329
commit 557097e2ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 21 additions and 21 deletions

View File

@ -104,7 +104,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
self.computed_var_dependencies = defaultdict(set) self.computed_var_dependencies = defaultdict(set)
for cvar_name, cvar in self.computed_vars.items(): for cvar_name, cvar in self.computed_vars.items():
# Add the dependencies. # Add the dependencies.
for var in cvar.deps(): for var in cvar.deps(objclass=type(self)):
self.computed_var_dependencies[var].add(cvar_name) self.computed_var_dependencies[var].add(cvar_name)
# Initialize the mutable fields. # Initialize the mutable fields.
@ -492,9 +492,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
func = arglist_factory(param) func = arglist_factory(param)
else: else:
continue continue
# link dynamically created ComputedVar to this state class for dep determination func.fget.__name__ = param # to allow passing as a prop
func.__objclass__ = cls
func.fget.__name__ = param
cls.vars[param] = cls.computed_vars[param] = func.set_state(cls) # type: ignore cls.vars[param] = cls.computed_vars[param] = func.set_state(cls) # type: ignore
setattr(cls, param, func) setattr(cls, param, func)

View File

@ -825,9 +825,6 @@ class ComputedVar(Var, property):
If the value is already cached on the instance, return the cached value. If the value is already cached on the instance, return the cached value.
If this ComputedVar doesn't know what type of object it is attached to, then save
a reference as self.__objclass__.
Args: Args:
instance: the instance of the class accessing this computed var. instance: the instance of the class accessing this computed var.
owner: the class that this descriptor is attached to. owner: the class that this descriptor is attached to.
@ -835,9 +832,6 @@ class ComputedVar(Var, property):
Returns: Returns:
The value of the var for the given instance. The value of the var for the given instance.
""" """
if not hasattr(self, "__objclass__"):
self.__objclass__ = owner
if instance is None: if instance is None:
return super().__get__(instance, owner) return super().__get__(instance, owner)
@ -846,21 +840,22 @@ class ComputedVar(Var, property):
setattr(instance, self.cache_attr, super().__get__(instance, owner)) setattr(instance, self.cache_attr, super().__get__(instance, owner))
return getattr(instance, self.cache_attr) return getattr(instance, self.cache_attr)
def deps(self, obj: Optional[FunctionType] = None) -> Set[str]: def deps(
self,
objclass: Type,
obj: Optional[FunctionType] = None,
) -> Set[str]:
"""Determine var dependencies of this ComputedVar. """Determine var dependencies of this ComputedVar.
Save references to attributes accessed on "self". Recursively called Save references to attributes accessed on "self". Recursively called
when the function makes a method call on "self". when the function makes a method call on "self".
Args: Args:
objclass: the class obj this ComputedVar is attached to.
obj: the object to disassemble (defaults to the fget function). obj: the object to disassemble (defaults to the fget function).
Returns: Returns:
A set of variable names accessed by the given obj. A set of variable names accessed by the given obj.
Raises:
RuntimeError: if this ComputedVar does not have a reference to the class
it is attached to. (Assign var.__objclass__ manually to workaround.)
""" """
d = set() d = set()
if obj is None: if obj is None:
@ -880,11 +875,12 @@ class ComputedVar(Var, property):
if self_is_top_of_stack and instruction.opname == "LOAD_ATTR": if self_is_top_of_stack and instruction.opname == "LOAD_ATTR":
d.add(instruction.argval) d.add(instruction.argval)
elif self_is_top_of_stack and instruction.opname == "LOAD_METHOD": elif self_is_top_of_stack and instruction.opname == "LOAD_METHOD":
if not hasattr(self, "__objclass__"): d.update(
raise RuntimeError( self.deps(
f"ComputedVar {self.name!r} is not bound to a State subclass.", objclass=objclass,
obj=getattr(objclass, instruction.argval),
) )
d.update(self.deps(obj=getattr(self.__objclass__, instruction.argval))) )
self_is_top_of_stack = False self_is_top_of_stack = False
return d return d

View File

@ -120,7 +120,9 @@ def test_add_page_set_route_dynamic(app: App, index_page, windows_platform: bool
app.add_page(index_page, route=route) app.add_page(index_page, route=route)
assert set(app.pages.keys()) == {"test/[dynamic]"} assert set(app.pages.keys()) == {"test/[dynamic]"}
assert "dynamic" in app.state.computed_vars assert "dynamic" in app.state.computed_vars
assert app.state.computed_vars["dynamic"].deps() == {"router_data"} assert app.state.computed_vars["dynamic"].deps(objclass=DefaultState) == {
"router_data"
}
assert "router_data" in app.state().computed_var_dependencies assert "router_data" in app.state().computed_var_dependencies

View File

@ -882,7 +882,11 @@ def test_conditional_computed_vars():
assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"} assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"}
assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"} assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"}
assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"} assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"}
assert ms.computed_vars["rendered_var"].deps() == {"flag", "t1", "t2"} assert ms.computed_vars["rendered_var"].deps(objclass=MainState) == {
"flag",
"t1",
"t2",
}
def test_event_handlers_convert_to_fns(test_state, child_state): def test_event_handlers_convert_to_fns(test_state, child_state):