improve lifespan typecheck and debug (#4014)
* add lifespan debug statement * improve some of the logic for lifespan tasks * fix partial name with update_wrapper
This commit is contained in:
parent
9ca5d4a095
commit
1b3422dab6
@ -6,11 +6,13 @@ import asyncio
|
|||||||
import contextlib
|
import contextlib
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import sys
|
|
||||||
from typing import Callable, Coroutine, Set, Union
|
from typing import Callable, Coroutine, Set, Union
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from reflex.utils import console
|
||||||
|
from reflex.utils.exceptions import InvalidLifespanTaskType
|
||||||
|
|
||||||
from .mixin import AppMixin
|
from .mixin import AppMixin
|
||||||
|
|
||||||
|
|
||||||
@ -26,6 +28,7 @@ class LifespanMixin(AppMixin):
|
|||||||
try:
|
try:
|
||||||
async with contextlib.AsyncExitStack() as stack:
|
async with contextlib.AsyncExitStack() as stack:
|
||||||
for task in self.lifespan_tasks:
|
for task in self.lifespan_tasks:
|
||||||
|
run_msg = f"Started lifespan task: {task.__name__} as {{type}}" # type: ignore
|
||||||
if isinstance(task, asyncio.Task):
|
if isinstance(task, asyncio.Task):
|
||||||
running_tasks.append(task)
|
running_tasks.append(task)
|
||||||
else:
|
else:
|
||||||
@ -35,15 +38,19 @@ class LifespanMixin(AppMixin):
|
|||||||
_t = task()
|
_t = task()
|
||||||
if isinstance(_t, contextlib._AsyncGeneratorContextManager):
|
if isinstance(_t, contextlib._AsyncGeneratorContextManager):
|
||||||
await stack.enter_async_context(_t)
|
await stack.enter_async_context(_t)
|
||||||
|
console.debug(run_msg.format(type="asynccontextmanager"))
|
||||||
elif isinstance(_t, Coroutine):
|
elif isinstance(_t, Coroutine):
|
||||||
running_tasks.append(asyncio.create_task(_t))
|
task_ = asyncio.create_task(_t)
|
||||||
|
task_.add_done_callback(lambda t: t.result())
|
||||||
|
running_tasks.append(task_)
|
||||||
|
console.debug(run_msg.format(type="coroutine"))
|
||||||
|
else:
|
||||||
|
console.debug(run_msg.format(type="function"))
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
cancel_kwargs = (
|
|
||||||
{"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {}
|
|
||||||
)
|
|
||||||
for task in running_tasks:
|
for task in running_tasks:
|
||||||
task.cancel(**cancel_kwargs)
|
console.debug(f"Canceling lifespan task: {task}")
|
||||||
|
task.cancel(msg="lifespan_cleanup")
|
||||||
|
|
||||||
def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
|
def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
|
||||||
"""Register a task to run during the lifespan of the app.
|
"""Register a task to run during the lifespan of the app.
|
||||||
@ -51,7 +58,18 @@ class LifespanMixin(AppMixin):
|
|||||||
Args:
|
Args:
|
||||||
task: The task to register.
|
task: The task to register.
|
||||||
task_kwargs: The kwargs of the task.
|
task_kwargs: The kwargs of the task.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidLifespanTaskType: If the task is a generator function.
|
||||||
"""
|
"""
|
||||||
|
if inspect.isgeneratorfunction(task) or inspect.isasyncgenfunction(task):
|
||||||
|
raise InvalidLifespanTaskType(
|
||||||
|
f"Task {task.__name__} of type generator must be decorated with contextlib.asynccontextmanager."
|
||||||
|
)
|
||||||
|
|
||||||
if task_kwargs:
|
if task_kwargs:
|
||||||
|
original_task = task
|
||||||
task = functools.partial(task, **task_kwargs) # type: ignore
|
task = functools.partial(task, **task_kwargs) # type: ignore
|
||||||
|
functools.update_wrapper(task, original_task) # type: ignore
|
||||||
self.lifespan_tasks.add(task) # type: ignore
|
self.lifespan_tasks.add(task) # type: ignore
|
||||||
|
console.debug(f"Registered lifespan task: {task.__name__}") # type: ignore
|
||||||
|
@ -111,3 +111,7 @@ class GeneratedCodeHasNoFunctionDefs(ReflexError):
|
|||||||
|
|
||||||
class PrimitiveUnserializableToJSON(ReflexError, ValueError):
|
class PrimitiveUnserializableToJSON(ReflexError, ValueError):
|
||||||
"""Raised when a primitive type is unserializable to JSON. Usually with NaN and Infinity."""
|
"""Raised when a primitive type is unserializable to JSON. Usually with NaN and Infinity."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidLifespanTaskType(ReflexError, TypeError):
|
||||||
|
"""Raised when an invalid task type is registered as a lifespan task."""
|
||||||
|
Loading…
Reference in New Issue
Block a user