add basic test to cover it

This commit is contained in:
Khaleel Al-Adhami 2024-10-15 17:09:39 -07:00
parent f41b6e4322
commit 233033cb7d
2 changed files with 36 additions and 6 deletions

View File

@ -1019,9 +1019,16 @@ def call_event_handler(
)
def compare_types(provided_type, accepted_type):
provided_type_origin = get_origin(provided_type)
accepted_type_origin = get_origin(accepted_type)
if provided_type_origin is None and accepted_type_origin is None:
# Check if both are concrete types (e.g., int)
return issubclass(provided_type, accepted_type)
# Check if both are generic types (e.g., List)
if (get_origin(provided_type) or provided_type) != (
get_origin(accepted_type) or accepted_type
if (provided_type_origin or provided_type) != (
accepted_type_origin or accepted_type
):
return False

View File

@ -20,13 +20,18 @@ from reflex.event import (
EventChain,
EventHandler,
empty_event,
identity_event,
input_event,
parse_args_spec,
)
from reflex.state import BaseState
from reflex.style import Style
from reflex.utils import imports
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
from reflex.utils.exceptions import (
EventFnArgMismatch,
EventHandlerArgMismatch,
EventHandlerArgTypeMismatch,
)
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
@ -43,6 +48,12 @@ def test_state():
def do_something_arg(self, arg):
pass
def do_something_with_bool(self, arg: bool):
pass
def do_something_with_int(self, arg: int):
pass
return TestState
@ -95,8 +106,9 @@ def component2() -> Type[Component]:
"""
return {
**super().get_event_triggers(),
"on_open": lambda e0: [e0],
"on_close": lambda e0: [e0],
"on_open": identity_event(bool),
"on_close": identity_event(bool),
"on_user_visited_count_changed": identity_event(int),
}
def _get_imports(self) -> ParsedImportDict:
@ -582,7 +594,8 @@ def test_get_event_triggers(component1, component2):
assert component1().get_event_triggers().keys() == default_triggers
assert (
component2().get_event_triggers().keys()
== {"on_open", "on_close", "on_prop_event"} | default_triggers
== {"on_open", "on_close", "on_prop_event", "on_user_visited_count_changed"}
| default_triggers
)
@ -918,6 +931,16 @@ def test_invalid_event_handler_args(component2, test_state):
on_prop_event=[test_state.do_something_arg, test_state.do_something]
)
# Event Handler types must match
with pytest.raises(EventHandlerArgTypeMismatch):
component2.create(
on_user_visited_count_changed=test_state.do_something_with_bool
)
component2.create(on_open=test_state.do_something_with_int)
component2.create(on_open=test_state.do_something_with_bool)
component2.create(on_user_visited_count_changed=test_state.do_something_with_int)
# lambda cannot return weird values.
with pytest.raises(ValueError):
component2.create(on_click=lambda: 1)