remove base_state from event types

This commit is contained in:
Khaleel Al-Adhami 2025-02-03 14:52:52 -08:00
parent ef93161840
commit 035b116a21
5 changed files with 88 additions and 67 deletions

View File

@ -25,7 +25,6 @@ from typing import (
Callable,
Coroutine,
Dict,
Generic,
List,
MutableMapping,
Optional,
@ -76,7 +75,6 @@ from reflex.components.radix import themes
from reflex.config import environment, get_config
from reflex.event import (
_EVENT_FIELDS,
BASE_STATE,
Event,
EventHandler,
EventSpec,
@ -196,7 +194,7 @@ class OverlayFragment(Fragment):
@dataclasses.dataclass(
frozen=True,
)
class UnevaluatedPage(Generic[BASE_STATE]):
class UnevaluatedPage:
"""An uncompiled page."""
component: Union[Component, ComponentCallable]
@ -204,7 +202,7 @@ class UnevaluatedPage(Generic[BASE_STATE]):
title: Union[Var, str, None]
description: Union[Var, str, None]
image: str
on_load: Union[EventType[[], BASE_STATE], None]
on_load: Union[EventType[()], None]
meta: List[Dict[str, str]]
@ -279,7 +277,7 @@ class App(MiddlewareMixin, LifespanMixin):
_state_manager: Optional[StateManager] = None
# Mapping from a route to event handlers to trigger when the page loads.
_load_events: Dict[str, List[IndividualEventType[[], Any]]] = dataclasses.field(
_load_events: Dict[str, List[IndividualEventType[()]]] = dataclasses.field(
default_factory=dict
)
@ -544,7 +542,7 @@ class App(MiddlewareMixin, LifespanMixin):
title: str | Var | None = None,
description: str | Var | None = None,
image: str = constants.DefaultPage.IMAGE,
on_load: EventType[[], BASE_STATE] | None = None,
on_load: EventType[()] | None = None,
meta: list[dict[str, str]] = constants.DefaultPage.META_LIST,
):
"""Add a page to the app.
@ -648,7 +646,7 @@ class App(MiddlewareMixin, LifespanMixin):
if save_page:
self._pages[route] = component
def get_load_events(self, route: str) -> list[IndividualEventType[[], Any]]:
def get_load_events(self, route: str) -> list[IndividualEventType[()]]:
"""Get the load events for a route.
Args:
@ -710,7 +708,7 @@ class App(MiddlewareMixin, LifespanMixin):
title: str = constants.Page404.TITLE,
image: str = constants.Page404.IMAGE,
description: str = constants.Page404.DESCRIPTION,
on_load: EventType[[], BASE_STATE] | None = None,
on_load: EventType[()] | None = None,
meta: list[dict[str, str]] = constants.DefaultPage.META_LIST,
):
"""Define a custom 404 page for any url having no match.

View File

@ -20,17 +20,16 @@ from typing import (
Tuple,
Type,
Union,
Unpack,
get_type_hints,
overload,
)
from typing_extensions import (
Concatenate,
ParamSpec,
Protocol,
TypeAliasType,
TypedDict,
TypeVar,
TypeVarTuple,
get_args,
get_origin,
)
@ -1763,8 +1762,8 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV
)
P = ParamSpec("P")
Q = ParamSpec("Q")
P = TypeVarTuple("P")
Q = TypeVarTuple("Q")
T = TypeVar("T")
V = TypeVar("V")
V2 = TypeVar("V2")
@ -1773,10 +1772,10 @@ V4 = TypeVar("V4")
V5 = TypeVar("V5")
class EventCallback(Generic[P, T]):
class EventCallback(Generic[Unpack[P]]):
"""A descriptor that wraps a function to be used as an event."""
def __init__(self, func: Callable[Concatenate[Any, P], T]):
def __init__(self, func: Callable[[Any, Unpack[P]], Any]):
"""Initialize the descriptor with the function to be wrapped.
Args:
@ -1804,37 +1803,37 @@ class EventCallback(Generic[P, T]):
@overload
def __call__(
self: EventCallback[Q, T],
) -> EventCallback[Q, T]: ...
self: EventCallback[Unpack[Q]],
) -> EventCallback[Unpack[Q]]: ...
@overload
def __call__(
self: EventCallback[Concatenate[V, Q], T], value: V | Var[V]
) -> EventCallback[Q, T]: ...
self: EventCallback[V, Unpack[Q]], value: V | Var[V]
) -> EventCallback[Unpack[Q]]: ...
@overload
def __call__(
self: EventCallback[Concatenate[V, V2, Q], T],
self: EventCallback[V, V2, Unpack[Q]],
value: V | Var[V],
value2: V2 | Var[V2],
) -> EventCallback[Q, T]: ...
) -> EventCallback[Unpack[Q]]: ...
@overload
def __call__(
self: EventCallback[Concatenate[V, V2, V3, Q], T],
self: EventCallback[V, V2, V3, Unpack[Q]],
value: V | Var[V],
value2: V2 | Var[V2],
value3: V3 | Var[V3],
) -> EventCallback[Q, T]: ...
) -> EventCallback[Unpack[Q]]: ...
@overload
def __call__(
self: EventCallback[Concatenate[V, V2, V3, V4, Q], T],
self: EventCallback[V, V2, V3, V4, Unpack[Q]],
value: V | Var[V],
value2: V2 | Var[V2],
value3: V3 | Var[V3],
value4: V4 | Var[V4],
) -> EventCallback[Q, T]: ...
) -> EventCallback[Unpack[Q]]: ...
def __call__(self, *values) -> EventCallback: # pyright: ignore [reportInconsistentOverload]
"""Call the function with the values.
@ -1849,11 +1848,11 @@ class EventCallback(Generic[P, T]):
@overload
def __get__(
self: EventCallback[P, T], instance: None, owner: Any
) -> EventCallback[P, T]: ...
self: EventCallback[Unpack[P]], instance: None, owner: Any
) -> EventCallback[Unpack[P]]: ...
@overload
def __get__(self, instance: Any, owner: Any) -> Callable[P, T]: ...
def __get__(self, instance: Any, owner: Any) -> Callable[[Unpack[P]]]: ...
def __get__(self, instance: Any, owner: Any) -> Callable:
"""Get the function with the instance bound to it.
@ -1871,7 +1870,51 @@ class EventCallback(Generic[P, T]):
return partial(self.func, instance)
G = ParamSpec("G")
class LambdaEventCallback(Protocol[Unpack[P]]):
"""A protocol for a lambda event callback."""
@overload
def __call__(self: LambdaEventCallback[()]) -> Any: ...
@overload
def __call__(self: LambdaEventCallback[V], value: Var[V], /) -> Any: ...
@overload
def __call__(
self: LambdaEventCallback[V, V2], value: Var[V], value2: Var[V2], /
) -> Any: ...
@overload
def __call__(
self: LambdaEventCallback[V, V2, V3],
value: Var[V],
value2: Var[V2],
value3: Var[V3],
/,
) -> Any: ...
def __call__(self, *args: Var) -> Any:
"""Call the lambda with the args.
Args:
*args: The args to call the lambda with.
Returns:
The result of calling the lambda with the args.
"""
BasicEventTypes = EventSpec | EventHandler | Var[Any]
ARGS = TypeVarTuple("ARGS")
LambdaOrState = LambdaEventCallback[Unpack[ARGS]] | EventCallback[Unpack[ARGS]]
ItemOrList = V | List[V]
IndividualEventType = BasicEventTypes | LambdaOrState[Unpack[ARGS]] | LambdaOrState[()]
EventType = ItemOrList[IndividualEventType[Unpack[ARGS]]]
if TYPE_CHECKING:
from reflex.state import BaseState
@ -1880,25 +1923,6 @@ if TYPE_CHECKING:
else:
BASE_STATE = TypeVar("BASE_STATE")
StateCallable = TypeAliasType(
"StateCallable",
Callable[Concatenate[BASE_STATE, G], Any],
type_params=(G, BASE_STATE),
)
IndividualEventType = Union[
EventSpec,
EventHandler,
Callable[G, Any],
StateCallable[G, BASE_STATE],
EventCallback[G, Any],
Var[Any],
]
ItemOrList = Union[V, List[V]]
EventType = ItemOrList[IndividualEventType[G, BASE_STATE]]
class EventNamespace(types.SimpleNamespace):
"""A namespace for event related classes."""
@ -1919,24 +1943,26 @@ class EventNamespace(types.SimpleNamespace):
@staticmethod
def __call__(
func: None = None, *, background: bool | None = None
) -> Callable[[Callable[Concatenate[BASE_STATE, P], T]], EventCallback[P, T]]: ... # pyright: ignore [reportInvalidTypeVarUse]
) -> Callable[
[Callable[[BASE_STATE, Unpack[P]], Any]], EventCallback[Unpack[P]] # pyright: ignore [reportInvalidTypeVarUse]
]: ...
@overload
@staticmethod
def __call__(
func: Callable[Concatenate[BASE_STATE, P], T],
func: Callable[[BASE_STATE, Unpack[P]], Any],
*,
background: bool | None = None,
) -> EventCallback[P, T]: ...
) -> EventCallback[Unpack[P]]: ...
@staticmethod
def __call__(
func: Callable[Concatenate[BASE_STATE, P], T] | None = None,
func: Callable[[BASE_STATE, Unpack[P]], Any] | None = None,
*,
background: bool | None = None,
) -> Union[
EventCallback[P, T],
Callable[[Callable[Concatenate[BASE_STATE, P], T]], EventCallback[P, T]],
EventCallback[Unpack[P]],
Callable[[Callable[[BASE_STATE, Unpack[P]], Any]], EventCallback[Unpack[P]]],
]:
"""Wrap a function to be used as an event.
@ -1952,8 +1978,8 @@ class EventNamespace(types.SimpleNamespace):
"""
def wrapper(
func: Callable[Concatenate[BASE_STATE, P], T],
) -> EventCallback[P, T]:
func: Callable[[BASE_STATE, Unpack[P]], T],
) -> EventCallback[Unpack[P]]:
if background is True:
if not inspect.iscoroutinefunction(
func

View File

@ -6,7 +6,7 @@ from collections import defaultdict
from typing import Any, Callable, Dict, List
from reflex.config import get_config
from reflex.event import BASE_STATE, EventType
from reflex.event import EventType
DECORATED_PAGES: Dict[str, List] = defaultdict(list)
@ -18,7 +18,7 @@ def page(
description: str | None = None,
meta: list[Any] | None = None,
script_tags: list[Any] | None = None,
on_load: EventType[[], BASE_STATE] | None = None,
on_load: EventType[()] | None = None,
):
"""Decorate a function as a page.

View File

@ -502,7 +502,7 @@ if TYPE_CHECKING:
def format_queue_events(
events: EventType | None = None,
events: EventType[Any] | None = None,
args_spec: Optional[ArgsSpec] = None,
) -> Var[EventChain]:
"""Format a list of event handler / event spec as a javascript callback.

View File

@ -75,7 +75,6 @@ DEFAULT_IMPORTS = {
"EventHandler",
"EventSpec",
"EventType",
"BASE_STATE",
"KeyInputInfo",
],
"reflex.style": ["Style"],
@ -502,7 +501,7 @@ def _generate_component_create_functiondef(
def figure_out_return_type(annotation: Any):
if inspect.isclass(annotation) and issubclass(annotation, inspect._empty):
return ast.Name(id="EventType[..., BASE_STATE]")
return ast.Name(id="EventType[Any]")
if not isinstance(annotation, str) and get_origin(annotation) is tuple:
arguments = get_args(annotation)
@ -518,7 +517,7 @@ def _generate_component_create_functiondef(
# Get all prefixes of the type arguments
all_count_args_type = [
ast.Name(
f"EventType[[{', '.join([ast.unparse(arg) for arg in type_args[:i]])}], BASE_STATE]"
f"EventType[[{', '.join([ast.unparse(arg) for arg in type_args[:i]])}]]"
)
for i in range(len(type_args) + 1)
]
@ -532,7 +531,7 @@ def _generate_component_create_functiondef(
inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]")
if inside_of_tuple == "()":
return ast.Name(id="EventType[[], BASE_STATE]")
return ast.Name(id="EventType[()]")
arguments = [""]
@ -559,16 +558,14 @@ def _generate_component_create_functiondef(
]
all_count_args_type = [
ast.Name(
f"EventType[[{', '.join(arguments_without_var[:i])}], BASE_STATE]"
)
ast.Name(f"EventType[[{', '.join(arguments_without_var[:i])}]]")
for i in range(len(arguments) + 1)
]
return ast.Name(
id=f"Union[{', '.join(map(ast.unparse, all_count_args_type))}]"
)
return ast.Name(id="EventType[..., BASE_STATE]")
return ast.Name(id="EventType[Any]")
event_triggers = clz().get_event_triggers()