split lifespan and middleware logic in separate mixin files (#3557)

* split lifespan and middleware logic in separate mixin files

* fix for 3.8

* fix for unit tests

* add missing sys import

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
Thomas Brandého 2024-06-29 17:50:04 +02:00 committed by GitHub
parent ad1d82f7ad
commit f0bab665ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 177 additions and 123 deletions

View File

@ -7,7 +7,6 @@ import concurrent.futures
import contextlib import contextlib
import copy import copy
import functools import functools
import inspect
import io import io
import multiprocessing import multiprocessing
import os import os
@ -40,6 +39,7 @@ from starlette_admin.contrib.sqla.view import ModelView
from reflex import constants from reflex import constants
from reflex.admin import AdminDash from reflex.admin import AdminDash
from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
from reflex.base import Base from reflex.base import Base
from reflex.compiler import compiler from reflex.compiler import compiler
from reflex.compiler import utils as compiler_utils from reflex.compiler import utils as compiler_utils
@ -61,7 +61,6 @@ from reflex.components.core.upload import Upload, get_upload_dir
from reflex.components.radix import themes from reflex.components.radix import themes
from reflex.config import get_config from reflex.config import get_config
from reflex.event import Event, EventHandler, EventSpec from reflex.event import Event, EventHandler, EventSpec
from reflex.middleware import HydrateMiddleware, Middleware
from reflex.model import Model from reflex.model import Model
from reflex.page import ( from reflex.page import (
DECORATED_PAGES, DECORATED_PAGES,
@ -108,50 +107,7 @@ class OverlayFragment(Fragment):
pass pass
class LifespanMixin(Base): class App(MiddlewareMixin, LifespanMixin, Base):
"""A Mixin that allow tasks to run during the whole app lifespan."""
# Lifespan tasks that are planned to run.
lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set()
@contextlib.asynccontextmanager
async def _run_lifespan_tasks(self, app: FastAPI):
running_tasks = []
try:
async with contextlib.AsyncExitStack() as stack:
for task in self.lifespan_tasks:
if isinstance(task, asyncio.Task):
running_tasks.append(task)
else:
signature = inspect.signature(task)
if "app" in signature.parameters:
task = functools.partial(task, app=app)
_t = task()
if isinstance(_t, contextlib._AsyncGeneratorContextManager):
await stack.enter_async_context(_t)
elif isinstance(_t, Coroutine):
running_tasks.append(asyncio.create_task(_t))
yield
finally:
cancel_kwargs = (
{"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {}
)
for task in running_tasks:
task.cancel(**cancel_kwargs)
def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
"""Register a task to run during the lifespan of the app.
Args:
task: The task to register.
task_kwargs: The kwargs of the task.
"""
if task_kwargs:
task = functools.partial(task, **task_kwargs) # type: ignore
self.lifespan_tasks.add(task) # type: ignore
class App(LifespanMixin, Base):
"""The main Reflex app that encapsulates the backend and frontend. """The main Reflex app that encapsulates the backend and frontend.
Every Reflex app needs an app defined in its main module. Every Reflex app needs an app defined in its main module.
@ -210,9 +166,6 @@ class App(LifespanMixin, Base):
# Class to manage many client states. # Class to manage many client states.
_state_manager: Optional[StateManager] = None _state_manager: Optional[StateManager] = None
# Middleware to add to the app. Users should use `add_middleware`. PRIVATE.
middleware: List[Middleware] = []
# Mapping from a route to event handlers to trigger when the page loads. PRIVATE. # Mapping from a route to event handlers to trigger when the page loads. PRIVATE.
load_events: Dict[str, List[Union[EventHandler, EventSpec]]] = {} load_events: Dict[str, List[Union[EventHandler, EventSpec]]] = {}
@ -253,14 +206,17 @@ class App(LifespanMixin, Base):
if "breakpoints" in self.style: if "breakpoints" in self.style:
set_breakpoints(self.style.pop("breakpoints")) set_breakpoints(self.style.pop("breakpoints"))
# Add middleware.
self.middleware.append(HydrateMiddleware())
# Set up the API. # Set up the API.
self.api = FastAPI(lifespan=self._run_lifespan_tasks) self.api = FastAPI(lifespan=self._run_lifespan_tasks)
self._add_cors() self._add_cors()
self._add_default_endpoints() self._add_default_endpoints()
for clz in App.__mro__:
if clz == App:
continue
if issubclass(clz, AppMixin):
clz._init_mixin(self)
self._setup_state() self._setup_state()
# Set up the admin dash. # Set up the admin dash.
@ -385,77 +341,6 @@ class App(LifespanMixin, Base):
raise ValueError("The state manager has not been initialized.") raise ValueError("The state manager has not been initialized.")
return self._state_manager return self._state_manager
async def _preprocess(self, state: BaseState, event: Event) -> StateUpdate | None:
"""Preprocess the event.
This is where middleware can modify the event before it is processed.
Each middleware is called in the order it was added to the app.
If a middleware returns an update, the event is not processed and the
update is returned.
Args:
state: The state to preprocess.
event: The event to preprocess.
Returns:
An optional state to return.
"""
for middleware in self.middleware:
if asyncio.iscoroutinefunction(middleware.preprocess):
out = await middleware.preprocess(app=self, state=state, event=event) # type: ignore
else:
out = middleware.preprocess(app=self, state=state, event=event) # type: ignore
if out is not None:
return out # type: ignore
async def _postprocess(
self, state: BaseState, event: Event, update: StateUpdate
) -> StateUpdate:
"""Postprocess the event.
This is where middleware can modify the delta after it is processed.
Each middleware is called in the order it was added to the app.
Args:
state: The state to postprocess.
event: The event to postprocess.
update: The current state update.
Returns:
The state update to return.
"""
for middleware in self.middleware:
if asyncio.iscoroutinefunction(middleware.postprocess):
out = await middleware.postprocess(
app=self, # type: ignore
state=state,
event=event,
update=update,
)
else:
out = middleware.postprocess(
app=self, # type: ignore
state=state,
event=event,
update=update,
)
if out is not None:
return out # type: ignore
return update
def add_middleware(self, middleware: Middleware, index: int | None = None):
"""Add middleware to the app.
Args:
middleware: The middleware to add.
index: The index to add the middleware at.
"""
if index is None:
self.middleware.append(middleware)
else:
self.middleware.insert(index, middleware)
@staticmethod @staticmethod
def _generate_component(component: Component | ComponentCallable) -> Component: def _generate_component(component: Component | ComponentCallable) -> Component:
"""Generate a component from a callable. """Generate a component from a callable.

View File

@ -0,0 +1,5 @@
"""Reflex AppMixins package."""
from .lifespan import LifespanMixin
from .middleware import MiddlewareMixin
from .mixin import AppMixin

View File

@ -0,0 +1,57 @@
"""Mixin that allow tasks to run during the whole app lifespan."""
from __future__ import annotations
import asyncio
import contextlib
import functools
import inspect
import sys
from typing import Callable, Coroutine, Set, Union
from fastapi import FastAPI
from .mixin import AppMixin
class LifespanMixin(AppMixin):
"""A Mixin that allow tasks to run during the whole app lifespan."""
# Lifespan tasks that are planned to run.
lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set()
@contextlib.asynccontextmanager
async def _run_lifespan_tasks(self, app: FastAPI):
running_tasks = []
try:
async with contextlib.AsyncExitStack() as stack:
for task in self.lifespan_tasks:
if isinstance(task, asyncio.Task):
running_tasks.append(task)
else:
signature = inspect.signature(task)
if "app" in signature.parameters:
task = functools.partial(task, app=app)
_t = task()
if isinstance(_t, contextlib._AsyncGeneratorContextManager):
await stack.enter_async_context(_t)
elif isinstance(_t, Coroutine):
running_tasks.append(asyncio.create_task(_t))
yield
finally:
cancel_kwargs = (
{"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {}
)
for task in running_tasks:
task.cancel(**cancel_kwargs)
def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
"""Register a task to run during the lifespan of the app.
Args:
task: The task to register.
task_kwargs: The kwargs of the task.
"""
if task_kwargs:
task = functools.partial(task, **task_kwargs) # type: ignore
self.lifespan_tasks.add(task) # type: ignore

View File

@ -0,0 +1,93 @@
"""Middleware Mixin that allow to add middleware to the app."""
from __future__ import annotations
import asyncio
from typing import List
from reflex.event import Event
from reflex.middleware import HydrateMiddleware, Middleware
from reflex.state import BaseState, StateUpdate
from .mixin import AppMixin
class MiddlewareMixin(AppMixin):
"""Middleware Mixin that allow to add middleware to the app."""
# Middleware to add to the app. Users should use `add_middleware`. PRIVATE.
middleware: List[Middleware] = []
def _init_mixin(self):
self.middleware.append(HydrateMiddleware())
def add_middleware(self, middleware: Middleware, index: int | None = None):
"""Add middleware to the app.
Args:
middleware: The middleware to add.
index: The index to add the middleware at.
"""
if index is None:
self.middleware.append(middleware)
else:
self.middleware.insert(index, middleware)
async def _preprocess(self, state: BaseState, event: Event) -> StateUpdate | None:
"""Preprocess the event.
This is where middleware can modify the event before it is processed.
Each middleware is called in the order it was added to the app.
If a middleware returns an update, the event is not processed and the
update is returned.
Args:
state: The state to preprocess.
event: The event to preprocess.
Returns:
An optional state to return.
"""
for middleware in self.middleware:
if asyncio.iscoroutinefunction(middleware.preprocess):
out = await middleware.preprocess(app=self, state=state, event=event) # type: ignore
else:
out = middleware.preprocess(app=self, state=state, event=event) # type: ignore
if out is not None:
return out # type: ignore
async def _postprocess(
self, state: BaseState, event: Event, update: StateUpdate
) -> StateUpdate:
"""Postprocess the event.
This is where middleware can modify the delta after it is processed.
Each middleware is called in the order it was added to the app.
Args:
state: The state to postprocess.
event: The event to postprocess.
update: The current state update.
Returns:
The state update to return.
"""
for middleware in self.middleware:
if asyncio.iscoroutinefunction(middleware.postprocess):
out = await middleware.postprocess(
app=self, # type: ignore
state=state,
event=event,
update=update,
)
else:
out = middleware.postprocess(
app=self, # type: ignore
state=state,
event=event,
update=update,
)
if out is not None:
return out # type: ignore
return update

View File

@ -0,0 +1,14 @@
"""Default mixin for all app mixins."""
from reflex.base import Base
class AppMixin(Base):
"""Define the base class for all app mixins."""
def _init_mixin(self):
"""Initialize the mixin.
Any App mixin can override this method to perform any initialization.
"""
...