diff --git a/reflex/event.py b/reflex/event.py index f576d4ff8..74504e713 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -1566,86 +1566,12 @@ if sys.version_info >= (3, 10): return partial(self.func, instance) # type: ignore - @overload - def event_handler( - func: None = None, *, background: bool | None = None - ) -> Callable[[Callable[Concatenate[Any, P], T]], EventCallback[P, T]]: ... - @overload - def event_handler( - func: Callable[Concatenate[Any, P], T], - *, - background: bool | None = None, - ) -> EventCallback[P, T]: ... - - def event_handler( - func: Callable[Concatenate[Any, P], T] | None = None, - *, - background: bool | None = None, - ) -> Union[ - EventCallback[P, T], - Callable[[Callable[Concatenate[Any, P], T]], EventCallback[P, T]], - ]: - """Wrap a function to be used as an event. - - Args: - func: The function to wrap. - background: Whether the event should be run in the background. Defaults to False. - - Returns: - The wrapped function. - """ - - def wrapper(func: Callable[Concatenate[Any, P], T]) -> EventCallback[P, T]: - if background is True: - return background_event_decorator(func, __internal_reflex_call=True) # type: ignore - return func # type: ignore - - if func is not None: - return wrapper(func) - return wrapper else: class EventCallback(Generic[P, T]): """A descriptor that wraps a function to be used as an event.""" - @overload - def event_handler( - func: None = None, *, background: bool | None = None - ) -> Callable[[Callable[P, T]], Callable[P, T]]: ... - - @overload - def event_handler( - func: Callable[P, T], *, background: bool | None = None - ) -> Callable[P, T]: ... - - def event_handler( - func: Callable[P, T] | None = None, - *, - background: bool | None = None, - ) -> Union[ - Callable[P, T], - Callable[[Callable[P, T]], Callable[P, T]], - ]: - """Wrap a function to be used as an event. - - Args: - func: The function to wrap. - background: Whether the event should be run in the background. Defaults to False. - - Returns: - The wrapped function. - """ - - def wrapper(func: Callable[P, T]) -> Callable[P, T]: - if background is True: - return background_event_decorator(func, __internal_reflex_call=True) # type: ignore - return func # type: ignore - - if func is not None: - return wrapper(func) - return wrapper - G = ParamSpec("G") @@ -1672,7 +1598,93 @@ class EventNamespace(types.SimpleNamespace): LiteralEventChainVar = LiteralEventChainVar EventType = EventType - __call__ = staticmethod(event_handler) + # __call__ = staticmethod(event_handler) + + if sys.version_info >= (3, 10): + + @overload + @staticmethod + def __call__( + func: None = None, *, background: bool | None = None + ) -> Callable[[Callable[Concatenate[Any, P], T]], EventCallback[P, T]]: ... + + @overload + @staticmethod + def __call__( + func: Callable[Concatenate[Any, P], T], + *, + background: bool | None = None, + ) -> EventCallback[P, T]: ... + + @staticmethod + def __call__( + func: Callable[Concatenate[Any, P], T] | None = None, + *, + background: bool | None = None, + ) -> Union[ + EventCallback[P, T], + Callable[[Callable[Concatenate[Any, P], T]], EventCallback[P, T]], + ]: + """Wrap a function to be used as an event. + + Args: + func: The function to wrap. + background: Whether the event should be run in the background. Defaults to False. + + Returns: + The wrapped function. + """ + + def wrapper(func: Callable[Concatenate[Any, P], T]) -> EventCallback[P, T]: + if background is True: + return background_event_decorator(func, __internal_reflex_call=True) # type: ignore + return func # type: ignore + + if func is not None: + return wrapper(func) + return wrapper + else: + + @overload + @staticmethod + def event_handler( + func: None = None, *, background: bool | None = None + ) -> Callable[[Callable[P, T]], Callable[P, T]]: ... + + @overload + @staticmethod + def event_handler( + func: Callable[P, T], *, background: bool | None = None + ) -> Callable[P, T]: ... + + @staticmethod + def event_handler( + func: Callable[P, T] | None = None, + *, + background: bool | None = None, + ) -> Union[ + Callable[P, T], + Callable[[Callable[P, T]], Callable[P, T]], + ]: + """Wrap a function to be used as an event. + + Args: + func: The function to wrap. + background: Whether the event should be run in the background. Defaults to False. + + Returns: + The wrapped function. + """ + + def wrapper(func: Callable[P, T]) -> Callable[P, T]: + if background is True: + return background_event_decorator(func, __internal_reflex_call=True) # type: ignore + return func # type: ignore + + if func is not None: + return wrapper(func) + return wrapper + get_event = staticmethod(get_event) get_hydrate_event = staticmethod(get_hydrate_event) fix_events = staticmethod(fix_events) @@ -1705,3 +1717,5 @@ class EventNamespace(types.SimpleNamespace): event = EventNamespace() + +event.event_handler() diff --git a/tests/integration/test_background_task.py b/tests/integration/test_background_task.py index 7c3f36f7a..87aa1459b 100644 --- a/tests/integration/test_background_task.py +++ b/tests/integration/test_background_task.py @@ -23,7 +23,6 @@ def BackgroundTask(): iterations: int = 10 @rx.event(background=True) - @rx.event async def handle_event(self): async with self: self._task_id += 1 @@ -33,7 +32,6 @@ def BackgroundTask(): await asyncio.sleep(0.005) @rx.event(background=True) - @rx.event async def handle_event_yield_only(self): async with self: self._task_id += 1 @@ -62,7 +60,6 @@ def BackgroundTask(): await asyncio.sleep(0.02) @rx.event(background=True) - @rx.event async def non_blocking_pause(self): await asyncio.sleep(0.02) @@ -75,14 +72,12 @@ def BackgroundTask(): await asyncio.sleep(0.005) @rx.event(background=True) - @rx.event async def handle_racy_event(self): await asyncio.gather( self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task() ) @rx.event(background=True) - @rx.event async def nested_async_with_self(self): async with self: self.counter += 1 @@ -95,7 +90,6 @@ def BackgroundTask(): await third_state._triple_count() @rx.event(background=True) - @rx.event async def yield_in_async_with_self(self): async with self: self.counter += 1 @@ -104,7 +98,6 @@ def BackgroundTask(): class OtherState(rx.State): @rx.event(background=True) - @rx.event async def get_other_state(self): async with self: state = await self.get_state(State)