[ENG-3943]type check for event handler if spec arg are typed (#4046)

* type check for event handler if spec arg are typed

* fix the typecheck logic

* rearrange logic pieces

* add try except

* add try except around compare

* change form and improve type checking

* print key instead

* dang it darglint

* change wording

* add basic test to cover it

* add a slightly more complicated test

* challenge it a bit by doing small capital list

* add multiple argspec

* fix slider event order

* i hate 3.9

* add note for UnionType

* move function to types

* add a test for type hint is subclass

* make on submit dict str any

* add testing for dict cases

* add check against any

* accept dict str str

* bruh i used i twice

* escape strings and print actual error message

* disable the error and print deprecation warning instead

* disable tests

* fix doc message

---------

Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
This commit is contained in:
Thomas Brandého 2024-10-31 12:45:28 -07:00 committed by GitHub
parent c8cecbf3cc
commit c07eb2a6a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 387 additions and 69 deletions

View File

@ -480,6 +480,7 @@ class Component(BaseComponent, ABC):
kwargs["event_triggers"][key] = self._create_event_chain(
value=value, # type: ignore
args_spec=component_specific_triggers[key],
key=key,
)
# Remove any keys that were added as events.
@ -540,12 +541,14 @@ class Component(BaseComponent, ABC):
List[Union[EventHandler, EventSpec, EventVar]],
Callable,
],
key: Optional[str] = None,
) -> Union[EventChain, Var]:
"""Create an event chain from a variety of input types.
Args:
args_spec: The args_spec of the event trigger being bound.
value: The value to create the event chain from.
key: The key of the event trigger being bound.
Returns:
The event chain.
@ -560,7 +563,7 @@ class Component(BaseComponent, ABC):
elif isinstance(value, EventVar):
value = [value]
elif issubclass(value._var_type, (EventChain, EventSpec)):
return self._create_event_chain(args_spec, value.guess_type())
return self._create_event_chain(args_spec, value.guess_type(), key=key)
else:
raise ValueError(
f"Invalid event chain: {str(value)} of type {value._var_type}"
@ -579,10 +582,10 @@ class Component(BaseComponent, ABC):
for v in value:
if isinstance(v, (EventHandler, EventSpec)):
# Call the event handler to get the event.
events.append(call_event_handler(v, args_spec))
events.append(call_event_handler(v, args_spec, key=key))
elif isinstance(v, Callable):
# Call the lambda to get the event chain.
result = call_event_fn(v, args_spec)
result = call_event_fn(v, args_spec, key=key)
if isinstance(result, Var):
raise ValueError(
f"Invalid event chain: {v}. Cannot use a Var-returning "
@ -599,7 +602,7 @@ class Component(BaseComponent, ABC):
result = call_event_fn(value, args_spec)
if isinstance(result, Var):
# Recursively call this function if the lambda returned an EventChain Var.
return self._create_event_chain(args_spec, result)
return self._create_event_chain(args_spec, result, key=key)
events = [*result]
# Otherwise, raise an error.
@ -1722,6 +1725,7 @@ class CustomComponent(Component):
args_spec=event_triggers_in_component_declaration.get(
key, empty_event
),
key=key,
)
self.props[format.to_camel_case(key)] = value
continue

View File

@ -111,6 +111,15 @@ def on_submit_event_spec() -> Tuple[Var[Dict[str, Any]]]:
return (FORM_DATA,)
def on_submit_string_event_spec() -> Tuple[Var[Dict[str, str]]]:
"""Event handler spec for the on_submit event.
Returns:
The event handler spec.
"""
return (FORM_DATA,)
class Form(BaseHTML):
"""Display the form element."""
@ -150,7 +159,7 @@ class Form(BaseHTML):
handle_submit_unique_name: Var[str]
# Fired when the form is submitted
on_submit: EventHandler[on_submit_event_spec]
on_submit: EventHandler[on_submit_event_spec, on_submit_string_event_spec]
@classmethod
def create(cls, *children, **props):

View File

@ -271,6 +271,7 @@ class Fieldset(Element):
...
def on_submit_event_spec() -> Tuple[Var[Dict[str, Any]]]: ...
def on_submit_string_event_spec() -> Tuple[Var[Dict[str, str]]]: ...
class Form(BaseHTML):
@overload
@ -337,7 +338,9 @@ class Form(BaseHTML):
on_mouse_over: Optional[EventType[[]]] = None,
on_mouse_up: Optional[EventType[[]]] = None,
on_scroll: Optional[EventType[[]]] = None,
on_submit: Optional[EventType[Dict[str, Any]]] = None,
on_submit: Optional[
Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]]
] = None,
on_unmount: Optional[EventType[[]]] = None,
**props,
) -> "Form":

View File

@ -129,7 +129,9 @@ class FormRoot(FormComponent, HTMLForm):
on_mouse_over: Optional[EventType[[]]] = None,
on_mouse_up: Optional[EventType[[]]] = None,
on_scroll: Optional[EventType[[]]] = None,
on_submit: Optional[EventType[Dict[str, Any]]] = None,
on_submit: Optional[
Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]]
] = None,
on_unmount: Optional[EventType[[]]] = None,
**props,
) -> "FormRoot":
@ -596,7 +598,9 @@ class Form(FormRoot):
on_mouse_over: Optional[EventType[[]]] = None,
on_mouse_up: Optional[EventType[[]]] = None,
on_scroll: Optional[EventType[[]]] = None,
on_submit: Optional[EventType[Dict[str, Any]]] = None,
on_submit: Optional[
Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]]
] = None,
on_unmount: Optional[EventType[[]]] = None,
**props,
) -> "Form":
@ -720,7 +724,9 @@ class FormNamespace(ComponentNamespace):
on_mouse_over: Optional[EventType[[]]] = None,
on_mouse_up: Optional[EventType[[]]] = None,
on_scroll: Optional[EventType[[]]] = None,
on_submit: Optional[EventType[Dict[str, Any]]] = None,
on_submit: Optional[
Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]]
] = None,
on_unmount: Optional[EventType[[]]] = None,
**props,
) -> "Form":

View File

@ -2,11 +2,11 @@
from __future__ import annotations
from typing import List, Literal, Optional, Tuple, Union
from typing import List, Literal, Optional, Union
from reflex.components.component import Component
from reflex.components.core.breakpoints import Responsive
from reflex.event import EventHandler
from reflex.event import EventHandler, identity_event
from reflex.vars.base import Var
from ..base import (
@ -14,19 +14,11 @@ from ..base import (
RadixThemesComponent,
)
def on_value_event_spec(
value: Var[List[Union[int, float]]],
) -> Tuple[Var[List[Union[int, float]]]]:
"""Event handler spec for the value event.
Args:
value: The value of the event.
Returns:
The event handler spec.
"""
return (value,) # type: ignore
on_value_event_spec = (
identity_event(list[Union[int, float]]),
identity_event(list[int]),
identity_event(list[float]),
)
class Slider(RadixThemesComponent):

View File

@ -3,18 +3,20 @@
# ------------------- DO NOT EDIT ----------------------
# This file was generated by `reflex/utils/pyi_generator.py`!
# ------------------------------------------------------
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload
from typing import Any, Dict, List, Literal, Optional, Union, overload
from reflex.components.core.breakpoints import Breakpoints
from reflex.event import EventType
from reflex.event import EventType, identity_event
from reflex.style import Style
from reflex.vars.base import Var
from ..base import RadixThemesComponent
def on_value_event_spec(
value: Var[List[Union[int, float]]],
) -> Tuple[Var[List[Union[int, float]]]]: ...
on_value_event_spec = (
identity_event(list[Union[int, float]]),
identity_event(list[int]),
identity_event(list[float]),
)
class Slider(RadixThemesComponent):
@overload
@ -138,7 +140,13 @@ class Slider(RadixThemesComponent):
autofocus: Optional[bool] = None,
custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
on_blur: Optional[EventType[[]]] = None,
on_change: Optional[EventType[List[Union[int, float]]]] = None,
on_change: Optional[
Union[
EventType[list[Union[int, float]]],
EventType[list[int]],
EventType[list[float]],
]
] = None,
on_click: Optional[EventType[[]]] = None,
on_context_menu: Optional[EventType[[]]] = None,
on_double_click: Optional[EventType[[]]] = None,
@ -153,7 +161,13 @@ class Slider(RadixThemesComponent):
on_mouse_up: Optional[EventType[[]]] = None,
on_scroll: Optional[EventType[[]]] = None,
on_unmount: Optional[EventType[[]]] = None,
on_value_commit: Optional[EventType[List[Union[int, float]]]] = None,
on_value_commit: Optional[
Union[
EventType[list[Union[int, float]]],
EventType[list[int]],
EventType[list[float]],
]
] = None,
**props,
) -> "Slider":
"""Create a Slider component.

View File

@ -29,8 +29,12 @@ from typing_extensions import ParamSpec, Protocol, get_args, get_origin
from reflex import constants
from reflex.utils import console, format
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
from reflex.utils.types import ArgsSpec, GenericType
from reflex.utils.exceptions import (
EventFnArgMismatch,
EventHandlerArgMismatch,
EventHandlerArgTypeMismatch,
)
from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass
from reflex.vars import VarData
from reflex.vars.base import (
LiteralVar,
@ -401,7 +405,9 @@ class EventChain(EventActionsMixin):
default_factory=list
)
args_spec: Optional[Callable] = dataclasses.field(default=None)
args_spec: Optional[Union[Callable, Sequence[Callable]]] = dataclasses.field(
default=None
)
invocation: Optional[Var] = dataclasses.field(default=None)
@ -1053,7 +1059,8 @@ def get_hydrate_event(state) -> str:
def call_event_handler(
event_handler: EventHandler | EventSpec,
arg_spec: ArgsSpec,
arg_spec: ArgsSpec | Sequence[ArgsSpec],
key: Optional[str] = None,
) -> EventSpec:
"""Call an event handler to get the event spec.
@ -1064,12 +1071,16 @@ def call_event_handler(
Args:
event_handler: The event handler.
arg_spec: The lambda that define the argument(s) to pass to the event handler.
key: The key to pass to the event handler.
Raises:
EventHandlerArgMismatch: if number of arguments expected by event_handler doesn't match the spec.
Returns:
The event spec from calling the event handler.
# noqa: DAR401 failure
"""
parsed_args = parse_args_spec(arg_spec) # type: ignore
@ -1077,19 +1088,113 @@ def call_event_handler(
# Handle partial application of EventSpec args
return event_handler.add_args(*parsed_args)
args = inspect.getfullargspec(event_handler.fn).args
n_args = len(args) - 1 # subtract 1 for bound self arg
if n_args == len(parsed_args):
return event_handler(*parsed_args) # type: ignore
else:
provided_callback_fullspec = inspect.getfullargspec(event_handler.fn)
provided_callback_n_args = (
len(provided_callback_fullspec.args) - 1
) # subtract 1 for bound self arg
if provided_callback_n_args != len(parsed_args):
raise EventHandlerArgMismatch(
"The number of arguments accepted by "
f"{event_handler.fn.__qualname__} ({n_args}) "
f"{event_handler.fn.__qualname__} ({provided_callback_n_args}) "
"does not match the arguments passed by the event trigger: "
f"{[str(v) for v in parsed_args]}\n"
"See https://reflex.dev/docs/events/event-arguments/"
)
all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec
event_spec_return_types = list(
filter(
lambda event_spec_return_type: event_spec_return_type is not None
and get_origin(event_spec_return_type) is tuple,
(get_type_hints(arg_spec).get("return", None) for arg_spec in all_arg_spec),
)
)
if event_spec_return_types:
failures = []
for event_spec_index, event_spec_return_type in enumerate(
event_spec_return_types
):
args = get_args(event_spec_return_type)
args_types_without_vars = [
arg if get_origin(arg) is not Var else get_args(arg)[0] for arg in args
]
try:
type_hints_of_provided_callback = get_type_hints(event_handler.fn)
except NameError:
type_hints_of_provided_callback = {}
failed_type_check = False
# check that args of event handler are matching the spec if type hints are provided
for i, arg in enumerate(provided_callback_fullspec.args[1:]):
if arg not in type_hints_of_provided_callback:
continue
try:
compare_result = typehint_issubclass(
args_types_without_vars[i], type_hints_of_provided_callback[arg]
)
except TypeError:
# TODO: In 0.7.0, remove this block and raise the exception
# raise TypeError(
# f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
# ) from e
console.warn(
f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
)
compare_result = False
if compare_result:
continue
else:
failure = EventHandlerArgTypeMismatch(
f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead."
)
failures.append(failure)
failed_type_check = True
break
if not failed_type_check:
if event_spec_index:
args = get_args(event_spec_return_types[0])
args_types_without_vars = [
arg if get_origin(arg) is not Var else get_args(arg)[0]
for arg in args
]
expect_string = ", ".join(
repr(arg) for arg in args_types_without_vars
).replace("[", "\\[")
given_string = ", ".join(
repr(type_hints_of_provided_callback.get(arg, Any))
for arg in provided_callback_fullspec.args[1:]
).replace("[", "\\[")
console.warn(
f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_handler.fn.__qualname__} instead. "
f"This may lead to unexpected behavior but is intentionally ignored for {key}."
)
return event_handler(*parsed_args)
if failures:
console.deprecate(
"Mismatched event handler argument types",
"\n".join([str(f) for f in failures]),
"0.6.5",
"0.7.0",
)
return event_handler(*parsed_args) # type: ignore
def unwrap_var_annotation(annotation: GenericType):
"""Unwrap a Var annotation or return it as is if it's not Var[X].
@ -1128,7 +1233,7 @@ def resolve_annotation(annotations: dict[str, Any], arg_name: str):
return annotation
def parse_args_spec(arg_spec: ArgsSpec):
def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]):
"""Parse the args provided in the ArgsSpec of an event trigger.
Args:
@ -1137,6 +1242,8 @@ def parse_args_spec(arg_spec: ArgsSpec):
Returns:
The parsed args.
"""
# if there's multiple, the first is the default
arg_spec = arg_spec[0] if isinstance(arg_spec, Sequence) else arg_spec
spec = inspect.getfullargspec(arg_spec)
annotations = get_type_hints(arg_spec)
@ -1152,13 +1259,18 @@ def parse_args_spec(arg_spec: ArgsSpec):
)
def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]:
def check_fn_match_arg_spec(
fn: Callable,
arg_spec: ArgsSpec,
key: Optional[str] = None,
) -> List[Var]:
"""Ensures that the function signature matches the passed argument specification
or raises an EventFnArgMismatch if they do not.
Args:
fn: The function to be validated.
arg_spec: The argument specification for the event trigger.
key: The key to pass to the event handler.
Returns:
The parsed arguments from the argument specification.
@ -1184,7 +1296,11 @@ def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]:
return parsed_args
def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
def call_event_fn(
fn: Callable,
arg_spec: ArgsSpec,
key: Optional[str] = None,
) -> list[EventSpec] | Var:
"""Call a function to a list of event specs.
The function should return a single EventSpec, a list of EventSpecs, or a
@ -1193,6 +1309,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
Args:
fn: The function to call.
arg_spec: The argument spec for the event trigger.
key: The key to pass to the event handler.
Returns:
The event specs from calling the function or a Var.
@ -1205,7 +1322,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
from reflex.utils.exceptions import EventHandlerValueError
# Check that fn signature matches arg_spec
parsed_args = check_fn_match_arg_spec(fn, arg_spec)
parsed_args = check_fn_match_arg_spec(fn, arg_spec, key=key)
# Call the function with the parsed args.
out = fn(*parsed_args)
@ -1223,7 +1340,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
for e in out:
if isinstance(e, EventHandler):
# An un-called EventHandler gets all of the args of the event trigger.
e = call_event_handler(e, arg_spec)
e = call_event_handler(e, arg_spec, key=key)
# Make sure the event spec is valid.
if not isinstance(e, EventSpec):
@ -1433,7 +1550,12 @@ class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
Returns:
The created LiteralEventChainVar instance.
"""
sig = inspect.signature(value.args_spec) # type: ignore
arg_spec = (
value.args_spec[0]
if isinstance(value.args_spec, Sequence)
else value.args_spec
)
sig = inspect.signature(arg_spec) # type: ignore
if sig.parameters:
arg_def = tuple((f"_{p}" for p in sig.parameters))
arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])

View File

@ -90,7 +90,11 @@ class MatchTypeError(ReflexError, TypeError):
class EventHandlerArgMismatch(ReflexError, TypeError):
"""Raised when the number of args accepted by an EventHandler is differs from that provided by the event trigger."""
"""Raised when the number of args accepted by an EventHandler differs from that provided by the event trigger."""
class EventHandlerArgTypeMismatch(ReflexError, TypeError):
"""Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger."""
class EventFnArgMismatch(ReflexError, TypeError):

View File

@ -490,7 +490,7 @@ def _generate_component_create_functiondef(
def figure_out_return_type(annotation: Any):
if inspect.isclass(annotation) and issubclass(annotation, inspect._empty):
return ast.Name(id="Optional[EventType]")
return ast.Name(id="EventType")
if not isinstance(annotation, str) and get_origin(annotation) is tuple:
arguments = get_args(annotation)
@ -509,20 +509,13 @@ def _generate_component_create_functiondef(
# Create EventType using the joined string
event_type = ast.Name(id=f"EventType[{args_str}]")
# Wrap in Optional
optional_type = ast.Subscript(
value=ast.Name(id="Optional"),
slice=ast.Index(value=event_type),
ctx=ast.Load(),
)
return ast.Name(id=ast.unparse(optional_type))
return event_type
if isinstance(annotation, str) and annotation.startswith("Tuple["):
inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]")
if inside_of_tuple == "()":
return ast.Name(id="Optional[EventType[[]]]")
return ast.Name(id="EventType[[]]")
arguments = [""]
@ -548,10 +541,8 @@ def _generate_component_create_functiondef(
for argument in arguments
]
return ast.Name(
id=f"Optional[EventType[{', '.join(arguments_without_var)}]]"
)
return ast.Name(id="Optional[EventType]")
return ast.Name(id=f"EventType[{', '.join(arguments_without_var)}]")
return ast.Name(id="EventType")
event_triggers = clz().get_event_triggers()
@ -560,8 +551,33 @@ def _generate_component_create_functiondef(
(
ast.arg(
arg=trigger,
annotation=figure_out_return_type(
inspect.signature(event_triggers[trigger]).return_annotation
annotation=ast.Subscript(
ast.Name("Optional"),
ast.Index( # type: ignore
value=ast.Name(
id=ast.unparse(
figure_out_return_type(
inspect.signature(event_specs).return_annotation
)
if not isinstance(
event_specs := event_triggers[trigger], tuple
)
else ast.Subscript(
ast.Name("Union"),
ast.Tuple(
[
figure_out_return_type(
inspect.signature(
event_spec
).return_annotation
)
for event_spec in event_specs
]
),
)
)
)
),
),
),
ast.Constant(value=None),

View File

@ -774,3 +774,69 @@ def validate_parameter_literals(func):
# Store this here for performance.
StateBases = get_base_class(StateVar)
StateIterBases = get_base_class(StateIterVar)
def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool:
"""Check if a type hint is a subclass of another type hint.
Args:
possible_subclass: The type hint to check.
possible_superclass: The type hint to check against.
Returns:
Whether the type hint is a subclass of the other type hint.
"""
if possible_superclass is Any:
return True
if possible_subclass is Any:
return False
provided_type_origin = get_origin(possible_subclass)
accepted_type_origin = get_origin(possible_superclass)
if provided_type_origin is None and accepted_type_origin is None:
# In this case, we are dealing with a non-generic type, so we can use issubclass
return issubclass(possible_subclass, possible_superclass)
# Remove this check when Python 3.10 is the minimum supported version
if hasattr(types, "UnionType"):
provided_type_origin = (
Union if provided_type_origin is types.UnionType else provided_type_origin
)
accepted_type_origin = (
Union if accepted_type_origin is types.UnionType else accepted_type_origin
)
# Get type arguments (e.g., [float, int] for Dict[float, int])
provided_args = get_args(possible_subclass)
accepted_args = get_args(possible_superclass)
if accepted_type_origin is Union:
if provided_type_origin is not Union:
return any(
typehint_issubclass(possible_subclass, accepted_arg)
for accepted_arg in accepted_args
)
return all(
any(
typehint_issubclass(provided_arg, accepted_arg)
for accepted_arg in accepted_args
)
for provided_arg in provided_args
)
# Check if the origin of both types is the same (e.g., list for List[int])
# This probably should be issubclass instead of ==
if (provided_type_origin or possible_subclass) != (
accepted_type_origin or possible_superclass
):
return False
# Ensure all specific types are compatible with accepted types
# Note this is not necessarily correct, as it doesn't check against contravariance and covariance
# It also ignores when the length of the arguments is different
return all(
typehint_issubclass(provided_arg, accepted_arg)
for provided_arg, accepted_arg in zip(provided_args, accepted_args)
if accepted_arg is not Any
)

View File

@ -20,13 +20,17 @@ from reflex.event import (
EventChain,
EventHandler,
empty_event,
identity_event,
input_event,
parse_args_spec,
)
from reflex.state import BaseState
from reflex.style import Style
from reflex.utils import imports
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
from reflex.utils.exceptions import (
EventFnArgMismatch,
EventHandlerArgMismatch,
)
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
@ -43,6 +47,18 @@ def test_state():
def do_something_arg(self, arg):
pass
def do_something_with_bool(self, arg: bool):
pass
def do_something_with_int(self, arg: int):
pass
def do_something_with_list_int(self, arg: list[int]):
pass
def do_something_with_list_str(self, arg: list[str]):
pass
return TestState
@ -95,8 +111,10 @@ def component2() -> Type[Component]:
"""
return {
**super().get_event_triggers(),
"on_open": lambda e0: [e0],
"on_close": lambda e0: [e0],
"on_open": identity_event(bool),
"on_close": identity_event(bool),
"on_user_visited_count_changed": identity_event(int),
"on_user_list_changed": identity_event(List[str]),
}
def _get_imports(self) -> ParsedImportDict:
@ -582,7 +600,14 @@ def test_get_event_triggers(component1, component2):
assert component1().get_event_triggers().keys() == default_triggers
assert (
component2().get_event_triggers().keys()
== {"on_open", "on_close", "on_prop_event"} | default_triggers
== {
"on_open",
"on_close",
"on_prop_event",
"on_user_visited_count_changed",
"on_user_list_changed",
}
| default_triggers
)
@ -903,6 +928,22 @@ def test_invalid_event_handler_args(component2, test_state):
on_prop_event=[test_state.do_something_arg, test_state.do_something]
)
# Enable when 0.7.0 happens
# # Event Handler types must match
# with pytest.raises(EventHandlerArgTypeMismatch):
# component2.create(
# on_user_visited_count_changed=test_state.do_something_with_bool
# )
# with pytest.raises(EventHandlerArgTypeMismatch):
# component2.create(on_user_list_changed=test_state.do_something_with_int)
# with pytest.raises(EventHandlerArgTypeMismatch):
# component2.create(on_user_list_changed=test_state.do_something_with_list_int)
# component2.create(on_open=test_state.do_something_with_int)
# component2.create(on_open=test_state.do_something_with_bool)
# component2.create(on_user_visited_count_changed=test_state.do_something_with_int)
# component2.create(on_user_list_changed=test_state.do_something_with_list_str)
# lambda cannot return weird values.
with pytest.raises(ValueError):
component2.create(on_click=lambda: 1)

View File

@ -2,7 +2,7 @@ import os
import typing
from functools import cached_property
from pathlib import Path
from typing import Any, ClassVar, List, Literal, Type, Union
from typing import Any, ClassVar, Dict, List, Literal, Type, Union
import pytest
import typer
@ -77,6 +77,47 @@ def test_is_generic_alias(cls: type, expected: bool):
assert types.is_generic_alias(cls) == expected
@pytest.mark.parametrize(
("subclass", "superclass", "expected"),
[
*[
(base_type, base_type, True)
for base_type in [int, float, str, bool, list, dict]
],
*[
(one_type, another_type, False)
for one_type in [int, float, str, list, dict]
for another_type in [int, float, str, list, dict]
if one_type != another_type
],
(bool, int, True),
(int, bool, False),
(list, List, True),
(list, List[str], True), # this is wrong, but it's a limitation of the function
(List, list, True),
(List[int], list, True),
(List[int], List, True),
(List[int], List[str], False),
(List[int], List[int], True),
(List[int], List[float], False),
(List[int], List[Union[int, float]], True),
(List[int], List[Union[float, str]], False),
(Union[int, float], List[Union[int, float]], False),
(Union[int, float], Union[int, float, str], True),
(Union[int, float], Union[str, float], False),
(Dict[str, int], Dict[str, int], True),
(Dict[str, bool], Dict[str, int], True),
(Dict[str, int], Dict[str, bool], False),
(Dict[str, Any], dict[str, str], False),
(Dict[str, str], dict[str, str], True),
(Dict[str, str], dict[str, Any], True),
(Dict[str, Any], dict[str, Any], True),
],
)
def test_typehint_issubclass(subclass, superclass, expected):
assert types.typehint_issubclass(subclass, superclass) == expected
def test_validate_invalid_bun_path(mocker):
"""Test that an error is thrown when a custom specified bun path is not valid
or does not exist.