diff --git a/reflex/components/component.py b/reflex/components/component.py index 1de884e8a..af9da1b4e 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -4,6 +4,7 @@ from __future__ import annotations import copy import dataclasses +import inspect import typing from abc import ABC, abstractmethod from functools import lru_cache, wraps @@ -21,6 +22,8 @@ from typing import ( Set, Type, Union, + get_args, + get_origin, ) from typing_extensions import Self @@ -43,6 +46,7 @@ from reflex.constants import ( from reflex.constants.compiler import SpecialAttributes from reflex.constants.state import FRONTEND_EVENT_STATE from reflex.event import ( + EventActionsMixin, EventCallback, EventChain, EventHandler, @@ -508,7 +512,7 @@ class Component(BaseComponent, ABC): # Remove any keys that were added as events. for key in kwargs["event_triggers"]: - del kwargs[key] + kwargs.pop(key, None) # Place data_ and aria_ attributes into custom_attrs special_attributes = tuple( @@ -1672,11 +1676,51 @@ class CustomComponent(Component): props = {key: value for key, value in kwargs.items() if key in props_types} kwargs = {key: value for key, value in kwargs.items() if key not in props_types} + event_types = { + key + for key in props + if ( + (get_origin((annotation := props_types.get(key))) or annotation) + == EventHandler + ) + } + + def get_args_spec(key: str) -> types.ArgsSpec | Sequence[types.ArgsSpec]: + type_ = props_types[key] + + return ( + args[0] + if (args := get_args(type_)) + else ( + annotation_args[1] + if get_origin( + ( + annotation := inspect.getfullargspec( + component_fn + ).annotations[key] + ) + ) + is typing.Annotated + and (annotation_args := get_args(annotation)) + else no_args_event_spec + ) + ) + super().__init__( + event_triggers={ + key: EventChain.create( + value=props[key], + args_spec=get_args_spec(key), + key=key, + ) + for key in event_types + }, **kwargs, ) - to_camel_cased_props = {format.to_camel_case(key) for key in props} + to_camel_cased_props = { + format.to_camel_case(key) for key in props if key not in event_types + } self.get_props = lambda: to_camel_cased_props # pyright: ignore [reportIncompatibleVariableOverride] # Unset the style. @@ -1685,9 +1729,6 @@ class CustomComponent(Component): # Set the tag to the name of the function. self.tag = format.to_title_case(self.component_fn.__name__) - # Get the event triggers defined in the component declaration. - event_triggers_in_component_declaration = self.get_event_triggers() - for key, value in props.items(): # Skip kwargs that are not props. if key not in props_types: @@ -1699,20 +1740,14 @@ class CustomComponent(Component): type_ = props_types[key] # Handle event chains. - if types._issubclass(type_, EventChain): - value = EventChain.create( - value=value, - args_spec=event_triggers_in_component_declaration.get( - key, no_args_event_spec - ), - key=key, + if types._issubclass(type_, EventActionsMixin): + inspect.getfullargspec(component_fn).annotations[key] + self.props[camel_cased_key] = EventChain.create( + value=value, args_spec=get_args_spec(key), key=key ) - self.props[camel_cased_key] = value continue value = LiteralVar.create(value) - - # Set the prop. self.props[camel_cased_key] = value setattr(self, camel_cased_key, value) @@ -1793,7 +1828,15 @@ class CustomComponent(Component): return [ Var( _js_expr=name, - _var_type=(prop._var_type if isinstance(prop, Var) else type(prop)), + _var_type=( + prop._var_type + if isinstance(prop, Var) + else ( + type(prop) + if not isinstance(prop, EventActionsMixin) + else EventChain + ) + ), ).guess_type() for name, prop in self.props.items() ]