add basic test to cover it
This commit is contained in:
parent
f41b6e4322
commit
233033cb7d
@ -1019,9 +1019,16 @@ def call_event_handler(
|
|||||||
)
|
)
|
||||||
|
|
||||||
def compare_types(provided_type, accepted_type):
|
def compare_types(provided_type, accepted_type):
|
||||||
|
provided_type_origin = get_origin(provided_type)
|
||||||
|
accepted_type_origin = get_origin(accepted_type)
|
||||||
|
|
||||||
|
if provided_type_origin is None and accepted_type_origin is None:
|
||||||
|
# Check if both are concrete types (e.g., int)
|
||||||
|
return issubclass(provided_type, accepted_type)
|
||||||
|
|
||||||
# Check if both are generic types (e.g., List)
|
# Check if both are generic types (e.g., List)
|
||||||
if (get_origin(provided_type) or provided_type) != (
|
if (provided_type_origin or provided_type) != (
|
||||||
get_origin(accepted_type) or accepted_type
|
accepted_type_origin or accepted_type
|
||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -20,13 +20,18 @@ from reflex.event import (
|
|||||||
EventChain,
|
EventChain,
|
||||||
EventHandler,
|
EventHandler,
|
||||||
empty_event,
|
empty_event,
|
||||||
|
identity_event,
|
||||||
input_event,
|
input_event,
|
||||||
parse_args_spec,
|
parse_args_spec,
|
||||||
)
|
)
|
||||||
from reflex.state import BaseState
|
from reflex.state import BaseState
|
||||||
from reflex.style import Style
|
from reflex.style import Style
|
||||||
from reflex.utils import imports
|
from reflex.utils import imports
|
||||||
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
|
from reflex.utils.exceptions import (
|
||||||
|
EventFnArgMismatch,
|
||||||
|
EventHandlerArgMismatch,
|
||||||
|
EventHandlerArgTypeMismatch,
|
||||||
|
)
|
||||||
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
|
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
|
||||||
from reflex.vars import VarData
|
from reflex.vars import VarData
|
||||||
from reflex.vars.base import LiteralVar, Var
|
from reflex.vars.base import LiteralVar, Var
|
||||||
@ -43,6 +48,12 @@ def test_state():
|
|||||||
def do_something_arg(self, arg):
|
def do_something_arg(self, arg):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def do_something_with_bool(self, arg: bool):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def do_something_with_int(self, arg: int):
|
||||||
|
pass
|
||||||
|
|
||||||
return TestState
|
return TestState
|
||||||
|
|
||||||
|
|
||||||
@ -95,8 +106,9 @@ def component2() -> Type[Component]:
|
|||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
**super().get_event_triggers(),
|
**super().get_event_triggers(),
|
||||||
"on_open": lambda e0: [e0],
|
"on_open": identity_event(bool),
|
||||||
"on_close": lambda e0: [e0],
|
"on_close": identity_event(bool),
|
||||||
|
"on_user_visited_count_changed": identity_event(int),
|
||||||
}
|
}
|
||||||
|
|
||||||
def _get_imports(self) -> ParsedImportDict:
|
def _get_imports(self) -> ParsedImportDict:
|
||||||
@ -582,7 +594,8 @@ def test_get_event_triggers(component1, component2):
|
|||||||
assert component1().get_event_triggers().keys() == default_triggers
|
assert component1().get_event_triggers().keys() == default_triggers
|
||||||
assert (
|
assert (
|
||||||
component2().get_event_triggers().keys()
|
component2().get_event_triggers().keys()
|
||||||
== {"on_open", "on_close", "on_prop_event"} | default_triggers
|
== {"on_open", "on_close", "on_prop_event", "on_user_visited_count_changed"}
|
||||||
|
| default_triggers
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -918,6 +931,16 @@ def test_invalid_event_handler_args(component2, test_state):
|
|||||||
on_prop_event=[test_state.do_something_arg, test_state.do_something]
|
on_prop_event=[test_state.do_something_arg, test_state.do_something]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Event Handler types must match
|
||||||
|
with pytest.raises(EventHandlerArgTypeMismatch):
|
||||||
|
component2.create(
|
||||||
|
on_user_visited_count_changed=test_state.do_something_with_bool
|
||||||
|
)
|
||||||
|
|
||||||
|
component2.create(on_open=test_state.do_something_with_int)
|
||||||
|
component2.create(on_open=test_state.do_something_with_bool)
|
||||||
|
component2.create(on_user_visited_count_changed=test_state.do_something_with_int)
|
||||||
|
|
||||||
# lambda cannot return weird values.
|
# lambda cannot return weird values.
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
component2.create(on_click=lambda: 1)
|
component2.create(on_click=lambda: 1)
|
||||||
|
Loading…
Reference in New Issue
Block a user