[ENG-4083] Track internal changes in dataclass instances (#4558)
* [ENG-4083] Track internal changes in dataclass instances Create a dynamic subclass of MutableProxy with `__dataclass_fields__` set according to the dataclass being wrapped. * support dataclasses.asdict on MutableProxy instances
This commit is contained in:
parent
4b89b8260b
commit
41cb2d8cff
@ -3649,6 +3649,9 @@ def get_state_manager() -> StateManager:
|
||||
class MutableProxy(wrapt.ObjectProxy):
|
||||
"""A proxy for a mutable object that tracks changes."""
|
||||
|
||||
# Hint for finding the base class of the proxy.
|
||||
__base_proxy__ = "MutableProxy"
|
||||
|
||||
# Methods on wrapped objects which should mark the state as dirty.
|
||||
__mark_dirty_attrs__ = {
|
||||
"add",
|
||||
@ -3691,6 +3694,39 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
BaseModelV1,
|
||||
)
|
||||
|
||||
# Dynamically generated classes for tracking dataclass mutations.
|
||||
__dataclass_proxies__: Dict[type, type] = {}
|
||||
|
||||
def __new__(cls, wrapped: Any, *args, **kwargs) -> MutableProxy:
|
||||
"""Create a proxy instance for a mutable object that tracks changes.
|
||||
|
||||
Args:
|
||||
wrapped: The object to proxy.
|
||||
*args: Other args passed to MutableProxy (ignored).
|
||||
**kwargs: Other kwargs passed to MutableProxy (ignored).
|
||||
|
||||
Returns:
|
||||
The proxy instance.
|
||||
"""
|
||||
if dataclasses.is_dataclass(wrapped):
|
||||
wrapped_cls = type(wrapped)
|
||||
wrapper_cls_name = wrapped_cls.__name__ + cls.__name__
|
||||
# Find the associated class
|
||||
if wrapper_cls_name not in cls.__dataclass_proxies__:
|
||||
# Create a new class that has the __dataclass_fields__ defined
|
||||
cls.__dataclass_proxies__[wrapper_cls_name] = type(
|
||||
wrapper_cls_name,
|
||||
(cls,),
|
||||
{
|
||||
dataclasses._FIELDS: getattr( # pyright: ignore [reportGeneralTypeIssues]
|
||||
wrapped_cls,
|
||||
dataclasses._FIELDS, # pyright: ignore [reportGeneralTypeIssues]
|
||||
),
|
||||
},
|
||||
)
|
||||
cls = cls.__dataclass_proxies__[wrapper_cls_name]
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, wrapped: Any, state: BaseState, field_name: str):
|
||||
"""Create a proxy for a mutable object that tracks changes.
|
||||
|
||||
@ -3747,7 +3783,27 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
Returns:
|
||||
Whether the value is of a mutable type.
|
||||
"""
|
||||
return isinstance(value, cls.__mutable_types__)
|
||||
return isinstance(value, cls.__mutable_types__) or (
|
||||
dataclasses.is_dataclass(value) and not isinstance(value, Var)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_called_from_dataclasses_internal() -> bool:
|
||||
"""Check if the current function is called from dataclasses helper.
|
||||
|
||||
Returns:
|
||||
Whether the current function is called from dataclasses internal code.
|
||||
"""
|
||||
# Walk up the stack a bit to see if we are called from dataclasses
|
||||
# internal code, for example `asdict` or `astuple`.
|
||||
frame = inspect.currentframe()
|
||||
for _ in range(5):
|
||||
# Why not `inspect.stack()` -- this is much faster!
|
||||
if not (frame := frame and frame.f_back):
|
||||
break
|
||||
if inspect.getfile(frame) == dataclasses.__file__:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _wrap_recursive(self, value: Any) -> Any:
|
||||
"""Wrap a value recursively if it is mutable.
|
||||
@ -3758,9 +3814,13 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
Returns:
|
||||
The wrapped value.
|
||||
"""
|
||||
# When called from dataclasses internal code, return the unwrapped value
|
||||
if self._is_called_from_dataclasses_internal():
|
||||
return value
|
||||
# Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
|
||||
if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
|
||||
return type(self)(
|
||||
base_cls = globals()[self.__base_proxy__]
|
||||
return base_cls(
|
||||
wrapped=value,
|
||||
state=self._self_state,
|
||||
field_name=self._self_field_name,
|
||||
@ -3968,6 +4028,9 @@ class ImmutableMutableProxy(MutableProxy):
|
||||
to modify the wrapped object when the StateProxy is immutable.
|
||||
"""
|
||||
|
||||
# Ensure that recursively wrapped proxies use ImmutableMutableProxy as base.
|
||||
__base_proxy__ = "ImmutableMutableProxy"
|
||||
|
||||
def _mark_dirty(
|
||||
self,
|
||||
wrapped=None,
|
||||
|
@ -1936,6 +1936,14 @@ def mock_app(mock_app_simple: rx.App, state_manager: StateManager) -> rx.App:
|
||||
return mock_app_simple
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelDC:
|
||||
"""A dataclass."""
|
||||
|
||||
foo: str = "bar"
|
||||
ls: list[dict] = dataclasses.field(default_factory=list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
|
||||
"""Test that the state proxy works.
|
||||
@ -2038,6 +2046,7 @@ class BackgroundTaskState(BaseState):
|
||||
|
||||
order: List[str] = []
|
||||
dict_list: Dict[str, List[int]] = {"foo": [1, 2, 3]}
|
||||
dc: ModelDC = ModelDC()
|
||||
|
||||
def __init__(self, **kwargs): # noqa: D107
|
||||
super().__init__(**kwargs)
|
||||
@ -2063,10 +2072,18 @@ class BackgroundTaskState(BaseState):
|
||||
with pytest.raises(ImmutableStateError):
|
||||
self.order.append("bad idea")
|
||||
|
||||
with pytest.raises(ImmutableStateError):
|
||||
# Cannot manipulate dataclass attributes.
|
||||
self.dc.foo = "baz"
|
||||
|
||||
with pytest.raises(ImmutableStateError):
|
||||
# Even nested access to mutables raises an exception.
|
||||
self.dict_list["foo"].append(42)
|
||||
|
||||
with pytest.raises(ImmutableStateError):
|
||||
# Cannot modify dataclass list attribute.
|
||||
self.dc.ls.append({"foo": "bar"})
|
||||
|
||||
with pytest.raises(ImmutableStateError):
|
||||
# Direct calling another handler that modifies state raises an exception.
|
||||
self.other()
|
||||
@ -3582,13 +3599,6 @@ class ModelV2(BaseModelV2):
|
||||
foo: str = "bar"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelDC:
|
||||
"""A dataclass."""
|
||||
|
||||
foo: str = "bar"
|
||||
|
||||
|
||||
class PydanticState(rx.State):
|
||||
"""A state with pydantic BaseModel vars."""
|
||||
|
||||
@ -3610,11 +3620,22 @@ def test_mutable_models():
|
||||
assert state.dirty_vars == {"v2"}
|
||||
state.dirty_vars.clear()
|
||||
|
||||
# Not yet supported ENG-4083
|
||||
# assert isinstance(state.dc, MutableProxy) #noqa: ERA001
|
||||
# state.dc.foo = "baz" #noqa: ERA001
|
||||
# assert state.dirty_vars == {"dc"} #noqa: ERA001
|
||||
# state.dirty_vars.clear() #noqa: ERA001
|
||||
assert isinstance(state.dc, MutableProxy)
|
||||
state.dc.foo = "baz"
|
||||
assert state.dirty_vars == {"dc"}
|
||||
state.dirty_vars.clear()
|
||||
assert state.dirty_vars == set()
|
||||
state.dc.ls.append({"hi": "reflex"})
|
||||
assert state.dirty_vars == {"dc"}
|
||||
state.dirty_vars.clear()
|
||||
assert state.dirty_vars == set()
|
||||
assert dataclasses.asdict(state.dc) == {"foo": "baz", "ls": [{"hi": "reflex"}]}
|
||||
assert dataclasses.astuple(state.dc) == ("baz", [{"hi": "reflex"}])
|
||||
# creating a new instance shouldn't mark the state dirty
|
||||
assert dataclasses.replace(state.dc, foo="quuc") == ModelDC(
|
||||
foo="quuc", ls=[{"hi": "reflex"}]
|
||||
)
|
||||
assert state.dirty_vars == set()
|
||||
|
||||
|
||||
def test_get_value():
|
||||
|
Loading…
Reference in New Issue
Block a user