diff --git a/reflex/app.py b/reflex/app.py index d9104ece6..2f4e57a63 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -25,7 +25,6 @@ from typing import ( Callable, Coroutine, Dict, - Generic, List, MutableMapping, Optional, @@ -76,7 +75,6 @@ from reflex.components.radix import themes from reflex.config import environment, get_config from reflex.event import ( _EVENT_FIELDS, - BASE_STATE, Event, EventHandler, EventSpec, @@ -196,7 +194,7 @@ class OverlayFragment(Fragment): @dataclasses.dataclass( frozen=True, ) -class UnevaluatedPage(Generic[BASE_STATE]): +class UnevaluatedPage: """An uncompiled page.""" component: Union[Component, ComponentCallable] @@ -204,7 +202,7 @@ class UnevaluatedPage(Generic[BASE_STATE]): title: Union[Var, str, None] description: Union[Var, str, None] image: str - on_load: Union[EventType[[], BASE_STATE], None] + on_load: Union[EventType[()], None] meta: List[Dict[str, str]] @@ -279,7 +277,7 @@ class App(MiddlewareMixin, LifespanMixin): _state_manager: Optional[StateManager] = None # Mapping from a route to event handlers to trigger when the page loads. - _load_events: Dict[str, List[IndividualEventType[[], Any]]] = dataclasses.field( + _load_events: Dict[str, List[IndividualEventType[()]]] = dataclasses.field( default_factory=dict ) @@ -544,7 +542,7 @@ class App(MiddlewareMixin, LifespanMixin): title: str | Var | None = None, description: str | Var | None = None, image: str = constants.DefaultPage.IMAGE, - on_load: EventType[[], BASE_STATE] | None = None, + on_load: EventType[()] | None = None, meta: list[dict[str, str]] = constants.DefaultPage.META_LIST, ): """Add a page to the app. @@ -648,7 +646,7 @@ class App(MiddlewareMixin, LifespanMixin): if save_page: self._pages[route] = component - def get_load_events(self, route: str) -> list[IndividualEventType[[], Any]]: + def get_load_events(self, route: str) -> list[IndividualEventType[()]]: """Get the load events for a route. Args: @@ -710,7 +708,7 @@ class App(MiddlewareMixin, LifespanMixin): title: str = constants.Page404.TITLE, image: str = constants.Page404.IMAGE, description: str = constants.Page404.DESCRIPTION, - on_load: EventType[[], BASE_STATE] | None = None, + on_load: EventType[()] | None = None, meta: list[dict[str, str]] = constants.DefaultPage.META_LIST, ): """Define a custom 404 page for any url having no match. diff --git a/reflex/event.py b/reflex/event.py index f35e88389..a838e616d 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -20,17 +20,16 @@ from typing import ( Tuple, Type, Union, + Unpack, get_type_hints, overload, ) from typing_extensions import ( - Concatenate, - ParamSpec, Protocol, - TypeAliasType, TypedDict, TypeVar, + TypeVarTuple, get_args, get_origin, ) @@ -1763,8 +1762,8 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV ) -P = ParamSpec("P") -Q = ParamSpec("Q") +P = TypeVarTuple("P") +Q = TypeVarTuple("Q") T = TypeVar("T") V = TypeVar("V") V2 = TypeVar("V2") @@ -1773,10 +1772,10 @@ V4 = TypeVar("V4") V5 = TypeVar("V5") -class EventCallback(Generic[P, T]): +class EventCallback(Generic[Unpack[P]]): """A descriptor that wraps a function to be used as an event.""" - def __init__(self, func: Callable[Concatenate[Any, P], T]): + def __init__(self, func: Callable[[Any, Unpack[P]], Any]): """Initialize the descriptor with the function to be wrapped. Args: @@ -1804,37 +1803,37 @@ class EventCallback(Generic[P, T]): @overload def __call__( - self: EventCallback[Q, T], - ) -> EventCallback[Q, T]: ... + self: EventCallback[Unpack[Q]], + ) -> EventCallback[Unpack[Q]]: ... @overload def __call__( - self: EventCallback[Concatenate[V, Q], T], value: V | Var[V] - ) -> EventCallback[Q, T]: ... + self: EventCallback[V, Unpack[Q]], value: V | Var[V] + ) -> EventCallback[Unpack[Q]]: ... @overload def __call__( - self: EventCallback[Concatenate[V, V2, Q], T], + self: EventCallback[V, V2, Unpack[Q]], value: V | Var[V], value2: V2 | Var[V2], - ) -> EventCallback[Q, T]: ... + ) -> EventCallback[Unpack[Q]]: ... @overload def __call__( - self: EventCallback[Concatenate[V, V2, V3, Q], T], + self: EventCallback[V, V2, V3, Unpack[Q]], value: V | Var[V], value2: V2 | Var[V2], value3: V3 | Var[V3], - ) -> EventCallback[Q, T]: ... + ) -> EventCallback[Unpack[Q]]: ... @overload def __call__( - self: EventCallback[Concatenate[V, V2, V3, V4, Q], T], + self: EventCallback[V, V2, V3, V4, Unpack[Q]], value: V | Var[V], value2: V2 | Var[V2], value3: V3 | Var[V3], value4: V4 | Var[V4], - ) -> EventCallback[Q, T]: ... + ) -> EventCallback[Unpack[Q]]: ... def __call__(self, *values) -> EventCallback: # pyright: ignore [reportInconsistentOverload] """Call the function with the values. @@ -1849,11 +1848,11 @@ class EventCallback(Generic[P, T]): @overload def __get__( - self: EventCallback[P, T], instance: None, owner: Any - ) -> EventCallback[P, T]: ... + self: EventCallback[Unpack[P]], instance: None, owner: Any + ) -> EventCallback[Unpack[P]]: ... @overload - def __get__(self, instance: Any, owner: Any) -> Callable[P, T]: ... + def __get__(self, instance: Any, owner: Any) -> Callable[[Unpack[P]]]: ... def __get__(self, instance: Any, owner: Any) -> Callable: """Get the function with the instance bound to it. @@ -1871,7 +1870,51 @@ class EventCallback(Generic[P, T]): return partial(self.func, instance) -G = ParamSpec("G") +class LambdaEventCallback(Protocol[Unpack[P]]): + """A protocol for a lambda event callback.""" + + @overload + def __call__(self: LambdaEventCallback[()]) -> Any: ... + + @overload + def __call__(self: LambdaEventCallback[V], value: Var[V], /) -> Any: ... + + @overload + def __call__( + self: LambdaEventCallback[V, V2], value: Var[V], value2: Var[V2], / + ) -> Any: ... + + @overload + def __call__( + self: LambdaEventCallback[V, V2, V3], + value: Var[V], + value2: Var[V2], + value3: Var[V3], + /, + ) -> Any: ... + + def __call__(self, *args: Var) -> Any: + """Call the lambda with the args. + + Args: + *args: The args to call the lambda with. + + Returns: + The result of calling the lambda with the args. + """ + + +BasicEventTypes = EventSpec | EventHandler | Var[Any] + +ARGS = TypeVarTuple("ARGS") + +LambdaOrState = LambdaEventCallback[Unpack[ARGS]] | EventCallback[Unpack[ARGS]] + +ItemOrList = V | List[V] + +IndividualEventType = BasicEventTypes | LambdaOrState[Unpack[ARGS]] | LambdaOrState[()] +EventType = ItemOrList[IndividualEventType[Unpack[ARGS]]] + if TYPE_CHECKING: from reflex.state import BaseState @@ -1880,25 +1923,6 @@ if TYPE_CHECKING: else: BASE_STATE = TypeVar("BASE_STATE") -StateCallable = TypeAliasType( - "StateCallable", - Callable[Concatenate[BASE_STATE, G], Any], - type_params=(G, BASE_STATE), -) - -IndividualEventType = Union[ - EventSpec, - EventHandler, - Callable[G, Any], - StateCallable[G, BASE_STATE], - EventCallback[G, Any], - Var[Any], -] - -ItemOrList = Union[V, List[V]] - -EventType = ItemOrList[IndividualEventType[G, BASE_STATE]] - class EventNamespace(types.SimpleNamespace): """A namespace for event related classes.""" @@ -1919,24 +1943,26 @@ class EventNamespace(types.SimpleNamespace): @staticmethod def __call__( func: None = None, *, background: bool | None = None - ) -> Callable[[Callable[Concatenate[BASE_STATE, P], T]], EventCallback[P, T]]: ... # pyright: ignore [reportInvalidTypeVarUse] + ) -> Callable[ + [Callable[[BASE_STATE, Unpack[P]], Any]], EventCallback[Unpack[P]] # pyright: ignore [reportInvalidTypeVarUse] + ]: ... @overload @staticmethod def __call__( - func: Callable[Concatenate[BASE_STATE, P], T], + func: Callable[[BASE_STATE, Unpack[P]], Any], *, background: bool | None = None, - ) -> EventCallback[P, T]: ... + ) -> EventCallback[Unpack[P]]: ... @staticmethod def __call__( - func: Callable[Concatenate[BASE_STATE, P], T] | None = None, + func: Callable[[BASE_STATE, Unpack[P]], Any] | None = None, *, background: bool | None = None, ) -> Union[ - EventCallback[P, T], - Callable[[Callable[Concatenate[BASE_STATE, P], T]], EventCallback[P, T]], + EventCallback[Unpack[P]], + Callable[[Callable[[BASE_STATE, Unpack[P]], Any]], EventCallback[Unpack[P]]], ]: """Wrap a function to be used as an event. @@ -1952,8 +1978,8 @@ class EventNamespace(types.SimpleNamespace): """ def wrapper( - func: Callable[Concatenate[BASE_STATE, P], T], - ) -> EventCallback[P, T]: + func: Callable[[BASE_STATE, Unpack[P]], T], + ) -> EventCallback[Unpack[P]]: if background is True: if not inspect.iscoroutinefunction( func diff --git a/reflex/page.py b/reflex/page.py index 5f118aad1..7d9ac1b5b 100644 --- a/reflex/page.py +++ b/reflex/page.py @@ -6,7 +6,7 @@ from collections import defaultdict from typing import Any, Callable, Dict, List from reflex.config import get_config -from reflex.event import BASE_STATE, EventType +from reflex.event import EventType DECORATED_PAGES: Dict[str, List] = defaultdict(list) @@ -18,7 +18,7 @@ def page( description: str | None = None, meta: list[Any] | None = None, script_tags: list[Any] | None = None, - on_load: EventType[[], BASE_STATE] | None = None, + on_load: EventType[()] | None = None, ): """Decorate a function as a page. diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 6f05e0982..225d52f3a 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -502,7 +502,7 @@ if TYPE_CHECKING: def format_queue_events( - events: EventType | None = None, + events: EventType[Any] | None = None, args_spec: Optional[ArgsSpec] = None, ) -> Var[EventChain]: """Format a list of event handler / event spec as a javascript callback. diff --git a/reflex/utils/pyi_generator.py b/reflex/utils/pyi_generator.py index beb355d31..f47173a0a 100644 --- a/reflex/utils/pyi_generator.py +++ b/reflex/utils/pyi_generator.py @@ -75,7 +75,6 @@ DEFAULT_IMPORTS = { "EventHandler", "EventSpec", "EventType", - "BASE_STATE", "KeyInputInfo", ], "reflex.style": ["Style"], @@ -502,7 +501,7 @@ def _generate_component_create_functiondef( def figure_out_return_type(annotation: Any): if inspect.isclass(annotation) and issubclass(annotation, inspect._empty): - return ast.Name(id="EventType[..., BASE_STATE]") + return ast.Name(id="EventType[Any]") if not isinstance(annotation, str) and get_origin(annotation) is tuple: arguments = get_args(annotation) @@ -518,7 +517,7 @@ def _generate_component_create_functiondef( # Get all prefixes of the type arguments all_count_args_type = [ ast.Name( - f"EventType[[{', '.join([ast.unparse(arg) for arg in type_args[:i]])}], BASE_STATE]" + f"EventType[[{', '.join([ast.unparse(arg) for arg in type_args[:i]])}]]" ) for i in range(len(type_args) + 1) ] @@ -532,7 +531,7 @@ def _generate_component_create_functiondef( inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]") if inside_of_tuple == "()": - return ast.Name(id="EventType[[], BASE_STATE]") + return ast.Name(id="EventType[()]") arguments = [""] @@ -559,16 +558,14 @@ def _generate_component_create_functiondef( ] all_count_args_type = [ - ast.Name( - f"EventType[[{', '.join(arguments_without_var[:i])}], BASE_STATE]" - ) + ast.Name(f"EventType[[{', '.join(arguments_without_var[:i])}]]") for i in range(len(arguments) + 1) ] return ast.Name( id=f"Union[{', '.join(map(ast.unparse, all_count_args_type))}]" ) - return ast.Name(id="EventType[..., BASE_STATE]") + return ast.Name(id="EventType[Any]") event_triggers = clz().get_event_triggers()