From f372402ee40607d1d6114a96de77aab378f3c905 Mon Sep 17 00:00:00 2001 From: Martin Xu <15661672+martinxu9@users.noreply.github.com> Date: Fri, 29 Mar 2024 06:26:07 -0700 Subject: [PATCH] [REF-2272] Support declaring EventHandlers directly in component (#2952) --- reflex/__init__.py | 1 + reflex/__init__.pyi | 1 + reflex/components/component.py | 48 ++++++++++++------- reflex/event.py | 19 +++++++- tests/components/test_component.py | 32 +++++++++++++ .../test_component_future_annotations.py | 37 ++++++++++++++ 6 files changed, 121 insertions(+), 17 deletions(-) create mode 100644 tests/components/test_component_future_annotations.py diff --git a/reflex/__init__.py b/reflex/__init__.py index a461ed9a4..9c5faea0f 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -131,6 +131,7 @@ _MAPPING = { "reflex.event": [ "event", "EventChain", + "EventHandler", "background", "call_script", "clear_local_storage", diff --git a/reflex/__init__.pyi b/reflex/__init__.pyi index c6651a41d..51c291343 100644 --- a/reflex/__init__.pyi +++ b/reflex/__init__.pyi @@ -113,6 +113,7 @@ from reflex import constants as constants from reflex.constants import Env as Env from reflex import event as event from reflex.event import EventChain as EventChain +from reflex.event import EventHandler as EventHandler from reflex.event import background as background from reflex.event import call_script as call_script from reflex.event import clear_local_storage as clear_local_storage diff --git a/reflex/components/component.py b/reflex/components/component.py index 58842062f..4e7654a76 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -236,6 +236,8 @@ class Component(BaseComponent, ABC): if types._issubclass(field.type_, Var): field.required = False field.default = Var.create(field.default) + elif types._issubclass(field.type_, EventHandler): + field.required = False # Ensure renamed props from parent classes are applied to the subclass. if cls._rename_props: @@ -272,7 +274,8 @@ class Component(BaseComponent, ABC): # Get the component fields, triggers, and props. fields = self.get_fields() - triggers = self.get_event_triggers().keys() + component_specific_triggers = self.get_event_triggers() + triggers = component_specific_triggers.keys() props = self.get_props() # Add any events triggers. @@ -327,7 +330,9 @@ class Component(BaseComponent, ABC): # Check if the key is an event trigger. if key in triggers: # Temporarily disable full control for event triggers. - kwargs["event_triggers"][key] = self._create_event_chain(key, value) + kwargs["event_triggers"][key] = self._create_event_chain( + value=value, args_spec=component_specific_triggers[key] + ) # Remove any keys that were added as events. for key in kwargs["event_triggers"]: @@ -359,7 +364,7 @@ class Component(BaseComponent, ABC): def _create_event_chain( self, - event_trigger: str, + args_spec: Any, value: Union[ Var, EventHandler, EventSpec, List[Union[EventHandler, EventSpec]], Callable ], @@ -367,7 +372,7 @@ class Component(BaseComponent, ABC): """Create an event chain from a variety of input types. Args: - event_trigger: The event trigger to bind the chain to. + args_spec: The args_spec of the the event trigger being bound. value: The value to create the event chain from. Returns: @@ -376,9 +381,6 @@ class Component(BaseComponent, ABC): Raises: ValueError: If the value is not a valid event chain. """ - # Check if the trigger is a controlled event. - triggers = self.get_event_triggers() - # If it's an event chain var, return it. if isinstance(value, Var): if value._var_type is not EventChain: @@ -388,8 +390,6 @@ class Component(BaseComponent, ABC): # Trust that the caller knows what they're doing passing an EventChain directly return value - arg_spec = triggers.get(event_trigger, lambda: []) - # If the input is a single event handler, wrap it in a list. if isinstance(value, (EventHandler, EventSpec)): value = [value] @@ -401,7 +401,7 @@ class Component(BaseComponent, ABC): if isinstance(v, EventHandler): # Call the event handler to get the event. try: - event = call_event_handler(v, arg_spec) # type: ignore + event = call_event_handler(v, args_spec) except ValueError as err: raise ValueError( f" {err} defined in the `{type(self).__name__}` component" @@ -414,13 +414,13 @@ class Component(BaseComponent, ABC): events.append(v) elif isinstance(v, Callable): # Call the lambda to get the event chain. - events.extend(call_event_fn(v, arg_spec)) # type: ignore + events.extend(call_event_fn(v, args_spec)) else: raise ValueError(f"Invalid event: {v}") # If the input is a callable, create an event chain. elif isinstance(value, Callable): - events = call_event_fn(value, arg_spec) # type: ignore + events = call_event_fn(value, args_spec) # Otherwise, raise an error. else: @@ -435,7 +435,7 @@ class Component(BaseComponent, ABC): event_actions.update(e.event_actions) # Return the event chain. - if isinstance(arg_spec, Var): + if isinstance(args_spec, Var): return EventChain( events=events, args_spec=None, @@ -444,7 +444,7 @@ class Component(BaseComponent, ABC): else: return EventChain( events=events, - args_spec=arg_spec, # type: ignore + args_spec=args_spec, event_actions=event_actions, ) @@ -454,7 +454,7 @@ class Component(BaseComponent, ABC): Returns: The event triggers. """ - return { + default_triggers = { EventTriggers.ON_FOCUS: lambda: [], EventTriggers.ON_BLUR: lambda: [], EventTriggers.ON_CLICK: lambda: [], @@ -471,6 +471,14 @@ class Component(BaseComponent, ABC): EventTriggers.ON_MOUNT: lambda: [], EventTriggers.ON_UNMOUNT: lambda: [], } + # Look for component specific triggers, + # e.g. variable declared as EventHandler types. + for field in self.get_fields().values(): + if types._issubclass(field.type_, EventHandler): + default_triggers[field.name] = getattr( + field.type_, "args_spec", lambda: [] + ) + return default_triggers def __repr__(self) -> str: """Represent the component in React. @@ -1352,6 +1360,9 @@ class CustomComponent(Component): # Set the tag to the name of the function. self.tag = format.to_title_case(self.component_fn.__name__) + # Get the event triggers defined in the component declaration. + event_triggers_in_component_declaration = self.get_event_triggers() + # Set the props. props = typing.get_type_hints(self.component_fn) for key, value in kwargs.items(): @@ -1364,7 +1375,12 @@ class CustomComponent(Component): # Handle event chains. if types._issubclass(type_, EventChain): - value = self._create_event_chain(key, value) + value = self._create_event_chain( + value=value, + args_spec=event_triggers_in_component_declaration.get( + key, lambda: [] + ), + ) self.props[format.to_camel_case(key)] = value continue diff --git a/reflex/event.py b/reflex/event.py index b8611b90e..19fdad5e9 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -1,4 +1,5 @@ """Define event classes to connect the frontend and backend.""" + from __future__ import annotations import inspect @@ -12,6 +13,7 @@ from typing import ( Optional, Tuple, Union, + _GenericAlias, # type: ignore get_type_hints, ) @@ -106,7 +108,7 @@ class EventHandler(EventActionsMixin): fn: Any # The full name of the state class this event handler is attached to. - # Emtpy string means this event handler is a server side event. + # Empty string means this event handler is a server side event. state_full_name: str = "" class Config: @@ -115,6 +117,21 @@ class EventHandler(EventActionsMixin): # Needed to allow serialization of Callable. frozen = True + @classmethod + def __class_getitem__(cls, args_spec: str) -> _GenericAlias: + """Get a typed EventHandler. + + Args: + args_spec: The args_spec of the EventHandler. + + Returns: + The EventHandler class item. + """ + gen = _GenericAlias(cls, Any) + # Cannot subclass special typing classes, so we need to set the args_spec dynamically as an attribute. + gen.args_spec = args_spec + return gen + @property def is_background(self) -> bool: """Whether the event handler is a background task. diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 9c090baa9..c5f5f9d51 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -1350,3 +1350,35 @@ def test_custom_component_add_imports(tags): assert baseline.get_imports() == {"react": _list_to_import_vars(tags)} assert test.get_imports() == baseline.get_imports() + + +def test_custom_component_declare_event_handlers_in_fields(): + class ReferenceComponent(Component): + def get_event_triggers(self) -> Dict[str, Any]: + """Test controlled triggers. + + Returns: + Test controlled triggers. + """ + return { + **super().get_event_triggers(), + "on_a": lambda e: [e], + "on_b": lambda e: [e.target.value], + "on_c": lambda e: [], + "on_d": lambda: [], + "on_e": lambda: [], + } + + class TestComponent(Component): + on_a: EventHandler[lambda e0: [e0]] + on_b: EventHandler[lambda e0: [e0.target.value]] + on_c: EventHandler[lambda e0: []] + on_d: EventHandler[lambda: []] + on_e: EventHandler + + custom_component = ReferenceComponent.create() + test_component = TestComponent.create() + assert ( + custom_component.get_event_triggers().keys() + == test_component.get_event_triggers().keys() + ) diff --git a/tests/components/test_component_future_annotations.py b/tests/components/test_component_future_annotations.py new file mode 100644 index 000000000..37aeb813a --- /dev/null +++ b/tests/components/test_component_future_annotations.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import Any + +from reflex.components.component import Component +from reflex.event import EventHandler + + +# This is a repeat of its namesake in test_component.py. +def test_custom_component_declare_event_handlers_in_fields(): + class ReferenceComponent(Component): + def get_event_triggers(self) -> dict[str, Any]: + """Test controlled triggers. + + Returns: + Test controlled triggers. + """ + return { + **super().get_event_triggers(), + "on_a": lambda e: [e], + "on_b": lambda e: [e.target.value], + "on_c": lambda e: [], + "on_d": lambda: [], + } + + class TestComponent(Component): + on_a: EventHandler[lambda e0: [e0]] + on_b: EventHandler[lambda e0: [e0.target.value]] + on_c: EventHandler[lambda e0: []] + on_d: EventHandler[lambda: []] + + custom_component = ReferenceComponent.create() + test_component = TestComponent.create() + assert ( + custom_component.get_event_triggers().keys() + == test_component.get_event_triggers().keys() + )