From 549ab4e708f7992c1e2fe87513789fb9e6b5fc56 Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Fri, 21 Jul 2023 18:47:38 +0000 Subject: [PATCH] rx.App `state` arg should not be required (#1361) --- reflex/app.py | 20 ++++++++++- reflex/constants.py | 4 +++ tests/conftest.py | 11 ++++++ tests/middleware/test_hydrate_middleware.py | 28 ++++++++------- tests/test_app.py | 39 +++++++++++++++------ 5 files changed, 77 insertions(+), 25 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index efed23b10..94194358e 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -2,6 +2,7 @@ import asyncio import inspect +import os from multiprocessing.pool import ThreadPool from typing import ( Any, @@ -42,7 +43,7 @@ from reflex.route import ( verify_route_validity, ) from reflex.state import DefaultState, State, StateManager, StateUpdate -from reflex.utils import format, types +from reflex.utils import console, format, types # Define custom types. ComponentCallable = Callable[[], Component] @@ -100,8 +101,25 @@ class App(Base): Raises: ValueError: If the event namespace is not provided in the config. + Also, if there are multiple client subclasses of rx.State(Subclasses of rx.State should consist + of the DefaultState and the client app state). """ super().__init__(*args, **kwargs) + state_subclasses = State.__subclasses__() + is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ + + # Special case to allow test cases have multiple subclasses of rx.State. + if not is_testing_env: + # Only the default state and the client state should be allowed as subclasses. + if len(state_subclasses) > 2: + raise ValueError( + "rx.State has been subclassed multiple times. Only one subclass is allowed" + ) + if self.state != DefaultState: + console.deprecate( + "Passing the state as keyword argument to `rx.App` is deprecated." + ) + self.state = state_subclasses[-1] # Get the config config = get_config() diff --git a/reflex/constants.py b/reflex/constants.py index f8adf5744..2763ba71c 100644 --- a/reflex/constants.py +++ b/reflex/constants.py @@ -200,6 +200,10 @@ TOKEN_EXPIRATION = 60 * 60 # The event namespace for websocket EVENT_NAMESPACE = get_value("EVENT_NAMESPACE") +# Testing variables. +# Testing os env set by pytest when running a test case. +PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST" + # Env modes class Env(str, Enum): """The environment modes.""" diff --git a/tests/conftest.py b/tests/conftest.py index 6ba10f2ee..807860ac6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,9 +9,20 @@ import pytest import reflex as rx from reflex import constants +from reflex.app import App from reflex.event import EventSpec +@pytest.fixture +def app() -> App: + """A base app. + + Returns: + The app. + """ + return App() + + @pytest.fixture(scope="function") def windows_platform() -> Generator: """Check if system is windows. diff --git a/tests/middleware/test_hydrate_middleware.py b/tests/middleware/test_hydrate_middleware.py index c67244b60..8d446a6ae 100644 --- a/tests/middleware/test_hydrate_middleware.py +++ b/tests/middleware/test_hydrate_middleware.py @@ -78,25 +78,27 @@ def hydrate_middleware() -> HydrateMiddleware: @pytest.mark.asyncio @pytest.mark.parametrize( - "State, expected, event_fixture", + "test_state, expected, event_fixture", [ (TestState, {"test_state": {"num": 1}}, "event1"), (TestState2, {"test_state2": {"num": 1}}, "event2"), (TestState3, {"test_state3": {"num": 1}}, "event3"), ], ) -async def test_preprocess(State, hydrate_middleware, request, event_fixture, expected): +async def test_preprocess( + test_state, hydrate_middleware, request, event_fixture, expected +): """Test that a state hydrate event is processed correctly. Args: - State: state to process event - hydrate_middleware: instance of HydrateMiddleware - request: pytest fixture request - event_fixture: The event fixture(an Event) - expected: expected delta + test_state: State to process event. + hydrate_middleware: Instance of HydrateMiddleware. + request: Pytest fixture request. + event_fixture: The event fixture(an Event). + expected: Expected delta. """ - app = App(state=State, load_events={"index": [State.test_handler]}) - state = State() + app = App(state=test_state, load_events={"index": [test_state.test_handler]}) + state = test_state() update = await hydrate_middleware.preprocess( app=app, event=request.getfixturevalue(event_fixture), state=state @@ -120,8 +122,8 @@ async def test_preprocess_multiple_load_events(hydrate_middleware, event1): """Test that a state hydrate event for multiple on-load events is processed correctly. Args: - hydrate_middleware: instance of HydrateMiddleware - event1: an Event. + hydrate_middleware: Instance of HydrateMiddleware + event1: An Event. """ app = App( state=TestState, @@ -151,8 +153,8 @@ async def test_preprocess_no_events(hydrate_middleware, event1): """Test that app without on_load is processed correctly. Args: - hydrate_middleware: instance of HydrateMiddleware - event1: an Event. + hydrate_middleware: Instance of HydrateMiddleware + event1: An Event. """ state = TestState() update = await hydrate_middleware.preprocess( diff --git a/tests/test_app.py b/tests/test_app.py index bd85ffe7d..391a45f49 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -27,16 +27,6 @@ from reflex.utils import format from reflex.vars import ComputedVar -@pytest.fixture -def app() -> App: - """A base app. - - Returns: - The app. - """ - return App() - - @pytest.fixture def index_page(): """An index page. @@ -79,6 +69,20 @@ def test_state() -> Type[State]: return TestState +@pytest.fixture() +def redundant_test_state() -> Type[State]: + """A default state. + + Returns: + A default state. + """ + + class RedundantTestState(State): + var: int + + return RedundantTestState + + @pytest.fixture() def test_model() -> Type[Model]: """A default model. @@ -170,6 +174,19 @@ def test_default_app(app: App): assert app.admin_dash is None +def test_multiple_states_error(monkeypatch, test_state, redundant_test_state): + """Test that an error is thrown when multiple classes subclass rx.State. + + Args: + monkeypatch: Pytest monkeypatch object. + test_state: A test state subclassing rx.State. + redundant_test_state: Another test state subclassing rx.State. + """ + monkeypatch.delenv(constants.PYTEST_CURRENT_TEST) + with pytest.raises(ValueError): + App() + + def test_add_page_default_route(app: App, index_page, about_page): """Test adding a page to an app. @@ -708,7 +725,7 @@ class DynamicState(State): There are several counters: * loaded: counts how many times `on_load` was triggered by the hydrate middleware - * counter: counts how many times `on_counter` was triggered by a non-naviagational event + * counter: counts how many times `on_counter` was triggered by a non-navigational event -> these events should NOT trigger reload or recalculation of router_data dependent vars * side_effect_counter: counts how many times a computed var was recalculated when the dynamic route var was dirty