[ENG-3943]type check for event handler if spec arg are typed (#4046)
* type check for event handler if spec arg are typed * fix the typecheck logic * rearrange logic pieces * add try except * add try except around compare * change form and improve type checking * print key instead * dang it darglint * change wording * add basic test to cover it * add a slightly more complicated test * challenge it a bit by doing small capital list * add multiple argspec * fix slider event order * i hate 3.9 * add note for UnionType * move function to types * add a test for type hint is subclass * make on submit dict str any * add testing for dict cases * add check against any * accept dict str str * bruh i used i twice * escape strings and print actual error message * disable the error and print deprecation warning instead * disable tests * fix doc message --------- Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
This commit is contained in:
parent
c8cecbf3cc
commit
c07eb2a6a0
@ -480,6 +480,7 @@ class Component(BaseComponent, ABC):
|
||||
kwargs["event_triggers"][key] = self._create_event_chain(
|
||||
value=value, # type: ignore
|
||||
args_spec=component_specific_triggers[key],
|
||||
key=key,
|
||||
)
|
||||
|
||||
# Remove any keys that were added as events.
|
||||
@ -540,12 +541,14 @@ class Component(BaseComponent, ABC):
|
||||
List[Union[EventHandler, EventSpec, EventVar]],
|
||||
Callable,
|
||||
],
|
||||
key: Optional[str] = None,
|
||||
) -> Union[EventChain, Var]:
|
||||
"""Create an event chain from a variety of input types.
|
||||
|
||||
Args:
|
||||
args_spec: The args_spec of the event trigger being bound.
|
||||
value: The value to create the event chain from.
|
||||
key: The key of the event trigger being bound.
|
||||
|
||||
Returns:
|
||||
The event chain.
|
||||
@ -560,7 +563,7 @@ class Component(BaseComponent, ABC):
|
||||
elif isinstance(value, EventVar):
|
||||
value = [value]
|
||||
elif issubclass(value._var_type, (EventChain, EventSpec)):
|
||||
return self._create_event_chain(args_spec, value.guess_type())
|
||||
return self._create_event_chain(args_spec, value.guess_type(), key=key)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid event chain: {str(value)} of type {value._var_type}"
|
||||
@ -579,10 +582,10 @@ class Component(BaseComponent, ABC):
|
||||
for v in value:
|
||||
if isinstance(v, (EventHandler, EventSpec)):
|
||||
# Call the event handler to get the event.
|
||||
events.append(call_event_handler(v, args_spec))
|
||||
events.append(call_event_handler(v, args_spec, key=key))
|
||||
elif isinstance(v, Callable):
|
||||
# Call the lambda to get the event chain.
|
||||
result = call_event_fn(v, args_spec)
|
||||
result = call_event_fn(v, args_spec, key=key)
|
||||
if isinstance(result, Var):
|
||||
raise ValueError(
|
||||
f"Invalid event chain: {v}. Cannot use a Var-returning "
|
||||
@ -599,7 +602,7 @@ class Component(BaseComponent, ABC):
|
||||
result = call_event_fn(value, args_spec)
|
||||
if isinstance(result, Var):
|
||||
# Recursively call this function if the lambda returned an EventChain Var.
|
||||
return self._create_event_chain(args_spec, result)
|
||||
return self._create_event_chain(args_spec, result, key=key)
|
||||
events = [*result]
|
||||
|
||||
# Otherwise, raise an error.
|
||||
@ -1722,6 +1725,7 @@ class CustomComponent(Component):
|
||||
args_spec=event_triggers_in_component_declaration.get(
|
||||
key, empty_event
|
||||
),
|
||||
key=key,
|
||||
)
|
||||
self.props[format.to_camel_case(key)] = value
|
||||
continue
|
||||
|
@ -111,6 +111,15 @@ def on_submit_event_spec() -> Tuple[Var[Dict[str, Any]]]:
|
||||
return (FORM_DATA,)
|
||||
|
||||
|
||||
def on_submit_string_event_spec() -> Tuple[Var[Dict[str, str]]]:
|
||||
"""Event handler spec for the on_submit event.
|
||||
|
||||
Returns:
|
||||
The event handler spec.
|
||||
"""
|
||||
return (FORM_DATA,)
|
||||
|
||||
|
||||
class Form(BaseHTML):
|
||||
"""Display the form element."""
|
||||
|
||||
@ -150,7 +159,7 @@ class Form(BaseHTML):
|
||||
handle_submit_unique_name: Var[str]
|
||||
|
||||
# Fired when the form is submitted
|
||||
on_submit: EventHandler[on_submit_event_spec]
|
||||
on_submit: EventHandler[on_submit_event_spec, on_submit_string_event_spec]
|
||||
|
||||
@classmethod
|
||||
def create(cls, *children, **props):
|
||||
|
@ -271,6 +271,7 @@ class Fieldset(Element):
|
||||
...
|
||||
|
||||
def on_submit_event_spec() -> Tuple[Var[Dict[str, Any]]]: ...
|
||||
def on_submit_string_event_spec() -> Tuple[Var[Dict[str, str]]]: ...
|
||||
|
||||
class Form(BaseHTML):
|
||||
@overload
|
||||
@ -337,7 +338,9 @@ class Form(BaseHTML):
|
||||
on_mouse_over: Optional[EventType[[]]] = None,
|
||||
on_mouse_up: Optional[EventType[[]]] = None,
|
||||
on_scroll: Optional[EventType[[]]] = None,
|
||||
on_submit: Optional[EventType[Dict[str, Any]]] = None,
|
||||
on_submit: Optional[
|
||||
Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]]
|
||||
] = None,
|
||||
on_unmount: Optional[EventType[[]]] = None,
|
||||
**props,
|
||||
) -> "Form":
|
||||
|
@ -129,7 +129,9 @@ class FormRoot(FormComponent, HTMLForm):
|
||||
on_mouse_over: Optional[EventType[[]]] = None,
|
||||
on_mouse_up: Optional[EventType[[]]] = None,
|
||||
on_scroll: Optional[EventType[[]]] = None,
|
||||
on_submit: Optional[EventType[Dict[str, Any]]] = None,
|
||||
on_submit: Optional[
|
||||
Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]]
|
||||
] = None,
|
||||
on_unmount: Optional[EventType[[]]] = None,
|
||||
**props,
|
||||
) -> "FormRoot":
|
||||
@ -596,7 +598,9 @@ class Form(FormRoot):
|
||||
on_mouse_over: Optional[EventType[[]]] = None,
|
||||
on_mouse_up: Optional[EventType[[]]] = None,
|
||||
on_scroll: Optional[EventType[[]]] = None,
|
||||
on_submit: Optional[EventType[Dict[str, Any]]] = None,
|
||||
on_submit: Optional[
|
||||
Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]]
|
||||
] = None,
|
||||
on_unmount: Optional[EventType[[]]] = None,
|
||||
**props,
|
||||
) -> "Form":
|
||||
@ -720,7 +724,9 @@ class FormNamespace(ComponentNamespace):
|
||||
on_mouse_over: Optional[EventType[[]]] = None,
|
||||
on_mouse_up: Optional[EventType[[]]] = None,
|
||||
on_scroll: Optional[EventType[[]]] = None,
|
||||
on_submit: Optional[EventType[Dict[str, Any]]] = None,
|
||||
on_submit: Optional[
|
||||
Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]]
|
||||
] = None,
|
||||
on_unmount: Optional[EventType[[]]] = None,
|
||||
**props,
|
||||
) -> "Form":
|
||||
|
@ -2,11 +2,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.core.breakpoints import Responsive
|
||||
from reflex.event import EventHandler
|
||||
from reflex.event import EventHandler, identity_event
|
||||
from reflex.vars.base import Var
|
||||
|
||||
from ..base import (
|
||||
@ -14,19 +14,11 @@ from ..base import (
|
||||
RadixThemesComponent,
|
||||
)
|
||||
|
||||
|
||||
def on_value_event_spec(
|
||||
value: Var[List[Union[int, float]]],
|
||||
) -> Tuple[Var[List[Union[int, float]]]]:
|
||||
"""Event handler spec for the value event.
|
||||
|
||||
Args:
|
||||
value: The value of the event.
|
||||
|
||||
Returns:
|
||||
The event handler spec.
|
||||
"""
|
||||
return (value,) # type: ignore
|
||||
on_value_event_spec = (
|
||||
identity_event(list[Union[int, float]]),
|
||||
identity_event(list[int]),
|
||||
identity_event(list[float]),
|
||||
)
|
||||
|
||||
|
||||
class Slider(RadixThemesComponent):
|
||||
|
@ -3,18 +3,20 @@
|
||||
# ------------------- DO NOT EDIT ----------------------
|
||||
# This file was generated by `reflex/utils/pyi_generator.py`!
|
||||
# ------------------------------------------------------
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload
|
||||
from typing import Any, Dict, List, Literal, Optional, Union, overload
|
||||
|
||||
from reflex.components.core.breakpoints import Breakpoints
|
||||
from reflex.event import EventType
|
||||
from reflex.event import EventType, identity_event
|
||||
from reflex.style import Style
|
||||
from reflex.vars.base import Var
|
||||
|
||||
from ..base import RadixThemesComponent
|
||||
|
||||
def on_value_event_spec(
|
||||
value: Var[List[Union[int, float]]],
|
||||
) -> Tuple[Var[List[Union[int, float]]]]: ...
|
||||
on_value_event_spec = (
|
||||
identity_event(list[Union[int, float]]),
|
||||
identity_event(list[int]),
|
||||
identity_event(list[float]),
|
||||
)
|
||||
|
||||
class Slider(RadixThemesComponent):
|
||||
@overload
|
||||
@ -138,7 +140,13 @@ class Slider(RadixThemesComponent):
|
||||
autofocus: Optional[bool] = None,
|
||||
custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
|
||||
on_blur: Optional[EventType[[]]] = None,
|
||||
on_change: Optional[EventType[List[Union[int, float]]]] = None,
|
||||
on_change: Optional[
|
||||
Union[
|
||||
EventType[list[Union[int, float]]],
|
||||
EventType[list[int]],
|
||||
EventType[list[float]],
|
||||
]
|
||||
] = None,
|
||||
on_click: Optional[EventType[[]]] = None,
|
||||
on_context_menu: Optional[EventType[[]]] = None,
|
||||
on_double_click: Optional[EventType[[]]] = None,
|
||||
@ -153,7 +161,13 @@ class Slider(RadixThemesComponent):
|
||||
on_mouse_up: Optional[EventType[[]]] = None,
|
||||
on_scroll: Optional[EventType[[]]] = None,
|
||||
on_unmount: Optional[EventType[[]]] = None,
|
||||
on_value_commit: Optional[EventType[List[Union[int, float]]]] = None,
|
||||
on_value_commit: Optional[
|
||||
Union[
|
||||
EventType[list[Union[int, float]]],
|
||||
EventType[list[int]],
|
||||
EventType[list[float]],
|
||||
]
|
||||
] = None,
|
||||
**props,
|
||||
) -> "Slider":
|
||||
"""Create a Slider component.
|
||||
|
154
reflex/event.py
154
reflex/event.py
@ -29,8 +29,12 @@ from typing_extensions import ParamSpec, Protocol, get_args, get_origin
|
||||
|
||||
from reflex import constants
|
||||
from reflex.utils import console, format
|
||||
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
|
||||
from reflex.utils.types import ArgsSpec, GenericType
|
||||
from reflex.utils.exceptions import (
|
||||
EventFnArgMismatch,
|
||||
EventHandlerArgMismatch,
|
||||
EventHandlerArgTypeMismatch,
|
||||
)
|
||||
from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass
|
||||
from reflex.vars import VarData
|
||||
from reflex.vars.base import (
|
||||
LiteralVar,
|
||||
@ -401,7 +405,9 @@ class EventChain(EventActionsMixin):
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
args_spec: Optional[Callable] = dataclasses.field(default=None)
|
||||
args_spec: Optional[Union[Callable, Sequence[Callable]]] = dataclasses.field(
|
||||
default=None
|
||||
)
|
||||
|
||||
invocation: Optional[Var] = dataclasses.field(default=None)
|
||||
|
||||
@ -1053,7 +1059,8 @@ def get_hydrate_event(state) -> str:
|
||||
|
||||
def call_event_handler(
|
||||
event_handler: EventHandler | EventSpec,
|
||||
arg_spec: ArgsSpec,
|
||||
arg_spec: ArgsSpec | Sequence[ArgsSpec],
|
||||
key: Optional[str] = None,
|
||||
) -> EventSpec:
|
||||
"""Call an event handler to get the event spec.
|
||||
|
||||
@ -1064,12 +1071,16 @@ def call_event_handler(
|
||||
Args:
|
||||
event_handler: The event handler.
|
||||
arg_spec: The lambda that define the argument(s) to pass to the event handler.
|
||||
key: The key to pass to the event handler.
|
||||
|
||||
Raises:
|
||||
EventHandlerArgMismatch: if number of arguments expected by event_handler doesn't match the spec.
|
||||
|
||||
Returns:
|
||||
The event spec from calling the event handler.
|
||||
|
||||
# noqa: DAR401 failure
|
||||
|
||||
"""
|
||||
parsed_args = parse_args_spec(arg_spec) # type: ignore
|
||||
|
||||
@ -1077,19 +1088,113 @@ def call_event_handler(
|
||||
# Handle partial application of EventSpec args
|
||||
return event_handler.add_args(*parsed_args)
|
||||
|
||||
args = inspect.getfullargspec(event_handler.fn).args
|
||||
n_args = len(args) - 1 # subtract 1 for bound self arg
|
||||
if n_args == len(parsed_args):
|
||||
return event_handler(*parsed_args) # type: ignore
|
||||
else:
|
||||
provided_callback_fullspec = inspect.getfullargspec(event_handler.fn)
|
||||
|
||||
provided_callback_n_args = (
|
||||
len(provided_callback_fullspec.args) - 1
|
||||
) # subtract 1 for bound self arg
|
||||
|
||||
if provided_callback_n_args != len(parsed_args):
|
||||
raise EventHandlerArgMismatch(
|
||||
"The number of arguments accepted by "
|
||||
f"{event_handler.fn.__qualname__} ({n_args}) "
|
||||
f"{event_handler.fn.__qualname__} ({provided_callback_n_args}) "
|
||||
"does not match the arguments passed by the event trigger: "
|
||||
f"{[str(v) for v in parsed_args]}\n"
|
||||
"See https://reflex.dev/docs/events/event-arguments/"
|
||||
)
|
||||
|
||||
all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec
|
||||
|
||||
event_spec_return_types = list(
|
||||
filter(
|
||||
lambda event_spec_return_type: event_spec_return_type is not None
|
||||
and get_origin(event_spec_return_type) is tuple,
|
||||
(get_type_hints(arg_spec).get("return", None) for arg_spec in all_arg_spec),
|
||||
)
|
||||
)
|
||||
|
||||
if event_spec_return_types:
|
||||
failures = []
|
||||
|
||||
for event_spec_index, event_spec_return_type in enumerate(
|
||||
event_spec_return_types
|
||||
):
|
||||
args = get_args(event_spec_return_type)
|
||||
|
||||
args_types_without_vars = [
|
||||
arg if get_origin(arg) is not Var else get_args(arg)[0] for arg in args
|
||||
]
|
||||
|
||||
try:
|
||||
type_hints_of_provided_callback = get_type_hints(event_handler.fn)
|
||||
except NameError:
|
||||
type_hints_of_provided_callback = {}
|
||||
|
||||
failed_type_check = False
|
||||
|
||||
# check that args of event handler are matching the spec if type hints are provided
|
||||
for i, arg in enumerate(provided_callback_fullspec.args[1:]):
|
||||
if arg not in type_hints_of_provided_callback:
|
||||
continue
|
||||
|
||||
try:
|
||||
compare_result = typehint_issubclass(
|
||||
args_types_without_vars[i], type_hints_of_provided_callback[arg]
|
||||
)
|
||||
except TypeError:
|
||||
# TODO: In 0.7.0, remove this block and raise the exception
|
||||
# raise TypeError(
|
||||
# f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
|
||||
# ) from e
|
||||
console.warn(
|
||||
f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
|
||||
)
|
||||
compare_result = False
|
||||
|
||||
if compare_result:
|
||||
continue
|
||||
else:
|
||||
failure = EventHandlerArgTypeMismatch(
|
||||
f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead."
|
||||
)
|
||||
failures.append(failure)
|
||||
failed_type_check = True
|
||||
break
|
||||
|
||||
if not failed_type_check:
|
||||
if event_spec_index:
|
||||
args = get_args(event_spec_return_types[0])
|
||||
|
||||
args_types_without_vars = [
|
||||
arg if get_origin(arg) is not Var else get_args(arg)[0]
|
||||
for arg in args
|
||||
]
|
||||
|
||||
expect_string = ", ".join(
|
||||
repr(arg) for arg in args_types_without_vars
|
||||
).replace("[", "\\[")
|
||||
|
||||
given_string = ", ".join(
|
||||
repr(type_hints_of_provided_callback.get(arg, Any))
|
||||
for arg in provided_callback_fullspec.args[1:]
|
||||
).replace("[", "\\[")
|
||||
|
||||
console.warn(
|
||||
f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_handler.fn.__qualname__} instead. "
|
||||
f"This may lead to unexpected behavior but is intentionally ignored for {key}."
|
||||
)
|
||||
return event_handler(*parsed_args)
|
||||
|
||||
if failures:
|
||||
console.deprecate(
|
||||
"Mismatched event handler argument types",
|
||||
"\n".join([str(f) for f in failures]),
|
||||
"0.6.5",
|
||||
"0.7.0",
|
||||
)
|
||||
|
||||
return event_handler(*parsed_args) # type: ignore
|
||||
|
||||
|
||||
def unwrap_var_annotation(annotation: GenericType):
|
||||
"""Unwrap a Var annotation or return it as is if it's not Var[X].
|
||||
@ -1128,7 +1233,7 @@ def resolve_annotation(annotations: dict[str, Any], arg_name: str):
|
||||
return annotation
|
||||
|
||||
|
||||
def parse_args_spec(arg_spec: ArgsSpec):
|
||||
def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]):
|
||||
"""Parse the args provided in the ArgsSpec of an event trigger.
|
||||
|
||||
Args:
|
||||
@ -1137,6 +1242,8 @@ def parse_args_spec(arg_spec: ArgsSpec):
|
||||
Returns:
|
||||
The parsed args.
|
||||
"""
|
||||
# if there's multiple, the first is the default
|
||||
arg_spec = arg_spec[0] if isinstance(arg_spec, Sequence) else arg_spec
|
||||
spec = inspect.getfullargspec(arg_spec)
|
||||
annotations = get_type_hints(arg_spec)
|
||||
|
||||
@ -1152,13 +1259,18 @@ def parse_args_spec(arg_spec: ArgsSpec):
|
||||
)
|
||||
|
||||
|
||||
def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]:
|
||||
def check_fn_match_arg_spec(
|
||||
fn: Callable,
|
||||
arg_spec: ArgsSpec,
|
||||
key: Optional[str] = None,
|
||||
) -> List[Var]:
|
||||
"""Ensures that the function signature matches the passed argument specification
|
||||
or raises an EventFnArgMismatch if they do not.
|
||||
|
||||
Args:
|
||||
fn: The function to be validated.
|
||||
arg_spec: The argument specification for the event trigger.
|
||||
key: The key to pass to the event handler.
|
||||
|
||||
Returns:
|
||||
The parsed arguments from the argument specification.
|
||||
@ -1184,7 +1296,11 @@ def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]:
|
||||
return parsed_args
|
||||
|
||||
|
||||
def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
|
||||
def call_event_fn(
|
||||
fn: Callable,
|
||||
arg_spec: ArgsSpec,
|
||||
key: Optional[str] = None,
|
||||
) -> list[EventSpec] | Var:
|
||||
"""Call a function to a list of event specs.
|
||||
|
||||
The function should return a single EventSpec, a list of EventSpecs, or a
|
||||
@ -1193,6 +1309,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
|
||||
Args:
|
||||
fn: The function to call.
|
||||
arg_spec: The argument spec for the event trigger.
|
||||
key: The key to pass to the event handler.
|
||||
|
||||
Returns:
|
||||
The event specs from calling the function or a Var.
|
||||
@ -1205,7 +1322,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
|
||||
from reflex.utils.exceptions import EventHandlerValueError
|
||||
|
||||
# Check that fn signature matches arg_spec
|
||||
parsed_args = check_fn_match_arg_spec(fn, arg_spec)
|
||||
parsed_args = check_fn_match_arg_spec(fn, arg_spec, key=key)
|
||||
|
||||
# Call the function with the parsed args.
|
||||
out = fn(*parsed_args)
|
||||
@ -1223,7 +1340,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
|
||||
for e in out:
|
||||
if isinstance(e, EventHandler):
|
||||
# An un-called EventHandler gets all of the args of the event trigger.
|
||||
e = call_event_handler(e, arg_spec)
|
||||
e = call_event_handler(e, arg_spec, key=key)
|
||||
|
||||
# Make sure the event spec is valid.
|
||||
if not isinstance(e, EventSpec):
|
||||
@ -1433,7 +1550,12 @@ class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
|
||||
Returns:
|
||||
The created LiteralEventChainVar instance.
|
||||
"""
|
||||
sig = inspect.signature(value.args_spec) # type: ignore
|
||||
arg_spec = (
|
||||
value.args_spec[0]
|
||||
if isinstance(value.args_spec, Sequence)
|
||||
else value.args_spec
|
||||
)
|
||||
sig = inspect.signature(arg_spec) # type: ignore
|
||||
if sig.parameters:
|
||||
arg_def = tuple((f"_{p}" for p in sig.parameters))
|
||||
arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])
|
||||
|
@ -90,7 +90,11 @@ class MatchTypeError(ReflexError, TypeError):
|
||||
|
||||
|
||||
class EventHandlerArgMismatch(ReflexError, TypeError):
|
||||
"""Raised when the number of args accepted by an EventHandler is differs from that provided by the event trigger."""
|
||||
"""Raised when the number of args accepted by an EventHandler differs from that provided by the event trigger."""
|
||||
|
||||
|
||||
class EventHandlerArgTypeMismatch(ReflexError, TypeError):
|
||||
"""Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger."""
|
||||
|
||||
|
||||
class EventFnArgMismatch(ReflexError, TypeError):
|
||||
|
@ -490,7 +490,7 @@ def _generate_component_create_functiondef(
|
||||
|
||||
def figure_out_return_type(annotation: Any):
|
||||
if inspect.isclass(annotation) and issubclass(annotation, inspect._empty):
|
||||
return ast.Name(id="Optional[EventType]")
|
||||
return ast.Name(id="EventType")
|
||||
|
||||
if not isinstance(annotation, str) and get_origin(annotation) is tuple:
|
||||
arguments = get_args(annotation)
|
||||
@ -509,20 +509,13 @@ def _generate_component_create_functiondef(
|
||||
# Create EventType using the joined string
|
||||
event_type = ast.Name(id=f"EventType[{args_str}]")
|
||||
|
||||
# Wrap in Optional
|
||||
optional_type = ast.Subscript(
|
||||
value=ast.Name(id="Optional"),
|
||||
slice=ast.Index(value=event_type),
|
||||
ctx=ast.Load(),
|
||||
)
|
||||
|
||||
return ast.Name(id=ast.unparse(optional_type))
|
||||
return event_type
|
||||
|
||||
if isinstance(annotation, str) and annotation.startswith("Tuple["):
|
||||
inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]")
|
||||
|
||||
if inside_of_tuple == "()":
|
||||
return ast.Name(id="Optional[EventType[[]]]")
|
||||
return ast.Name(id="EventType[[]]")
|
||||
|
||||
arguments = [""]
|
||||
|
||||
@ -548,10 +541,8 @@ def _generate_component_create_functiondef(
|
||||
for argument in arguments
|
||||
]
|
||||
|
||||
return ast.Name(
|
||||
id=f"Optional[EventType[{', '.join(arguments_without_var)}]]"
|
||||
)
|
||||
return ast.Name(id="Optional[EventType]")
|
||||
return ast.Name(id=f"EventType[{', '.join(arguments_without_var)}]")
|
||||
return ast.Name(id="EventType")
|
||||
|
||||
event_triggers = clz().get_event_triggers()
|
||||
|
||||
@ -560,8 +551,33 @@ def _generate_component_create_functiondef(
|
||||
(
|
||||
ast.arg(
|
||||
arg=trigger,
|
||||
annotation=figure_out_return_type(
|
||||
inspect.signature(event_triggers[trigger]).return_annotation
|
||||
annotation=ast.Subscript(
|
||||
ast.Name("Optional"),
|
||||
ast.Index( # type: ignore
|
||||
value=ast.Name(
|
||||
id=ast.unparse(
|
||||
figure_out_return_type(
|
||||
inspect.signature(event_specs).return_annotation
|
||||
)
|
||||
if not isinstance(
|
||||
event_specs := event_triggers[trigger], tuple
|
||||
)
|
||||
else ast.Subscript(
|
||||
ast.Name("Union"),
|
||||
ast.Tuple(
|
||||
[
|
||||
figure_out_return_type(
|
||||
inspect.signature(
|
||||
event_spec
|
||||
).return_annotation
|
||||
)
|
||||
for event_spec in event_specs
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
),
|
||||
),
|
||||
),
|
||||
ast.Constant(value=None),
|
||||
|
@ -774,3 +774,69 @@ def validate_parameter_literals(func):
|
||||
# Store this here for performance.
|
||||
StateBases = get_base_class(StateVar)
|
||||
StateIterBases = get_base_class(StateIterVar)
|
||||
|
||||
|
||||
def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool:
|
||||
"""Check if a type hint is a subclass of another type hint.
|
||||
|
||||
Args:
|
||||
possible_subclass: The type hint to check.
|
||||
possible_superclass: The type hint to check against.
|
||||
|
||||
Returns:
|
||||
Whether the type hint is a subclass of the other type hint.
|
||||
"""
|
||||
if possible_superclass is Any:
|
||||
return True
|
||||
if possible_subclass is Any:
|
||||
return False
|
||||
|
||||
provided_type_origin = get_origin(possible_subclass)
|
||||
accepted_type_origin = get_origin(possible_superclass)
|
||||
|
||||
if provided_type_origin is None and accepted_type_origin is None:
|
||||
# In this case, we are dealing with a non-generic type, so we can use issubclass
|
||||
return issubclass(possible_subclass, possible_superclass)
|
||||
|
||||
# Remove this check when Python 3.10 is the minimum supported version
|
||||
if hasattr(types, "UnionType"):
|
||||
provided_type_origin = (
|
||||
Union if provided_type_origin is types.UnionType else provided_type_origin
|
||||
)
|
||||
accepted_type_origin = (
|
||||
Union if accepted_type_origin is types.UnionType else accepted_type_origin
|
||||
)
|
||||
|
||||
# Get type arguments (e.g., [float, int] for Dict[float, int])
|
||||
provided_args = get_args(possible_subclass)
|
||||
accepted_args = get_args(possible_superclass)
|
||||
|
||||
if accepted_type_origin is Union:
|
||||
if provided_type_origin is not Union:
|
||||
return any(
|
||||
typehint_issubclass(possible_subclass, accepted_arg)
|
||||
for accepted_arg in accepted_args
|
||||
)
|
||||
return all(
|
||||
any(
|
||||
typehint_issubclass(provided_arg, accepted_arg)
|
||||
for accepted_arg in accepted_args
|
||||
)
|
||||
for provided_arg in provided_args
|
||||
)
|
||||
|
||||
# Check if the origin of both types is the same (e.g., list for List[int])
|
||||
# This probably should be issubclass instead of ==
|
||||
if (provided_type_origin or possible_subclass) != (
|
||||
accepted_type_origin or possible_superclass
|
||||
):
|
||||
return False
|
||||
|
||||
# Ensure all specific types are compatible with accepted types
|
||||
# Note this is not necessarily correct, as it doesn't check against contravariance and covariance
|
||||
# It also ignores when the length of the arguments is different
|
||||
return all(
|
||||
typehint_issubclass(provided_arg, accepted_arg)
|
||||
for provided_arg, accepted_arg in zip(provided_args, accepted_args)
|
||||
if accepted_arg is not Any
|
||||
)
|
||||
|
@ -20,13 +20,17 @@ from reflex.event import (
|
||||
EventChain,
|
||||
EventHandler,
|
||||
empty_event,
|
||||
identity_event,
|
||||
input_event,
|
||||
parse_args_spec,
|
||||
)
|
||||
from reflex.state import BaseState
|
||||
from reflex.style import Style
|
||||
from reflex.utils import imports
|
||||
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
|
||||
from reflex.utils.exceptions import (
|
||||
EventFnArgMismatch,
|
||||
EventHandlerArgMismatch,
|
||||
)
|
||||
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
|
||||
from reflex.vars import VarData
|
||||
from reflex.vars.base import LiteralVar, Var
|
||||
@ -43,6 +47,18 @@ def test_state():
|
||||
def do_something_arg(self, arg):
|
||||
pass
|
||||
|
||||
def do_something_with_bool(self, arg: bool):
|
||||
pass
|
||||
|
||||
def do_something_with_int(self, arg: int):
|
||||
pass
|
||||
|
||||
def do_something_with_list_int(self, arg: list[int]):
|
||||
pass
|
||||
|
||||
def do_something_with_list_str(self, arg: list[str]):
|
||||
pass
|
||||
|
||||
return TestState
|
||||
|
||||
|
||||
@ -95,8 +111,10 @@ def component2() -> Type[Component]:
|
||||
"""
|
||||
return {
|
||||
**super().get_event_triggers(),
|
||||
"on_open": lambda e0: [e0],
|
||||
"on_close": lambda e0: [e0],
|
||||
"on_open": identity_event(bool),
|
||||
"on_close": identity_event(bool),
|
||||
"on_user_visited_count_changed": identity_event(int),
|
||||
"on_user_list_changed": identity_event(List[str]),
|
||||
}
|
||||
|
||||
def _get_imports(self) -> ParsedImportDict:
|
||||
@ -582,7 +600,14 @@ def test_get_event_triggers(component1, component2):
|
||||
assert component1().get_event_triggers().keys() == default_triggers
|
||||
assert (
|
||||
component2().get_event_triggers().keys()
|
||||
== {"on_open", "on_close", "on_prop_event"} | default_triggers
|
||||
== {
|
||||
"on_open",
|
||||
"on_close",
|
||||
"on_prop_event",
|
||||
"on_user_visited_count_changed",
|
||||
"on_user_list_changed",
|
||||
}
|
||||
| default_triggers
|
||||
)
|
||||
|
||||
|
||||
@ -903,6 +928,22 @@ def test_invalid_event_handler_args(component2, test_state):
|
||||
on_prop_event=[test_state.do_something_arg, test_state.do_something]
|
||||
)
|
||||
|
||||
# Enable when 0.7.0 happens
|
||||
# # Event Handler types must match
|
||||
# with pytest.raises(EventHandlerArgTypeMismatch):
|
||||
# component2.create(
|
||||
# on_user_visited_count_changed=test_state.do_something_with_bool
|
||||
# )
|
||||
# with pytest.raises(EventHandlerArgTypeMismatch):
|
||||
# component2.create(on_user_list_changed=test_state.do_something_with_int)
|
||||
# with pytest.raises(EventHandlerArgTypeMismatch):
|
||||
# component2.create(on_user_list_changed=test_state.do_something_with_list_int)
|
||||
|
||||
# component2.create(on_open=test_state.do_something_with_int)
|
||||
# component2.create(on_open=test_state.do_something_with_bool)
|
||||
# component2.create(on_user_visited_count_changed=test_state.do_something_with_int)
|
||||
# component2.create(on_user_list_changed=test_state.do_something_with_list_str)
|
||||
|
||||
# lambda cannot return weird values.
|
||||
with pytest.raises(ValueError):
|
||||
component2.create(on_click=lambda: 1)
|
||||
|
@ -2,7 +2,7 @@ import os
|
||||
import typing
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, List, Literal, Type, Union
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Type, Union
|
||||
|
||||
import pytest
|
||||
import typer
|
||||
@ -77,6 +77,47 @@ def test_is_generic_alias(cls: type, expected: bool):
|
||||
assert types.is_generic_alias(cls) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("subclass", "superclass", "expected"),
|
||||
[
|
||||
*[
|
||||
(base_type, base_type, True)
|
||||
for base_type in [int, float, str, bool, list, dict]
|
||||
],
|
||||
*[
|
||||
(one_type, another_type, False)
|
||||
for one_type in [int, float, str, list, dict]
|
||||
for another_type in [int, float, str, list, dict]
|
||||
if one_type != another_type
|
||||
],
|
||||
(bool, int, True),
|
||||
(int, bool, False),
|
||||
(list, List, True),
|
||||
(list, List[str], True), # this is wrong, but it's a limitation of the function
|
||||
(List, list, True),
|
||||
(List[int], list, True),
|
||||
(List[int], List, True),
|
||||
(List[int], List[str], False),
|
||||
(List[int], List[int], True),
|
||||
(List[int], List[float], False),
|
||||
(List[int], List[Union[int, float]], True),
|
||||
(List[int], List[Union[float, str]], False),
|
||||
(Union[int, float], List[Union[int, float]], False),
|
||||
(Union[int, float], Union[int, float, str], True),
|
||||
(Union[int, float], Union[str, float], False),
|
||||
(Dict[str, int], Dict[str, int], True),
|
||||
(Dict[str, bool], Dict[str, int], True),
|
||||
(Dict[str, int], Dict[str, bool], False),
|
||||
(Dict[str, Any], dict[str, str], False),
|
||||
(Dict[str, str], dict[str, str], True),
|
||||
(Dict[str, str], dict[str, Any], True),
|
||||
(Dict[str, Any], dict[str, Any], True),
|
||||
],
|
||||
)
|
||||
def test_typehint_issubclass(subclass, superclass, expected):
|
||||
assert types.typehint_issubclass(subclass, superclass) == expected
|
||||
|
||||
|
||||
def test_validate_invalid_bun_path(mocker):
|
||||
"""Test that an error is thrown when a custom specified bun path is not valid
|
||||
or does not exist.
|
||||
|
Loading…
Reference in New Issue
Block a user