remove base_state from event types

This commit is contained in:
Khaleel Al-Adhami 2025-02-03 14:52:52 -08:00
parent ef93161840
commit 035b116a21
5 changed files with 88 additions and 67 deletions

View File

@ -25,7 +25,6 @@ from typing import (
Callable, Callable,
Coroutine, Coroutine,
Dict, Dict,
Generic,
List, List,
MutableMapping, MutableMapping,
Optional, Optional,
@ -76,7 +75,6 @@ from reflex.components.radix import themes
from reflex.config import environment, get_config from reflex.config import environment, get_config
from reflex.event import ( from reflex.event import (
_EVENT_FIELDS, _EVENT_FIELDS,
BASE_STATE,
Event, Event,
EventHandler, EventHandler,
EventSpec, EventSpec,
@ -196,7 +194,7 @@ class OverlayFragment(Fragment):
@dataclasses.dataclass( @dataclasses.dataclass(
frozen=True, frozen=True,
) )
class UnevaluatedPage(Generic[BASE_STATE]): class UnevaluatedPage:
"""An uncompiled page.""" """An uncompiled page."""
component: Union[Component, ComponentCallable] component: Union[Component, ComponentCallable]
@ -204,7 +202,7 @@ class UnevaluatedPage(Generic[BASE_STATE]):
title: Union[Var, str, None] title: Union[Var, str, None]
description: Union[Var, str, None] description: Union[Var, str, None]
image: str image: str
on_load: Union[EventType[[], BASE_STATE], None] on_load: Union[EventType[()], None]
meta: List[Dict[str, str]] meta: List[Dict[str, str]]
@ -279,7 +277,7 @@ class App(MiddlewareMixin, LifespanMixin):
_state_manager: Optional[StateManager] = None _state_manager: Optional[StateManager] = None
# Mapping from a route to event handlers to trigger when the page loads. # Mapping from a route to event handlers to trigger when the page loads.
_load_events: Dict[str, List[IndividualEventType[[], Any]]] = dataclasses.field( _load_events: Dict[str, List[IndividualEventType[()]]] = dataclasses.field(
default_factory=dict default_factory=dict
) )
@ -544,7 +542,7 @@ class App(MiddlewareMixin, LifespanMixin):
title: str | Var | None = None, title: str | Var | None = None,
description: str | Var | None = None, description: str | Var | None = None,
image: str = constants.DefaultPage.IMAGE, image: str = constants.DefaultPage.IMAGE,
on_load: EventType[[], BASE_STATE] | None = None, on_load: EventType[()] | None = None,
meta: list[dict[str, str]] = constants.DefaultPage.META_LIST, meta: list[dict[str, str]] = constants.DefaultPage.META_LIST,
): ):
"""Add a page to the app. """Add a page to the app.
@ -648,7 +646,7 @@ class App(MiddlewareMixin, LifespanMixin):
if save_page: if save_page:
self._pages[route] = component self._pages[route] = component
def get_load_events(self, route: str) -> list[IndividualEventType[[], Any]]: def get_load_events(self, route: str) -> list[IndividualEventType[()]]:
"""Get the load events for a route. """Get the load events for a route.
Args: Args:
@ -710,7 +708,7 @@ class App(MiddlewareMixin, LifespanMixin):
title: str = constants.Page404.TITLE, title: str = constants.Page404.TITLE,
image: str = constants.Page404.IMAGE, image: str = constants.Page404.IMAGE,
description: str = constants.Page404.DESCRIPTION, description: str = constants.Page404.DESCRIPTION,
on_load: EventType[[], BASE_STATE] | None = None, on_load: EventType[()] | None = None,
meta: list[dict[str, str]] = constants.DefaultPage.META_LIST, meta: list[dict[str, str]] = constants.DefaultPage.META_LIST,
): ):
"""Define a custom 404 page for any url having no match. """Define a custom 404 page for any url having no match.

View File

@ -20,17 +20,16 @@ from typing import (
Tuple, Tuple,
Type, Type,
Union, Union,
Unpack,
get_type_hints, get_type_hints,
overload, overload,
) )
from typing_extensions import ( from typing_extensions import (
Concatenate,
ParamSpec,
Protocol, Protocol,
TypeAliasType,
TypedDict, TypedDict,
TypeVar, TypeVar,
TypeVarTuple,
get_args, get_args,
get_origin, get_origin,
) )
@ -1763,8 +1762,8 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV
) )
P = ParamSpec("P") P = TypeVarTuple("P")
Q = ParamSpec("Q") Q = TypeVarTuple("Q")
T = TypeVar("T") T = TypeVar("T")
V = TypeVar("V") V = TypeVar("V")
V2 = TypeVar("V2") V2 = TypeVar("V2")
@ -1773,10 +1772,10 @@ V4 = TypeVar("V4")
V5 = TypeVar("V5") V5 = TypeVar("V5")
class EventCallback(Generic[P, T]): class EventCallback(Generic[Unpack[P]]):
"""A descriptor that wraps a function to be used as an event.""" """A descriptor that wraps a function to be used as an event."""
def __init__(self, func: Callable[Concatenate[Any, P], T]): def __init__(self, func: Callable[[Any, Unpack[P]], Any]):
"""Initialize the descriptor with the function to be wrapped. """Initialize the descriptor with the function to be wrapped.
Args: Args:
@ -1804,37 +1803,37 @@ class EventCallback(Generic[P, T]):
@overload @overload
def __call__( def __call__(
self: EventCallback[Q, T], self: EventCallback[Unpack[Q]],
) -> EventCallback[Q, T]: ... ) -> EventCallback[Unpack[Q]]: ...
@overload @overload
def __call__( def __call__(
self: EventCallback[Concatenate[V, Q], T], value: V | Var[V] self: EventCallback[V, Unpack[Q]], value: V | Var[V]
) -> EventCallback[Q, T]: ... ) -> EventCallback[Unpack[Q]]: ...
@overload @overload
def __call__( def __call__(
self: EventCallback[Concatenate[V, V2, Q], T], self: EventCallback[V, V2, Unpack[Q]],
value: V | Var[V], value: V | Var[V],
value2: V2 | Var[V2], value2: V2 | Var[V2],
) -> EventCallback[Q, T]: ... ) -> EventCallback[Unpack[Q]]: ...
@overload @overload
def __call__( def __call__(
self: EventCallback[Concatenate[V, V2, V3, Q], T], self: EventCallback[V, V2, V3, Unpack[Q]],
value: V | Var[V], value: V | Var[V],
value2: V2 | Var[V2], value2: V2 | Var[V2],
value3: V3 | Var[V3], value3: V3 | Var[V3],
) -> EventCallback[Q, T]: ... ) -> EventCallback[Unpack[Q]]: ...
@overload @overload
def __call__( def __call__(
self: EventCallback[Concatenate[V, V2, V3, V4, Q], T], self: EventCallback[V, V2, V3, V4, Unpack[Q]],
value: V | Var[V], value: V | Var[V],
value2: V2 | Var[V2], value2: V2 | Var[V2],
value3: V3 | Var[V3], value3: V3 | Var[V3],
value4: V4 | Var[V4], value4: V4 | Var[V4],
) -> EventCallback[Q, T]: ... ) -> EventCallback[Unpack[Q]]: ...
def __call__(self, *values) -> EventCallback: # pyright: ignore [reportInconsistentOverload] def __call__(self, *values) -> EventCallback: # pyright: ignore [reportInconsistentOverload]
"""Call the function with the values. """Call the function with the values.
@ -1849,11 +1848,11 @@ class EventCallback(Generic[P, T]):
@overload @overload
def __get__( def __get__(
self: EventCallback[P, T], instance: None, owner: Any self: EventCallback[Unpack[P]], instance: None, owner: Any
) -> EventCallback[P, T]: ... ) -> EventCallback[Unpack[P]]: ...
@overload @overload
def __get__(self, instance: Any, owner: Any) -> Callable[P, T]: ... def __get__(self, instance: Any, owner: Any) -> Callable[[Unpack[P]]]: ...
def __get__(self, instance: Any, owner: Any) -> Callable: def __get__(self, instance: Any, owner: Any) -> Callable:
"""Get the function with the instance bound to it. """Get the function with the instance bound to it.
@ -1871,7 +1870,51 @@ class EventCallback(Generic[P, T]):
return partial(self.func, instance) return partial(self.func, instance)
G = ParamSpec("G") class LambdaEventCallback(Protocol[Unpack[P]]):
"""A protocol for a lambda event callback."""
@overload
def __call__(self: LambdaEventCallback[()]) -> Any: ...
@overload
def __call__(self: LambdaEventCallback[V], value: Var[V], /) -> Any: ...
@overload
def __call__(
self: LambdaEventCallback[V, V2], value: Var[V], value2: Var[V2], /
) -> Any: ...
@overload
def __call__(
self: LambdaEventCallback[V, V2, V3],
value: Var[V],
value2: Var[V2],
value3: Var[V3],
/,
) -> Any: ...
def __call__(self, *args: Var) -> Any:
"""Call the lambda with the args.
Args:
*args: The args to call the lambda with.
Returns:
The result of calling the lambda with the args.
"""
BasicEventTypes = EventSpec | EventHandler | Var[Any]
ARGS = TypeVarTuple("ARGS")
LambdaOrState = LambdaEventCallback[Unpack[ARGS]] | EventCallback[Unpack[ARGS]]
ItemOrList = V | List[V]
IndividualEventType = BasicEventTypes | LambdaOrState[Unpack[ARGS]] | LambdaOrState[()]
EventType = ItemOrList[IndividualEventType[Unpack[ARGS]]]
if TYPE_CHECKING: if TYPE_CHECKING:
from reflex.state import BaseState from reflex.state import BaseState
@ -1880,25 +1923,6 @@ if TYPE_CHECKING:
else: else:
BASE_STATE = TypeVar("BASE_STATE") BASE_STATE = TypeVar("BASE_STATE")
StateCallable = TypeAliasType(
"StateCallable",
Callable[Concatenate[BASE_STATE, G], Any],
type_params=(G, BASE_STATE),
)
IndividualEventType = Union[
EventSpec,
EventHandler,
Callable[G, Any],
StateCallable[G, BASE_STATE],
EventCallback[G, Any],
Var[Any],
]
ItemOrList = Union[V, List[V]]
EventType = ItemOrList[IndividualEventType[G, BASE_STATE]]
class EventNamespace(types.SimpleNamespace): class EventNamespace(types.SimpleNamespace):
"""A namespace for event related classes.""" """A namespace for event related classes."""
@ -1919,24 +1943,26 @@ class EventNamespace(types.SimpleNamespace):
@staticmethod @staticmethod
def __call__( def __call__(
func: None = None, *, background: bool | None = None func: None = None, *, background: bool | None = None
) -> Callable[[Callable[Concatenate[BASE_STATE, P], T]], EventCallback[P, T]]: ... # pyright: ignore [reportInvalidTypeVarUse] ) -> Callable[
[Callable[[BASE_STATE, Unpack[P]], Any]], EventCallback[Unpack[P]] # pyright: ignore [reportInvalidTypeVarUse]
]: ...
@overload @overload
@staticmethod @staticmethod
def __call__( def __call__(
func: Callable[Concatenate[BASE_STATE, P], T], func: Callable[[BASE_STATE, Unpack[P]], Any],
*, *,
background: bool | None = None, background: bool | None = None,
) -> EventCallback[P, T]: ... ) -> EventCallback[Unpack[P]]: ...
@staticmethod @staticmethod
def __call__( def __call__(
func: Callable[Concatenate[BASE_STATE, P], T] | None = None, func: Callable[[BASE_STATE, Unpack[P]], Any] | None = None,
*, *,
background: bool | None = None, background: bool | None = None,
) -> Union[ ) -> Union[
EventCallback[P, T], EventCallback[Unpack[P]],
Callable[[Callable[Concatenate[BASE_STATE, P], T]], EventCallback[P, T]], Callable[[Callable[[BASE_STATE, Unpack[P]], Any]], EventCallback[Unpack[P]]],
]: ]:
"""Wrap a function to be used as an event. """Wrap a function to be used as an event.
@ -1952,8 +1978,8 @@ class EventNamespace(types.SimpleNamespace):
""" """
def wrapper( def wrapper(
func: Callable[Concatenate[BASE_STATE, P], T], func: Callable[[BASE_STATE, Unpack[P]], T],
) -> EventCallback[P, T]: ) -> EventCallback[Unpack[P]]:
if background is True: if background is True:
if not inspect.iscoroutinefunction( if not inspect.iscoroutinefunction(
func func

View File

@ -6,7 +6,7 @@ from collections import defaultdict
from typing import Any, Callable, Dict, List from typing import Any, Callable, Dict, List
from reflex.config import get_config from reflex.config import get_config
from reflex.event import BASE_STATE, EventType from reflex.event import EventType
DECORATED_PAGES: Dict[str, List] = defaultdict(list) DECORATED_PAGES: Dict[str, List] = defaultdict(list)
@ -18,7 +18,7 @@ def page(
description: str | None = None, description: str | None = None,
meta: list[Any] | None = None, meta: list[Any] | None = None,
script_tags: list[Any] | None = None, script_tags: list[Any] | None = None,
on_load: EventType[[], BASE_STATE] | None = None, on_load: EventType[()] | None = None,
): ):
"""Decorate a function as a page. """Decorate a function as a page.

View File

@ -502,7 +502,7 @@ if TYPE_CHECKING:
def format_queue_events( def format_queue_events(
events: EventType | None = None, events: EventType[Any] | None = None,
args_spec: Optional[ArgsSpec] = None, args_spec: Optional[ArgsSpec] = None,
) -> Var[EventChain]: ) -> Var[EventChain]:
"""Format a list of event handler / event spec as a javascript callback. """Format a list of event handler / event spec as a javascript callback.

View File

@ -75,7 +75,6 @@ DEFAULT_IMPORTS = {
"EventHandler", "EventHandler",
"EventSpec", "EventSpec",
"EventType", "EventType",
"BASE_STATE",
"KeyInputInfo", "KeyInputInfo",
], ],
"reflex.style": ["Style"], "reflex.style": ["Style"],
@ -502,7 +501,7 @@ def _generate_component_create_functiondef(
def figure_out_return_type(annotation: Any): def figure_out_return_type(annotation: Any):
if inspect.isclass(annotation) and issubclass(annotation, inspect._empty): if inspect.isclass(annotation) and issubclass(annotation, inspect._empty):
return ast.Name(id="EventType[..., BASE_STATE]") return ast.Name(id="EventType[Any]")
if not isinstance(annotation, str) and get_origin(annotation) is tuple: if not isinstance(annotation, str) and get_origin(annotation) is tuple:
arguments = get_args(annotation) arguments = get_args(annotation)
@ -518,7 +517,7 @@ def _generate_component_create_functiondef(
# Get all prefixes of the type arguments # Get all prefixes of the type arguments
all_count_args_type = [ all_count_args_type = [
ast.Name( ast.Name(
f"EventType[[{', '.join([ast.unparse(arg) for arg in type_args[:i]])}], BASE_STATE]" f"EventType[[{', '.join([ast.unparse(arg) for arg in type_args[:i]])}]]"
) )
for i in range(len(type_args) + 1) for i in range(len(type_args) + 1)
] ]
@ -532,7 +531,7 @@ def _generate_component_create_functiondef(
inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]") inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]")
if inside_of_tuple == "()": if inside_of_tuple == "()":
return ast.Name(id="EventType[[], BASE_STATE]") return ast.Name(id="EventType[()]")
arguments = [""] arguments = [""]
@ -559,16 +558,14 @@ def _generate_component_create_functiondef(
] ]
all_count_args_type = [ all_count_args_type = [
ast.Name( ast.Name(f"EventType[[{', '.join(arguments_without_var[:i])}]]")
f"EventType[[{', '.join(arguments_without_var[:i])}], BASE_STATE]"
)
for i in range(len(arguments) + 1) for i in range(len(arguments) + 1)
] ]
return ast.Name( return ast.Name(
id=f"Union[{', '.join(map(ast.unparse, all_count_args_type))}]" id=f"Union[{', '.join(map(ast.unparse, all_count_args_type))}]"
) )
return ast.Name(id="EventType[..., BASE_STATE]") return ast.Name(id="EventType[Any]")
event_triggers = clz().get_event_triggers() event_triggers = clz().get_event_triggers()