diff --git a/reflex/components/component.py b/reflex/components/component.py index 29263bf7b..047ca3863 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -474,9 +474,11 @@ class Component(BaseComponent, ABC): # e.g. variable declared as EventHandler types. for field in self.get_fields().values(): if types._issubclass(field.type_, EventHandler): - default_triggers[field.name] = getattr( - field.type_, "args_spec", lambda: [] - ) + args_spec = None + annotation = field.annotation + if hasattr(annotation, "__metadata__"): + args_spec = annotation.__metadata__[0] + default_triggers[field.name] = args_spec or (lambda: []) return default_triggers def __repr__(self) -> str: diff --git a/reflex/event.py b/reflex/event.py index 19fdad5e9..bef86cdc5 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -13,7 +13,6 @@ from typing import ( Optional, Tuple, Union, - _GenericAlias, # type: ignore get_type_hints, ) @@ -23,6 +22,11 @@ from reflex.utils import console, format from reflex.utils.types import ArgsSpec from reflex.vars import BaseVar, Var +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated + class Event(Base): """An event that describes any state change in the app.""" @@ -118,7 +122,7 @@ class EventHandler(EventActionsMixin): frozen = True @classmethod - def __class_getitem__(cls, args_spec: str) -> _GenericAlias: + def __class_getitem__(cls, args_spec: str) -> Annotated: """Get a typed EventHandler. Args: @@ -127,10 +131,7 @@ class EventHandler(EventActionsMixin): Returns: The EventHandler class item. """ - gen = _GenericAlias(cls, Any) - # Cannot subclass special typing classes, so we need to set the args_spec dynamically as an attribute. - gen.args_spec = args_spec - return gen + return Annotated[cls, args_spec] @property def is_background(self) -> bool: diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 3c3b640a9..25756b552 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -15,7 +15,7 @@ from reflex.components.component import ( custom_component, ) from reflex.constants import EventTriggers -from reflex.event import EventChain, EventHandler +from reflex.event import EventChain, EventHandler, parse_args_spec from reflex.state import BaseState from reflex.style import Style from reflex.utils import imports @@ -1542,11 +1542,12 @@ def test_custom_component_declare_event_handlers_in_fields(): """ return { **super().get_event_triggers(), - "on_a": lambda e: [e], - "on_b": lambda e: [e.target.value], - "on_c": lambda e: [], + "on_a": lambda e0: [e0], + "on_b": lambda e0: [e0.target.value], + "on_c": lambda e0: [], "on_d": lambda: [], "on_e": lambda: [], + "on_f": lambda a, b, c: [c, b, a], } class TestComponent(Component): @@ -1555,10 +1556,16 @@ def test_custom_component_declare_event_handlers_in_fields(): on_c: EventHandler[lambda e0: []] on_d: EventHandler[lambda: []] on_e: EventHandler + on_f: EventHandler[lambda a, b, c: [c, b, a]] custom_component = ReferenceComponent.create() test_component = TestComponent.create() - assert ( - custom_component.get_event_triggers().keys() - == test_component.get_event_triggers().keys() - ) + custom_triggers = custom_component.get_event_triggers() + test_triggers = test_component.get_event_triggers() + assert custom_triggers.keys() == test_triggers.keys() + for trigger_name in custom_component.get_event_triggers(): + for v1, v2 in zip( + parse_args_spec(test_triggers[trigger_name]), + parse_args_spec(custom_triggers[trigger_name]), + ): + assert v1.equals(v2)