add Bare SQLAlchemy mutation tracking, improve typing (#3628)
This commit is contained in:
parent
9e1789a6c2
commit
d621115f9b
@ -29,6 +29,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import dill
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
try:
|
||||
import pydantic.v1 as pydantic
|
||||
@ -2963,7 +2964,7 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
pydantic.BaseModel.__dict__
|
||||
)
|
||||
|
||||
__mutable_types__ = (list, dict, set, Base)
|
||||
__mutable_types__ = (list, dict, set, Base, DeclarativeBase)
|
||||
|
||||
def __init__(self, wrapped: Any, state: BaseState, field_name: str):
|
||||
"""Create a proxy for a mutable object that tracks changes.
|
||||
|
@ -231,7 +231,7 @@ def tmp_working_dir(tmp_path):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mutable_state():
|
||||
def mutable_state() -> MutableTestState:
|
||||
"""Create a Test state containing mutable types.
|
||||
|
||||
Returns:
|
||||
|
@ -2,8 +2,12 @@
|
||||
|
||||
from typing import Dict, List, Set, Union
|
||||
|
||||
from sqlalchemy import ARRAY, JSON, String
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
import reflex as rx
|
||||
from reflex.state import BaseState
|
||||
from reflex.utils.serializers import serializer
|
||||
|
||||
|
||||
class DictMutationTestState(BaseState):
|
||||
@ -145,15 +149,47 @@ class CustomVar(rx.Base):
|
||||
custom: OtherBase = OtherBase()
|
||||
|
||||
|
||||
class MutableSQLABase(DeclarativeBase):
|
||||
"""SQLAlchemy base model for mutable vars."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MutableSQLAModel(MutableSQLABase):
|
||||
"""SQLAlchemy model for mutable vars."""
|
||||
|
||||
__tablename__: str = "mutable_test_state"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
strlist: Mapped[List[str]] = mapped_column(ARRAY(String))
|
||||
hashmap: Mapped[Dict[str, str]] = mapped_column(JSON)
|
||||
test_set: Mapped[Set[str]] = mapped_column(ARRAY(String))
|
||||
|
||||
|
||||
@serializer
|
||||
def serialize_mutable_sqla_model(
|
||||
model: MutableSQLAModel,
|
||||
) -> Dict[str, Union[List[str], Dict[str, str]]]:
|
||||
"""Serialize the MutableSQLAModel.
|
||||
|
||||
Args:
|
||||
model: The MutableSQLAModel instance to serialize.
|
||||
|
||||
Returns:
|
||||
The serialized model.
|
||||
"""
|
||||
return {"strlist": model.strlist, "hashmap": model.hashmap}
|
||||
|
||||
|
||||
class MutableTestState(BaseState):
|
||||
"""A test state."""
|
||||
|
||||
array: List[Union[str, List, Dict[str, str]]] = [
|
||||
array: List[Union[str, int, List, Dict[str, str]]] = [
|
||||
"value",
|
||||
[1, 2, 3],
|
||||
{"key": "value"},
|
||||
]
|
||||
hashmap: Dict[str, Union[List, str, Dict[str, str]]] = {
|
||||
hashmap: Dict[str, Union[List, str, Dict[str, Union[str, Dict]]]] = {
|
||||
"key": ["list", "of", "values"],
|
||||
"another_key": "another_value",
|
||||
"third_key": {"key": "value"},
|
||||
@ -161,6 +197,11 @@ class MutableTestState(BaseState):
|
||||
test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"}
|
||||
custom: CustomVar = CustomVar()
|
||||
_be_custom: CustomVar = CustomVar()
|
||||
sqla_model: MutableSQLAModel = MutableSQLAModel(
|
||||
strlist=["a", "b", "c"],
|
||||
hashmap={"key": "value"},
|
||||
test_set={"one", "two", "three"},
|
||||
)
|
||||
|
||||
def reassign_mutables(self):
|
||||
"""Assign mutable fields to different values."""
|
||||
@ -171,3 +212,8 @@ class MutableTestState(BaseState):
|
||||
"mod_third_key": {"key": "value"},
|
||||
}
|
||||
self.test_set = {1, 2, 3, 4, "five"}
|
||||
self.sqla_model = MutableSQLAModel(
|
||||
strlist=["d", "e", "f"],
|
||||
hashmap={"key": "value"},
|
||||
test_set={"one", "two", "three"},
|
||||
)
|
||||
|
@ -8,7 +8,7 @@ import json
|
||||
import os
|
||||
import sys
|
||||
from textwrap import dedent
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Union
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
@ -40,6 +40,7 @@ from reflex.testing import chdir
|
||||
from reflex.utils import format, prerequisites, types
|
||||
from reflex.utils.format import json_dumps
|
||||
from reflex.vars import BaseVar, ComputedVar
|
||||
from tests.states.mutation import MutableSQLAModel, MutableTestState
|
||||
|
||||
from .states import GenState
|
||||
|
||||
@ -1389,7 +1390,7 @@ def test_backend_method():
|
||||
assert bms._be_method()
|
||||
|
||||
|
||||
def test_setattr_of_mutable_types(mutable_state):
|
||||
def test_setattr_of_mutable_types(mutable_state: MutableTestState):
|
||||
"""Test that mutable types are converted to corresponding Reflex wrappers.
|
||||
|
||||
Args:
|
||||
@ -1398,6 +1399,7 @@ def test_setattr_of_mutable_types(mutable_state):
|
||||
array = mutable_state.array
|
||||
hashmap = mutable_state.hashmap
|
||||
test_set = mutable_state.test_set
|
||||
sqla_model = mutable_state.sqla_model
|
||||
|
||||
assert isinstance(array, MutableProxy)
|
||||
assert isinstance(array, list)
|
||||
@ -1425,11 +1427,21 @@ def test_setattr_of_mutable_types(mutable_state):
|
||||
assert isinstance(mutable_state.custom.test_set, set)
|
||||
assert isinstance(mutable_state.custom.custom, MutableProxy)
|
||||
|
||||
assert isinstance(sqla_model, MutableProxy)
|
||||
assert isinstance(sqla_model, MutableSQLAModel)
|
||||
assert isinstance(sqla_model.strlist, MutableProxy)
|
||||
assert isinstance(sqla_model.strlist, list)
|
||||
assert isinstance(sqla_model.hashmap, MutableProxy)
|
||||
assert isinstance(sqla_model.hashmap, dict)
|
||||
assert isinstance(sqla_model.test_set, MutableProxy)
|
||||
assert isinstance(sqla_model.test_set, set)
|
||||
|
||||
mutable_state.reassign_mutables()
|
||||
|
||||
array = mutable_state.array
|
||||
hashmap = mutable_state.hashmap
|
||||
test_set = mutable_state.test_set
|
||||
sqla_model = mutable_state.sqla_model
|
||||
|
||||
assert isinstance(array, MutableProxy)
|
||||
assert isinstance(array, list)
|
||||
@ -1448,6 +1460,15 @@ def test_setattr_of_mutable_types(mutable_state):
|
||||
assert isinstance(test_set, MutableProxy)
|
||||
assert isinstance(test_set, set)
|
||||
|
||||
assert isinstance(sqla_model, MutableProxy)
|
||||
assert isinstance(sqla_model, MutableSQLAModel)
|
||||
assert isinstance(sqla_model.strlist, MutableProxy)
|
||||
assert isinstance(sqla_model.strlist, list)
|
||||
assert isinstance(sqla_model.hashmap, MutableProxy)
|
||||
assert isinstance(sqla_model.hashmap, dict)
|
||||
assert isinstance(sqla_model.test_set, MutableProxy)
|
||||
assert isinstance(sqla_model.test_set, set)
|
||||
|
||||
|
||||
def test_error_on_state_method_shadow():
|
||||
"""Test that an error is thrown when an event handler shadows a state method."""
|
||||
@ -2091,7 +2112,7 @@ async def test_background_task_no_chain():
|
||||
await bts.bad_chain2()
|
||||
|
||||
|
||||
def test_mutable_list(mutable_state):
|
||||
def test_mutable_list(mutable_state: MutableTestState):
|
||||
"""Test that mutable lists are tracked correctly.
|
||||
|
||||
Args:
|
||||
@ -2121,7 +2142,7 @@ def test_mutable_list(mutable_state):
|
||||
assert_array_dirty()
|
||||
mutable_state.array.reverse()
|
||||
assert_array_dirty()
|
||||
mutable_state.array.sort()
|
||||
mutable_state.array.sort() # type: ignore[reportCallIssue,reportUnknownMemberType]
|
||||
assert_array_dirty()
|
||||
mutable_state.array[0] = 666
|
||||
assert_array_dirty()
|
||||
@ -2145,7 +2166,7 @@ def test_mutable_list(mutable_state):
|
||||
assert_array_dirty()
|
||||
|
||||
|
||||
def test_mutable_dict(mutable_state):
|
||||
def test_mutable_dict(mutable_state: MutableTestState):
|
||||
"""Test that mutable dicts are tracked correctly.
|
||||
|
||||
Args:
|
||||
@ -2159,40 +2180,40 @@ def test_mutable_dict(mutable_state):
|
||||
assert not mutable_state.dirty_vars
|
||||
|
||||
# Test all dict operations
|
||||
mutable_state.hashmap.update({"new_key": 43})
|
||||
mutable_state.hashmap.update({"new_key": "43"})
|
||||
assert_hashmap_dirty()
|
||||
assert mutable_state.hashmap.setdefault("another_key", 66) == "another_value"
|
||||
assert mutable_state.hashmap.setdefault("another_key", "66") == "another_value"
|
||||
assert_hashmap_dirty()
|
||||
assert mutable_state.hashmap.setdefault("setdefault_key", 67) == 67
|
||||
assert mutable_state.hashmap.setdefault("setdefault_key", "67") == "67"
|
||||
assert_hashmap_dirty()
|
||||
assert mutable_state.hashmap.setdefault("setdefault_key", 68) == 67
|
||||
assert mutable_state.hashmap.setdefault("setdefault_key", "68") == "67"
|
||||
assert_hashmap_dirty()
|
||||
assert mutable_state.hashmap.pop("new_key") == 43
|
||||
assert mutable_state.hashmap.pop("new_key") == "43"
|
||||
assert_hashmap_dirty()
|
||||
mutable_state.hashmap.popitem()
|
||||
assert_hashmap_dirty()
|
||||
mutable_state.hashmap.clear()
|
||||
assert_hashmap_dirty()
|
||||
mutable_state.hashmap["new_key"] = 42
|
||||
mutable_state.hashmap["new_key"] = "42"
|
||||
assert_hashmap_dirty()
|
||||
del mutable_state.hashmap["new_key"]
|
||||
assert_hashmap_dirty()
|
||||
if sys.version_info >= (3, 9):
|
||||
mutable_state.hashmap |= {"new_key": 44}
|
||||
mutable_state.hashmap |= {"new_key": "44"}
|
||||
assert_hashmap_dirty()
|
||||
|
||||
# Test nested dict operations
|
||||
mutable_state.hashmap["array"] = []
|
||||
assert_hashmap_dirty()
|
||||
mutable_state.hashmap["array"].append(1)
|
||||
mutable_state.hashmap["array"].append("1")
|
||||
assert_hashmap_dirty()
|
||||
mutable_state.hashmap["dict"] = {}
|
||||
assert_hashmap_dirty()
|
||||
mutable_state.hashmap["dict"]["key"] = 42
|
||||
mutable_state.hashmap["dict"]["key"] = "42"
|
||||
assert_hashmap_dirty()
|
||||
mutable_state.hashmap["dict"]["dict"] = {}
|
||||
assert_hashmap_dirty()
|
||||
mutable_state.hashmap["dict"]["dict"]["key"] = 43
|
||||
mutable_state.hashmap["dict"]["dict"]["key"] = "43"
|
||||
assert_hashmap_dirty()
|
||||
|
||||
# Test proxy returned from `setdefault` and `get`
|
||||
@ -2214,14 +2235,14 @@ def test_mutable_dict(mutable_state):
|
||||
mutable_value_third_ref = mutable_state.hashmap.pop("setdefault_mutable_key")
|
||||
assert not isinstance(mutable_value_third_ref, MutableProxy)
|
||||
assert_hashmap_dirty()
|
||||
mutable_value_third_ref.append("baz")
|
||||
mutable_value_third_ref.append("baz") # type: ignore[reportUnknownMemberType,reportAttributeAccessIssue,reportUnusedCallResult]
|
||||
assert not mutable_state.dirty_vars
|
||||
# Unfortunately previous refs still will mark the state dirty... nothing doing about that
|
||||
assert mutable_value.pop()
|
||||
assert_hashmap_dirty()
|
||||
|
||||
|
||||
def test_mutable_set(mutable_state):
|
||||
def test_mutable_set(mutable_state: MutableTestState):
|
||||
"""Test that mutable sets are tracked correctly.
|
||||
|
||||
Args:
|
||||
@ -2263,7 +2284,7 @@ def test_mutable_set(mutable_state):
|
||||
assert_set_dirty()
|
||||
|
||||
|
||||
def test_mutable_custom(mutable_state):
|
||||
def test_mutable_custom(mutable_state: MutableTestState):
|
||||
"""Test that mutable custom types derived from Base are tracked correctly.
|
||||
|
||||
Args:
|
||||
@ -2278,17 +2299,38 @@ def test_mutable_custom(mutable_state):
|
||||
|
||||
mutable_state.custom.foo = "bar"
|
||||
assert_custom_dirty()
|
||||
mutable_state.custom.array.append(42)
|
||||
mutable_state.custom.array.append("42")
|
||||
assert_custom_dirty()
|
||||
mutable_state.custom.hashmap["key"] = 68
|
||||
mutable_state.custom.hashmap["key"] = "value"
|
||||
assert_custom_dirty()
|
||||
mutable_state.custom.test_set.add(42)
|
||||
mutable_state.custom.test_set.add("foo")
|
||||
assert_custom_dirty()
|
||||
mutable_state.custom.custom.bar = "baz"
|
||||
assert_custom_dirty()
|
||||
|
||||
|
||||
def test_mutable_backend(mutable_state):
|
||||
def test_mutable_sqla_model(mutable_state: MutableTestState):
|
||||
"""Test that mutable SQLA models are tracked correctly.
|
||||
|
||||
Args:
|
||||
mutable_state: A test state.
|
||||
"""
|
||||
assert not mutable_state.dirty_vars
|
||||
|
||||
def assert_sqla_model_dirty():
|
||||
assert mutable_state.dirty_vars == {"sqla_model"}
|
||||
mutable_state._clean()
|
||||
assert not mutable_state.dirty_vars
|
||||
|
||||
mutable_state.sqla_model.strlist.append("foo")
|
||||
assert_sqla_model_dirty()
|
||||
mutable_state.sqla_model.hashmap["key"] = "value"
|
||||
assert_sqla_model_dirty()
|
||||
mutable_state.sqla_model.test_set.add("bar")
|
||||
assert_sqla_model_dirty()
|
||||
|
||||
|
||||
def test_mutable_backend(mutable_state: MutableTestState):
|
||||
"""Test that mutable backend vars are tracked correctly.
|
||||
|
||||
Args:
|
||||
@ -2303,11 +2345,11 @@ def test_mutable_backend(mutable_state):
|
||||
|
||||
mutable_state._be_custom.foo = "bar"
|
||||
assert_custom_dirty()
|
||||
mutable_state._be_custom.array.append(42)
|
||||
mutable_state._be_custom.array.append("baz")
|
||||
assert_custom_dirty()
|
||||
mutable_state._be_custom.hashmap["key"] = 68
|
||||
mutable_state._be_custom.hashmap["key"] = "value"
|
||||
assert_custom_dirty()
|
||||
mutable_state._be_custom.test_set.add(42)
|
||||
mutable_state._be_custom.test_set.add("foo")
|
||||
assert_custom_dirty()
|
||||
mutable_state._be_custom.custom.bar = "baz"
|
||||
assert_custom_dirty()
|
||||
@ -2320,7 +2362,7 @@ def test_mutable_backend(mutable_state):
|
||||
(copy.deepcopy,),
|
||||
],
|
||||
)
|
||||
def test_mutable_copy(mutable_state, copy_func):
|
||||
def test_mutable_copy(mutable_state: MutableTestState, copy_func: Callable):
|
||||
"""Test that mutable types are copied correctly.
|
||||
|
||||
Args:
|
||||
@ -2347,7 +2389,7 @@ def test_mutable_copy(mutable_state, copy_func):
|
||||
(copy.deepcopy,),
|
||||
],
|
||||
)
|
||||
def test_mutable_copy_vars(mutable_state, copy_func):
|
||||
def test_mutable_copy_vars(mutable_state: MutableTestState, copy_func: Callable):
|
||||
"""Test that mutable types are copied correctly.
|
||||
|
||||
Args:
|
||||
|
Loading…
Reference in New Issue
Block a user