add basic test to cover it
This commit is contained in:
parent
f41b6e4322
commit
233033cb7d
@ -1019,9 +1019,16 @@ def call_event_handler(
|
||||
)
|
||||
|
||||
def compare_types(provided_type, accepted_type):
|
||||
provided_type_origin = get_origin(provided_type)
|
||||
accepted_type_origin = get_origin(accepted_type)
|
||||
|
||||
if provided_type_origin is None and accepted_type_origin is None:
|
||||
# Check if both are concrete types (e.g., int)
|
||||
return issubclass(provided_type, accepted_type)
|
||||
|
||||
# Check if both are generic types (e.g., List)
|
||||
if (get_origin(provided_type) or provided_type) != (
|
||||
get_origin(accepted_type) or accepted_type
|
||||
if (provided_type_origin or provided_type) != (
|
||||
accepted_type_origin or accepted_type
|
||||
):
|
||||
return False
|
||||
|
||||
|
@ -20,13 +20,18 @@ from reflex.event import (
|
||||
EventChain,
|
||||
EventHandler,
|
||||
empty_event,
|
||||
identity_event,
|
||||
input_event,
|
||||
parse_args_spec,
|
||||
)
|
||||
from reflex.state import BaseState
|
||||
from reflex.style import Style
|
||||
from reflex.utils import imports
|
||||
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
|
||||
from reflex.utils.exceptions import (
|
||||
EventFnArgMismatch,
|
||||
EventHandlerArgMismatch,
|
||||
EventHandlerArgTypeMismatch,
|
||||
)
|
||||
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
|
||||
from reflex.vars import VarData
|
||||
from reflex.vars.base import LiteralVar, Var
|
||||
@ -43,6 +48,12 @@ def test_state():
|
||||
def do_something_arg(self, arg):
|
||||
pass
|
||||
|
||||
def do_something_with_bool(self, arg: bool):
|
||||
pass
|
||||
|
||||
def do_something_with_int(self, arg: int):
|
||||
pass
|
||||
|
||||
return TestState
|
||||
|
||||
|
||||
@ -95,8 +106,9 @@ def component2() -> Type[Component]:
|
||||
"""
|
||||
return {
|
||||
**super().get_event_triggers(),
|
||||
"on_open": lambda e0: [e0],
|
||||
"on_close": lambda e0: [e0],
|
||||
"on_open": identity_event(bool),
|
||||
"on_close": identity_event(bool),
|
||||
"on_user_visited_count_changed": identity_event(int),
|
||||
}
|
||||
|
||||
def _get_imports(self) -> ParsedImportDict:
|
||||
@ -582,7 +594,8 @@ def test_get_event_triggers(component1, component2):
|
||||
assert component1().get_event_triggers().keys() == default_triggers
|
||||
assert (
|
||||
component2().get_event_triggers().keys()
|
||||
== {"on_open", "on_close", "on_prop_event"} | default_triggers
|
||||
== {"on_open", "on_close", "on_prop_event", "on_user_visited_count_changed"}
|
||||
| default_triggers
|
||||
)
|
||||
|
||||
|
||||
@ -918,6 +931,16 @@ def test_invalid_event_handler_args(component2, test_state):
|
||||
on_prop_event=[test_state.do_something_arg, test_state.do_something]
|
||||
)
|
||||
|
||||
# Event Handler types must match
|
||||
with pytest.raises(EventHandlerArgTypeMismatch):
|
||||
component2.create(
|
||||
on_user_visited_count_changed=test_state.do_something_with_bool
|
||||
)
|
||||
|
||||
component2.create(on_open=test_state.do_something_with_int)
|
||||
component2.create(on_open=test_state.do_something_with_bool)
|
||||
component2.create(on_user_visited_count_changed=test_state.do_something_with_int)
|
||||
|
||||
# lambda cannot return weird values.
|
||||
with pytest.raises(ValueError):
|
||||
component2.create(on_click=lambda: 1)
|
||||
|
Loading…
Reference in New Issue
Block a user