split lifespan and middleware logic in separate mixin files (#3557)
* split lifespan and middleware logic in separate mixin files * fix for 3.8 * fix for unit tests * add missing sys import --------- Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
parent
ad1d82f7ad
commit
f0bab665ce
131
reflex/app.py
131
reflex/app.py
@ -7,7 +7,6 @@ import concurrent.futures
|
|||||||
import contextlib
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
|
||||||
import io
|
import io
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
@ -40,6 +39,7 @@ from starlette_admin.contrib.sqla.view import ModelView
|
|||||||
|
|
||||||
from reflex import constants
|
from reflex import constants
|
||||||
from reflex.admin import AdminDash
|
from reflex.admin import AdminDash
|
||||||
|
from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
|
||||||
from reflex.base import Base
|
from reflex.base import Base
|
||||||
from reflex.compiler import compiler
|
from reflex.compiler import compiler
|
||||||
from reflex.compiler import utils as compiler_utils
|
from reflex.compiler import utils as compiler_utils
|
||||||
@ -61,7 +61,6 @@ from reflex.components.core.upload import Upload, get_upload_dir
|
|||||||
from reflex.components.radix import themes
|
from reflex.components.radix import themes
|
||||||
from reflex.config import get_config
|
from reflex.config import get_config
|
||||||
from reflex.event import Event, EventHandler, EventSpec
|
from reflex.event import Event, EventHandler, EventSpec
|
||||||
from reflex.middleware import HydrateMiddleware, Middleware
|
|
||||||
from reflex.model import Model
|
from reflex.model import Model
|
||||||
from reflex.page import (
|
from reflex.page import (
|
||||||
DECORATED_PAGES,
|
DECORATED_PAGES,
|
||||||
@ -108,50 +107,7 @@ class OverlayFragment(Fragment):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LifespanMixin(Base):
|
class App(MiddlewareMixin, LifespanMixin, Base):
|
||||||
"""A Mixin that allow tasks to run during the whole app lifespan."""
|
|
||||||
|
|
||||||
# Lifespan tasks that are planned to run.
|
|
||||||
lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set()
|
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
|
||||||
async def _run_lifespan_tasks(self, app: FastAPI):
|
|
||||||
running_tasks = []
|
|
||||||
try:
|
|
||||||
async with contextlib.AsyncExitStack() as stack:
|
|
||||||
for task in self.lifespan_tasks:
|
|
||||||
if isinstance(task, asyncio.Task):
|
|
||||||
running_tasks.append(task)
|
|
||||||
else:
|
|
||||||
signature = inspect.signature(task)
|
|
||||||
if "app" in signature.parameters:
|
|
||||||
task = functools.partial(task, app=app)
|
|
||||||
_t = task()
|
|
||||||
if isinstance(_t, contextlib._AsyncGeneratorContextManager):
|
|
||||||
await stack.enter_async_context(_t)
|
|
||||||
elif isinstance(_t, Coroutine):
|
|
||||||
running_tasks.append(asyncio.create_task(_t))
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
cancel_kwargs = (
|
|
||||||
{"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {}
|
|
||||||
)
|
|
||||||
for task in running_tasks:
|
|
||||||
task.cancel(**cancel_kwargs)
|
|
||||||
|
|
||||||
def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
|
|
||||||
"""Register a task to run during the lifespan of the app.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task: The task to register.
|
|
||||||
task_kwargs: The kwargs of the task.
|
|
||||||
"""
|
|
||||||
if task_kwargs:
|
|
||||||
task = functools.partial(task, **task_kwargs) # type: ignore
|
|
||||||
self.lifespan_tasks.add(task) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
class App(LifespanMixin, Base):
|
|
||||||
"""The main Reflex app that encapsulates the backend and frontend.
|
"""The main Reflex app that encapsulates the backend and frontend.
|
||||||
|
|
||||||
Every Reflex app needs an app defined in its main module.
|
Every Reflex app needs an app defined in its main module.
|
||||||
@ -210,9 +166,6 @@ class App(LifespanMixin, Base):
|
|||||||
# Class to manage many client states.
|
# Class to manage many client states.
|
||||||
_state_manager: Optional[StateManager] = None
|
_state_manager: Optional[StateManager] = None
|
||||||
|
|
||||||
# Middleware to add to the app. Users should use `add_middleware`. PRIVATE.
|
|
||||||
middleware: List[Middleware] = []
|
|
||||||
|
|
||||||
# Mapping from a route to event handlers to trigger when the page loads. PRIVATE.
|
# Mapping from a route to event handlers to trigger when the page loads. PRIVATE.
|
||||||
load_events: Dict[str, List[Union[EventHandler, EventSpec]]] = {}
|
load_events: Dict[str, List[Union[EventHandler, EventSpec]]] = {}
|
||||||
|
|
||||||
@ -253,14 +206,17 @@ class App(LifespanMixin, Base):
|
|||||||
if "breakpoints" in self.style:
|
if "breakpoints" in self.style:
|
||||||
set_breakpoints(self.style.pop("breakpoints"))
|
set_breakpoints(self.style.pop("breakpoints"))
|
||||||
|
|
||||||
# Add middleware.
|
|
||||||
self.middleware.append(HydrateMiddleware())
|
|
||||||
|
|
||||||
# Set up the API.
|
# Set up the API.
|
||||||
self.api = FastAPI(lifespan=self._run_lifespan_tasks)
|
self.api = FastAPI(lifespan=self._run_lifespan_tasks)
|
||||||
self._add_cors()
|
self._add_cors()
|
||||||
self._add_default_endpoints()
|
self._add_default_endpoints()
|
||||||
|
|
||||||
|
for clz in App.__mro__:
|
||||||
|
if clz == App:
|
||||||
|
continue
|
||||||
|
if issubclass(clz, AppMixin):
|
||||||
|
clz._init_mixin(self)
|
||||||
|
|
||||||
self._setup_state()
|
self._setup_state()
|
||||||
|
|
||||||
# Set up the admin dash.
|
# Set up the admin dash.
|
||||||
@ -385,77 +341,6 @@ class App(LifespanMixin, Base):
|
|||||||
raise ValueError("The state manager has not been initialized.")
|
raise ValueError("The state manager has not been initialized.")
|
||||||
return self._state_manager
|
return self._state_manager
|
||||||
|
|
||||||
async def _preprocess(self, state: BaseState, event: Event) -> StateUpdate | None:
|
|
||||||
"""Preprocess the event.
|
|
||||||
|
|
||||||
This is where middleware can modify the event before it is processed.
|
|
||||||
Each middleware is called in the order it was added to the app.
|
|
||||||
|
|
||||||
If a middleware returns an update, the event is not processed and the
|
|
||||||
update is returned.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: The state to preprocess.
|
|
||||||
event: The event to preprocess.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An optional state to return.
|
|
||||||
"""
|
|
||||||
for middleware in self.middleware:
|
|
||||||
if asyncio.iscoroutinefunction(middleware.preprocess):
|
|
||||||
out = await middleware.preprocess(app=self, state=state, event=event) # type: ignore
|
|
||||||
else:
|
|
||||||
out = middleware.preprocess(app=self, state=state, event=event) # type: ignore
|
|
||||||
if out is not None:
|
|
||||||
return out # type: ignore
|
|
||||||
|
|
||||||
async def _postprocess(
|
|
||||||
self, state: BaseState, event: Event, update: StateUpdate
|
|
||||||
) -> StateUpdate:
|
|
||||||
"""Postprocess the event.
|
|
||||||
|
|
||||||
This is where middleware can modify the delta after it is processed.
|
|
||||||
Each middleware is called in the order it was added to the app.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: The state to postprocess.
|
|
||||||
event: The event to postprocess.
|
|
||||||
update: The current state update.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The state update to return.
|
|
||||||
"""
|
|
||||||
for middleware in self.middleware:
|
|
||||||
if asyncio.iscoroutinefunction(middleware.postprocess):
|
|
||||||
out = await middleware.postprocess(
|
|
||||||
app=self, # type: ignore
|
|
||||||
state=state,
|
|
||||||
event=event,
|
|
||||||
update=update,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
out = middleware.postprocess(
|
|
||||||
app=self, # type: ignore
|
|
||||||
state=state,
|
|
||||||
event=event,
|
|
||||||
update=update,
|
|
||||||
)
|
|
||||||
if out is not None:
|
|
||||||
return out # type: ignore
|
|
||||||
return update
|
|
||||||
|
|
||||||
def add_middleware(self, middleware: Middleware, index: int | None = None):
|
|
||||||
"""Add middleware to the app.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
middleware: The middleware to add.
|
|
||||||
index: The index to add the middleware at.
|
|
||||||
"""
|
|
||||||
if index is None:
|
|
||||||
self.middleware.append(middleware)
|
|
||||||
else:
|
|
||||||
self.middleware.insert(index, middleware)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _generate_component(component: Component | ComponentCallable) -> Component:
|
def _generate_component(component: Component | ComponentCallable) -> Component:
|
||||||
"""Generate a component from a callable.
|
"""Generate a component from a callable.
|
||||||
|
5
reflex/app_mixins/__init__.py
Normal file
5
reflex/app_mixins/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
"""Reflex AppMixins package."""
|
||||||
|
|
||||||
|
from .lifespan import LifespanMixin
|
||||||
|
from .middleware import MiddlewareMixin
|
||||||
|
from .mixin import AppMixin
|
57
reflex/app_mixins/lifespan.py
Normal file
57
reflex/app_mixins/lifespan.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
"""Mixin that allow tasks to run during the whole app lifespan."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import contextlib
|
||||||
|
import functools
|
||||||
|
import inspect
|
||||||
|
import sys
|
||||||
|
from typing import Callable, Coroutine, Set, Union
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from .mixin import AppMixin
|
||||||
|
|
||||||
|
|
||||||
|
class LifespanMixin(AppMixin):
|
||||||
|
"""A Mixin that allow tasks to run during the whole app lifespan."""
|
||||||
|
|
||||||
|
# Lifespan tasks that are planned to run.
|
||||||
|
lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set()
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def _run_lifespan_tasks(self, app: FastAPI):
|
||||||
|
running_tasks = []
|
||||||
|
try:
|
||||||
|
async with contextlib.AsyncExitStack() as stack:
|
||||||
|
for task in self.lifespan_tasks:
|
||||||
|
if isinstance(task, asyncio.Task):
|
||||||
|
running_tasks.append(task)
|
||||||
|
else:
|
||||||
|
signature = inspect.signature(task)
|
||||||
|
if "app" in signature.parameters:
|
||||||
|
task = functools.partial(task, app=app)
|
||||||
|
_t = task()
|
||||||
|
if isinstance(_t, contextlib._AsyncGeneratorContextManager):
|
||||||
|
await stack.enter_async_context(_t)
|
||||||
|
elif isinstance(_t, Coroutine):
|
||||||
|
running_tasks.append(asyncio.create_task(_t))
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
cancel_kwargs = (
|
||||||
|
{"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {}
|
||||||
|
)
|
||||||
|
for task in running_tasks:
|
||||||
|
task.cancel(**cancel_kwargs)
|
||||||
|
|
||||||
|
def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
|
||||||
|
"""Register a task to run during the lifespan of the app.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The task to register.
|
||||||
|
task_kwargs: The kwargs of the task.
|
||||||
|
"""
|
||||||
|
if task_kwargs:
|
||||||
|
task = functools.partial(task, **task_kwargs) # type: ignore
|
||||||
|
self.lifespan_tasks.add(task) # type: ignore
|
93
reflex/app_mixins/middleware.py
Normal file
93
reflex/app_mixins/middleware.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
"""Middleware Mixin that allow to add middleware to the app."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from reflex.event import Event
|
||||||
|
from reflex.middleware import HydrateMiddleware, Middleware
|
||||||
|
from reflex.state import BaseState, StateUpdate
|
||||||
|
|
||||||
|
from .mixin import AppMixin
|
||||||
|
|
||||||
|
|
||||||
|
class MiddlewareMixin(AppMixin):
|
||||||
|
"""Middleware Mixin that allow to add middleware to the app."""
|
||||||
|
|
||||||
|
# Middleware to add to the app. Users should use `add_middleware`. PRIVATE.
|
||||||
|
middleware: List[Middleware] = []
|
||||||
|
|
||||||
|
def _init_mixin(self):
|
||||||
|
self.middleware.append(HydrateMiddleware())
|
||||||
|
|
||||||
|
def add_middleware(self, middleware: Middleware, index: int | None = None):
|
||||||
|
"""Add middleware to the app.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
middleware: The middleware to add.
|
||||||
|
index: The index to add the middleware at.
|
||||||
|
"""
|
||||||
|
if index is None:
|
||||||
|
self.middleware.append(middleware)
|
||||||
|
else:
|
||||||
|
self.middleware.insert(index, middleware)
|
||||||
|
|
||||||
|
async def _preprocess(self, state: BaseState, event: Event) -> StateUpdate | None:
|
||||||
|
"""Preprocess the event.
|
||||||
|
|
||||||
|
This is where middleware can modify the event before it is processed.
|
||||||
|
Each middleware is called in the order it was added to the app.
|
||||||
|
|
||||||
|
If a middleware returns an update, the event is not processed and the
|
||||||
|
update is returned.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The state to preprocess.
|
||||||
|
event: The event to preprocess.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An optional state to return.
|
||||||
|
"""
|
||||||
|
for middleware in self.middleware:
|
||||||
|
if asyncio.iscoroutinefunction(middleware.preprocess):
|
||||||
|
out = await middleware.preprocess(app=self, state=state, event=event) # type: ignore
|
||||||
|
else:
|
||||||
|
out = middleware.preprocess(app=self, state=state, event=event) # type: ignore
|
||||||
|
if out is not None:
|
||||||
|
return out # type: ignore
|
||||||
|
|
||||||
|
async def _postprocess(
|
||||||
|
self, state: BaseState, event: Event, update: StateUpdate
|
||||||
|
) -> StateUpdate:
|
||||||
|
"""Postprocess the event.
|
||||||
|
|
||||||
|
This is where middleware can modify the delta after it is processed.
|
||||||
|
Each middleware is called in the order it was added to the app.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The state to postprocess.
|
||||||
|
event: The event to postprocess.
|
||||||
|
update: The current state update.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The state update to return.
|
||||||
|
"""
|
||||||
|
for middleware in self.middleware:
|
||||||
|
if asyncio.iscoroutinefunction(middleware.postprocess):
|
||||||
|
out = await middleware.postprocess(
|
||||||
|
app=self, # type: ignore
|
||||||
|
state=state,
|
||||||
|
event=event,
|
||||||
|
update=update,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
out = middleware.postprocess(
|
||||||
|
app=self, # type: ignore
|
||||||
|
state=state,
|
||||||
|
event=event,
|
||||||
|
update=update,
|
||||||
|
)
|
||||||
|
if out is not None:
|
||||||
|
return out # type: ignore
|
||||||
|
return update
|
14
reflex/app_mixins/mixin.py
Normal file
14
reflex/app_mixins/mixin.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
"""Default mixin for all app mixins."""
|
||||||
|
|
||||||
|
from reflex.base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class AppMixin(Base):
|
||||||
|
"""Define the base class for all app mixins."""
|
||||||
|
|
||||||
|
def _init_mixin(self):
|
||||||
|
"""Initialize the mixin.
|
||||||
|
|
||||||
|
Any App mixin can override this method to perform any initialization.
|
||||||
|
"""
|
||||||
|
...
|
Loading…
Reference in New Issue
Block a user