improve lifespan typecheck and debug (#4014)

* add lifespan debug statement

* improve some of the logic for lifespan tasks

* fix partial name with update_wrapper
This commit is contained in:
Thomas Brandého 2024-09-27 16:17:30 -07:00 committed by GitHub
parent 9ca5d4a095
commit 1b3422dab6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 28 additions and 6 deletions

View File

@ -6,11 +6,13 @@ import asyncio
import contextlib
import functools
import inspect
import sys
from typing import Callable, Coroutine, Set, Union
from fastapi import FastAPI
from reflex.utils import console
from reflex.utils.exceptions import InvalidLifespanTaskType
from .mixin import AppMixin
@ -26,6 +28,7 @@ class LifespanMixin(AppMixin):
try:
async with contextlib.AsyncExitStack() as stack:
for task in self.lifespan_tasks:
run_msg = f"Started lifespan task: {task.__name__} as {{type}}" # type: ignore
if isinstance(task, asyncio.Task):
running_tasks.append(task)
else:
@ -35,15 +38,19 @@ class LifespanMixin(AppMixin):
_t = task()
if isinstance(_t, contextlib._AsyncGeneratorContextManager):
await stack.enter_async_context(_t)
console.debug(run_msg.format(type="asynccontextmanager"))
elif isinstance(_t, Coroutine):
running_tasks.append(asyncio.create_task(_t))
task_ = asyncio.create_task(_t)
task_.add_done_callback(lambda t: t.result())
running_tasks.append(task_)
console.debug(run_msg.format(type="coroutine"))
else:
console.debug(run_msg.format(type="function"))
yield
finally:
cancel_kwargs = (
{"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {}
)
for task in running_tasks:
task.cancel(**cancel_kwargs)
console.debug(f"Canceling lifespan task: {task}")
task.cancel(msg="lifespan_cleanup")
def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
"""Register a task to run during the lifespan of the app.
@ -51,7 +58,18 @@ class LifespanMixin(AppMixin):
Args:
task: The task to register.
task_kwargs: The kwargs of the task.
Raises:
InvalidLifespanTaskType: If the task is a generator function.
"""
if inspect.isgeneratorfunction(task) or inspect.isasyncgenfunction(task):
raise InvalidLifespanTaskType(
f"Task {task.__name__} of type generator must be decorated with contextlib.asynccontextmanager."
)
if task_kwargs:
original_task = task
task = functools.partial(task, **task_kwargs) # type: ignore
functools.update_wrapper(task, original_task) # type: ignore
self.lifespan_tasks.add(task) # type: ignore
console.debug(f"Registered lifespan task: {task.__name__}") # type: ignore

View File

@ -111,3 +111,7 @@ class GeneratedCodeHasNoFunctionDefs(ReflexError):
class PrimitiveUnserializableToJSON(ReflexError, ValueError):
"""Raised when a primitive type is unserializable to JSON. Usually with NaN and Infinity."""
class InvalidLifespanTaskType(ReflexError, TypeError):
"""Raised when an invalid task type is registered as a lifespan task."""