diff --git a/reflex/event.py b/reflex/event.py index ece684f50..1fd611776 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -1019,9 +1019,16 @@ def call_event_handler( ) def compare_types(provided_type, accepted_type): + provided_type_origin = get_origin(provided_type) + accepted_type_origin = get_origin(accepted_type) + + if provided_type_origin is None and accepted_type_origin is None: + # Check if both are concrete types (e.g., int) + return issubclass(provided_type, accepted_type) + # Check if both are generic types (e.g., List) - if (get_origin(provided_type) or provided_type) != ( - get_origin(accepted_type) or accepted_type + if (provided_type_origin or provided_type) != ( + accepted_type_origin or accepted_type ): return False diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index b7b721a92..bbc4a3f0a 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -20,13 +20,18 @@ from reflex.event import ( EventChain, EventHandler, empty_event, + identity_event, input_event, parse_args_spec, ) from reflex.state import BaseState from reflex.style import Style from reflex.utils import imports -from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch +from reflex.utils.exceptions import ( + EventFnArgMismatch, + EventHandlerArgMismatch, + EventHandlerArgTypeMismatch, +) from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports from reflex.vars import VarData from reflex.vars.base import LiteralVar, Var @@ -43,6 +48,12 @@ def test_state(): def do_something_arg(self, arg): pass + def do_something_with_bool(self, arg: bool): + pass + + def do_something_with_int(self, arg: int): + pass + return TestState @@ -95,8 +106,9 @@ def component2() -> Type[Component]: """ return { **super().get_event_triggers(), - "on_open": lambda e0: [e0], - "on_close": lambda e0: [e0], + "on_open": identity_event(bool), + "on_close": identity_event(bool), + "on_user_visited_count_changed": identity_event(int), } def _get_imports(self) -> ParsedImportDict: @@ -582,7 +594,8 @@ def test_get_event_triggers(component1, component2): assert component1().get_event_triggers().keys() == default_triggers assert ( component2().get_event_triggers().keys() - == {"on_open", "on_close", "on_prop_event"} | default_triggers + == {"on_open", "on_close", "on_prop_event", "on_user_visited_count_changed"} + | default_triggers ) @@ -918,6 +931,16 @@ def test_invalid_event_handler_args(component2, test_state): on_prop_event=[test_state.do_something_arg, test_state.do_something] ) + # Event Handler types must match + with pytest.raises(EventHandlerArgTypeMismatch): + component2.create( + on_user_visited_count_changed=test_state.do_something_with_bool + ) + + component2.create(on_open=test_state.do_something_with_int) + component2.create(on_open=test_state.do_something_with_bool) + component2.create(on_user_visited_count_changed=test_state.do_something_with_int) + # lambda cannot return weird values. with pytest.raises(ValueError): component2.create(on_click=lambda: 1)