use better typing for on_load (#4274)
* use better typing for on_load * make app dataclass * get it right pyright * make lifespan into a dataclass
This commit is contained in:
parent
2ab662b757
commit
4254eadce3
@ -46,7 +46,6 @@ from starlette_admin.contrib.sqla.view import ModelView
|
||||
from reflex import constants
|
||||
from reflex.admin import AdminDash
|
||||
from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
|
||||
from reflex.base import Base
|
||||
from reflex.compiler import compiler
|
||||
from reflex.compiler import utils as compiler_utils
|
||||
from reflex.compiler.compiler import (
|
||||
@ -70,7 +69,14 @@ from reflex.components.core.client_side_routing import (
|
||||
from reflex.components.core.upload import Upload, get_upload_dir
|
||||
from reflex.components.radix import themes
|
||||
from reflex.config import environment, get_config
|
||||
from reflex.event import Event, EventHandler, EventSpec, window_alert
|
||||
from reflex.event import (
|
||||
Event,
|
||||
EventHandler,
|
||||
EventSpec,
|
||||
EventType,
|
||||
IndividualEventType,
|
||||
window_alert,
|
||||
)
|
||||
from reflex.model import Model, get_db_status
|
||||
from reflex.page import (
|
||||
DECORATED_PAGES,
|
||||
@ -189,11 +195,12 @@ class UnevaluatedPage:
|
||||
title: Union[Var, str, None]
|
||||
description: Union[Var, str, None]
|
||||
image: str
|
||||
on_load: Union[EventHandler, EventSpec, List[Union[EventHandler, EventSpec]], None]
|
||||
on_load: Union[EventType[[]], None]
|
||||
meta: List[Dict[str, str]]
|
||||
|
||||
|
||||
class App(MiddlewareMixin, LifespanMixin, Base):
|
||||
@dataclasses.dataclass()
|
||||
class App(MiddlewareMixin, LifespanMixin):
|
||||
"""The main Reflex app that encapsulates the backend and frontend.
|
||||
|
||||
Every Reflex app needs an app defined in its main module.
|
||||
@ -215,24 +222,26 @@ class App(MiddlewareMixin, LifespanMixin, Base):
|
||||
"""
|
||||
|
||||
# The global [theme](https://reflex.dev/docs/styling/theming/#theme) for the entire app.
|
||||
theme: Optional[Component] = themes.theme(accent_color="blue")
|
||||
theme: Optional[Component] = dataclasses.field(
|
||||
default_factory=lambda: themes.theme(accent_color="blue")
|
||||
)
|
||||
|
||||
# The [global style](https://reflex.dev/docs/styling/overview/#global-styles}) for the app.
|
||||
style: ComponentStyle = {}
|
||||
style: ComponentStyle = dataclasses.field(default_factory=dict)
|
||||
|
||||
# A list of URLs to [stylesheets](https://reflex.dev/docs/styling/custom-stylesheets/) to include in the app.
|
||||
stylesheets: List[str] = []
|
||||
stylesheets: List[str] = dataclasses.field(default_factory=list)
|
||||
|
||||
# A component that is present on every page (defaults to the Connection Error banner).
|
||||
overlay_component: Optional[Union[Component, ComponentCallable]] = (
|
||||
default_overlay_component()
|
||||
dataclasses.field(default_factory=default_overlay_component)
|
||||
)
|
||||
|
||||
# Error boundary component to wrap the app with.
|
||||
error_boundary: Optional[ComponentCallable] = default_error_boundary
|
||||
|
||||
# Components to add to the head of every page.
|
||||
head_components: List[Component] = []
|
||||
head_components: List[Component] = dataclasses.field(default_factory=list)
|
||||
|
||||
# The Socket.IO AsyncServer instance.
|
||||
sio: Optional[AsyncServer] = None
|
||||
@ -244,10 +253,12 @@ class App(MiddlewareMixin, LifespanMixin, Base):
|
||||
html_custom_attrs: Optional[Dict[str, str]] = None
|
||||
|
||||
# A map from a route to an unevaluated page. PRIVATE.
|
||||
unevaluated_pages: Dict[str, UnevaluatedPage] = {}
|
||||
unevaluated_pages: Dict[str, UnevaluatedPage] = dataclasses.field(
|
||||
default_factory=dict
|
||||
)
|
||||
|
||||
# A map from a page route to the component to render. Users should use `add_page`. PRIVATE.
|
||||
pages: Dict[str, Component] = {}
|
||||
pages: Dict[str, Component] = dataclasses.field(default_factory=dict)
|
||||
|
||||
# The backend API object. PRIVATE.
|
||||
api: FastAPI = None # type: ignore
|
||||
@ -259,7 +270,9 @@ class App(MiddlewareMixin, LifespanMixin, Base):
|
||||
_state_manager: Optional[StateManager] = None
|
||||
|
||||
# Mapping from a route to event handlers to trigger when the page loads. PRIVATE.
|
||||
load_events: Dict[str, List[Union[EventHandler, EventSpec]]] = {}
|
||||
load_events: Dict[str, List[IndividualEventType[[]]]] = dataclasses.field(
|
||||
default_factory=dict
|
||||
)
|
||||
|
||||
# Admin dashboard to view and manage the database. PRIVATE.
|
||||
admin_dash: Optional[AdminDash] = None
|
||||
@ -268,7 +281,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
|
||||
event_namespace: Optional[EventNamespace] = None
|
||||
|
||||
# Background tasks that are currently running. PRIVATE.
|
||||
background_tasks: Set[asyncio.Task] = set()
|
||||
background_tasks: Set[asyncio.Task] = dataclasses.field(default_factory=set)
|
||||
|
||||
# Frontend Error Handler Function
|
||||
frontend_exception_handler: Callable[[Exception], None] = (
|
||||
@ -280,23 +293,14 @@ class App(MiddlewareMixin, LifespanMixin, Base):
|
||||
[Exception], Union[EventSpec, List[EventSpec], None]
|
||||
] = default_backend_exception_handler
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __post_init__(self):
|
||||
"""Initialize the app.
|
||||
|
||||
Args:
|
||||
**kwargs: Kwargs to initialize the app with.
|
||||
|
||||
Raises:
|
||||
ValueError: If the event namespace is not provided in the config.
|
||||
Also, if there are multiple client subclasses of rx.BaseState(Subclasses of rx.BaseState should consist
|
||||
of the DefaultState and the client app state).
|
||||
"""
|
||||
if "connect_error_component" in kwargs:
|
||||
raise ValueError(
|
||||
"`connect_error_component` is deprecated, use `overlay_component` instead"
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Special case to allow test cases have multiple subclasses of rx.BaseState.
|
||||
if not is_testing_env() and BaseState.__subclasses__() != [State]:
|
||||
# Only rx.State is allowed as Base State subclass.
|
||||
@ -471,9 +475,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
|
||||
title: str | Var | None = None,
|
||||
description: str | Var | None = None,
|
||||
image: str = constants.DefaultPage.IMAGE,
|
||||
on_load: (
|
||||
EventHandler | EventSpec | list[EventHandler | EventSpec] | None
|
||||
) = None,
|
||||
on_load: EventType[[]] | None = None,
|
||||
meta: list[dict[str, str]] = constants.DefaultPage.META_LIST,
|
||||
):
|
||||
"""Add a page to the app.
|
||||
@ -559,7 +561,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
|
||||
self._check_routes_conflict(route)
|
||||
self.pages[route] = component
|
||||
|
||||
def get_load_events(self, route: str) -> list[EventHandler | EventSpec]:
|
||||
def get_load_events(self, route: str) -> list[IndividualEventType[[]]]:
|
||||
"""Get the load events for a route.
|
||||
|
||||
Args:
|
||||
@ -618,9 +620,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
|
||||
title: str = constants.Page404.TITLE,
|
||||
image: str = constants.Page404.IMAGE,
|
||||
description: str = constants.Page404.DESCRIPTION,
|
||||
on_load: (
|
||||
EventHandler | EventSpec | list[EventHandler | EventSpec] | None
|
||||
) = None,
|
||||
on_load: EventType[[]] | None = None,
|
||||
meta: list[dict[str, str]] = constants.DefaultPage.META_LIST,
|
||||
):
|
||||
"""Define a custom 404 page for any url having no match.
|
||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import functools
|
||||
import inspect
|
||||
from typing import Callable, Coroutine, Set, Union
|
||||
@ -16,11 +17,14 @@ from reflex.utils.exceptions import InvalidLifespanTaskType
|
||||
from .mixin import AppMixin
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LifespanMixin(AppMixin):
|
||||
"""A Mixin that allow tasks to run during the whole app lifespan."""
|
||||
|
||||
# Lifespan tasks that are planned to run.
|
||||
lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set()
|
||||
lifespan_tasks: Set[Union[asyncio.Task, Callable]] = dataclasses.field(
|
||||
default_factory=set
|
||||
)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _run_lifespan_tasks(self, app: FastAPI):
|
||||
|
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
from typing import List
|
||||
|
||||
from reflex.event import Event
|
||||
@ -12,11 +13,12 @@ from reflex.state import BaseState, StateUpdate
|
||||
from .mixin import AppMixin
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MiddlewareMixin(AppMixin):
|
||||
"""Middleware Mixin that allow to add middleware to the app."""
|
||||
|
||||
# Middleware to add to the app. Users should use `add_middleware`. PRIVATE.
|
||||
middleware: List[Middleware] = []
|
||||
middleware: List[Middleware] = dataclasses.field(default_factory=list)
|
||||
|
||||
def _init_mixin(self):
|
||||
self.middleware.append(HydrateMiddleware())
|
||||
|
@ -1,9 +1,10 @@
|
||||
"""Default mixin for all app mixins."""
|
||||
|
||||
from reflex.base import Base
|
||||
import dataclasses
|
||||
|
||||
|
||||
class AppMixin(Base):
|
||||
@dataclasses.dataclass
|
||||
class AppMixin:
|
||||
"""Define the base class for all app mixins."""
|
||||
|
||||
def _init_mixin(self):
|
||||
|
@ -6,6 +6,7 @@ from collections import defaultdict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from reflex.config import get_config
|
||||
from reflex.event import EventType
|
||||
|
||||
DECORATED_PAGES: Dict[str, List] = defaultdict(list)
|
||||
|
||||
@ -17,7 +18,7 @@ def page(
|
||||
description: str | None = None,
|
||||
meta: list[Any] | None = None,
|
||||
script_tags: list[Any] | None = None,
|
||||
on_load: Any | list[Any] | None = None,
|
||||
on_load: EventType[[]] | None = None,
|
||||
):
|
||||
"""Decorate a function as a page.
|
||||
|
||||
|
@ -1211,7 +1211,7 @@ async def test_process_events(mocker, token: str):
|
||||
],
|
||||
)
|
||||
def test_overlay_component(
|
||||
state: State | None,
|
||||
state: Type[State] | None,
|
||||
overlay_component: Component | ComponentCallable | None,
|
||||
exp_page_child: Type[Component] | None,
|
||||
):
|
||||
@ -1403,13 +1403,6 @@ def test_app_state_determination():
|
||||
assert a4.state is not None
|
||||
|
||||
|
||||
# for coverage
|
||||
def test_raise_on_connect_error():
|
||||
"""Test that the connect_error function is called."""
|
||||
with pytest.raises(ValueError):
|
||||
App(connect_error_component="Foo")
|
||||
|
||||
|
||||
def test_raise_on_state():
|
||||
"""Test that the state is set."""
|
||||
# state kwargs is deprecated, we just make sure the app is created anyway.
|
||||
|
@ -2725,6 +2725,7 @@ class OnLoadState(State):
|
||||
|
||||
num: int = 0
|
||||
|
||||
@rx.event
|
||||
def test_handler(self):
|
||||
"""Test handler."""
|
||||
self.num += 1
|
||||
|
Loading…
Reference in New Issue
Block a user