From c7d3876fe63ff91cfe82fba45cab0752dd0d9645 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Sat, 23 Nov 2024 10:48:50 -0800 Subject: [PATCH] [ENG-4137] Handle generic alias passing inspect.isclass check (#4427) On py3.9 and py3.10, `dict[str, str]` and other typing forms are kinda considered classes, but they still fail when doing `issubclass`, so specifically exclude generic aliases before calling issubclass. Fix #4424 Bonus fix: support upcasting of pydantic v1 and v2 models --- reflex/state.py | 8 ++- tests/units/test_state.py | 136 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 137 insertions(+), 7 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 349dc59e9..95f7f64f6 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1748,7 +1748,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if value is None: continue hinted_args = value_inside_optional(hinted_args) - if isinstance(value, dict) and inspect.isclass(hinted_args): + if ( + isinstance(value, dict) + and inspect.isclass(hinted_args) + and not types.is_generic_alias(hinted_args) # py3.9-py3.10 + ): if issubclass(hinted_args, Model): # Remove non-fields from the payload payload[arg] = hinted_args( @@ -1759,7 +1763,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): } ) elif dataclasses.is_dataclass(hinted_args) or issubclass( - hinted_args, Base + hinted_args, (Base, BaseModelV1, BaseModelV2) ): payload[arg] = hinted_args(**value) if isinstance(value, list) and (hinted_args is set or hinted_args is Set): diff --git a/tests/units/test_state.py b/tests/units/test_state.py index c8a52e6c0..45c021bd8 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -10,7 +10,17 @@ import os import sys import threading from textwrap import dedent -from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union +from typing import ( + Any, + AsyncGenerator, + Callable, + Dict, + List, + Optional, + Set, + Tuple, + Union, +) from unittest.mock import AsyncMock, Mock import pytest @@ -1828,12 +1838,11 @@ async def test_state_manager_lock_expire_contend( @pytest.fixture(scope="function") -def mock_app(monkeypatch, state_manager: StateManager) -> rx.App: - """Mock app fixture. +def mock_app_simple(monkeypatch) -> rx.App: + """Simple Mock app fixture. Args: monkeypatch: Pytest monkeypatch object. - state_manager: A state manager. Returns: The app, after mocking out prerequisites.get_app() @@ -1844,7 +1853,6 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App: setattr(app_module, CompileVars.APP, app) app.state = TestState - app._state_manager = state_manager app.event_namespace.emit = AsyncMock() # type: ignore def _mock_get_app(*args, **kwargs): @@ -1854,6 +1862,21 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App: return app +@pytest.fixture(scope="function") +def mock_app(mock_app_simple: rx.App, state_manager: StateManager) -> rx.App: + """Mock app fixture. + + Args: + mock_app_simple: A simple mock app. + state_manager: A state manager. + + Returns: + The app, after mocking out prerequisites.get_app() + """ + mock_app_simple._state_manager = state_manager + return mock_app_simple + + @pytest.mark.asyncio async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): """Test that the state proxy works. @@ -3506,3 +3529,106 @@ def test_init_mixin() -> None: with pytest.raises(ReflexRuntimeError): SubMixin() + + +class ReflexModel(rx.Model): + """A model for testing.""" + + foo: str + + +class UpcastState(rx.State): + """A state for testing upcasting.""" + + passed: bool = False + + def rx_model(self, m: ReflexModel): # noqa: D102 + assert isinstance(m, ReflexModel) + self.passed = True + + def rx_base(self, o: Object): # noqa: D102 + assert isinstance(o, Object) + self.passed = True + + def rx_base_or_none(self, o: Optional[Object]): # noqa: D102 + if o is not None: + assert isinstance(o, Object) + self.passed = True + + def rx_basemodelv1(self, m: ModelV1): # noqa: D102 + assert isinstance(m, ModelV1) + self.passed = True + + def rx_basemodelv2(self, m: ModelV2): # noqa: D102 + assert isinstance(m, ModelV2) + self.passed = True + + def rx_dataclass(self, dc: ModelDC): # noqa: D102 + assert isinstance(dc, ModelDC) + self.passed = True + + def py_set(self, s: set): # noqa: D102 + assert isinstance(s, set) + self.passed = True + + def py_Set(self, s: Set): # noqa: D102 + assert isinstance(s, Set) + self.passed = True + + def py_tuple(self, t: tuple): # noqa: D102 + assert isinstance(t, tuple) + self.passed = True + + def py_Tuple(self, t: Tuple): # noqa: D102 + assert isinstance(t, tuple) + self.passed = True + + def py_dict(self, d: dict[str, str]): # noqa: D102 + assert isinstance(d, dict) + self.passed = True + + def py_list(self, ls: list[str]): # noqa: D102 + assert isinstance(ls, list) + self.passed = True + + def py_Any(self, a: Any): # noqa: D102 + assert isinstance(a, list) + self.passed = True + + def py_unresolvable(self, u: "Unresolvable"): # noqa: D102, F821 # type: ignore + assert isinstance(u, list) + self.passed = True + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("mock_app_simple") +@pytest.mark.parametrize( + ("handler", "payload"), + [ + (UpcastState.rx_model, {"m": {"foo": "bar"}}), + (UpcastState.rx_base, {"o": {"foo": "bar"}}), + (UpcastState.rx_base_or_none, {"o": {"foo": "bar"}}), + (UpcastState.rx_base_or_none, {"o": None}), + (UpcastState.rx_basemodelv1, {"m": {"foo": "bar"}}), + (UpcastState.rx_basemodelv2, {"m": {"foo": "bar"}}), + (UpcastState.rx_dataclass, {"dc": {"foo": "bar"}}), + (UpcastState.py_set, {"s": ["foo", "foo"]}), + (UpcastState.py_Set, {"s": ["foo", "foo"]}), + (UpcastState.py_tuple, {"t": ["foo", "foo"]}), + (UpcastState.py_Tuple, {"t": ["foo", "foo"]}), + (UpcastState.py_dict, {"d": {"foo": "bar"}}), + (UpcastState.py_list, {"ls": ["foo", "foo"]}), + (UpcastState.py_Any, {"a": ["foo"]}), + (UpcastState.py_unresolvable, {"u": ["foo"]}), + ], +) +async def test_upcast_event_handler_arg(handler, payload): + """Test that upcast event handler args work correctly. + + Args: + handler: The handler to test. + payload: The payload to test. + """ + state = UpcastState() + async for update in state._process_event(handler, state, payload): + assert update.delta == {UpcastState.get_full_name(): {"passed": True}}