Per-instance backend variables (#959)

* test_state: check that _backend_vars are not shared between instances

Each instance of State should get its own backend vars

* per-instance backend vars

Attempt to fix #958
This commit is contained in:
Masen Furer 2023-05-07 16:23:31 -07:00 committed by GitHub
parent 960e4ec171
commit 4515561e61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 4 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import copy
import functools import functools
import traceback import traceback
from abc import ABC from abc import ABC
@ -74,6 +75,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Mapping of var name to set of computed variables that depend on it # Mapping of var name to set of computed variables that depend on it
computed_var_dependencies: Dict[str, Set[str]] = {} computed_var_dependencies: Dict[str, Set[str]] = {}
# Per-instance copy of backend variable values
_backend_vars: Dict[str, Any] = {}
def __init__(self, *args, parent_state: Optional[State] = None, **kwargs): def __init__(self, *args, parent_state: Optional[State] = None, **kwargs):
"""Initialize the state. """Initialize the state.
@ -106,6 +110,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Initialize the mutable fields. # Initialize the mutable fields.
self._init_mutable_fields() self._init_mutable_fields()
# Create a fresh copy of the backend variables for this instance
self._backend_vars = copy.deepcopy(self.backend_vars)
def _init_mutable_fields(self): def _init_mutable_fields(self):
"""Initialize mutable fields. """Initialize mutable fields.
@ -219,6 +226,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
"dirty_substates", "dirty_substates",
"router_data", "router_data",
"computed_var_dependencies", "computed_var_dependencies",
"_backend_vars",
} }
@classmethod @classmethod
@ -511,8 +519,8 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
} }
if name in inherited_vars: if name in inherited_vars:
return getattr(super().__getattribute__("parent_state"), name) return getattr(super().__getattribute__("parent_state"), name)
elif name in super().__getattribute__("backend_vars"): elif name in super().__getattribute__("_backend_vars"):
return super().__getattribute__("backend_vars").__getitem__(name) return super().__getattribute__("_backend_vars").__getitem__(name)
return super().__getattribute__(name) return super().__getattribute__(name)
def __setattr__(self, name: str, value: Any): def __setattr__(self, name: str, value: Any):
@ -530,8 +538,8 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
setattr(self.parent_state, name, value) setattr(self.parent_state, name, value)
return return
if types.is_backend_variable(name): if types.is_backend_variable(name) and name != "_backend_vars":
self.backend_vars.__setitem__(name, value) self._backend_vars.__setitem__(name, value)
self.dirty_vars.add(name) self.dirty_vars.add(name)
self.mark_dirty() self.mark_dirty()
return return

View File

@ -821,6 +821,26 @@ def test_dirty_computed_var_from_backend_var(interdependent_state):
} }
def test_per_state_backend_var(interdependent_state):
"""Set backend var on one instance, expect no affect in other instances.
Args:
interdependent_state: A state with varying Var dependencies.
"""
s2 = InterdependentState()
assert s2._v2 == interdependent_state._v2
interdependent_state._v2 = 2
assert s2._v2 != interdependent_state._v2
s3 = InterdependentState()
assert s3._v2 != interdependent_state._v2
# both s2 and s3 should still have the default value
assert s2._v2 == s3._v2
# changing s2._v2 should not affect others
s2._v2 = 4
assert s2._v2 != interdependent_state._v2
assert s2._v2 != s3._v2
def test_child_state(): def test_child_state():
"""Test that the child state computed vars can reference parent state vars.""" """Test that the child state computed vars can reference parent state vars."""