[ENG-4083] Track internal changes in dataclass instances (#4558)
* [ENG-4083] Track internal changes in dataclass instances Create a dynamic subclass of MutableProxy with `__dataclass_fields__` set according to the dataclass being wrapped. * support dataclasses.asdict on MutableProxy instances
This commit is contained in:
parent
4b89b8260b
commit
41cb2d8cff
@ -3649,6 +3649,9 @@ def get_state_manager() -> StateManager:
|
|||||||
class MutableProxy(wrapt.ObjectProxy):
|
class MutableProxy(wrapt.ObjectProxy):
|
||||||
"""A proxy for a mutable object that tracks changes."""
|
"""A proxy for a mutable object that tracks changes."""
|
||||||
|
|
||||||
|
# Hint for finding the base class of the proxy.
|
||||||
|
__base_proxy__ = "MutableProxy"
|
||||||
|
|
||||||
# Methods on wrapped objects which should mark the state as dirty.
|
# Methods on wrapped objects which should mark the state as dirty.
|
||||||
__mark_dirty_attrs__ = {
|
__mark_dirty_attrs__ = {
|
||||||
"add",
|
"add",
|
||||||
@ -3691,6 +3694,39 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
BaseModelV1,
|
BaseModelV1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Dynamically generated classes for tracking dataclass mutations.
|
||||||
|
__dataclass_proxies__: Dict[type, type] = {}
|
||||||
|
|
||||||
|
def __new__(cls, wrapped: Any, *args, **kwargs) -> MutableProxy:
|
||||||
|
"""Create a proxy instance for a mutable object that tracks changes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
wrapped: The object to proxy.
|
||||||
|
*args: Other args passed to MutableProxy (ignored).
|
||||||
|
**kwargs: Other kwargs passed to MutableProxy (ignored).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The proxy instance.
|
||||||
|
"""
|
||||||
|
if dataclasses.is_dataclass(wrapped):
|
||||||
|
wrapped_cls = type(wrapped)
|
||||||
|
wrapper_cls_name = wrapped_cls.__name__ + cls.__name__
|
||||||
|
# Find the associated class
|
||||||
|
if wrapper_cls_name not in cls.__dataclass_proxies__:
|
||||||
|
# Create a new class that has the __dataclass_fields__ defined
|
||||||
|
cls.__dataclass_proxies__[wrapper_cls_name] = type(
|
||||||
|
wrapper_cls_name,
|
||||||
|
(cls,),
|
||||||
|
{
|
||||||
|
dataclasses._FIELDS: getattr( # pyright: ignore [reportGeneralTypeIssues]
|
||||||
|
wrapped_cls,
|
||||||
|
dataclasses._FIELDS, # pyright: ignore [reportGeneralTypeIssues]
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
cls = cls.__dataclass_proxies__[wrapper_cls_name]
|
||||||
|
return super().__new__(cls)
|
||||||
|
|
||||||
def __init__(self, wrapped: Any, state: BaseState, field_name: str):
|
def __init__(self, wrapped: Any, state: BaseState, field_name: str):
|
||||||
"""Create a proxy for a mutable object that tracks changes.
|
"""Create a proxy for a mutable object that tracks changes.
|
||||||
|
|
||||||
@ -3747,7 +3783,27 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
Returns:
|
Returns:
|
||||||
Whether the value is of a mutable type.
|
Whether the value is of a mutable type.
|
||||||
"""
|
"""
|
||||||
return isinstance(value, cls.__mutable_types__)
|
return isinstance(value, cls.__mutable_types__) or (
|
||||||
|
dataclasses.is_dataclass(value) and not isinstance(value, Var)
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_called_from_dataclasses_internal() -> bool:
|
||||||
|
"""Check if the current function is called from dataclasses helper.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Whether the current function is called from dataclasses internal code.
|
||||||
|
"""
|
||||||
|
# Walk up the stack a bit to see if we are called from dataclasses
|
||||||
|
# internal code, for example `asdict` or `astuple`.
|
||||||
|
frame = inspect.currentframe()
|
||||||
|
for _ in range(5):
|
||||||
|
# Why not `inspect.stack()` -- this is much faster!
|
||||||
|
if not (frame := frame and frame.f_back):
|
||||||
|
break
|
||||||
|
if inspect.getfile(frame) == dataclasses.__file__:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def _wrap_recursive(self, value: Any) -> Any:
|
def _wrap_recursive(self, value: Any) -> Any:
|
||||||
"""Wrap a value recursively if it is mutable.
|
"""Wrap a value recursively if it is mutable.
|
||||||
@ -3758,9 +3814,13 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
Returns:
|
Returns:
|
||||||
The wrapped value.
|
The wrapped value.
|
||||||
"""
|
"""
|
||||||
|
# When called from dataclasses internal code, return the unwrapped value
|
||||||
|
if self._is_called_from_dataclasses_internal():
|
||||||
|
return value
|
||||||
# Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
|
# Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
|
||||||
if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
|
if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
|
||||||
return type(self)(
|
base_cls = globals()[self.__base_proxy__]
|
||||||
|
return base_cls(
|
||||||
wrapped=value,
|
wrapped=value,
|
||||||
state=self._self_state,
|
state=self._self_state,
|
||||||
field_name=self._self_field_name,
|
field_name=self._self_field_name,
|
||||||
@ -3968,6 +4028,9 @@ class ImmutableMutableProxy(MutableProxy):
|
|||||||
to modify the wrapped object when the StateProxy is immutable.
|
to modify the wrapped object when the StateProxy is immutable.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Ensure that recursively wrapped proxies use ImmutableMutableProxy as base.
|
||||||
|
__base_proxy__ = "ImmutableMutableProxy"
|
||||||
|
|
||||||
def _mark_dirty(
|
def _mark_dirty(
|
||||||
self,
|
self,
|
||||||
wrapped=None,
|
wrapped=None,
|
||||||
|
@ -1936,6 +1936,14 @@ def mock_app(mock_app_simple: rx.App, state_manager: StateManager) -> rx.App:
|
|||||||
return mock_app_simple
|
return mock_app_simple
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ModelDC:
|
||||||
|
"""A dataclass."""
|
||||||
|
|
||||||
|
foo: str = "bar"
|
||||||
|
ls: list[dict] = dataclasses.field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
|
async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
|
||||||
"""Test that the state proxy works.
|
"""Test that the state proxy works.
|
||||||
@ -2038,6 +2046,7 @@ class BackgroundTaskState(BaseState):
|
|||||||
|
|
||||||
order: List[str] = []
|
order: List[str] = []
|
||||||
dict_list: Dict[str, List[int]] = {"foo": [1, 2, 3]}
|
dict_list: Dict[str, List[int]] = {"foo": [1, 2, 3]}
|
||||||
|
dc: ModelDC = ModelDC()
|
||||||
|
|
||||||
def __init__(self, **kwargs): # noqa: D107
|
def __init__(self, **kwargs): # noqa: D107
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@ -2063,10 +2072,18 @@ class BackgroundTaskState(BaseState):
|
|||||||
with pytest.raises(ImmutableStateError):
|
with pytest.raises(ImmutableStateError):
|
||||||
self.order.append("bad idea")
|
self.order.append("bad idea")
|
||||||
|
|
||||||
|
with pytest.raises(ImmutableStateError):
|
||||||
|
# Cannot manipulate dataclass attributes.
|
||||||
|
self.dc.foo = "baz"
|
||||||
|
|
||||||
with pytest.raises(ImmutableStateError):
|
with pytest.raises(ImmutableStateError):
|
||||||
# Even nested access to mutables raises an exception.
|
# Even nested access to mutables raises an exception.
|
||||||
self.dict_list["foo"].append(42)
|
self.dict_list["foo"].append(42)
|
||||||
|
|
||||||
|
with pytest.raises(ImmutableStateError):
|
||||||
|
# Cannot modify dataclass list attribute.
|
||||||
|
self.dc.ls.append({"foo": "bar"})
|
||||||
|
|
||||||
with pytest.raises(ImmutableStateError):
|
with pytest.raises(ImmutableStateError):
|
||||||
# Direct calling another handler that modifies state raises an exception.
|
# Direct calling another handler that modifies state raises an exception.
|
||||||
self.other()
|
self.other()
|
||||||
@ -3582,13 +3599,6 @@ class ModelV2(BaseModelV2):
|
|||||||
foo: str = "bar"
|
foo: str = "bar"
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class ModelDC:
|
|
||||||
"""A dataclass."""
|
|
||||||
|
|
||||||
foo: str = "bar"
|
|
||||||
|
|
||||||
|
|
||||||
class PydanticState(rx.State):
|
class PydanticState(rx.State):
|
||||||
"""A state with pydantic BaseModel vars."""
|
"""A state with pydantic BaseModel vars."""
|
||||||
|
|
||||||
@ -3610,11 +3620,22 @@ def test_mutable_models():
|
|||||||
assert state.dirty_vars == {"v2"}
|
assert state.dirty_vars == {"v2"}
|
||||||
state.dirty_vars.clear()
|
state.dirty_vars.clear()
|
||||||
|
|
||||||
# Not yet supported ENG-4083
|
assert isinstance(state.dc, MutableProxy)
|
||||||
# assert isinstance(state.dc, MutableProxy) #noqa: ERA001
|
state.dc.foo = "baz"
|
||||||
# state.dc.foo = "baz" #noqa: ERA001
|
assert state.dirty_vars == {"dc"}
|
||||||
# assert state.dirty_vars == {"dc"} #noqa: ERA001
|
state.dirty_vars.clear()
|
||||||
# state.dirty_vars.clear() #noqa: ERA001
|
assert state.dirty_vars == set()
|
||||||
|
state.dc.ls.append({"hi": "reflex"})
|
||||||
|
assert state.dirty_vars == {"dc"}
|
||||||
|
state.dirty_vars.clear()
|
||||||
|
assert state.dirty_vars == set()
|
||||||
|
assert dataclasses.asdict(state.dc) == {"foo": "baz", "ls": [{"hi": "reflex"}]}
|
||||||
|
assert dataclasses.astuple(state.dc) == ("baz", [{"hi": "reflex"}])
|
||||||
|
# creating a new instance shouldn't mark the state dirty
|
||||||
|
assert dataclasses.replace(state.dc, foo="quuc") == ModelDC(
|
||||||
|
foo="quuc", ls=[{"hi": "reflex"}]
|
||||||
|
)
|
||||||
|
assert state.dirty_vars == set()
|
||||||
|
|
||||||
|
|
||||||
def test_get_value():
|
def test_get_value():
|
||||||
|
Loading…
Reference in New Issue
Block a user