From d621115f9b30e988e872921c3969cfd5447d5ddf Mon Sep 17 00:00:00 2001 From: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com> Date: Tue, 9 Jul 2024 20:13:28 +0200 Subject: [PATCH] add Bare SQLAlchemy mutation tracking, improve typing (#3628) --- reflex/state.py | 3 +- tests/conftest.py | 2 +- tests/states/mutation.py | 50 ++++++++++++++++++++- tests/test_state.py | 96 +++++++++++++++++++++++++++++----------- 4 files changed, 120 insertions(+), 31 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 9569b2aba..c8d970a4f 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -29,6 +29,7 @@ from typing import ( ) import dill +from sqlalchemy.orm import DeclarativeBase try: import pydantic.v1 as pydantic @@ -2963,7 +2964,7 @@ class MutableProxy(wrapt.ObjectProxy): pydantic.BaseModel.__dict__ ) - __mutable_types__ = (list, dict, set, Base) + __mutable_types__ = (list, dict, set, Base, DeclarativeBase) def __init__(self, wrapped: Any, state: BaseState, field_name: str): """Create a proxy for a mutable object that tracks changes. diff --git a/tests/conftest.py b/tests/conftest.py index 71815ca9a..589d35cd7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -231,7 +231,7 @@ def tmp_working_dir(tmp_path): @pytest.fixture -def mutable_state(): +def mutable_state() -> MutableTestState: """Create a Test state containing mutable types. Returns: diff --git a/tests/states/mutation.py b/tests/states/mutation.py index 5825b6d12..b05f558a1 100644 --- a/tests/states/mutation.py +++ b/tests/states/mutation.py @@ -2,8 +2,12 @@ from typing import Dict, List, Set, Union +from sqlalchemy import ARRAY, JSON, String +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + import reflex as rx from reflex.state import BaseState +from reflex.utils.serializers import serializer class DictMutationTestState(BaseState): @@ -145,15 +149,47 @@ class CustomVar(rx.Base): custom: OtherBase = OtherBase() +class MutableSQLABase(DeclarativeBase): + """SQLAlchemy base model for mutable vars.""" + + pass + + +class MutableSQLAModel(MutableSQLABase): + """SQLAlchemy model for mutable vars.""" + + __tablename__: str = "mutable_test_state" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + strlist: Mapped[List[str]] = mapped_column(ARRAY(String)) + hashmap: Mapped[Dict[str, str]] = mapped_column(JSON) + test_set: Mapped[Set[str]] = mapped_column(ARRAY(String)) + + +@serializer +def serialize_mutable_sqla_model( + model: MutableSQLAModel, +) -> Dict[str, Union[List[str], Dict[str, str]]]: + """Serialize the MutableSQLAModel. + + Args: + model: The MutableSQLAModel instance to serialize. + + Returns: + The serialized model. + """ + return {"strlist": model.strlist, "hashmap": model.hashmap} + + class MutableTestState(BaseState): """A test state.""" - array: List[Union[str, List, Dict[str, str]]] = [ + array: List[Union[str, int, List, Dict[str, str]]] = [ "value", [1, 2, 3], {"key": "value"}, ] - hashmap: Dict[str, Union[List, str, Dict[str, str]]] = { + hashmap: Dict[str, Union[List, str, Dict[str, Union[str, Dict]]]] = { "key": ["list", "of", "values"], "another_key": "another_value", "third_key": {"key": "value"}, @@ -161,6 +197,11 @@ class MutableTestState(BaseState): test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"} custom: CustomVar = CustomVar() _be_custom: CustomVar = CustomVar() + sqla_model: MutableSQLAModel = MutableSQLAModel( + strlist=["a", "b", "c"], + hashmap={"key": "value"}, + test_set={"one", "two", "three"}, + ) def reassign_mutables(self): """Assign mutable fields to different values.""" @@ -171,3 +212,8 @@ class MutableTestState(BaseState): "mod_third_key": {"key": "value"}, } self.test_set = {1, 2, 3, 4, "five"} + self.sqla_model = MutableSQLAModel( + strlist=["d", "e", "f"], + hashmap={"key": "value"}, + test_set={"one", "two", "three"}, + ) diff --git a/tests/test_state.py b/tests/test_state.py index 1254a800e..d81d88d82 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -8,7 +8,7 @@ import json import os import sys from textwrap import dedent -from typing import Any, Dict, Generator, List, Optional, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Union from unittest.mock import AsyncMock, Mock import pytest @@ -40,6 +40,7 @@ from reflex.testing import chdir from reflex.utils import format, prerequisites, types from reflex.utils.format import json_dumps from reflex.vars import BaseVar, ComputedVar +from tests.states.mutation import MutableSQLAModel, MutableTestState from .states import GenState @@ -1389,7 +1390,7 @@ def test_backend_method(): assert bms._be_method() -def test_setattr_of_mutable_types(mutable_state): +def test_setattr_of_mutable_types(mutable_state: MutableTestState): """Test that mutable types are converted to corresponding Reflex wrappers. Args: @@ -1398,6 +1399,7 @@ def test_setattr_of_mutable_types(mutable_state): array = mutable_state.array hashmap = mutable_state.hashmap test_set = mutable_state.test_set + sqla_model = mutable_state.sqla_model assert isinstance(array, MutableProxy) assert isinstance(array, list) @@ -1425,11 +1427,21 @@ def test_setattr_of_mutable_types(mutable_state): assert isinstance(mutable_state.custom.test_set, set) assert isinstance(mutable_state.custom.custom, MutableProxy) + assert isinstance(sqla_model, MutableProxy) + assert isinstance(sqla_model, MutableSQLAModel) + assert isinstance(sqla_model.strlist, MutableProxy) + assert isinstance(sqla_model.strlist, list) + assert isinstance(sqla_model.hashmap, MutableProxy) + assert isinstance(sqla_model.hashmap, dict) + assert isinstance(sqla_model.test_set, MutableProxy) + assert isinstance(sqla_model.test_set, set) + mutable_state.reassign_mutables() array = mutable_state.array hashmap = mutable_state.hashmap test_set = mutable_state.test_set + sqla_model = mutable_state.sqla_model assert isinstance(array, MutableProxy) assert isinstance(array, list) @@ -1448,6 +1460,15 @@ def test_setattr_of_mutable_types(mutable_state): assert isinstance(test_set, MutableProxy) assert isinstance(test_set, set) + assert isinstance(sqla_model, MutableProxy) + assert isinstance(sqla_model, MutableSQLAModel) + assert isinstance(sqla_model.strlist, MutableProxy) + assert isinstance(sqla_model.strlist, list) + assert isinstance(sqla_model.hashmap, MutableProxy) + assert isinstance(sqla_model.hashmap, dict) + assert isinstance(sqla_model.test_set, MutableProxy) + assert isinstance(sqla_model.test_set, set) + def test_error_on_state_method_shadow(): """Test that an error is thrown when an event handler shadows a state method.""" @@ -2091,7 +2112,7 @@ async def test_background_task_no_chain(): await bts.bad_chain2() -def test_mutable_list(mutable_state): +def test_mutable_list(mutable_state: MutableTestState): """Test that mutable lists are tracked correctly. Args: @@ -2121,7 +2142,7 @@ def test_mutable_list(mutable_state): assert_array_dirty() mutable_state.array.reverse() assert_array_dirty() - mutable_state.array.sort() + mutable_state.array.sort() # type: ignore[reportCallIssue,reportUnknownMemberType] assert_array_dirty() mutable_state.array[0] = 666 assert_array_dirty() @@ -2145,7 +2166,7 @@ def test_mutable_list(mutable_state): assert_array_dirty() -def test_mutable_dict(mutable_state): +def test_mutable_dict(mutable_state: MutableTestState): """Test that mutable dicts are tracked correctly. Args: @@ -2159,40 +2180,40 @@ def test_mutable_dict(mutable_state): assert not mutable_state.dirty_vars # Test all dict operations - mutable_state.hashmap.update({"new_key": 43}) + mutable_state.hashmap.update({"new_key": "43"}) assert_hashmap_dirty() - assert mutable_state.hashmap.setdefault("another_key", 66) == "another_value" + assert mutable_state.hashmap.setdefault("another_key", "66") == "another_value" assert_hashmap_dirty() - assert mutable_state.hashmap.setdefault("setdefault_key", 67) == 67 + assert mutable_state.hashmap.setdefault("setdefault_key", "67") == "67" assert_hashmap_dirty() - assert mutable_state.hashmap.setdefault("setdefault_key", 68) == 67 + assert mutable_state.hashmap.setdefault("setdefault_key", "68") == "67" assert_hashmap_dirty() - assert mutable_state.hashmap.pop("new_key") == 43 + assert mutable_state.hashmap.pop("new_key") == "43" assert_hashmap_dirty() mutable_state.hashmap.popitem() assert_hashmap_dirty() mutable_state.hashmap.clear() assert_hashmap_dirty() - mutable_state.hashmap["new_key"] = 42 + mutable_state.hashmap["new_key"] = "42" assert_hashmap_dirty() del mutable_state.hashmap["new_key"] assert_hashmap_dirty() if sys.version_info >= (3, 9): - mutable_state.hashmap |= {"new_key": 44} + mutable_state.hashmap |= {"new_key": "44"} assert_hashmap_dirty() # Test nested dict operations mutable_state.hashmap["array"] = [] assert_hashmap_dirty() - mutable_state.hashmap["array"].append(1) + mutable_state.hashmap["array"].append("1") assert_hashmap_dirty() mutable_state.hashmap["dict"] = {} assert_hashmap_dirty() - mutable_state.hashmap["dict"]["key"] = 42 + mutable_state.hashmap["dict"]["key"] = "42" assert_hashmap_dirty() mutable_state.hashmap["dict"]["dict"] = {} assert_hashmap_dirty() - mutable_state.hashmap["dict"]["dict"]["key"] = 43 + mutable_state.hashmap["dict"]["dict"]["key"] = "43" assert_hashmap_dirty() # Test proxy returned from `setdefault` and `get` @@ -2214,14 +2235,14 @@ def test_mutable_dict(mutable_state): mutable_value_third_ref = mutable_state.hashmap.pop("setdefault_mutable_key") assert not isinstance(mutable_value_third_ref, MutableProxy) assert_hashmap_dirty() - mutable_value_third_ref.append("baz") + mutable_value_third_ref.append("baz") # type: ignore[reportUnknownMemberType,reportAttributeAccessIssue,reportUnusedCallResult] assert not mutable_state.dirty_vars # Unfortunately previous refs still will mark the state dirty... nothing doing about that assert mutable_value.pop() assert_hashmap_dirty() -def test_mutable_set(mutable_state): +def test_mutable_set(mutable_state: MutableTestState): """Test that mutable sets are tracked correctly. Args: @@ -2263,7 +2284,7 @@ def test_mutable_set(mutable_state): assert_set_dirty() -def test_mutable_custom(mutable_state): +def test_mutable_custom(mutable_state: MutableTestState): """Test that mutable custom types derived from Base are tracked correctly. Args: @@ -2278,17 +2299,38 @@ def test_mutable_custom(mutable_state): mutable_state.custom.foo = "bar" assert_custom_dirty() - mutable_state.custom.array.append(42) + mutable_state.custom.array.append("42") assert_custom_dirty() - mutable_state.custom.hashmap["key"] = 68 + mutable_state.custom.hashmap["key"] = "value" assert_custom_dirty() - mutable_state.custom.test_set.add(42) + mutable_state.custom.test_set.add("foo") assert_custom_dirty() mutable_state.custom.custom.bar = "baz" assert_custom_dirty() -def test_mutable_backend(mutable_state): +def test_mutable_sqla_model(mutable_state: MutableTestState): + """Test that mutable SQLA models are tracked correctly. + + Args: + mutable_state: A test state. + """ + assert not mutable_state.dirty_vars + + def assert_sqla_model_dirty(): + assert mutable_state.dirty_vars == {"sqla_model"} + mutable_state._clean() + assert not mutable_state.dirty_vars + + mutable_state.sqla_model.strlist.append("foo") + assert_sqla_model_dirty() + mutable_state.sqla_model.hashmap["key"] = "value" + assert_sqla_model_dirty() + mutable_state.sqla_model.test_set.add("bar") + assert_sqla_model_dirty() + + +def test_mutable_backend(mutable_state: MutableTestState): """Test that mutable backend vars are tracked correctly. Args: @@ -2303,11 +2345,11 @@ def test_mutable_backend(mutable_state): mutable_state._be_custom.foo = "bar" assert_custom_dirty() - mutable_state._be_custom.array.append(42) + mutable_state._be_custom.array.append("baz") assert_custom_dirty() - mutable_state._be_custom.hashmap["key"] = 68 + mutable_state._be_custom.hashmap["key"] = "value" assert_custom_dirty() - mutable_state._be_custom.test_set.add(42) + mutable_state._be_custom.test_set.add("foo") assert_custom_dirty() mutable_state._be_custom.custom.bar = "baz" assert_custom_dirty() @@ -2320,7 +2362,7 @@ def test_mutable_backend(mutable_state): (copy.deepcopy,), ], ) -def test_mutable_copy(mutable_state, copy_func): +def test_mutable_copy(mutable_state: MutableTestState, copy_func: Callable): """Test that mutable types are copied correctly. Args: @@ -2347,7 +2389,7 @@ def test_mutable_copy(mutable_state, copy_func): (copy.deepcopy,), ], ) -def test_mutable_copy_vars(mutable_state, copy_func): +def test_mutable_copy_vars(mutable_state: MutableTestState, copy_func: Callable): """Test that mutable types are copied correctly. Args: