add basic test to cover it

This commit is contained in:
Khaleel Al-Adhami 2024-10-15 17:09:39 -07:00
parent f41b6e4322
commit 233033cb7d
2 changed files with 36 additions and 6 deletions

View File

@ -1019,9 +1019,16 @@ def call_event_handler(
) )
def compare_types(provided_type, accepted_type): def compare_types(provided_type, accepted_type):
provided_type_origin = get_origin(provided_type)
accepted_type_origin = get_origin(accepted_type)
if provided_type_origin is None and accepted_type_origin is None:
# Check if both are concrete types (e.g., int)
return issubclass(provided_type, accepted_type)
# Check if both are generic types (e.g., List) # Check if both are generic types (e.g., List)
if (get_origin(provided_type) or provided_type) != ( if (provided_type_origin or provided_type) != (
get_origin(accepted_type) or accepted_type accepted_type_origin or accepted_type
): ):
return False return False

View File

@ -20,13 +20,18 @@ from reflex.event import (
EventChain, EventChain,
EventHandler, EventHandler,
empty_event, empty_event,
identity_event,
input_event, input_event,
parse_args_spec, parse_args_spec,
) )
from reflex.state import BaseState from reflex.state import BaseState
from reflex.style import Style from reflex.style import Style
from reflex.utils import imports from reflex.utils import imports
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch from reflex.utils.exceptions import (
EventFnArgMismatch,
EventHandlerArgMismatch,
EventHandlerArgTypeMismatch,
)
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.vars import VarData from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var from reflex.vars.base import LiteralVar, Var
@ -43,6 +48,12 @@ def test_state():
def do_something_arg(self, arg): def do_something_arg(self, arg):
pass pass
def do_something_with_bool(self, arg: bool):
pass
def do_something_with_int(self, arg: int):
pass
return TestState return TestState
@ -95,8 +106,9 @@ def component2() -> Type[Component]:
""" """
return { return {
**super().get_event_triggers(), **super().get_event_triggers(),
"on_open": lambda e0: [e0], "on_open": identity_event(bool),
"on_close": lambda e0: [e0], "on_close": identity_event(bool),
"on_user_visited_count_changed": identity_event(int),
} }
def _get_imports(self) -> ParsedImportDict: def _get_imports(self) -> ParsedImportDict:
@ -582,7 +594,8 @@ def test_get_event_triggers(component1, component2):
assert component1().get_event_triggers().keys() == default_triggers assert component1().get_event_triggers().keys() == default_triggers
assert ( assert (
component2().get_event_triggers().keys() component2().get_event_triggers().keys()
== {"on_open", "on_close", "on_prop_event"} | default_triggers == {"on_open", "on_close", "on_prop_event", "on_user_visited_count_changed"}
| default_triggers
) )
@ -918,6 +931,16 @@ def test_invalid_event_handler_args(component2, test_state):
on_prop_event=[test_state.do_something_arg, test_state.do_something] on_prop_event=[test_state.do_something_arg, test_state.do_something]
) )
# Event Handler types must match
with pytest.raises(EventHandlerArgTypeMismatch):
component2.create(
on_user_visited_count_changed=test_state.do_something_with_bool
)
component2.create(on_open=test_state.do_something_with_int)
component2.create(on_open=test_state.do_something_with_bool)
component2.create(on_user_visited_count_changed=test_state.do_something_with_int)
# lambda cannot return weird values. # lambda cannot return weird values.
with pytest.raises(ValueError): with pytest.raises(ValueError):
component2.create(on_click=lambda: 1) component2.create(on_click=lambda: 1)