From f0bab665ceecb7321020b00ffeea28b335683150 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Sat, 29 Jun 2024 17:50:04 +0200 Subject: [PATCH] split lifespan and middleware logic in separate mixin files (#3557) * split lifespan and middleware logic in separate mixin files * fix for 3.8 * fix for unit tests * add missing sys import --------- Co-authored-by: Masen Furer --- reflex/app.py | 131 ++------------------------------ reflex/app_mixins/__init__.py | 5 ++ reflex/app_mixins/lifespan.py | 57 ++++++++++++++ reflex/app_mixins/middleware.py | 93 +++++++++++++++++++++++ reflex/app_mixins/mixin.py | 14 ++++ 5 files changed, 177 insertions(+), 123 deletions(-) create mode 100644 reflex/app_mixins/__init__.py create mode 100644 reflex/app_mixins/lifespan.py create mode 100644 reflex/app_mixins/middleware.py create mode 100644 reflex/app_mixins/mixin.py diff --git a/reflex/app.py b/reflex/app.py index 50fda94de..8e8544e7b 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -7,7 +7,6 @@ import concurrent.futures import contextlib import copy import functools -import inspect import io import multiprocessing import os @@ -40,6 +39,7 @@ from starlette_admin.contrib.sqla.view import ModelView from reflex import constants from reflex.admin import AdminDash +from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin from reflex.base import Base from reflex.compiler import compiler from reflex.compiler import utils as compiler_utils @@ -61,7 +61,6 @@ from reflex.components.core.upload import Upload, get_upload_dir from reflex.components.radix import themes from reflex.config import get_config from reflex.event import Event, EventHandler, EventSpec -from reflex.middleware import HydrateMiddleware, Middleware from reflex.model import Model from reflex.page import ( DECORATED_PAGES, @@ -108,50 +107,7 @@ class OverlayFragment(Fragment): pass -class LifespanMixin(Base): - """A Mixin that allow tasks to run during the whole app lifespan.""" - - # Lifespan tasks that are planned to run. - lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set() - - @contextlib.asynccontextmanager - async def _run_lifespan_tasks(self, app: FastAPI): - running_tasks = [] - try: - async with contextlib.AsyncExitStack() as stack: - for task in self.lifespan_tasks: - if isinstance(task, asyncio.Task): - running_tasks.append(task) - else: - signature = inspect.signature(task) - if "app" in signature.parameters: - task = functools.partial(task, app=app) - _t = task() - if isinstance(_t, contextlib._AsyncGeneratorContextManager): - await stack.enter_async_context(_t) - elif isinstance(_t, Coroutine): - running_tasks.append(asyncio.create_task(_t)) - yield - finally: - cancel_kwargs = ( - {"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {} - ) - for task in running_tasks: - task.cancel(**cancel_kwargs) - - def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs): - """Register a task to run during the lifespan of the app. - - Args: - task: The task to register. - task_kwargs: The kwargs of the task. - """ - if task_kwargs: - task = functools.partial(task, **task_kwargs) # type: ignore - self.lifespan_tasks.add(task) # type: ignore - - -class App(LifespanMixin, Base): +class App(MiddlewareMixin, LifespanMixin, Base): """The main Reflex app that encapsulates the backend and frontend. Every Reflex app needs an app defined in its main module. @@ -210,9 +166,6 @@ class App(LifespanMixin, Base): # Class to manage many client states. _state_manager: Optional[StateManager] = None - # Middleware to add to the app. Users should use `add_middleware`. PRIVATE. - middleware: List[Middleware] = [] - # Mapping from a route to event handlers to trigger when the page loads. PRIVATE. load_events: Dict[str, List[Union[EventHandler, EventSpec]]] = {} @@ -253,14 +206,17 @@ class App(LifespanMixin, Base): if "breakpoints" in self.style: set_breakpoints(self.style.pop("breakpoints")) - # Add middleware. - self.middleware.append(HydrateMiddleware()) - # Set up the API. self.api = FastAPI(lifespan=self._run_lifespan_tasks) self._add_cors() self._add_default_endpoints() + for clz in App.__mro__: + if clz == App: + continue + if issubclass(clz, AppMixin): + clz._init_mixin(self) + self._setup_state() # Set up the admin dash. @@ -385,77 +341,6 @@ class App(LifespanMixin, Base): raise ValueError("The state manager has not been initialized.") return self._state_manager - async def _preprocess(self, state: BaseState, event: Event) -> StateUpdate | None: - """Preprocess the event. - - This is where middleware can modify the event before it is processed. - Each middleware is called in the order it was added to the app. - - If a middleware returns an update, the event is not processed and the - update is returned. - - Args: - state: The state to preprocess. - event: The event to preprocess. - - Returns: - An optional state to return. - """ - for middleware in self.middleware: - if asyncio.iscoroutinefunction(middleware.preprocess): - out = await middleware.preprocess(app=self, state=state, event=event) # type: ignore - else: - out = middleware.preprocess(app=self, state=state, event=event) # type: ignore - if out is not None: - return out # type: ignore - - async def _postprocess( - self, state: BaseState, event: Event, update: StateUpdate - ) -> StateUpdate: - """Postprocess the event. - - This is where middleware can modify the delta after it is processed. - Each middleware is called in the order it was added to the app. - - Args: - state: The state to postprocess. - event: The event to postprocess. - update: The current state update. - - Returns: - The state update to return. - """ - for middleware in self.middleware: - if asyncio.iscoroutinefunction(middleware.postprocess): - out = await middleware.postprocess( - app=self, # type: ignore - state=state, - event=event, - update=update, - ) - else: - out = middleware.postprocess( - app=self, # type: ignore - state=state, - event=event, - update=update, - ) - if out is not None: - return out # type: ignore - return update - - def add_middleware(self, middleware: Middleware, index: int | None = None): - """Add middleware to the app. - - Args: - middleware: The middleware to add. - index: The index to add the middleware at. - """ - if index is None: - self.middleware.append(middleware) - else: - self.middleware.insert(index, middleware) - @staticmethod def _generate_component(component: Component | ComponentCallable) -> Component: """Generate a component from a callable. diff --git a/reflex/app_mixins/__init__.py b/reflex/app_mixins/__init__.py new file mode 100644 index 000000000..86c0aa2bd --- /dev/null +++ b/reflex/app_mixins/__init__.py @@ -0,0 +1,5 @@ +"""Reflex AppMixins package.""" + +from .lifespan import LifespanMixin +from .middleware import MiddlewareMixin +from .mixin import AppMixin diff --git a/reflex/app_mixins/lifespan.py b/reflex/app_mixins/lifespan.py new file mode 100644 index 000000000..2b5ed8b58 --- /dev/null +++ b/reflex/app_mixins/lifespan.py @@ -0,0 +1,57 @@ +"""Mixin that allow tasks to run during the whole app lifespan.""" + +from __future__ import annotations + +import asyncio +import contextlib +import functools +import inspect +import sys +from typing import Callable, Coroutine, Set, Union + +from fastapi import FastAPI + +from .mixin import AppMixin + + +class LifespanMixin(AppMixin): + """A Mixin that allow tasks to run during the whole app lifespan.""" + + # Lifespan tasks that are planned to run. + lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set() + + @contextlib.asynccontextmanager + async def _run_lifespan_tasks(self, app: FastAPI): + running_tasks = [] + try: + async with contextlib.AsyncExitStack() as stack: + for task in self.lifespan_tasks: + if isinstance(task, asyncio.Task): + running_tasks.append(task) + else: + signature = inspect.signature(task) + if "app" in signature.parameters: + task = functools.partial(task, app=app) + _t = task() + if isinstance(_t, contextlib._AsyncGeneratorContextManager): + await stack.enter_async_context(_t) + elif isinstance(_t, Coroutine): + running_tasks.append(asyncio.create_task(_t)) + yield + finally: + cancel_kwargs = ( + {"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {} + ) + for task in running_tasks: + task.cancel(**cancel_kwargs) + + def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs): + """Register a task to run during the lifespan of the app. + + Args: + task: The task to register. + task_kwargs: The kwargs of the task. + """ + if task_kwargs: + task = functools.partial(task, **task_kwargs) # type: ignore + self.lifespan_tasks.add(task) # type: ignore diff --git a/reflex/app_mixins/middleware.py b/reflex/app_mixins/middleware.py new file mode 100644 index 000000000..1e42faf18 --- /dev/null +++ b/reflex/app_mixins/middleware.py @@ -0,0 +1,93 @@ +"""Middleware Mixin that allow to add middleware to the app.""" + +from __future__ import annotations + +import asyncio +from typing import List + +from reflex.event import Event +from reflex.middleware import HydrateMiddleware, Middleware +from reflex.state import BaseState, StateUpdate + +from .mixin import AppMixin + + +class MiddlewareMixin(AppMixin): + """Middleware Mixin that allow to add middleware to the app.""" + + # Middleware to add to the app. Users should use `add_middleware`. PRIVATE. + middleware: List[Middleware] = [] + + def _init_mixin(self): + self.middleware.append(HydrateMiddleware()) + + def add_middleware(self, middleware: Middleware, index: int | None = None): + """Add middleware to the app. + + Args: + middleware: The middleware to add. + index: The index to add the middleware at. + """ + if index is None: + self.middleware.append(middleware) + else: + self.middleware.insert(index, middleware) + + async def _preprocess(self, state: BaseState, event: Event) -> StateUpdate | None: + """Preprocess the event. + + This is where middleware can modify the event before it is processed. + Each middleware is called in the order it was added to the app. + + If a middleware returns an update, the event is not processed and the + update is returned. + + Args: + state: The state to preprocess. + event: The event to preprocess. + + Returns: + An optional state to return. + """ + for middleware in self.middleware: + if asyncio.iscoroutinefunction(middleware.preprocess): + out = await middleware.preprocess(app=self, state=state, event=event) # type: ignore + else: + out = middleware.preprocess(app=self, state=state, event=event) # type: ignore + if out is not None: + return out # type: ignore + + async def _postprocess( + self, state: BaseState, event: Event, update: StateUpdate + ) -> StateUpdate: + """Postprocess the event. + + This is where middleware can modify the delta after it is processed. + Each middleware is called in the order it was added to the app. + + Args: + state: The state to postprocess. + event: The event to postprocess. + update: The current state update. + + Returns: + The state update to return. + """ + for middleware in self.middleware: + if asyncio.iscoroutinefunction(middleware.postprocess): + out = await middleware.postprocess( + app=self, # type: ignore + state=state, + event=event, + update=update, + ) + else: + out = middleware.postprocess( + app=self, # type: ignore + state=state, + event=event, + update=update, + ) + if out is not None: + return out # type: ignore + return update diff --git a/reflex/app_mixins/mixin.py b/reflex/app_mixins/mixin.py new file mode 100644 index 000000000..ed301c495 --- /dev/null +++ b/reflex/app_mixins/mixin.py @@ -0,0 +1,14 @@ +"""Default mixin for all app mixins.""" + +from reflex.base import Base + + +class AppMixin(Base): + """Define the base class for all app mixins.""" + + def _init_mixin(self): + """Initialize the mixin. + + Any App mixin can override this method to perform any initialization. + """ + ...