add Bare SQLAlchemy mutation tracking, improve typing (#3628)

This commit is contained in:
benedikt-bartscher 2024-07-09 20:13:28 +02:00 committed by GitHub
parent 9e1789a6c2
commit d621115f9b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 120 additions and 31 deletions

View File

@ -29,6 +29,7 @@ from typing import (
)
import dill
from sqlalchemy.orm import DeclarativeBase
try:
import pydantic.v1 as pydantic
@ -2963,7 +2964,7 @@ class MutableProxy(wrapt.ObjectProxy):
pydantic.BaseModel.__dict__
)
__mutable_types__ = (list, dict, set, Base)
__mutable_types__ = (list, dict, set, Base, DeclarativeBase)
def __init__(self, wrapped: Any, state: BaseState, field_name: str):
"""Create a proxy for a mutable object that tracks changes.

View File

@ -231,7 +231,7 @@ def tmp_working_dir(tmp_path):
@pytest.fixture
def mutable_state():
def mutable_state() -> MutableTestState:
"""Create a Test state containing mutable types.
Returns:

View File

@ -2,8 +2,12 @@
from typing import Dict, List, Set, Union
from sqlalchemy import ARRAY, JSON, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
import reflex as rx
from reflex.state import BaseState
from reflex.utils.serializers import serializer
class DictMutationTestState(BaseState):
@ -145,15 +149,47 @@ class CustomVar(rx.Base):
custom: OtherBase = OtherBase()
class MutableSQLABase(DeclarativeBase):
"""SQLAlchemy base model for mutable vars."""
pass
class MutableSQLAModel(MutableSQLABase):
"""SQLAlchemy model for mutable vars."""
__tablename__: str = "mutable_test_state"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
strlist: Mapped[List[str]] = mapped_column(ARRAY(String))
hashmap: Mapped[Dict[str, str]] = mapped_column(JSON)
test_set: Mapped[Set[str]] = mapped_column(ARRAY(String))
@serializer
def serialize_mutable_sqla_model(
model: MutableSQLAModel,
) -> Dict[str, Union[List[str], Dict[str, str]]]:
"""Serialize the MutableSQLAModel.
Args:
model: The MutableSQLAModel instance to serialize.
Returns:
The serialized model.
"""
return {"strlist": model.strlist, "hashmap": model.hashmap}
class MutableTestState(BaseState):
"""A test state."""
array: List[Union[str, List, Dict[str, str]]] = [
array: List[Union[str, int, List, Dict[str, str]]] = [
"value",
[1, 2, 3],
{"key": "value"},
]
hashmap: Dict[str, Union[List, str, Dict[str, str]]] = {
hashmap: Dict[str, Union[List, str, Dict[str, Union[str, Dict]]]] = {
"key": ["list", "of", "values"],
"another_key": "another_value",
"third_key": {"key": "value"},
@ -161,6 +197,11 @@ class MutableTestState(BaseState):
test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"}
custom: CustomVar = CustomVar()
_be_custom: CustomVar = CustomVar()
sqla_model: MutableSQLAModel = MutableSQLAModel(
strlist=["a", "b", "c"],
hashmap={"key": "value"},
test_set={"one", "two", "three"},
)
def reassign_mutables(self):
"""Assign mutable fields to different values."""
@ -171,3 +212,8 @@ class MutableTestState(BaseState):
"mod_third_key": {"key": "value"},
}
self.test_set = {1, 2, 3, 4, "five"}
self.sqla_model = MutableSQLAModel(
strlist=["d", "e", "f"],
hashmap={"key": "value"},
test_set={"one", "two", "three"},
)

View File

@ -8,7 +8,7 @@ import json
import os
import sys
from textwrap import dedent
from typing import Any, Dict, Generator, List, Optional, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Union
from unittest.mock import AsyncMock, Mock
import pytest
@ -40,6 +40,7 @@ from reflex.testing import chdir
from reflex.utils import format, prerequisites, types
from reflex.utils.format import json_dumps
from reflex.vars import BaseVar, ComputedVar
from tests.states.mutation import MutableSQLAModel, MutableTestState
from .states import GenState
@ -1389,7 +1390,7 @@ def test_backend_method():
assert bms._be_method()
def test_setattr_of_mutable_types(mutable_state):
def test_setattr_of_mutable_types(mutable_state: MutableTestState):
"""Test that mutable types are converted to corresponding Reflex wrappers.
Args:
@ -1398,6 +1399,7 @@ def test_setattr_of_mutable_types(mutable_state):
array = mutable_state.array
hashmap = mutable_state.hashmap
test_set = mutable_state.test_set
sqla_model = mutable_state.sqla_model
assert isinstance(array, MutableProxy)
assert isinstance(array, list)
@ -1425,11 +1427,21 @@ def test_setattr_of_mutable_types(mutable_state):
assert isinstance(mutable_state.custom.test_set, set)
assert isinstance(mutable_state.custom.custom, MutableProxy)
assert isinstance(sqla_model, MutableProxy)
assert isinstance(sqla_model, MutableSQLAModel)
assert isinstance(sqla_model.strlist, MutableProxy)
assert isinstance(sqla_model.strlist, list)
assert isinstance(sqla_model.hashmap, MutableProxy)
assert isinstance(sqla_model.hashmap, dict)
assert isinstance(sqla_model.test_set, MutableProxy)
assert isinstance(sqla_model.test_set, set)
mutable_state.reassign_mutables()
array = mutable_state.array
hashmap = mutable_state.hashmap
test_set = mutable_state.test_set
sqla_model = mutable_state.sqla_model
assert isinstance(array, MutableProxy)
assert isinstance(array, list)
@ -1448,6 +1460,15 @@ def test_setattr_of_mutable_types(mutable_state):
assert isinstance(test_set, MutableProxy)
assert isinstance(test_set, set)
assert isinstance(sqla_model, MutableProxy)
assert isinstance(sqla_model, MutableSQLAModel)
assert isinstance(sqla_model.strlist, MutableProxy)
assert isinstance(sqla_model.strlist, list)
assert isinstance(sqla_model.hashmap, MutableProxy)
assert isinstance(sqla_model.hashmap, dict)
assert isinstance(sqla_model.test_set, MutableProxy)
assert isinstance(sqla_model.test_set, set)
def test_error_on_state_method_shadow():
"""Test that an error is thrown when an event handler shadows a state method."""
@ -2091,7 +2112,7 @@ async def test_background_task_no_chain():
await bts.bad_chain2()
def test_mutable_list(mutable_state):
def test_mutable_list(mutable_state: MutableTestState):
"""Test that mutable lists are tracked correctly.
Args:
@ -2121,7 +2142,7 @@ def test_mutable_list(mutable_state):
assert_array_dirty()
mutable_state.array.reverse()
assert_array_dirty()
mutable_state.array.sort()
mutable_state.array.sort() # type: ignore[reportCallIssue,reportUnknownMemberType]
assert_array_dirty()
mutable_state.array[0] = 666
assert_array_dirty()
@ -2145,7 +2166,7 @@ def test_mutable_list(mutable_state):
assert_array_dirty()
def test_mutable_dict(mutable_state):
def test_mutable_dict(mutable_state: MutableTestState):
"""Test that mutable dicts are tracked correctly.
Args:
@ -2159,40 +2180,40 @@ def test_mutable_dict(mutable_state):
assert not mutable_state.dirty_vars
# Test all dict operations
mutable_state.hashmap.update({"new_key": 43})
mutable_state.hashmap.update({"new_key": "43"})
assert_hashmap_dirty()
assert mutable_state.hashmap.setdefault("another_key", 66) == "another_value"
assert mutable_state.hashmap.setdefault("another_key", "66") == "another_value"
assert_hashmap_dirty()
assert mutable_state.hashmap.setdefault("setdefault_key", 67) == 67
assert mutable_state.hashmap.setdefault("setdefault_key", "67") == "67"
assert_hashmap_dirty()
assert mutable_state.hashmap.setdefault("setdefault_key", 68) == 67
assert mutable_state.hashmap.setdefault("setdefault_key", "68") == "67"
assert_hashmap_dirty()
assert mutable_state.hashmap.pop("new_key") == 43
assert mutable_state.hashmap.pop("new_key") == "43"
assert_hashmap_dirty()
mutable_state.hashmap.popitem()
assert_hashmap_dirty()
mutable_state.hashmap.clear()
assert_hashmap_dirty()
mutable_state.hashmap["new_key"] = 42
mutable_state.hashmap["new_key"] = "42"
assert_hashmap_dirty()
del mutable_state.hashmap["new_key"]
assert_hashmap_dirty()
if sys.version_info >= (3, 9):
mutable_state.hashmap |= {"new_key": 44}
mutable_state.hashmap |= {"new_key": "44"}
assert_hashmap_dirty()
# Test nested dict operations
mutable_state.hashmap["array"] = []
assert_hashmap_dirty()
mutable_state.hashmap["array"].append(1)
mutable_state.hashmap["array"].append("1")
assert_hashmap_dirty()
mutable_state.hashmap["dict"] = {}
assert_hashmap_dirty()
mutable_state.hashmap["dict"]["key"] = 42
mutable_state.hashmap["dict"]["key"] = "42"
assert_hashmap_dirty()
mutable_state.hashmap["dict"]["dict"] = {}
assert_hashmap_dirty()
mutable_state.hashmap["dict"]["dict"]["key"] = 43
mutable_state.hashmap["dict"]["dict"]["key"] = "43"
assert_hashmap_dirty()
# Test proxy returned from `setdefault` and `get`
@ -2214,14 +2235,14 @@ def test_mutable_dict(mutable_state):
mutable_value_third_ref = mutable_state.hashmap.pop("setdefault_mutable_key")
assert not isinstance(mutable_value_third_ref, MutableProxy)
assert_hashmap_dirty()
mutable_value_third_ref.append("baz")
mutable_value_third_ref.append("baz") # type: ignore[reportUnknownMemberType,reportAttributeAccessIssue,reportUnusedCallResult]
assert not mutable_state.dirty_vars
# Unfortunately previous refs still will mark the state dirty... nothing doing about that
assert mutable_value.pop()
assert_hashmap_dirty()
def test_mutable_set(mutable_state):
def test_mutable_set(mutable_state: MutableTestState):
"""Test that mutable sets are tracked correctly.
Args:
@ -2263,7 +2284,7 @@ def test_mutable_set(mutable_state):
assert_set_dirty()
def test_mutable_custom(mutable_state):
def test_mutable_custom(mutable_state: MutableTestState):
"""Test that mutable custom types derived from Base are tracked correctly.
Args:
@ -2278,17 +2299,38 @@ def test_mutable_custom(mutable_state):
mutable_state.custom.foo = "bar"
assert_custom_dirty()
mutable_state.custom.array.append(42)
mutable_state.custom.array.append("42")
assert_custom_dirty()
mutable_state.custom.hashmap["key"] = 68
mutable_state.custom.hashmap["key"] = "value"
assert_custom_dirty()
mutable_state.custom.test_set.add(42)
mutable_state.custom.test_set.add("foo")
assert_custom_dirty()
mutable_state.custom.custom.bar = "baz"
assert_custom_dirty()
def test_mutable_backend(mutable_state):
def test_mutable_sqla_model(mutable_state: MutableTestState):
"""Test that mutable SQLA models are tracked correctly.
Args:
mutable_state: A test state.
"""
assert not mutable_state.dirty_vars
def assert_sqla_model_dirty():
assert mutable_state.dirty_vars == {"sqla_model"}
mutable_state._clean()
assert not mutable_state.dirty_vars
mutable_state.sqla_model.strlist.append("foo")
assert_sqla_model_dirty()
mutable_state.sqla_model.hashmap["key"] = "value"
assert_sqla_model_dirty()
mutable_state.sqla_model.test_set.add("bar")
assert_sqla_model_dirty()
def test_mutable_backend(mutable_state: MutableTestState):
"""Test that mutable backend vars are tracked correctly.
Args:
@ -2303,11 +2345,11 @@ def test_mutable_backend(mutable_state):
mutable_state._be_custom.foo = "bar"
assert_custom_dirty()
mutable_state._be_custom.array.append(42)
mutable_state._be_custom.array.append("baz")
assert_custom_dirty()
mutable_state._be_custom.hashmap["key"] = 68
mutable_state._be_custom.hashmap["key"] = "value"
assert_custom_dirty()
mutable_state._be_custom.test_set.add(42)
mutable_state._be_custom.test_set.add("foo")
assert_custom_dirty()
mutable_state._be_custom.custom.bar = "baz"
assert_custom_dirty()
@ -2320,7 +2362,7 @@ def test_mutable_backend(mutable_state):
(copy.deepcopy,),
],
)
def test_mutable_copy(mutable_state, copy_func):
def test_mutable_copy(mutable_state: MutableTestState, copy_func: Callable):
"""Test that mutable types are copied correctly.
Args:
@ -2347,7 +2389,7 @@ def test_mutable_copy(mutable_state, copy_func):
(copy.deepcopy,),
],
)
def test_mutable_copy_vars(mutable_state, copy_func):
def test_mutable_copy_vars(mutable_state: MutableTestState, copy_func: Callable):
"""Test that mutable types are copied correctly.
Args: