Fix annotated EventHandler (#3076)
This commit is contained in:
parent
713ee06ab7
commit
fc0be257a3
@ -474,9 +474,11 @@ class Component(BaseComponent, ABC):
|
|||||||
# e.g. variable declared as EventHandler types.
|
# e.g. variable declared as EventHandler types.
|
||||||
for field in self.get_fields().values():
|
for field in self.get_fields().values():
|
||||||
if types._issubclass(field.type_, EventHandler):
|
if types._issubclass(field.type_, EventHandler):
|
||||||
default_triggers[field.name] = getattr(
|
args_spec = None
|
||||||
field.type_, "args_spec", lambda: []
|
annotation = field.annotation
|
||||||
)
|
if hasattr(annotation, "__metadata__"):
|
||||||
|
args_spec = annotation.__metadata__[0]
|
||||||
|
default_triggers[field.name] = args_spec or (lambda: [])
|
||||||
return default_triggers
|
return default_triggers
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
|
@ -13,7 +13,6 @@ from typing import (
|
|||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
_GenericAlias, # type: ignore
|
|
||||||
get_type_hints,
|
get_type_hints,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -23,6 +22,11 @@ from reflex.utils import console, format
|
|||||||
from reflex.utils.types import ArgsSpec
|
from reflex.utils.types import ArgsSpec
|
||||||
from reflex.vars import BaseVar, Var
|
from reflex.vars import BaseVar, Var
|
||||||
|
|
||||||
|
try:
|
||||||
|
from typing import Annotated
|
||||||
|
except ImportError:
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
|
||||||
class Event(Base):
|
class Event(Base):
|
||||||
"""An event that describes any state change in the app."""
|
"""An event that describes any state change in the app."""
|
||||||
@ -118,7 +122,7 @@ class EventHandler(EventActionsMixin):
|
|||||||
frozen = True
|
frozen = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __class_getitem__(cls, args_spec: str) -> _GenericAlias:
|
def __class_getitem__(cls, args_spec: str) -> Annotated:
|
||||||
"""Get a typed EventHandler.
|
"""Get a typed EventHandler.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -127,10 +131,7 @@ class EventHandler(EventActionsMixin):
|
|||||||
Returns:
|
Returns:
|
||||||
The EventHandler class item.
|
The EventHandler class item.
|
||||||
"""
|
"""
|
||||||
gen = _GenericAlias(cls, Any)
|
return Annotated[cls, args_spec]
|
||||||
# Cannot subclass special typing classes, so we need to set the args_spec dynamically as an attribute.
|
|
||||||
gen.args_spec = args_spec
|
|
||||||
return gen
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_background(self) -> bool:
|
def is_background(self) -> bool:
|
||||||
|
@ -15,7 +15,7 @@ from reflex.components.component import (
|
|||||||
custom_component,
|
custom_component,
|
||||||
)
|
)
|
||||||
from reflex.constants import EventTriggers
|
from reflex.constants import EventTriggers
|
||||||
from reflex.event import EventChain, EventHandler
|
from reflex.event import EventChain, EventHandler, parse_args_spec
|
||||||
from reflex.state import BaseState
|
from reflex.state import BaseState
|
||||||
from reflex.style import Style
|
from reflex.style import Style
|
||||||
from reflex.utils import imports
|
from reflex.utils import imports
|
||||||
@ -1542,11 +1542,12 @@ def test_custom_component_declare_event_handlers_in_fields():
|
|||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
**super().get_event_triggers(),
|
**super().get_event_triggers(),
|
||||||
"on_a": lambda e: [e],
|
"on_a": lambda e0: [e0],
|
||||||
"on_b": lambda e: [e.target.value],
|
"on_b": lambda e0: [e0.target.value],
|
||||||
"on_c": lambda e: [],
|
"on_c": lambda e0: [],
|
||||||
"on_d": lambda: [],
|
"on_d": lambda: [],
|
||||||
"on_e": lambda: [],
|
"on_e": lambda: [],
|
||||||
|
"on_f": lambda a, b, c: [c, b, a],
|
||||||
}
|
}
|
||||||
|
|
||||||
class TestComponent(Component):
|
class TestComponent(Component):
|
||||||
@ -1555,10 +1556,16 @@ def test_custom_component_declare_event_handlers_in_fields():
|
|||||||
on_c: EventHandler[lambda e0: []]
|
on_c: EventHandler[lambda e0: []]
|
||||||
on_d: EventHandler[lambda: []]
|
on_d: EventHandler[lambda: []]
|
||||||
on_e: EventHandler
|
on_e: EventHandler
|
||||||
|
on_f: EventHandler[lambda a, b, c: [c, b, a]]
|
||||||
|
|
||||||
custom_component = ReferenceComponent.create()
|
custom_component = ReferenceComponent.create()
|
||||||
test_component = TestComponent.create()
|
test_component = TestComponent.create()
|
||||||
assert (
|
custom_triggers = custom_component.get_event_triggers()
|
||||||
custom_component.get_event_triggers().keys()
|
test_triggers = test_component.get_event_triggers()
|
||||||
== test_component.get_event_triggers().keys()
|
assert custom_triggers.keys() == test_triggers.keys()
|
||||||
)
|
for trigger_name in custom_component.get_event_triggers():
|
||||||
|
for v1, v2 in zip(
|
||||||
|
parse_args_spec(test_triggers[trigger_name]),
|
||||||
|
parse_args_spec(custom_triggers[trigger_name]),
|
||||||
|
):
|
||||||
|
assert v1.equals(v2)
|
||||||
|
Loading…
Reference in New Issue
Block a user