From 4254eadce3019f6651d17084e4c35152bfba5d9b Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Wed, 30 Oct 2024 16:52:16 -0700 Subject: [PATCH] use better typing for on_load (#4274) * use better typing for on_load * make app dataclass * get it right pyright * make lifespan into a dataclass --- reflex/app.py | 60 ++++++++++++++++----------------- reflex/app_mixins/lifespan.py | 6 +++- reflex/app_mixins/middleware.py | 4 ++- reflex/app_mixins/mixin.py | 5 +-- reflex/page.py | 3 +- tests/units/test_app.py | 9 +---- tests/units/test_state.py | 1 + 7 files changed, 45 insertions(+), 43 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 88fc9d473..5367ef20b 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -46,7 +46,6 @@ from starlette_admin.contrib.sqla.view import ModelView from reflex import constants from reflex.admin import AdminDash from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin -from reflex.base import Base from reflex.compiler import compiler from reflex.compiler import utils as compiler_utils from reflex.compiler.compiler import ( @@ -70,7 +69,14 @@ from reflex.components.core.client_side_routing import ( from reflex.components.core.upload import Upload, get_upload_dir from reflex.components.radix import themes from reflex.config import environment, get_config -from reflex.event import Event, EventHandler, EventSpec, window_alert +from reflex.event import ( + Event, + EventHandler, + EventSpec, + EventType, + IndividualEventType, + window_alert, +) from reflex.model import Model, get_db_status from reflex.page import ( DECORATED_PAGES, @@ -189,11 +195,12 @@ class UnevaluatedPage: title: Union[Var, str, None] description: Union[Var, str, None] image: str - on_load: Union[EventHandler, EventSpec, List[Union[EventHandler, EventSpec]], None] + on_load: Union[EventType[[]], None] meta: List[Dict[str, str]] -class App(MiddlewareMixin, LifespanMixin, Base): +@dataclasses.dataclass() +class App(MiddlewareMixin, LifespanMixin): """The main Reflex app that encapsulates the backend and frontend. Every Reflex app needs an app defined in its main module. @@ -215,24 +222,26 @@ class App(MiddlewareMixin, LifespanMixin, Base): """ # The global [theme](https://reflex.dev/docs/styling/theming/#theme) for the entire app. - theme: Optional[Component] = themes.theme(accent_color="blue") + theme: Optional[Component] = dataclasses.field( + default_factory=lambda: themes.theme(accent_color="blue") + ) # The [global style](https://reflex.dev/docs/styling/overview/#global-styles}) for the app. - style: ComponentStyle = {} + style: ComponentStyle = dataclasses.field(default_factory=dict) # A list of URLs to [stylesheets](https://reflex.dev/docs/styling/custom-stylesheets/) to include in the app. - stylesheets: List[str] = [] + stylesheets: List[str] = dataclasses.field(default_factory=list) # A component that is present on every page (defaults to the Connection Error banner). overlay_component: Optional[Union[Component, ComponentCallable]] = ( - default_overlay_component() + dataclasses.field(default_factory=default_overlay_component) ) # Error boundary component to wrap the app with. error_boundary: Optional[ComponentCallable] = default_error_boundary # Components to add to the head of every page. - head_components: List[Component] = [] + head_components: List[Component] = dataclasses.field(default_factory=list) # The Socket.IO AsyncServer instance. sio: Optional[AsyncServer] = None @@ -244,10 +253,12 @@ class App(MiddlewareMixin, LifespanMixin, Base): html_custom_attrs: Optional[Dict[str, str]] = None # A map from a route to an unevaluated page. PRIVATE. - unevaluated_pages: Dict[str, UnevaluatedPage] = {} + unevaluated_pages: Dict[str, UnevaluatedPage] = dataclasses.field( + default_factory=dict + ) # A map from a page route to the component to render. Users should use `add_page`. PRIVATE. - pages: Dict[str, Component] = {} + pages: Dict[str, Component] = dataclasses.field(default_factory=dict) # The backend API object. PRIVATE. api: FastAPI = None # type: ignore @@ -259,7 +270,9 @@ class App(MiddlewareMixin, LifespanMixin, Base): _state_manager: Optional[StateManager] = None # Mapping from a route to event handlers to trigger when the page loads. PRIVATE. - load_events: Dict[str, List[Union[EventHandler, EventSpec]]] = {} + load_events: Dict[str, List[IndividualEventType[[]]]] = dataclasses.field( + default_factory=dict + ) # Admin dashboard to view and manage the database. PRIVATE. admin_dash: Optional[AdminDash] = None @@ -268,7 +281,7 @@ class App(MiddlewareMixin, LifespanMixin, Base): event_namespace: Optional[EventNamespace] = None # Background tasks that are currently running. PRIVATE. - background_tasks: Set[asyncio.Task] = set() + background_tasks: Set[asyncio.Task] = dataclasses.field(default_factory=set) # Frontend Error Handler Function frontend_exception_handler: Callable[[Exception], None] = ( @@ -280,23 +293,14 @@ class App(MiddlewareMixin, LifespanMixin, Base): [Exception], Union[EventSpec, List[EventSpec], None] ] = default_backend_exception_handler - def __init__(self, **kwargs): + def __post_init__(self): """Initialize the app. - Args: - **kwargs: Kwargs to initialize the app with. - Raises: ValueError: If the event namespace is not provided in the config. Also, if there are multiple client subclasses of rx.BaseState(Subclasses of rx.BaseState should consist of the DefaultState and the client app state). """ - if "connect_error_component" in kwargs: - raise ValueError( - "`connect_error_component` is deprecated, use `overlay_component` instead" - ) - super().__init__(**kwargs) - # Special case to allow test cases have multiple subclasses of rx.BaseState. if not is_testing_env() and BaseState.__subclasses__() != [State]: # Only rx.State is allowed as Base State subclass. @@ -471,9 +475,7 @@ class App(MiddlewareMixin, LifespanMixin, Base): title: str | Var | None = None, description: str | Var | None = None, image: str = constants.DefaultPage.IMAGE, - on_load: ( - EventHandler | EventSpec | list[EventHandler | EventSpec] | None - ) = None, + on_load: EventType[[]] | None = None, meta: list[dict[str, str]] = constants.DefaultPage.META_LIST, ): """Add a page to the app. @@ -559,7 +561,7 @@ class App(MiddlewareMixin, LifespanMixin, Base): self._check_routes_conflict(route) self.pages[route] = component - def get_load_events(self, route: str) -> list[EventHandler | EventSpec]: + def get_load_events(self, route: str) -> list[IndividualEventType[[]]]: """Get the load events for a route. Args: @@ -618,9 +620,7 @@ class App(MiddlewareMixin, LifespanMixin, Base): title: str = constants.Page404.TITLE, image: str = constants.Page404.IMAGE, description: str = constants.Page404.DESCRIPTION, - on_load: ( - EventHandler | EventSpec | list[EventHandler | EventSpec] | None - ) = None, + on_load: EventType[[]] | None = None, meta: list[dict[str, str]] = constants.DefaultPage.META_LIST, ): """Define a custom 404 page for any url having no match. diff --git a/reflex/app_mixins/lifespan.py b/reflex/app_mixins/lifespan.py index ef882a2ea..52bf0be1d 100644 --- a/reflex/app_mixins/lifespan.py +++ b/reflex/app_mixins/lifespan.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio import contextlib +import dataclasses import functools import inspect from typing import Callable, Coroutine, Set, Union @@ -16,11 +17,14 @@ from reflex.utils.exceptions import InvalidLifespanTaskType from .mixin import AppMixin +@dataclasses.dataclass class LifespanMixin(AppMixin): """A Mixin that allow tasks to run during the whole app lifespan.""" # Lifespan tasks that are planned to run. - lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set() + lifespan_tasks: Set[Union[asyncio.Task, Callable]] = dataclasses.field( + default_factory=set + ) @contextlib.asynccontextmanager async def _run_lifespan_tasks(self, app: FastAPI): diff --git a/reflex/app_mixins/middleware.py b/reflex/app_mixins/middleware.py index 1e42faf18..30593d9ae 100644 --- a/reflex/app_mixins/middleware.py +++ b/reflex/app_mixins/middleware.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import dataclasses from typing import List from reflex.event import Event @@ -12,11 +13,12 @@ from reflex.state import BaseState, StateUpdate from .mixin import AppMixin +@dataclasses.dataclass class MiddlewareMixin(AppMixin): """Middleware Mixin that allow to add middleware to the app.""" # Middleware to add to the app. Users should use `add_middleware`. PRIVATE. - middleware: List[Middleware] = [] + middleware: List[Middleware] = dataclasses.field(default_factory=list) def _init_mixin(self): self.middleware.append(HydrateMiddleware()) diff --git a/reflex/app_mixins/mixin.py b/reflex/app_mixins/mixin.py index ed301c495..23207a462 100644 --- a/reflex/app_mixins/mixin.py +++ b/reflex/app_mixins/mixin.py @@ -1,9 +1,10 @@ """Default mixin for all app mixins.""" -from reflex.base import Base +import dataclasses -class AppMixin(Base): +@dataclasses.dataclass +class AppMixin: """Define the base class for all app mixins.""" def _init_mixin(self): diff --git a/reflex/page.py b/reflex/page.py index 52f0c8efc..87a8c49c2 100644 --- a/reflex/page.py +++ b/reflex/page.py @@ -6,6 +6,7 @@ from collections import defaultdict from typing import Any, Dict, List from reflex.config import get_config +from reflex.event import EventType DECORATED_PAGES: Dict[str, List] = defaultdict(list) @@ -17,7 +18,7 @@ def page( description: str | None = None, meta: list[Any] | None = None, script_tags: list[Any] | None = None, - on_load: Any | list[Any] | None = None, + on_load: EventType[[]] | None = None, ): """Decorate a function as a page. diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 7fba7ba1d..1e34a67c3 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -1211,7 +1211,7 @@ async def test_process_events(mocker, token: str): ], ) def test_overlay_component( - state: State | None, + state: Type[State] | None, overlay_component: Component | ComponentCallable | None, exp_page_child: Type[Component] | None, ): @@ -1403,13 +1403,6 @@ def test_app_state_determination(): assert a4.state is not None -# for coverage -def test_raise_on_connect_error(): - """Test that the connect_error function is called.""" - with pytest.raises(ValueError): - App(connect_error_component="Foo") - - def test_raise_on_state(): """Test that the state is set.""" # state kwargs is deprecated, we just make sure the app is created anyway. diff --git a/tests/units/test_state.py b/tests/units/test_state.py index fe2f652ac..271f2e794 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -2725,6 +2725,7 @@ class OnLoadState(State): num: int = 0 + @rx.event def test_handler(self): """Test handler.""" self.num += 1