[REF-2272] Support declaring EventHandlers directly in component (#2952)
This commit is contained in:
parent
e2eced79b1
commit
f372402ee4
@ -131,6 +131,7 @@ _MAPPING = {
|
|||||||
"reflex.event": [
|
"reflex.event": [
|
||||||
"event",
|
"event",
|
||||||
"EventChain",
|
"EventChain",
|
||||||
|
"EventHandler",
|
||||||
"background",
|
"background",
|
||||||
"call_script",
|
"call_script",
|
||||||
"clear_local_storage",
|
"clear_local_storage",
|
||||||
|
@ -113,6 +113,7 @@ from reflex import constants as constants
|
|||||||
from reflex.constants import Env as Env
|
from reflex.constants import Env as Env
|
||||||
from reflex import event as event
|
from reflex import event as event
|
||||||
from reflex.event import EventChain as EventChain
|
from reflex.event import EventChain as EventChain
|
||||||
|
from reflex.event import EventHandler as EventHandler
|
||||||
from reflex.event import background as background
|
from reflex.event import background as background
|
||||||
from reflex.event import call_script as call_script
|
from reflex.event import call_script as call_script
|
||||||
from reflex.event import clear_local_storage as clear_local_storage
|
from reflex.event import clear_local_storage as clear_local_storage
|
||||||
|
@ -236,6 +236,8 @@ class Component(BaseComponent, ABC):
|
|||||||
if types._issubclass(field.type_, Var):
|
if types._issubclass(field.type_, Var):
|
||||||
field.required = False
|
field.required = False
|
||||||
field.default = Var.create(field.default)
|
field.default = Var.create(field.default)
|
||||||
|
elif types._issubclass(field.type_, EventHandler):
|
||||||
|
field.required = False
|
||||||
|
|
||||||
# Ensure renamed props from parent classes are applied to the subclass.
|
# Ensure renamed props from parent classes are applied to the subclass.
|
||||||
if cls._rename_props:
|
if cls._rename_props:
|
||||||
@ -272,7 +274,8 @@ class Component(BaseComponent, ABC):
|
|||||||
|
|
||||||
# Get the component fields, triggers, and props.
|
# Get the component fields, triggers, and props.
|
||||||
fields = self.get_fields()
|
fields = self.get_fields()
|
||||||
triggers = self.get_event_triggers().keys()
|
component_specific_triggers = self.get_event_triggers()
|
||||||
|
triggers = component_specific_triggers.keys()
|
||||||
props = self.get_props()
|
props = self.get_props()
|
||||||
|
|
||||||
# Add any events triggers.
|
# Add any events triggers.
|
||||||
@ -327,7 +330,9 @@ class Component(BaseComponent, ABC):
|
|||||||
# Check if the key is an event trigger.
|
# Check if the key is an event trigger.
|
||||||
if key in triggers:
|
if key in triggers:
|
||||||
# Temporarily disable full control for event triggers.
|
# Temporarily disable full control for event triggers.
|
||||||
kwargs["event_triggers"][key] = self._create_event_chain(key, value)
|
kwargs["event_triggers"][key] = self._create_event_chain(
|
||||||
|
value=value, args_spec=component_specific_triggers[key]
|
||||||
|
)
|
||||||
|
|
||||||
# Remove any keys that were added as events.
|
# Remove any keys that were added as events.
|
||||||
for key in kwargs["event_triggers"]:
|
for key in kwargs["event_triggers"]:
|
||||||
@ -359,7 +364,7 @@ class Component(BaseComponent, ABC):
|
|||||||
|
|
||||||
def _create_event_chain(
|
def _create_event_chain(
|
||||||
self,
|
self,
|
||||||
event_trigger: str,
|
args_spec: Any,
|
||||||
value: Union[
|
value: Union[
|
||||||
Var, EventHandler, EventSpec, List[Union[EventHandler, EventSpec]], Callable
|
Var, EventHandler, EventSpec, List[Union[EventHandler, EventSpec]], Callable
|
||||||
],
|
],
|
||||||
@ -367,7 +372,7 @@ class Component(BaseComponent, ABC):
|
|||||||
"""Create an event chain from a variety of input types.
|
"""Create an event chain from a variety of input types.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_trigger: The event trigger to bind the chain to.
|
args_spec: The args_spec of the the event trigger being bound.
|
||||||
value: The value to create the event chain from.
|
value: The value to create the event chain from.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -376,9 +381,6 @@ class Component(BaseComponent, ABC):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If the value is not a valid event chain.
|
ValueError: If the value is not a valid event chain.
|
||||||
"""
|
"""
|
||||||
# Check if the trigger is a controlled event.
|
|
||||||
triggers = self.get_event_triggers()
|
|
||||||
|
|
||||||
# If it's an event chain var, return it.
|
# If it's an event chain var, return it.
|
||||||
if isinstance(value, Var):
|
if isinstance(value, Var):
|
||||||
if value._var_type is not EventChain:
|
if value._var_type is not EventChain:
|
||||||
@ -388,8 +390,6 @@ class Component(BaseComponent, ABC):
|
|||||||
# Trust that the caller knows what they're doing passing an EventChain directly
|
# Trust that the caller knows what they're doing passing an EventChain directly
|
||||||
return value
|
return value
|
||||||
|
|
||||||
arg_spec = triggers.get(event_trigger, lambda: [])
|
|
||||||
|
|
||||||
# If the input is a single event handler, wrap it in a list.
|
# If the input is a single event handler, wrap it in a list.
|
||||||
if isinstance(value, (EventHandler, EventSpec)):
|
if isinstance(value, (EventHandler, EventSpec)):
|
||||||
value = [value]
|
value = [value]
|
||||||
@ -401,7 +401,7 @@ class Component(BaseComponent, ABC):
|
|||||||
if isinstance(v, EventHandler):
|
if isinstance(v, EventHandler):
|
||||||
# Call the event handler to get the event.
|
# Call the event handler to get the event.
|
||||||
try:
|
try:
|
||||||
event = call_event_handler(v, arg_spec) # type: ignore
|
event = call_event_handler(v, args_spec)
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f" {err} defined in the `{type(self).__name__}` component"
|
f" {err} defined in the `{type(self).__name__}` component"
|
||||||
@ -414,13 +414,13 @@ class Component(BaseComponent, ABC):
|
|||||||
events.append(v)
|
events.append(v)
|
||||||
elif isinstance(v, Callable):
|
elif isinstance(v, Callable):
|
||||||
# Call the lambda to get the event chain.
|
# Call the lambda to get the event chain.
|
||||||
events.extend(call_event_fn(v, arg_spec)) # type: ignore
|
events.extend(call_event_fn(v, args_spec))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid event: {v}")
|
raise ValueError(f"Invalid event: {v}")
|
||||||
|
|
||||||
# If the input is a callable, create an event chain.
|
# If the input is a callable, create an event chain.
|
||||||
elif isinstance(value, Callable):
|
elif isinstance(value, Callable):
|
||||||
events = call_event_fn(value, arg_spec) # type: ignore
|
events = call_event_fn(value, args_spec)
|
||||||
|
|
||||||
# Otherwise, raise an error.
|
# Otherwise, raise an error.
|
||||||
else:
|
else:
|
||||||
@ -435,7 +435,7 @@ class Component(BaseComponent, ABC):
|
|||||||
event_actions.update(e.event_actions)
|
event_actions.update(e.event_actions)
|
||||||
|
|
||||||
# Return the event chain.
|
# Return the event chain.
|
||||||
if isinstance(arg_spec, Var):
|
if isinstance(args_spec, Var):
|
||||||
return EventChain(
|
return EventChain(
|
||||||
events=events,
|
events=events,
|
||||||
args_spec=None,
|
args_spec=None,
|
||||||
@ -444,7 +444,7 @@ class Component(BaseComponent, ABC):
|
|||||||
else:
|
else:
|
||||||
return EventChain(
|
return EventChain(
|
||||||
events=events,
|
events=events,
|
||||||
args_spec=arg_spec, # type: ignore
|
args_spec=args_spec,
|
||||||
event_actions=event_actions,
|
event_actions=event_actions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -454,7 +454,7 @@ class Component(BaseComponent, ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
The event triggers.
|
The event triggers.
|
||||||
"""
|
"""
|
||||||
return {
|
default_triggers = {
|
||||||
EventTriggers.ON_FOCUS: lambda: [],
|
EventTriggers.ON_FOCUS: lambda: [],
|
||||||
EventTriggers.ON_BLUR: lambda: [],
|
EventTriggers.ON_BLUR: lambda: [],
|
||||||
EventTriggers.ON_CLICK: lambda: [],
|
EventTriggers.ON_CLICK: lambda: [],
|
||||||
@ -471,6 +471,14 @@ class Component(BaseComponent, ABC):
|
|||||||
EventTriggers.ON_MOUNT: lambda: [],
|
EventTriggers.ON_MOUNT: lambda: [],
|
||||||
EventTriggers.ON_UNMOUNT: lambda: [],
|
EventTriggers.ON_UNMOUNT: lambda: [],
|
||||||
}
|
}
|
||||||
|
# Look for component specific triggers,
|
||||||
|
# e.g. variable declared as EventHandler types.
|
||||||
|
for field in self.get_fields().values():
|
||||||
|
if types._issubclass(field.type_, EventHandler):
|
||||||
|
default_triggers[field.name] = getattr(
|
||||||
|
field.type_, "args_spec", lambda: []
|
||||||
|
)
|
||||||
|
return default_triggers
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""Represent the component in React.
|
"""Represent the component in React.
|
||||||
@ -1352,6 +1360,9 @@ class CustomComponent(Component):
|
|||||||
# Set the tag to the name of the function.
|
# Set the tag to the name of the function.
|
||||||
self.tag = format.to_title_case(self.component_fn.__name__)
|
self.tag = format.to_title_case(self.component_fn.__name__)
|
||||||
|
|
||||||
|
# Get the event triggers defined in the component declaration.
|
||||||
|
event_triggers_in_component_declaration = self.get_event_triggers()
|
||||||
|
|
||||||
# Set the props.
|
# Set the props.
|
||||||
props = typing.get_type_hints(self.component_fn)
|
props = typing.get_type_hints(self.component_fn)
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
@ -1364,7 +1375,12 @@ class CustomComponent(Component):
|
|||||||
|
|
||||||
# Handle event chains.
|
# Handle event chains.
|
||||||
if types._issubclass(type_, EventChain):
|
if types._issubclass(type_, EventChain):
|
||||||
value = self._create_event_chain(key, value)
|
value = self._create_event_chain(
|
||||||
|
value=value,
|
||||||
|
args_spec=event_triggers_in_component_declaration.get(
|
||||||
|
key, lambda: []
|
||||||
|
),
|
||||||
|
)
|
||||||
self.props[format.to_camel_case(key)] = value
|
self.props[format.to_camel_case(key)] = value
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""Define event classes to connect the frontend and backend."""
|
"""Define event classes to connect the frontend and backend."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
@ -12,6 +13,7 @@ from typing import (
|
|||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
|
_GenericAlias, # type: ignore
|
||||||
get_type_hints,
|
get_type_hints,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -106,7 +108,7 @@ class EventHandler(EventActionsMixin):
|
|||||||
fn: Any
|
fn: Any
|
||||||
|
|
||||||
# The full name of the state class this event handler is attached to.
|
# The full name of the state class this event handler is attached to.
|
||||||
# Emtpy string means this event handler is a server side event.
|
# Empty string means this event handler is a server side event.
|
||||||
state_full_name: str = ""
|
state_full_name: str = ""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@ -115,6 +117,21 @@ class EventHandler(EventActionsMixin):
|
|||||||
# Needed to allow serialization of Callable.
|
# Needed to allow serialization of Callable.
|
||||||
frozen = True
|
frozen = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __class_getitem__(cls, args_spec: str) -> _GenericAlias:
|
||||||
|
"""Get a typed EventHandler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args_spec: The args_spec of the EventHandler.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The EventHandler class item.
|
||||||
|
"""
|
||||||
|
gen = _GenericAlias(cls, Any)
|
||||||
|
# Cannot subclass special typing classes, so we need to set the args_spec dynamically as an attribute.
|
||||||
|
gen.args_spec = args_spec
|
||||||
|
return gen
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_background(self) -> bool:
|
def is_background(self) -> bool:
|
||||||
"""Whether the event handler is a background task.
|
"""Whether the event handler is a background task.
|
||||||
|
@ -1350,3 +1350,35 @@ def test_custom_component_add_imports(tags):
|
|||||||
|
|
||||||
assert baseline.get_imports() == {"react": _list_to_import_vars(tags)}
|
assert baseline.get_imports() == {"react": _list_to_import_vars(tags)}
|
||||||
assert test.get_imports() == baseline.get_imports()
|
assert test.get_imports() == baseline.get_imports()
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_component_declare_event_handlers_in_fields():
|
||||||
|
class ReferenceComponent(Component):
|
||||||
|
def get_event_triggers(self) -> Dict[str, Any]:
|
||||||
|
"""Test controlled triggers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Test controlled triggers.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
**super().get_event_triggers(),
|
||||||
|
"on_a": lambda e: [e],
|
||||||
|
"on_b": lambda e: [e.target.value],
|
||||||
|
"on_c": lambda e: [],
|
||||||
|
"on_d": lambda: [],
|
||||||
|
"on_e": lambda: [],
|
||||||
|
}
|
||||||
|
|
||||||
|
class TestComponent(Component):
|
||||||
|
on_a: EventHandler[lambda e0: [e0]]
|
||||||
|
on_b: EventHandler[lambda e0: [e0.target.value]]
|
||||||
|
on_c: EventHandler[lambda e0: []]
|
||||||
|
on_d: EventHandler[lambda: []]
|
||||||
|
on_e: EventHandler
|
||||||
|
|
||||||
|
custom_component = ReferenceComponent.create()
|
||||||
|
test_component = TestComponent.create()
|
||||||
|
assert (
|
||||||
|
custom_component.get_event_triggers().keys()
|
||||||
|
== test_component.get_event_triggers().keys()
|
||||||
|
)
|
||||||
|
37
tests/components/test_component_future_annotations.py
Normal file
37
tests/components/test_component_future_annotations.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from reflex.components.component import Component
|
||||||
|
from reflex.event import EventHandler
|
||||||
|
|
||||||
|
|
||||||
|
# This is a repeat of its namesake in test_component.py.
|
||||||
|
def test_custom_component_declare_event_handlers_in_fields():
|
||||||
|
class ReferenceComponent(Component):
|
||||||
|
def get_event_triggers(self) -> dict[str, Any]:
|
||||||
|
"""Test controlled triggers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Test controlled triggers.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
**super().get_event_triggers(),
|
||||||
|
"on_a": lambda e: [e],
|
||||||
|
"on_b": lambda e: [e.target.value],
|
||||||
|
"on_c": lambda e: [],
|
||||||
|
"on_d": lambda: [],
|
||||||
|
}
|
||||||
|
|
||||||
|
class TestComponent(Component):
|
||||||
|
on_a: EventHandler[lambda e0: [e0]]
|
||||||
|
on_b: EventHandler[lambda e0: [e0.target.value]]
|
||||||
|
on_c: EventHandler[lambda e0: []]
|
||||||
|
on_d: EventHandler[lambda: []]
|
||||||
|
|
||||||
|
custom_component = ReferenceComponent.create()
|
||||||
|
test_component = TestComponent.create()
|
||||||
|
assert (
|
||||||
|
custom_component.get_event_triggers().keys()
|
||||||
|
== test_component.get_event_triggers().keys()
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user