Fix annotated EventHandler (#3076)

This commit is contained in:
Masen Furer 2024-04-11 15:34:00 -07:00 committed by GitHub
parent 713ee06ab7
commit fc0be257a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 27 additions and 17 deletions

View File

@ -474,9 +474,11 @@ class Component(BaseComponent, ABC):
# e.g. variable declared as EventHandler types.
for field in self.get_fields().values():
if types._issubclass(field.type_, EventHandler):
default_triggers[field.name] = getattr(
field.type_, "args_spec", lambda: []
)
args_spec = None
annotation = field.annotation
if hasattr(annotation, "__metadata__"):
args_spec = annotation.__metadata__[0]
default_triggers[field.name] = args_spec or (lambda: [])
return default_triggers
def __repr__(self) -> str:

View File

@ -13,7 +13,6 @@ from typing import (
Optional,
Tuple,
Union,
_GenericAlias, # type: ignore
get_type_hints,
)
@ -23,6 +22,11 @@ from reflex.utils import console, format
from reflex.utils.types import ArgsSpec
from reflex.vars import BaseVar, Var
try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated
class Event(Base):
"""An event that describes any state change in the app."""
@ -118,7 +122,7 @@ class EventHandler(EventActionsMixin):
frozen = True
@classmethod
def __class_getitem__(cls, args_spec: str) -> _GenericAlias:
def __class_getitem__(cls, args_spec: str) -> Annotated:
"""Get a typed EventHandler.
Args:
@ -127,10 +131,7 @@ class EventHandler(EventActionsMixin):
Returns:
The EventHandler class item.
"""
gen = _GenericAlias(cls, Any)
# Cannot subclass special typing classes, so we need to set the args_spec dynamically as an attribute.
gen.args_spec = args_spec
return gen
return Annotated[cls, args_spec]
@property
def is_background(self) -> bool:

View File

@ -15,7 +15,7 @@ from reflex.components.component import (
custom_component,
)
from reflex.constants import EventTriggers
from reflex.event import EventChain, EventHandler
from reflex.event import EventChain, EventHandler, parse_args_spec
from reflex.state import BaseState
from reflex.style import Style
from reflex.utils import imports
@ -1542,11 +1542,12 @@ def test_custom_component_declare_event_handlers_in_fields():
"""
return {
**super().get_event_triggers(),
"on_a": lambda e: [e],
"on_b": lambda e: [e.target.value],
"on_c": lambda e: [],
"on_a": lambda e0: [e0],
"on_b": lambda e0: [e0.target.value],
"on_c": lambda e0: [],
"on_d": lambda: [],
"on_e": lambda: [],
"on_f": lambda a, b, c: [c, b, a],
}
class TestComponent(Component):
@ -1555,10 +1556,16 @@ def test_custom_component_declare_event_handlers_in_fields():
on_c: EventHandler[lambda e0: []]
on_d: EventHandler[lambda: []]
on_e: EventHandler
on_f: EventHandler[lambda a, b, c: [c, b, a]]
custom_component = ReferenceComponent.create()
test_component = TestComponent.create()
assert (
custom_component.get_event_triggers().keys()
== test_component.get_event_triggers().keys()
)
custom_triggers = custom_component.get_event_triggers()
test_triggers = test_component.get_event_triggers()
assert custom_triggers.keys() == test_triggers.keys()
for trigger_name in custom_component.get_event_triggers():
for v1, v2 in zip(
parse_args_spec(test_triggers[trigger_name]),
parse_args_spec(custom_triggers[trigger_name]),
):
assert v1.equals(v2)