use better typing for on_load (#4274)
* use better typing for on_load * make app dataclass * get it right pyright * make lifespan into a dataclass
This commit is contained in:
parent
2ab662b757
commit
4254eadce3
@ -46,7 +46,6 @@ from starlette_admin.contrib.sqla.view import ModelView
|
|||||||
from reflex import constants
|
from reflex import constants
|
||||||
from reflex.admin import AdminDash
|
from reflex.admin import AdminDash
|
||||||
from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
|
from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
|
||||||
from reflex.base import Base
|
|
||||||
from reflex.compiler import compiler
|
from reflex.compiler import compiler
|
||||||
from reflex.compiler import utils as compiler_utils
|
from reflex.compiler import utils as compiler_utils
|
||||||
from reflex.compiler.compiler import (
|
from reflex.compiler.compiler import (
|
||||||
@ -70,7 +69,14 @@ from reflex.components.core.client_side_routing import (
|
|||||||
from reflex.components.core.upload import Upload, get_upload_dir
|
from reflex.components.core.upload import Upload, get_upload_dir
|
||||||
from reflex.components.radix import themes
|
from reflex.components.radix import themes
|
||||||
from reflex.config import environment, get_config
|
from reflex.config import environment, get_config
|
||||||
from reflex.event import Event, EventHandler, EventSpec, window_alert
|
from reflex.event import (
|
||||||
|
Event,
|
||||||
|
EventHandler,
|
||||||
|
EventSpec,
|
||||||
|
EventType,
|
||||||
|
IndividualEventType,
|
||||||
|
window_alert,
|
||||||
|
)
|
||||||
from reflex.model import Model, get_db_status
|
from reflex.model import Model, get_db_status
|
||||||
from reflex.page import (
|
from reflex.page import (
|
||||||
DECORATED_PAGES,
|
DECORATED_PAGES,
|
||||||
@ -189,11 +195,12 @@ class UnevaluatedPage:
|
|||||||
title: Union[Var, str, None]
|
title: Union[Var, str, None]
|
||||||
description: Union[Var, str, None]
|
description: Union[Var, str, None]
|
||||||
image: str
|
image: str
|
||||||
on_load: Union[EventHandler, EventSpec, List[Union[EventHandler, EventSpec]], None]
|
on_load: Union[EventType[[]], None]
|
||||||
meta: List[Dict[str, str]]
|
meta: List[Dict[str, str]]
|
||||||
|
|
||||||
|
|
||||||
class App(MiddlewareMixin, LifespanMixin, Base):
|
@dataclasses.dataclass()
|
||||||
|
class App(MiddlewareMixin, LifespanMixin):
|
||||||
"""The main Reflex app that encapsulates the backend and frontend.
|
"""The main Reflex app that encapsulates the backend and frontend.
|
||||||
|
|
||||||
Every Reflex app needs an app defined in its main module.
|
Every Reflex app needs an app defined in its main module.
|
||||||
@ -215,24 +222,26 @@ class App(MiddlewareMixin, LifespanMixin, Base):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# The global [theme](https://reflex.dev/docs/styling/theming/#theme) for the entire app.
|
# The global [theme](https://reflex.dev/docs/styling/theming/#theme) for the entire app.
|
||||||
theme: Optional[Component] = themes.theme(accent_color="blue")
|
theme: Optional[Component] = dataclasses.field(
|
||||||
|
default_factory=lambda: themes.theme(accent_color="blue")
|
||||||
|
)
|
||||||
|
|
||||||
# The [global style](https://reflex.dev/docs/styling/overview/#global-styles}) for the app.
|
# The [global style](https://reflex.dev/docs/styling/overview/#global-styles}) for the app.
|
||||||
style: ComponentStyle = {}
|
style: ComponentStyle = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
# A list of URLs to [stylesheets](https://reflex.dev/docs/styling/custom-stylesheets/) to include in the app.
|
# A list of URLs to [stylesheets](https://reflex.dev/docs/styling/custom-stylesheets/) to include in the app.
|
||||||
stylesheets: List[str] = []
|
stylesheets: List[str] = dataclasses.field(default_factory=list)
|
||||||
|
|
||||||
# A component that is present on every page (defaults to the Connection Error banner).
|
# A component that is present on every page (defaults to the Connection Error banner).
|
||||||
overlay_component: Optional[Union[Component, ComponentCallable]] = (
|
overlay_component: Optional[Union[Component, ComponentCallable]] = (
|
||||||
default_overlay_component()
|
dataclasses.field(default_factory=default_overlay_component)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Error boundary component to wrap the app with.
|
# Error boundary component to wrap the app with.
|
||||||
error_boundary: Optional[ComponentCallable] = default_error_boundary
|
error_boundary: Optional[ComponentCallable] = default_error_boundary
|
||||||
|
|
||||||
# Components to add to the head of every page.
|
# Components to add to the head of every page.
|
||||||
head_components: List[Component] = []
|
head_components: List[Component] = dataclasses.field(default_factory=list)
|
||||||
|
|
||||||
# The Socket.IO AsyncServer instance.
|
# The Socket.IO AsyncServer instance.
|
||||||
sio: Optional[AsyncServer] = None
|
sio: Optional[AsyncServer] = None
|
||||||
@ -244,10 +253,12 @@ class App(MiddlewareMixin, LifespanMixin, Base):
|
|||||||
html_custom_attrs: Optional[Dict[str, str]] = None
|
html_custom_attrs: Optional[Dict[str, str]] = None
|
||||||
|
|
||||||
# A map from a route to an unevaluated page. PRIVATE.
|
# A map from a route to an unevaluated page. PRIVATE.
|
||||||
unevaluated_pages: Dict[str, UnevaluatedPage] = {}
|
unevaluated_pages: Dict[str, UnevaluatedPage] = dataclasses.field(
|
||||||
|
default_factory=dict
|
||||||
|
)
|
||||||
|
|
||||||
# A map from a page route to the component to render. Users should use `add_page`. PRIVATE.
|
# A map from a page route to the component to render. Users should use `add_page`. PRIVATE.
|
||||||
pages: Dict[str, Component] = {}
|
pages: Dict[str, Component] = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
# The backend API object. PRIVATE.
|
# The backend API object. PRIVATE.
|
||||||
api: FastAPI = None # type: ignore
|
api: FastAPI = None # type: ignore
|
||||||
@ -259,7 +270,9 @@ class App(MiddlewareMixin, LifespanMixin, Base):
|
|||||||
_state_manager: Optional[StateManager] = None
|
_state_manager: Optional[StateManager] = None
|
||||||
|
|
||||||
# Mapping from a route to event handlers to trigger when the page loads. PRIVATE.
|
# Mapping from a route to event handlers to trigger when the page loads. PRIVATE.
|
||||||
load_events: Dict[str, List[Union[EventHandler, EventSpec]]] = {}
|
load_events: Dict[str, List[IndividualEventType[[]]]] = dataclasses.field(
|
||||||
|
default_factory=dict
|
||||||
|
)
|
||||||
|
|
||||||
# Admin dashboard to view and manage the database. PRIVATE.
|
# Admin dashboard to view and manage the database. PRIVATE.
|
||||||
admin_dash: Optional[AdminDash] = None
|
admin_dash: Optional[AdminDash] = None
|
||||||
@ -268,7 +281,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
|
|||||||
event_namespace: Optional[EventNamespace] = None
|
event_namespace: Optional[EventNamespace] = None
|
||||||
|
|
||||||
# Background tasks that are currently running. PRIVATE.
|
# Background tasks that are currently running. PRIVATE.
|
||||||
background_tasks: Set[asyncio.Task] = set()
|
background_tasks: Set[asyncio.Task] = dataclasses.field(default_factory=set)
|
||||||
|
|
||||||
# Frontend Error Handler Function
|
# Frontend Error Handler Function
|
||||||
frontend_exception_handler: Callable[[Exception], None] = (
|
frontend_exception_handler: Callable[[Exception], None] = (
|
||||||
@ -280,23 +293,14 @@ class App(MiddlewareMixin, LifespanMixin, Base):
|
|||||||
[Exception], Union[EventSpec, List[EventSpec], None]
|
[Exception], Union[EventSpec, List[EventSpec], None]
|
||||||
] = default_backend_exception_handler
|
] = default_backend_exception_handler
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __post_init__(self):
|
||||||
"""Initialize the app.
|
"""Initialize the app.
|
||||||
|
|
||||||
Args:
|
|
||||||
**kwargs: Kwargs to initialize the app with.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the event namespace is not provided in the config.
|
ValueError: If the event namespace is not provided in the config.
|
||||||
Also, if there are multiple client subclasses of rx.BaseState(Subclasses of rx.BaseState should consist
|
Also, if there are multiple client subclasses of rx.BaseState(Subclasses of rx.BaseState should consist
|
||||||
of the DefaultState and the client app state).
|
of the DefaultState and the client app state).
|
||||||
"""
|
"""
|
||||||
if "connect_error_component" in kwargs:
|
|
||||||
raise ValueError(
|
|
||||||
"`connect_error_component` is deprecated, use `overlay_component` instead"
|
|
||||||
)
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
# Special case to allow test cases have multiple subclasses of rx.BaseState.
|
# Special case to allow test cases have multiple subclasses of rx.BaseState.
|
||||||
if not is_testing_env() and BaseState.__subclasses__() != [State]:
|
if not is_testing_env() and BaseState.__subclasses__() != [State]:
|
||||||
# Only rx.State is allowed as Base State subclass.
|
# Only rx.State is allowed as Base State subclass.
|
||||||
@ -471,9 +475,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
|
|||||||
title: str | Var | None = None,
|
title: str | Var | None = None,
|
||||||
description: str | Var | None = None,
|
description: str | Var | None = None,
|
||||||
image: str = constants.DefaultPage.IMAGE,
|
image: str = constants.DefaultPage.IMAGE,
|
||||||
on_load: (
|
on_load: EventType[[]] | None = None,
|
||||||
EventHandler | EventSpec | list[EventHandler | EventSpec] | None
|
|
||||||
) = None,
|
|
||||||
meta: list[dict[str, str]] = constants.DefaultPage.META_LIST,
|
meta: list[dict[str, str]] = constants.DefaultPage.META_LIST,
|
||||||
):
|
):
|
||||||
"""Add a page to the app.
|
"""Add a page to the app.
|
||||||
@ -559,7 +561,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
|
|||||||
self._check_routes_conflict(route)
|
self._check_routes_conflict(route)
|
||||||
self.pages[route] = component
|
self.pages[route] = component
|
||||||
|
|
||||||
def get_load_events(self, route: str) -> list[EventHandler | EventSpec]:
|
def get_load_events(self, route: str) -> list[IndividualEventType[[]]]:
|
||||||
"""Get the load events for a route.
|
"""Get the load events for a route.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -618,9 +620,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
|
|||||||
title: str = constants.Page404.TITLE,
|
title: str = constants.Page404.TITLE,
|
||||||
image: str = constants.Page404.IMAGE,
|
image: str = constants.Page404.IMAGE,
|
||||||
description: str = constants.Page404.DESCRIPTION,
|
description: str = constants.Page404.DESCRIPTION,
|
||||||
on_load: (
|
on_load: EventType[[]] | None = None,
|
||||||
EventHandler | EventSpec | list[EventHandler | EventSpec] | None
|
|
||||||
) = None,
|
|
||||||
meta: list[dict[str, str]] = constants.DefaultPage.META_LIST,
|
meta: list[dict[str, str]] = constants.DefaultPage.META_LIST,
|
||||||
):
|
):
|
||||||
"""Define a custom 404 page for any url having no match.
|
"""Define a custom 404 page for any url having no match.
|
||||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import dataclasses
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Callable, Coroutine, Set, Union
|
from typing import Callable, Coroutine, Set, Union
|
||||||
@ -16,11 +17,14 @@ from reflex.utils.exceptions import InvalidLifespanTaskType
|
|||||||
from .mixin import AppMixin
|
from .mixin import AppMixin
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
class LifespanMixin(AppMixin):
|
class LifespanMixin(AppMixin):
|
||||||
"""A Mixin that allow tasks to run during the whole app lifespan."""
|
"""A Mixin that allow tasks to run during the whole app lifespan."""
|
||||||
|
|
||||||
# Lifespan tasks that are planned to run.
|
# Lifespan tasks that are planned to run.
|
||||||
lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set()
|
lifespan_tasks: Set[Union[asyncio.Task, Callable]] = dataclasses.field(
|
||||||
|
default_factory=set
|
||||||
|
)
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def _run_lifespan_tasks(self, app: FastAPI):
|
async def _run_lifespan_tasks(self, app: FastAPI):
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import dataclasses
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from reflex.event import Event
|
from reflex.event import Event
|
||||||
@ -12,11 +13,12 @@ from reflex.state import BaseState, StateUpdate
|
|||||||
from .mixin import AppMixin
|
from .mixin import AppMixin
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
class MiddlewareMixin(AppMixin):
|
class MiddlewareMixin(AppMixin):
|
||||||
"""Middleware Mixin that allow to add middleware to the app."""
|
"""Middleware Mixin that allow to add middleware to the app."""
|
||||||
|
|
||||||
# Middleware to add to the app. Users should use `add_middleware`. PRIVATE.
|
# Middleware to add to the app. Users should use `add_middleware`. PRIVATE.
|
||||||
middleware: List[Middleware] = []
|
middleware: List[Middleware] = dataclasses.field(default_factory=list)
|
||||||
|
|
||||||
def _init_mixin(self):
|
def _init_mixin(self):
|
||||||
self.middleware.append(HydrateMiddleware())
|
self.middleware.append(HydrateMiddleware())
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
"""Default mixin for all app mixins."""
|
"""Default mixin for all app mixins."""
|
||||||
|
|
||||||
from reflex.base import Base
|
import dataclasses
|
||||||
|
|
||||||
|
|
||||||
class AppMixin(Base):
|
@dataclasses.dataclass
|
||||||
|
class AppMixin:
|
||||||
"""Define the base class for all app mixins."""
|
"""Define the base class for all app mixins."""
|
||||||
|
|
||||||
def _init_mixin(self):
|
def _init_mixin(self):
|
||||||
|
@ -6,6 +6,7 @@ from collections import defaultdict
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from reflex.config import get_config
|
from reflex.config import get_config
|
||||||
|
from reflex.event import EventType
|
||||||
|
|
||||||
DECORATED_PAGES: Dict[str, List] = defaultdict(list)
|
DECORATED_PAGES: Dict[str, List] = defaultdict(list)
|
||||||
|
|
||||||
@ -17,7 +18,7 @@ def page(
|
|||||||
description: str | None = None,
|
description: str | None = None,
|
||||||
meta: list[Any] | None = None,
|
meta: list[Any] | None = None,
|
||||||
script_tags: list[Any] | None = None,
|
script_tags: list[Any] | None = None,
|
||||||
on_load: Any | list[Any] | None = None,
|
on_load: EventType[[]] | None = None,
|
||||||
):
|
):
|
||||||
"""Decorate a function as a page.
|
"""Decorate a function as a page.
|
||||||
|
|
||||||
|
@ -1211,7 +1211,7 @@ async def test_process_events(mocker, token: str):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_overlay_component(
|
def test_overlay_component(
|
||||||
state: State | None,
|
state: Type[State] | None,
|
||||||
overlay_component: Component | ComponentCallable | None,
|
overlay_component: Component | ComponentCallable | None,
|
||||||
exp_page_child: Type[Component] | None,
|
exp_page_child: Type[Component] | None,
|
||||||
):
|
):
|
||||||
@ -1403,13 +1403,6 @@ def test_app_state_determination():
|
|||||||
assert a4.state is not None
|
assert a4.state is not None
|
||||||
|
|
||||||
|
|
||||||
# for coverage
|
|
||||||
def test_raise_on_connect_error():
|
|
||||||
"""Test that the connect_error function is called."""
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
App(connect_error_component="Foo")
|
|
||||||
|
|
||||||
|
|
||||||
def test_raise_on_state():
|
def test_raise_on_state():
|
||||||
"""Test that the state is set."""
|
"""Test that the state is set."""
|
||||||
# state kwargs is deprecated, we just make sure the app is created anyway.
|
# state kwargs is deprecated, we just make sure the app is created anyway.
|
||||||
|
@ -2725,6 +2725,7 @@ class OnLoadState(State):
|
|||||||
|
|
||||||
num: int = 0
|
num: int = 0
|
||||||
|
|
||||||
|
@rx.event
|
||||||
def test_handler(self):
|
def test_handler(self):
|
||||||
"""Test handler."""
|
"""Test handler."""
|
||||||
self.num += 1
|
self.num += 1
|
||||||
|
Loading…
Reference in New Issue
Block a user