split lifespan and middleware logic in separate mixin files (#3557)

* split lifespan and middleware logic in separate mixin files

* fix for 3.8

* fix for unit tests

* add missing sys import

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
Thomas Brandého 2024-06-29 17:50:04 +02:00 committed by GitHub
parent ad1d82f7ad
commit f0bab665ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 177 additions and 123 deletions

View File

@ -7,7 +7,6 @@ import concurrent.futures
import contextlib
import copy
import functools
import inspect
import io
import multiprocessing
import os
@ -40,6 +39,7 @@ from starlette_admin.contrib.sqla.view import ModelView
from reflex import constants
from reflex.admin import AdminDash
from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
from reflex.base import Base
from reflex.compiler import compiler
from reflex.compiler import utils as compiler_utils
@ -61,7 +61,6 @@ from reflex.components.core.upload import Upload, get_upload_dir
from reflex.components.radix import themes
from reflex.config import get_config
from reflex.event import Event, EventHandler, EventSpec
from reflex.middleware import HydrateMiddleware, Middleware
from reflex.model import Model
from reflex.page import (
DECORATED_PAGES,
@ -108,50 +107,7 @@ class OverlayFragment(Fragment):
pass
class LifespanMixin(Base):
"""A Mixin that allow tasks to run during the whole app lifespan."""
# Lifespan tasks that are planned to run.
lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set()
@contextlib.asynccontextmanager
async def _run_lifespan_tasks(self, app: FastAPI):
running_tasks = []
try:
async with contextlib.AsyncExitStack() as stack:
for task in self.lifespan_tasks:
if isinstance(task, asyncio.Task):
running_tasks.append(task)
else:
signature = inspect.signature(task)
if "app" in signature.parameters:
task = functools.partial(task, app=app)
_t = task()
if isinstance(_t, contextlib._AsyncGeneratorContextManager):
await stack.enter_async_context(_t)
elif isinstance(_t, Coroutine):
running_tasks.append(asyncio.create_task(_t))
yield
finally:
cancel_kwargs = (
{"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {}
)
for task in running_tasks:
task.cancel(**cancel_kwargs)
def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
"""Register a task to run during the lifespan of the app.
Args:
task: The task to register.
task_kwargs: The kwargs of the task.
"""
if task_kwargs:
task = functools.partial(task, **task_kwargs) # type: ignore
self.lifespan_tasks.add(task) # type: ignore
class App(LifespanMixin, Base):
class App(MiddlewareMixin, LifespanMixin, Base):
"""The main Reflex app that encapsulates the backend and frontend.
Every Reflex app needs an app defined in its main module.
@ -210,9 +166,6 @@ class App(LifespanMixin, Base):
# Class to manage many client states.
_state_manager: Optional[StateManager] = None
# Middleware to add to the app. Users should use `add_middleware`. PRIVATE.
middleware: List[Middleware] = []
# Mapping from a route to event handlers to trigger when the page loads. PRIVATE.
load_events: Dict[str, List[Union[EventHandler, EventSpec]]] = {}
@ -253,14 +206,17 @@ class App(LifespanMixin, Base):
if "breakpoints" in self.style:
set_breakpoints(self.style.pop("breakpoints"))
# Add middleware.
self.middleware.append(HydrateMiddleware())
# Set up the API.
self.api = FastAPI(lifespan=self._run_lifespan_tasks)
self._add_cors()
self._add_default_endpoints()
for clz in App.__mro__:
if clz == App:
continue
if issubclass(clz, AppMixin):
clz._init_mixin(self)
self._setup_state()
# Set up the admin dash.
@ -385,77 +341,6 @@ class App(LifespanMixin, Base):
raise ValueError("The state manager has not been initialized.")
return self._state_manager
async def _preprocess(self, state: BaseState, event: Event) -> StateUpdate | None:
"""Preprocess the event.
This is where middleware can modify the event before it is processed.
Each middleware is called in the order it was added to the app.
If a middleware returns an update, the event is not processed and the
update is returned.
Args:
state: The state to preprocess.
event: The event to preprocess.
Returns:
An optional state to return.
"""
for middleware in self.middleware:
if asyncio.iscoroutinefunction(middleware.preprocess):
out = await middleware.preprocess(app=self, state=state, event=event) # type: ignore
else:
out = middleware.preprocess(app=self, state=state, event=event) # type: ignore
if out is not None:
return out # type: ignore
async def _postprocess(
self, state: BaseState, event: Event, update: StateUpdate
) -> StateUpdate:
"""Postprocess the event.
This is where middleware can modify the delta after it is processed.
Each middleware is called in the order it was added to the app.
Args:
state: The state to postprocess.
event: The event to postprocess.
update: The current state update.
Returns:
The state update to return.
"""
for middleware in self.middleware:
if asyncio.iscoroutinefunction(middleware.postprocess):
out = await middleware.postprocess(
app=self, # type: ignore
state=state,
event=event,
update=update,
)
else:
out = middleware.postprocess(
app=self, # type: ignore
state=state,
event=event,
update=update,
)
if out is not None:
return out # type: ignore
return update
def add_middleware(self, middleware: Middleware, index: int | None = None):
"""Add middleware to the app.
Args:
middleware: The middleware to add.
index: The index to add the middleware at.
"""
if index is None:
self.middleware.append(middleware)
else:
self.middleware.insert(index, middleware)
@staticmethod
def _generate_component(component: Component | ComponentCallable) -> Component:
"""Generate a component from a callable.

View File

@ -0,0 +1,5 @@
"""Reflex AppMixins package."""
from .lifespan import LifespanMixin
from .middleware import MiddlewareMixin
from .mixin import AppMixin

View File

@ -0,0 +1,57 @@
"""Mixin that allow tasks to run during the whole app lifespan."""
from __future__ import annotations
import asyncio
import contextlib
import functools
import inspect
import sys
from typing import Callable, Coroutine, Set, Union
from fastapi import FastAPI
from .mixin import AppMixin
class LifespanMixin(AppMixin):
"""A Mixin that allow tasks to run during the whole app lifespan."""
# Lifespan tasks that are planned to run.
lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set()
@contextlib.asynccontextmanager
async def _run_lifespan_tasks(self, app: FastAPI):
running_tasks = []
try:
async with contextlib.AsyncExitStack() as stack:
for task in self.lifespan_tasks:
if isinstance(task, asyncio.Task):
running_tasks.append(task)
else:
signature = inspect.signature(task)
if "app" in signature.parameters:
task = functools.partial(task, app=app)
_t = task()
if isinstance(_t, contextlib._AsyncGeneratorContextManager):
await stack.enter_async_context(_t)
elif isinstance(_t, Coroutine):
running_tasks.append(asyncio.create_task(_t))
yield
finally:
cancel_kwargs = (
{"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {}
)
for task in running_tasks:
task.cancel(**cancel_kwargs)
def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
"""Register a task to run during the lifespan of the app.
Args:
task: The task to register.
task_kwargs: The kwargs of the task.
"""
if task_kwargs:
task = functools.partial(task, **task_kwargs) # type: ignore
self.lifespan_tasks.add(task) # type: ignore

View File

@ -0,0 +1,93 @@
"""Middleware Mixin that allow to add middleware to the app."""
from __future__ import annotations
import asyncio
from typing import List
from reflex.event import Event
from reflex.middleware import HydrateMiddleware, Middleware
from reflex.state import BaseState, StateUpdate
from .mixin import AppMixin
class MiddlewareMixin(AppMixin):
"""Middleware Mixin that allow to add middleware to the app."""
# Middleware to add to the app. Users should use `add_middleware`. PRIVATE.
middleware: List[Middleware] = []
def _init_mixin(self):
self.middleware.append(HydrateMiddleware())
def add_middleware(self, middleware: Middleware, index: int | None = None):
"""Add middleware to the app.
Args:
middleware: The middleware to add.
index: The index to add the middleware at.
"""
if index is None:
self.middleware.append(middleware)
else:
self.middleware.insert(index, middleware)
async def _preprocess(self, state: BaseState, event: Event) -> StateUpdate | None:
"""Preprocess the event.
This is where middleware can modify the event before it is processed.
Each middleware is called in the order it was added to the app.
If a middleware returns an update, the event is not processed and the
update is returned.
Args:
state: The state to preprocess.
event: The event to preprocess.
Returns:
An optional state to return.
"""
for middleware in self.middleware:
if asyncio.iscoroutinefunction(middleware.preprocess):
out = await middleware.preprocess(app=self, state=state, event=event) # type: ignore
else:
out = middleware.preprocess(app=self, state=state, event=event) # type: ignore
if out is not None:
return out # type: ignore
async def _postprocess(
self, state: BaseState, event: Event, update: StateUpdate
) -> StateUpdate:
"""Postprocess the event.
This is where middleware can modify the delta after it is processed.
Each middleware is called in the order it was added to the app.
Args:
state: The state to postprocess.
event: The event to postprocess.
update: The current state update.
Returns:
The state update to return.
"""
for middleware in self.middleware:
if asyncio.iscoroutinefunction(middleware.postprocess):
out = await middleware.postprocess(
app=self, # type: ignore
state=state,
event=event,
update=update,
)
else:
out = middleware.postprocess(
app=self, # type: ignore
state=state,
event=event,
update=update,
)
if out is not None:
return out # type: ignore
return update

View File

@ -0,0 +1,14 @@
"""Default mixin for all app mixins."""
from reflex.base import Base
class AppMixin(Base):
"""Define the base class for all app mixins."""
def _init_mixin(self):
"""Initialize the mixin.
Any App mixin can override this method to perform any initialization.
"""
...