[ENG-4137] Handle generic alias passing inspect.isclass check (#4427)

On py3.9 and py3.10, `dict[str, str]` and other typing forms are kinda
considered classes, but they still fail when doing `issubclass`, so
specifically exclude generic aliases before calling issubclass.

Fix #4424

Bonus fix: support upcasting of pydantic v1 and v2 models
This commit is contained in:
Masen Furer 2024-11-23 10:48:50 -08:00 committed by GitHub
parent 000938414f
commit c7d3876fe6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 137 additions and 7 deletions

View File

@ -1748,7 +1748,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if value is None: if value is None:
continue continue
hinted_args = value_inside_optional(hinted_args) hinted_args = value_inside_optional(hinted_args)
if isinstance(value, dict) and inspect.isclass(hinted_args): if (
isinstance(value, dict)
and inspect.isclass(hinted_args)
and not types.is_generic_alias(hinted_args) # py3.9-py3.10
):
if issubclass(hinted_args, Model): if issubclass(hinted_args, Model):
# Remove non-fields from the payload # Remove non-fields from the payload
payload[arg] = hinted_args( payload[arg] = hinted_args(
@ -1759,7 +1763,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
} }
) )
elif dataclasses.is_dataclass(hinted_args) or issubclass( elif dataclasses.is_dataclass(hinted_args) or issubclass(
hinted_args, Base hinted_args, (Base, BaseModelV1, BaseModelV2)
): ):
payload[arg] = hinted_args(**value) payload[arg] = hinted_args(**value)
if isinstance(value, list) and (hinted_args is set or hinted_args is Set): if isinstance(value, list) and (hinted_args is set or hinted_args is Set):

View File

@ -10,7 +10,17 @@ import os
import sys import sys
import threading import threading
from textwrap import dedent from textwrap import dedent
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Union,
)
from unittest.mock import AsyncMock, Mock from unittest.mock import AsyncMock, Mock
import pytest import pytest
@ -1828,12 +1838,11 @@ async def test_state_manager_lock_expire_contend(
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def mock_app(monkeypatch, state_manager: StateManager) -> rx.App: def mock_app_simple(monkeypatch) -> rx.App:
"""Mock app fixture. """Simple Mock app fixture.
Args: Args:
monkeypatch: Pytest monkeypatch object. monkeypatch: Pytest monkeypatch object.
state_manager: A state manager.
Returns: Returns:
The app, after mocking out prerequisites.get_app() The app, after mocking out prerequisites.get_app()
@ -1844,7 +1853,6 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
setattr(app_module, CompileVars.APP, app) setattr(app_module, CompileVars.APP, app)
app.state = TestState app.state = TestState
app._state_manager = state_manager
app.event_namespace.emit = AsyncMock() # type: ignore app.event_namespace.emit = AsyncMock() # type: ignore
def _mock_get_app(*args, **kwargs): def _mock_get_app(*args, **kwargs):
@ -1854,6 +1862,21 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
return app return app
@pytest.fixture(scope="function")
def mock_app(mock_app_simple: rx.App, state_manager: StateManager) -> rx.App:
"""Mock app fixture.
Args:
mock_app_simple: A simple mock app.
state_manager: A state manager.
Returns:
The app, after mocking out prerequisites.get_app()
"""
mock_app_simple._state_manager = state_manager
return mock_app_simple
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
"""Test that the state proxy works. """Test that the state proxy works.
@ -3506,3 +3529,106 @@ def test_init_mixin() -> None:
with pytest.raises(ReflexRuntimeError): with pytest.raises(ReflexRuntimeError):
SubMixin() SubMixin()
class ReflexModel(rx.Model):
"""A model for testing."""
foo: str
class UpcastState(rx.State):
"""A state for testing upcasting."""
passed: bool = False
def rx_model(self, m: ReflexModel): # noqa: D102
assert isinstance(m, ReflexModel)
self.passed = True
def rx_base(self, o: Object): # noqa: D102
assert isinstance(o, Object)
self.passed = True
def rx_base_or_none(self, o: Optional[Object]): # noqa: D102
if o is not None:
assert isinstance(o, Object)
self.passed = True
def rx_basemodelv1(self, m: ModelV1): # noqa: D102
assert isinstance(m, ModelV1)
self.passed = True
def rx_basemodelv2(self, m: ModelV2): # noqa: D102
assert isinstance(m, ModelV2)
self.passed = True
def rx_dataclass(self, dc: ModelDC): # noqa: D102
assert isinstance(dc, ModelDC)
self.passed = True
def py_set(self, s: set): # noqa: D102
assert isinstance(s, set)
self.passed = True
def py_Set(self, s: Set): # noqa: D102
assert isinstance(s, Set)
self.passed = True
def py_tuple(self, t: tuple): # noqa: D102
assert isinstance(t, tuple)
self.passed = True
def py_Tuple(self, t: Tuple): # noqa: D102
assert isinstance(t, tuple)
self.passed = True
def py_dict(self, d: dict[str, str]): # noqa: D102
assert isinstance(d, dict)
self.passed = True
def py_list(self, ls: list[str]): # noqa: D102
assert isinstance(ls, list)
self.passed = True
def py_Any(self, a: Any): # noqa: D102
assert isinstance(a, list)
self.passed = True
def py_unresolvable(self, u: "Unresolvable"): # noqa: D102, F821 # type: ignore
assert isinstance(u, list)
self.passed = True
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_app_simple")
@pytest.mark.parametrize(
("handler", "payload"),
[
(UpcastState.rx_model, {"m": {"foo": "bar"}}),
(UpcastState.rx_base, {"o": {"foo": "bar"}}),
(UpcastState.rx_base_or_none, {"o": {"foo": "bar"}}),
(UpcastState.rx_base_or_none, {"o": None}),
(UpcastState.rx_basemodelv1, {"m": {"foo": "bar"}}),
(UpcastState.rx_basemodelv2, {"m": {"foo": "bar"}}),
(UpcastState.rx_dataclass, {"dc": {"foo": "bar"}}),
(UpcastState.py_set, {"s": ["foo", "foo"]}),
(UpcastState.py_Set, {"s": ["foo", "foo"]}),
(UpcastState.py_tuple, {"t": ["foo", "foo"]}),
(UpcastState.py_Tuple, {"t": ["foo", "foo"]}),
(UpcastState.py_dict, {"d": {"foo": "bar"}}),
(UpcastState.py_list, {"ls": ["foo", "foo"]}),
(UpcastState.py_Any, {"a": ["foo"]}),
(UpcastState.py_unresolvable, {"u": ["foo"]}),
],
)
async def test_upcast_event_handler_arg(handler, payload):
"""Test that upcast event handler args work correctly.
Args:
handler: The handler to test.
payload: The payload to test.
"""
state = UpcastState()
async for update in state._process_event(handler, state, payload):
assert update.delta == {UpcastState.get_full_name(): {"passed": True}}