diff --git a/reflex/app.py b/reflex/app.py index afc40e3b8..fc8efb420 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1462,10 +1462,10 @@ class EventNamespace(AsyncNamespace): app: App # Keep a mapping between socket ID and client token. - token_to_sid: dict[str, str] = {} + token_to_sid: dict[str, str] # Keep a mapping between client token and socket ID. - sid_to_token: dict[str, str] = {} + sid_to_token: dict[str, str] def __init__(self, namespace: str, app: App): """Initialize the event namespace. @@ -1475,6 +1475,8 @@ class EventNamespace(AsyncNamespace): app: The application object. """ super().__init__(namespace) + self.token_to_sid = {} + self.sid_to_token = {} self.app = app def on_connect(self, sid, environ): diff --git a/reflex/state.py b/reflex/state.py index 4c2fb8f59..c1e423f9e 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1806,7 +1806,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if value is None: continue hinted_args = value_inside_optional(hinted_args) - if isinstance(value, dict) and inspect.isclass(hinted_args): + if ( + isinstance(value, dict) + and inspect.isclass(hinted_args) + and not types.is_generic_alias(hinted_args) # py3.9-py3.10 + ): if issubclass(hinted_args, Model): # Remove non-fields from the payload payload[arg] = hinted_args( @@ -1817,7 +1821,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): } ) elif dataclasses.is_dataclass(hinted_args) or issubclass( - hinted_args, Base + hinted_args, (Base, BaseModelV1, BaseModelV2) ): payload[arg] = hinted_args(**value) if isinstance(value, list) and (hinted_args is set or hinted_args is Set): diff --git a/reflex/utils/console.py b/reflex/utils/console.py index 04e590910..b3ba7163d 100644 --- a/reflex/utils/console.py +++ b/reflex/utils/console.py @@ -26,7 +26,22 @@ def set_log_level(log_level: LogLevel): Args: log_level: The log level to set. + + Raises: + ValueError: If the log level is invalid. """ + if not isinstance(log_level, LogLevel): + deprecate( + feature_name="Passing a string to set_log_level", + reason="use reflex.constants.LogLevel enum instead", + deprecation_version="0.6.6", + removal_version="0.7.0", + ) + try: + log_level = getattr(LogLevel, log_level.upper()) + except AttributeError as ae: + raise ValueError(f"Invalid log level: {log_level}") from ae + global _LOG_LEVEL _LOG_LEVEL = log_level diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 41f5444ca..e79878c98 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -10,7 +10,17 @@ import os import sys import threading from textwrap import dedent -from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union +from typing import ( + Any, + AsyncGenerator, + Callable, + Dict, + List, + Optional, + Set, + Tuple, + Union, +) from unittest.mock import AsyncMock, Mock import pytest @@ -1829,12 +1839,11 @@ async def test_state_manager_lock_expire_contend( @pytest.fixture(scope="function") -def mock_app(monkeypatch, state_manager: StateManager) -> rx.App: - """Mock app fixture. +def mock_app_simple(monkeypatch) -> rx.App: + """Simple Mock app fixture. Args: monkeypatch: Pytest monkeypatch object. - state_manager: A state manager. Returns: The app, after mocking out prerequisites.get_app() @@ -1845,7 +1854,6 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App: setattr(app_module, CompileVars.APP, app) app.state = TestState - app._state_manager = state_manager app.event_namespace.emit = AsyncMock() # type: ignore def _mock_get_app(*args, **kwargs): @@ -1855,6 +1863,21 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App: return app +@pytest.fixture(scope="function") +def mock_app(mock_app_simple: rx.App, state_manager: StateManager) -> rx.App: + """Mock app fixture. + + Args: + mock_app_simple: A simple mock app. + state_manager: A state manager. + + Returns: + The app, after mocking out prerequisites.get_app() + """ + mock_app_simple._state_manager = state_manager + return mock_app_simple + + @pytest.mark.asyncio async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): """Test that the state proxy works. @@ -3507,3 +3530,106 @@ def test_init_mixin() -> None: with pytest.raises(ReflexRuntimeError): SubMixin() + + +class ReflexModel(rx.Model): + """A model for testing.""" + + foo: str + + +class UpcastState(rx.State): + """A state for testing upcasting.""" + + passed: bool = False + + def rx_model(self, m: ReflexModel): # noqa: D102 + assert isinstance(m, ReflexModel) + self.passed = True + + def rx_base(self, o: Object): # noqa: D102 + assert isinstance(o, Object) + self.passed = True + + def rx_base_or_none(self, o: Optional[Object]): # noqa: D102 + if o is not None: + assert isinstance(o, Object) + self.passed = True + + def rx_basemodelv1(self, m: ModelV1): # noqa: D102 + assert isinstance(m, ModelV1) + self.passed = True + + def rx_basemodelv2(self, m: ModelV2): # noqa: D102 + assert isinstance(m, ModelV2) + self.passed = True + + def rx_dataclass(self, dc: ModelDC): # noqa: D102 + assert isinstance(dc, ModelDC) + self.passed = True + + def py_set(self, s: set): # noqa: D102 + assert isinstance(s, set) + self.passed = True + + def py_Set(self, s: Set): # noqa: D102 + assert isinstance(s, Set) + self.passed = True + + def py_tuple(self, t: tuple): # noqa: D102 + assert isinstance(t, tuple) + self.passed = True + + def py_Tuple(self, t: Tuple): # noqa: D102 + assert isinstance(t, tuple) + self.passed = True + + def py_dict(self, d: dict[str, str]): # noqa: D102 + assert isinstance(d, dict) + self.passed = True + + def py_list(self, ls: list[str]): # noqa: D102 + assert isinstance(ls, list) + self.passed = True + + def py_Any(self, a: Any): # noqa: D102 + assert isinstance(a, list) + self.passed = True + + def py_unresolvable(self, u: "Unresolvable"): # noqa: D102, F821 # type: ignore + assert isinstance(u, list) + self.passed = True + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("mock_app_simple") +@pytest.mark.parametrize( + ("handler", "payload"), + [ + (UpcastState.rx_model, {"m": {"foo": "bar"}}), + (UpcastState.rx_base, {"o": {"foo": "bar"}}), + (UpcastState.rx_base_or_none, {"o": {"foo": "bar"}}), + (UpcastState.rx_base_or_none, {"o": None}), + (UpcastState.rx_basemodelv1, {"m": {"foo": "bar"}}), + (UpcastState.rx_basemodelv2, {"m": {"foo": "bar"}}), + (UpcastState.rx_dataclass, {"dc": {"foo": "bar"}}), + (UpcastState.py_set, {"s": ["foo", "foo"]}), + (UpcastState.py_Set, {"s": ["foo", "foo"]}), + (UpcastState.py_tuple, {"t": ["foo", "foo"]}), + (UpcastState.py_Tuple, {"t": ["foo", "foo"]}), + (UpcastState.py_dict, {"d": {"foo": "bar"}}), + (UpcastState.py_list, {"ls": ["foo", "foo"]}), + (UpcastState.py_Any, {"a": ["foo"]}), + (UpcastState.py_unresolvable, {"u": ["foo"]}), + ], +) +async def test_upcast_event_handler_arg(handler, payload): + """Test that upcast event handler args work correctly. + + Args: + handler: The handler to test. + payload: The payload to test. + """ + state = UpcastState() + async for update in state._process_event(handler, state, payload): + assert update.delta == {UpcastState.get_full_name(): {"passed": True}}