[REF-2272] Support declaring EventHandlers directly in component (#2952)

This commit is contained in:
Martin Xu 2024-03-29 06:26:07 -07:00 committed by GitHub
parent e2eced79b1
commit f372402ee4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 121 additions and 17 deletions

View File

@ -131,6 +131,7 @@ _MAPPING = {
"reflex.event": [
"event",
"EventChain",
"EventHandler",
"background",
"call_script",
"clear_local_storage",

View File

@ -113,6 +113,7 @@ from reflex import constants as constants
from reflex.constants import Env as Env
from reflex import event as event
from reflex.event import EventChain as EventChain
from reflex.event import EventHandler as EventHandler
from reflex.event import background as background
from reflex.event import call_script as call_script
from reflex.event import clear_local_storage as clear_local_storage

View File

@ -236,6 +236,8 @@ class Component(BaseComponent, ABC):
if types._issubclass(field.type_, Var):
field.required = False
field.default = Var.create(field.default)
elif types._issubclass(field.type_, EventHandler):
field.required = False
# Ensure renamed props from parent classes are applied to the subclass.
if cls._rename_props:
@ -272,7 +274,8 @@ class Component(BaseComponent, ABC):
# Get the component fields, triggers, and props.
fields = self.get_fields()
triggers = self.get_event_triggers().keys()
component_specific_triggers = self.get_event_triggers()
triggers = component_specific_triggers.keys()
props = self.get_props()
# Add any events triggers.
@ -327,7 +330,9 @@ class Component(BaseComponent, ABC):
# Check if the key is an event trigger.
if key in triggers:
# Temporarily disable full control for event triggers.
kwargs["event_triggers"][key] = self._create_event_chain(key, value)
kwargs["event_triggers"][key] = self._create_event_chain(
value=value, args_spec=component_specific_triggers[key]
)
# Remove any keys that were added as events.
for key in kwargs["event_triggers"]:
@ -359,7 +364,7 @@ class Component(BaseComponent, ABC):
def _create_event_chain(
self,
event_trigger: str,
args_spec: Any,
value: Union[
Var, EventHandler, EventSpec, List[Union[EventHandler, EventSpec]], Callable
],
@ -367,7 +372,7 @@ class Component(BaseComponent, ABC):
"""Create an event chain from a variety of input types.
Args:
event_trigger: The event trigger to bind the chain to.
args_spec: The args_spec of the the event trigger being bound.
value: The value to create the event chain from.
Returns:
@ -376,9 +381,6 @@ class Component(BaseComponent, ABC):
Raises:
ValueError: If the value is not a valid event chain.
"""
# Check if the trigger is a controlled event.
triggers = self.get_event_triggers()
# If it's an event chain var, return it.
if isinstance(value, Var):
if value._var_type is not EventChain:
@ -388,8 +390,6 @@ class Component(BaseComponent, ABC):
# Trust that the caller knows what they're doing passing an EventChain directly
return value
arg_spec = triggers.get(event_trigger, lambda: [])
# If the input is a single event handler, wrap it in a list.
if isinstance(value, (EventHandler, EventSpec)):
value = [value]
@ -401,7 +401,7 @@ class Component(BaseComponent, ABC):
if isinstance(v, EventHandler):
# Call the event handler to get the event.
try:
event = call_event_handler(v, arg_spec) # type: ignore
event = call_event_handler(v, args_spec)
except ValueError as err:
raise ValueError(
f" {err} defined in the `{type(self).__name__}` component"
@ -414,13 +414,13 @@ class Component(BaseComponent, ABC):
events.append(v)
elif isinstance(v, Callable):
# Call the lambda to get the event chain.
events.extend(call_event_fn(v, arg_spec)) # type: ignore
events.extend(call_event_fn(v, args_spec))
else:
raise ValueError(f"Invalid event: {v}")
# If the input is a callable, create an event chain.
elif isinstance(value, Callable):
events = call_event_fn(value, arg_spec) # type: ignore
events = call_event_fn(value, args_spec)
# Otherwise, raise an error.
else:
@ -435,7 +435,7 @@ class Component(BaseComponent, ABC):
event_actions.update(e.event_actions)
# Return the event chain.
if isinstance(arg_spec, Var):
if isinstance(args_spec, Var):
return EventChain(
events=events,
args_spec=None,
@ -444,7 +444,7 @@ class Component(BaseComponent, ABC):
else:
return EventChain(
events=events,
args_spec=arg_spec, # type: ignore
args_spec=args_spec,
event_actions=event_actions,
)
@ -454,7 +454,7 @@ class Component(BaseComponent, ABC):
Returns:
The event triggers.
"""
return {
default_triggers = {
EventTriggers.ON_FOCUS: lambda: [],
EventTriggers.ON_BLUR: lambda: [],
EventTriggers.ON_CLICK: lambda: [],
@ -471,6 +471,14 @@ class Component(BaseComponent, ABC):
EventTriggers.ON_MOUNT: lambda: [],
EventTriggers.ON_UNMOUNT: lambda: [],
}
# Look for component specific triggers,
# e.g. variable declared as EventHandler types.
for field in self.get_fields().values():
if types._issubclass(field.type_, EventHandler):
default_triggers[field.name] = getattr(
field.type_, "args_spec", lambda: []
)
return default_triggers
def __repr__(self) -> str:
"""Represent the component in React.
@ -1352,6 +1360,9 @@ class CustomComponent(Component):
# Set the tag to the name of the function.
self.tag = format.to_title_case(self.component_fn.__name__)
# Get the event triggers defined in the component declaration.
event_triggers_in_component_declaration = self.get_event_triggers()
# Set the props.
props = typing.get_type_hints(self.component_fn)
for key, value in kwargs.items():
@ -1364,7 +1375,12 @@ class CustomComponent(Component):
# Handle event chains.
if types._issubclass(type_, EventChain):
value = self._create_event_chain(key, value)
value = self._create_event_chain(
value=value,
args_spec=event_triggers_in_component_declaration.get(
key, lambda: []
),
)
self.props[format.to_camel_case(key)] = value
continue

View File

@ -1,4 +1,5 @@
"""Define event classes to connect the frontend and backend."""
from __future__ import annotations
import inspect
@ -12,6 +13,7 @@ from typing import (
Optional,
Tuple,
Union,
_GenericAlias, # type: ignore
get_type_hints,
)
@ -106,7 +108,7 @@ class EventHandler(EventActionsMixin):
fn: Any
# The full name of the state class this event handler is attached to.
# Emtpy string means this event handler is a server side event.
# Empty string means this event handler is a server side event.
state_full_name: str = ""
class Config:
@ -115,6 +117,21 @@ class EventHandler(EventActionsMixin):
# Needed to allow serialization of Callable.
frozen = True
@classmethod
def __class_getitem__(cls, args_spec: str) -> _GenericAlias:
"""Get a typed EventHandler.
Args:
args_spec: The args_spec of the EventHandler.
Returns:
The EventHandler class item.
"""
gen = _GenericAlias(cls, Any)
# Cannot subclass special typing classes, so we need to set the args_spec dynamically as an attribute.
gen.args_spec = args_spec
return gen
@property
def is_background(self) -> bool:
"""Whether the event handler is a background task.

View File

@ -1350,3 +1350,35 @@ def test_custom_component_add_imports(tags):
assert baseline.get_imports() == {"react": _list_to_import_vars(tags)}
assert test.get_imports() == baseline.get_imports()
def test_custom_component_declare_event_handlers_in_fields():
class ReferenceComponent(Component):
def get_event_triggers(self) -> Dict[str, Any]:
"""Test controlled triggers.
Returns:
Test controlled triggers.
"""
return {
**super().get_event_triggers(),
"on_a": lambda e: [e],
"on_b": lambda e: [e.target.value],
"on_c": lambda e: [],
"on_d": lambda: [],
"on_e": lambda: [],
}
class TestComponent(Component):
on_a: EventHandler[lambda e0: [e0]]
on_b: EventHandler[lambda e0: [e0.target.value]]
on_c: EventHandler[lambda e0: []]
on_d: EventHandler[lambda: []]
on_e: EventHandler
custom_component = ReferenceComponent.create()
test_component = TestComponent.create()
assert (
custom_component.get_event_triggers().keys()
== test_component.get_event_triggers().keys()
)

View File

@ -0,0 +1,37 @@
from __future__ import annotations
from typing import Any
from reflex.components.component import Component
from reflex.event import EventHandler
# This is a repeat of its namesake in test_component.py.
def test_custom_component_declare_event_handlers_in_fields():
class ReferenceComponent(Component):
def get_event_triggers(self) -> dict[str, Any]:
"""Test controlled triggers.
Returns:
Test controlled triggers.
"""
return {
**super().get_event_triggers(),
"on_a": lambda e: [e],
"on_b": lambda e: [e.target.value],
"on_c": lambda e: [],
"on_d": lambda: [],
}
class TestComponent(Component):
on_a: EventHandler[lambda e0: [e0]]
on_b: EventHandler[lambda e0: [e0.target.value]]
on_c: EventHandler[lambda e0: []]
on_d: EventHandler[lambda: []]
custom_component = ReferenceComponent.create()
test_component = TestComponent.create()
assert (
custom_component.get_event_triggers().keys()
== test_component.get_event_triggers().keys()
)