diff --git a/reflex/app.py b/reflex/app.py index 5923e3389..e350be515 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1389,7 +1389,7 @@ def upload(app: App): if isinstance(func, EventHandler): if func.is_background: raise UploadTypeError( - f"@rx.background is not supported for upload handler `{handler}`.", + f"@rx.event(background=True) is not supported for upload handler `{handler}`.", ) func = func.fn if isinstance(func, functools.partial): diff --git a/reflex/event.py b/reflex/event.py index 86620e65d..aa366e3bb 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -83,7 +83,7 @@ class Event: BACKGROUND_TASK_MARKER = "_reflex_background_task" -def background(fn): +def background(fn, *, __internal_reflex_call: bool = False): """Decorator to mark event handler as running in the background. Args: @@ -96,6 +96,13 @@ def background(fn): Raises: TypeError: If the function is not a coroutine function or async generator. """ + if not __internal_reflex_call: + console.deprecate( + "background-decorator", + "Use `rx.event(background=True)` instead.", + "0.6.5", + "0.7.0", + ) if not inspect.iscoroutinefunction(fn) and not inspect.isasyncgenfunction(fn): raise TypeError("Background task must be async function or generator.") setattr(fn, BACKGROUND_TASK_MARKER, True) @@ -1457,6 +1464,8 @@ V3 = TypeVar("V3") V4 = TypeVar("V4") V5 = TypeVar("V5") +background_event_decorator = background + if sys.version_info >= (3, 10): from typing import Concatenate @@ -1557,32 +1566,12 @@ if sys.version_info >= (3, 10): return partial(self.func, instance) # type: ignore - def event_handler(func: Callable[Concatenate[Any, P], T]) -> EventCallback[P, T]: - """Wrap a function to be used as an event. - Args: - func: The function to wrap. - - Returns: - The wrapped function. - """ - return func # type: ignore else: class EventCallback(Generic[P, T]): """A descriptor that wraps a function to be used as an event.""" - def event_handler(func: Callable[P, T]) -> Callable[P, T]: - """Wrap a function to be used as an event. - - Args: - func: The function to wrap. - - Returns: - The wrapped function. - """ - return func - G = ParamSpec("G") @@ -1608,8 +1597,93 @@ class EventNamespace(types.SimpleNamespace): EventChainVar = EventChainVar LiteralEventChainVar = LiteralEventChainVar EventType = EventType + EventCallback = EventCallback + + if sys.version_info >= (3, 10): + + @overload + @staticmethod + def __call__( + func: None = None, *, background: bool | None = None + ) -> Callable[[Callable[Concatenate[Any, P], T]], EventCallback[P, T]]: ... + + @overload + @staticmethod + def __call__( + func: Callable[Concatenate[Any, P], T], + *, + background: bool | None = None, + ) -> EventCallback[P, T]: ... + + @staticmethod + def __call__( + func: Callable[Concatenate[Any, P], T] | None = None, + *, + background: bool | None = None, + ) -> Union[ + EventCallback[P, T], + Callable[[Callable[Concatenate[Any, P], T]], EventCallback[P, T]], + ]: + """Wrap a function to be used as an event. + + Args: + func: The function to wrap. + background: Whether the event should be run in the background. Defaults to False. + + Returns: + The wrapped function. + """ + + def wrapper(func: Callable[Concatenate[Any, P], T]) -> EventCallback[P, T]: + if background is True: + return background_event_decorator(func, __internal_reflex_call=True) # type: ignore + return func # type: ignore + + if func is not None: + return wrapper(func) + return wrapper + else: + + @overload + @staticmethod + def __call__( + func: None = None, *, background: bool | None = None + ) -> Callable[[Callable[P, T]], Callable[P, T]]: ... + + @overload + @staticmethod + def __call__( + func: Callable[P, T], *, background: bool | None = None + ) -> Callable[P, T]: ... + + @staticmethod + def __call__( + func: Callable[P, T] | None = None, + *, + background: bool | None = None, + ) -> Union[ + Callable[P, T], + Callable[[Callable[P, T]], Callable[P, T]], + ]: + """Wrap a function to be used as an event. + + Args: + func: The function to wrap. + background: Whether the event should be run in the background. Defaults to False. + + Returns: + The wrapped function. + """ + + def wrapper(func: Callable[P, T]) -> Callable[P, T]: + if background is True: + return background_event_decorator(func, __internal_reflex_call=True) # type: ignore + return func # type: ignore + + if func is not None: + return wrapper(func) + return wrapper - __call__ = staticmethod(event_handler) get_event = staticmethod(get_event) get_hydrate_event = staticmethod(get_hydrate_event) fix_events = staticmethod(fix_events) diff --git a/reflex/experimental/misc.py b/reflex/experimental/misc.py index e3d237153..a2a5a0615 100644 --- a/reflex/experimental/misc.py +++ b/reflex/experimental/misc.py @@ -7,7 +7,7 @@ from typing import Any async def run_in_thread(func) -> Any: """Run a function in a separate thread. - To not block the UI event queue, run_in_thread must be inside inside a rx.background() decorated method. + To not block the UI event queue, run_in_thread must be inside inside a rx.event(background=True) decorated method. Args: func (callable): The non-async function to run. diff --git a/reflex/state.py b/reflex/state.py index 6e229b97d..2704d58f2 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2346,7 +2346,7 @@ class StateProxy(wrapt.ObjectProxy): class State(rx.State): counter: int = 0 - @rx.background + @rx.event(background=True) async def bg_increment(self): await asyncio.sleep(1) async with self: @@ -3248,7 +3248,7 @@ class StateManagerRedis(StateManager): raise LockExpiredError( f"Lock expired for token {token} while processing. Consider increasing " f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) " - "or use `@rx.background` decorator for long-running tasks." + "or use `@rx.event(background=True)` decorator for long-running tasks." ) client_token, substate_name = _split_substate_key(token) # If the substate name on the token doesn't match the instance name, it cannot have a parent. diff --git a/tests/integration/test_background_task.py b/tests/integration/test_background_task.py index a445112f3..87aa1459b 100644 --- a/tests/integration/test_background_task.py +++ b/tests/integration/test_background_task.py @@ -1,4 +1,4 @@ -"""Test @rx.background task functionality.""" +"""Test @rx.event(background=True) task functionality.""" from typing import Generator @@ -22,8 +22,7 @@ def BackgroundTask(): _task_id: int = 0 iterations: int = 10 - @rx.background - @rx.event + @rx.event(background=True) async def handle_event(self): async with self: self._task_id += 1 @@ -32,8 +31,7 @@ def BackgroundTask(): self.counter += 1 await asyncio.sleep(0.005) - @rx.background - @rx.event + @rx.event(background=True) async def handle_event_yield_only(self): async with self: self._task_id += 1 @@ -48,7 +46,7 @@ def BackgroundTask(): def increment(self): self.counter += 1 - @rx.background + @rx.event(background=True) async def increment_arbitrary(self, amount: int): async with self: self.counter += int(amount) @@ -61,8 +59,7 @@ def BackgroundTask(): async def blocking_pause(self): await asyncio.sleep(0.02) - @rx.background - @rx.event + @rx.event(background=True) async def non_blocking_pause(self): await asyncio.sleep(0.02) @@ -74,15 +71,13 @@ def BackgroundTask(): self.counter += 1 await asyncio.sleep(0.005) - @rx.background - @rx.event + @rx.event(background=True) async def handle_racy_event(self): await asyncio.gather( self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task() ) - @rx.background - @rx.event + @rx.event(background=True) async def nested_async_with_self(self): async with self: self.counter += 1 @@ -94,8 +89,7 @@ def BackgroundTask(): third_state = await self.get_state(ThirdState) await third_state._triple_count() - @rx.background - @rx.event + @rx.event(background=True) async def yield_in_async_with_self(self): async with self: self.counter += 1 @@ -103,8 +97,7 @@ def BackgroundTask(): self.counter += 1 class OtherState(rx.State): - @rx.background - @rx.event + @rx.event(background=True) async def get_other_state(self): async with self: state = await self.get_state(State) diff --git a/tests/units/states/upload.py b/tests/units/states/upload.py index f81e9f235..338025bcd 100644 --- a/tests/units/states/upload.py +++ b/tests/units/states/upload.py @@ -71,7 +71,7 @@ class FileUploadState(State): assert file.filename is not None self.img_list.append(file.filename) - @rx.background + @rx.event(background=True) async def bg_upload(self, files: List[rx.UploadFile]): """Background task cannot be upload handler. @@ -119,7 +119,7 @@ class ChildFileUploadState(FileStateBase1): assert file.filename is not None self.img_list.append(file.filename) - @rx.background + @rx.event(background=True) async def bg_upload(self, files: List[rx.UploadFile]): """Background task cannot be upload handler. @@ -167,7 +167,7 @@ class GrandChildFileUploadState(FileStateBase2): assert file.filename is not None self.img_list.append(file.filename) - @rx.background + @rx.event(background=True) async def bg_upload(self, files: List[rx.UploadFile]): """Background task cannot be upload handler. diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 6bb81522f..7fba7ba1d 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -874,7 +874,7 @@ async def test_upload_file_background(state, tmp_path, token): await fn(request_mock, [file_mock]) assert ( err.value.args[0] - == f"@rx.background is not supported for upload handler `{state.get_full_name()}.bg_upload`." + == f"@rx.event(background=True) is not supported for upload handler `{state.get_full_name()}.bg_upload`." ) if isinstance(app.state_manager, StateManagerRedis): diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 89dd1fd3d..8397954cf 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -1965,7 +1965,7 @@ class BackgroundTaskState(BaseState): """ return self.order - @rx.background + @rx.event(background=True) async def background_task(self): """A background task that updates the state.""" async with self: @@ -2002,7 +2002,7 @@ class BackgroundTaskState(BaseState): self.other() # direct calling event handlers works in context self._private_method() - @rx.background + @rx.event(background=True) async def background_task_reset(self): """A background task that resets the state.""" with pytest.raises(ImmutableStateError): @@ -2016,7 +2016,7 @@ class BackgroundTaskState(BaseState): async with self: self.order.append("reset") - @rx.background + @rx.event(background=True) async def background_task_generator(self): """A background task generator that does nothing.