add multiple argspec
This commit is contained in:
parent
fcf6aa6cf3
commit
0ec59aaf53
@ -2,11 +2,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.core.breakpoints import Responsive
|
||||
from reflex.event import EventHandler
|
||||
from reflex.event import EventHandler, identity_event
|
||||
from reflex.vars.base import Var
|
||||
|
||||
from ..base import (
|
||||
@ -14,19 +14,11 @@ from ..base import (
|
||||
RadixThemesComponent,
|
||||
)
|
||||
|
||||
|
||||
def on_value_event_spec(
|
||||
value: Var[List[int | float]],
|
||||
) -> Tuple[Var[List[int | float]]]:
|
||||
"""Event handler spec for the value event.
|
||||
|
||||
Args:
|
||||
value: The value of the event.
|
||||
|
||||
Returns:
|
||||
The event handler spec.
|
||||
"""
|
||||
return (value,) # type: ignore
|
||||
on_value_event_spec = (
|
||||
identity_event(list[Union[int, float]]),
|
||||
identity_event(list[int]),
|
||||
identity_event(list[float]),
|
||||
)
|
||||
|
||||
|
||||
class Slider(RadixThemesComponent):
|
||||
|
@ -3,18 +3,20 @@
|
||||
# ------------------- DO NOT EDIT ----------------------
|
||||
# This file was generated by `reflex/utils/pyi_generator.py`!
|
||||
# ------------------------------------------------------
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload
|
||||
from typing import Any, Dict, List, Literal, Optional, Union, overload
|
||||
|
||||
from reflex.components.core.breakpoints import Breakpoints
|
||||
from reflex.event import EventType
|
||||
from reflex.event import EventType, identity_event
|
||||
from reflex.style import Style
|
||||
from reflex.vars.base import Var
|
||||
|
||||
from ..base import RadixThemesComponent
|
||||
|
||||
def on_value_event_spec(
|
||||
value: Var[List[int | float]],
|
||||
) -> Tuple[Var[List[int | float]]]: ...
|
||||
on_value_event_spec = (
|
||||
identity_event(list[int]),
|
||||
identity_event(list[Union[int, float]]),
|
||||
identity_event(list[float]),
|
||||
)
|
||||
|
||||
class Slider(RadixThemesComponent):
|
||||
@overload
|
||||
@ -138,7 +140,13 @@ class Slider(RadixThemesComponent):
|
||||
autofocus: Optional[bool] = None,
|
||||
custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
|
||||
on_blur: Optional[EventType[[]]] = None,
|
||||
on_change: Optional[EventType[List[int | float]]] = None,
|
||||
on_change: Optional[
|
||||
Union[
|
||||
EventType[list[int]],
|
||||
EventType[list[Union[int, float]]],
|
||||
EventType[list[float]],
|
||||
]
|
||||
] = None,
|
||||
on_click: Optional[EventType[[]]] = None,
|
||||
on_context_menu: Optional[EventType[[]]] = None,
|
||||
on_double_click: Optional[EventType[[]]] = None,
|
||||
@ -153,7 +161,13 @@ class Slider(RadixThemesComponent):
|
||||
on_mouse_up: Optional[EventType[[]]] = None,
|
||||
on_scroll: Optional[EventType[[]]] = None,
|
||||
on_unmount: Optional[EventType[[]]] = None,
|
||||
on_value_commit: Optional[EventType[List[int | float]]] = None,
|
||||
on_value_commit: Optional[
|
||||
Union[
|
||||
EventType[list[int]],
|
||||
EventType[list[Union[int, float]]],
|
||||
EventType[list[float]],
|
||||
]
|
||||
] = None,
|
||||
**props,
|
||||
) -> "Slider":
|
||||
"""Create a Slider component.
|
||||
|
112
reflex/event.py
112
reflex/event.py
@ -16,6 +16,7 @@ from typing import (
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
@ -395,7 +396,9 @@ class EventChain(EventActionsMixin):
|
||||
|
||||
events: List[Union[EventSpec, EventVar]] = dataclasses.field(default_factory=list)
|
||||
|
||||
args_spec: Optional[Callable] = dataclasses.field(default=None)
|
||||
args_spec: Optional[Union[Callable, Sequence[Callable]]] = dataclasses.field(
|
||||
default=None
|
||||
)
|
||||
|
||||
invocation: Optional[Var] = dataclasses.field(default=None)
|
||||
|
||||
@ -1040,7 +1043,7 @@ def get_hydrate_event(state) -> str:
|
||||
|
||||
def call_event_handler(
|
||||
event_handler: EventHandler | EventSpec,
|
||||
arg_spec: ArgsSpec,
|
||||
arg_spec: ArgsSpec | Sequence[ArgsSpec],
|
||||
key: Optional[str] = None,
|
||||
) -> EventSpec:
|
||||
"""Call an event handler to get the event spec.
|
||||
@ -1084,6 +1087,9 @@ def call_event_handler(
|
||||
)
|
||||
|
||||
def compare_types(provided_type, accepted_type):
|
||||
if accepted_type is Any:
|
||||
return True
|
||||
|
||||
provided_type_origin = get_origin(provided_type)
|
||||
accepted_type_origin = get_origin(accepted_type)
|
||||
|
||||
@ -1091,6 +1097,13 @@ def call_event_handler(
|
||||
# Check if both are concrete types (e.g., int)
|
||||
return issubclass(provided_type, accepted_type)
|
||||
|
||||
provided_type_origin = (
|
||||
Union if provided_type_origin is types.UnionType else provided_type_origin
|
||||
)
|
||||
accepted_type_origin = (
|
||||
Union if accepted_type_origin is types.UnionType else accepted_type_origin
|
||||
)
|
||||
|
||||
# Check if both are generic types (e.g., List)
|
||||
if (provided_type_origin or provided_type) != (
|
||||
accepted_type_origin or accepted_type
|
||||
@ -1103,48 +1116,70 @@ def call_event_handler(
|
||||
|
||||
# Ensure all specific types are compatible with accepted types
|
||||
return all(
|
||||
issubclass(provided_arg, accepted_arg)
|
||||
compare_types(provided_arg, accepted_arg)
|
||||
for provided_arg, accepted_arg in zip(provided_args, accepted_args)
|
||||
if accepted_arg is not Any
|
||||
)
|
||||
|
||||
event_spec_return_type = get_type_hints(arg_spec).get("return", None)
|
||||
all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec
|
||||
|
||||
if (
|
||||
event_spec_return_type is not None
|
||||
and get_origin(event_spec_return_type) is tuple
|
||||
):
|
||||
args = get_args(event_spec_return_type)
|
||||
event_spec_return_types = list(
|
||||
filter(
|
||||
lambda event_spec_return_type: event_spec_return_type is not None
|
||||
and get_origin(event_spec_return_type) is tuple,
|
||||
(get_type_hints(arg_spec).get("return", None) for arg_spec in all_arg_spec),
|
||||
)
|
||||
)
|
||||
|
||||
args_types_without_vars = [
|
||||
arg if get_origin(arg) is not Var else get_args(arg)[0] for arg in args
|
||||
]
|
||||
if event_spec_return_types:
|
||||
failures = []
|
||||
|
||||
try:
|
||||
type_hints_of_provided_callback = get_type_hints(event_handler.fn)
|
||||
except NameError:
|
||||
type_hints_of_provided_callback = {}
|
||||
for event_spec_return_type in event_spec_return_types:
|
||||
args = get_args(event_spec_return_type)
|
||||
|
||||
# check that args of event handler are matching the spec if type hints are provided
|
||||
for i, arg in enumerate(provided_callback_fullspec.args[1:]):
|
||||
if arg not in type_hints_of_provided_callback:
|
||||
continue
|
||||
args_types_without_vars = [
|
||||
arg if get_origin(arg) is not Var else get_args(arg)[0] for arg in args
|
||||
]
|
||||
|
||||
try:
|
||||
compare_result = compare_types(
|
||||
args_types_without_vars[i], type_hints_of_provided_callback[arg]
|
||||
)
|
||||
except TypeError as e:
|
||||
raise TypeError(
|
||||
f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
|
||||
) from e
|
||||
type_hints_of_provided_callback = get_type_hints(event_handler.fn)
|
||||
except NameError:
|
||||
type_hints_of_provided_callback = {}
|
||||
|
||||
if compare_result:
|
||||
continue
|
||||
else:
|
||||
raise EventHandlerArgTypeMismatch(
|
||||
f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead."
|
||||
)
|
||||
failed_type_check = False
|
||||
|
||||
# check that args of event handler are matching the spec if type hints are provided
|
||||
for i, arg in enumerate(provided_callback_fullspec.args[1:]):
|
||||
if arg not in type_hints_of_provided_callback:
|
||||
continue
|
||||
|
||||
try:
|
||||
compare_result = compare_types(
|
||||
args_types_without_vars[i], type_hints_of_provided_callback[arg]
|
||||
)
|
||||
except TypeError as e:
|
||||
raise TypeError(
|
||||
f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
|
||||
) from e
|
||||
|
||||
if compare_result:
|
||||
continue
|
||||
else:
|
||||
failure = EventHandlerArgTypeMismatch(
|
||||
f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead."
|
||||
)
|
||||
if len(event_spec_return_types) == 1:
|
||||
raise failure
|
||||
else:
|
||||
failures.append(failure)
|
||||
failed_type_check = True
|
||||
break
|
||||
|
||||
if not failed_type_check:
|
||||
return event_handler(*parsed_args)
|
||||
|
||||
if failures:
|
||||
raise EventHandlerArgTypeMismatch("\n".join([str(f) for f in failures]))
|
||||
|
||||
return event_handler(*parsed_args) # type: ignore
|
||||
|
||||
@ -1186,7 +1221,7 @@ def resolve_annotation(annotations: dict[str, Any], arg_name: str):
|
||||
return annotation
|
||||
|
||||
|
||||
def parse_args_spec(arg_spec: ArgsSpec):
|
||||
def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]):
|
||||
"""Parse the args provided in the ArgsSpec of an event trigger.
|
||||
|
||||
Args:
|
||||
@ -1195,6 +1230,8 @@ def parse_args_spec(arg_spec: ArgsSpec):
|
||||
Returns:
|
||||
The parsed args.
|
||||
"""
|
||||
# if there's multiple, the first is the default
|
||||
arg_spec = arg_spec[0] if isinstance(arg_spec, Sequence) else arg_spec
|
||||
spec = inspect.getfullargspec(arg_spec)
|
||||
annotations = get_type_hints(arg_spec)
|
||||
|
||||
@ -1501,7 +1538,12 @@ class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
|
||||
Returns:
|
||||
The created LiteralEventChainVar instance.
|
||||
"""
|
||||
sig = inspect.signature(value.args_spec) # type: ignore
|
||||
arg_spec = (
|
||||
value.args_spec[0]
|
||||
if isinstance(value.args_spec, Sequence)
|
||||
else value.args_spec
|
||||
)
|
||||
sig = inspect.signature(arg_spec) # type: ignore
|
||||
if sig.parameters:
|
||||
arg_def = tuple((f"_{p}" for p in sig.parameters))
|
||||
arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])
|
||||
|
@ -490,7 +490,7 @@ def _generate_component_create_functiondef(
|
||||
|
||||
def figure_out_return_type(annotation: Any):
|
||||
if inspect.isclass(annotation) and issubclass(annotation, inspect._empty):
|
||||
return ast.Name(id="Optional[EventType]")
|
||||
return ast.Name(id="EventType")
|
||||
|
||||
if not isinstance(annotation, str) and get_origin(annotation) is tuple:
|
||||
arguments = get_args(annotation)
|
||||
@ -509,20 +509,13 @@ def _generate_component_create_functiondef(
|
||||
# Create EventType using the joined string
|
||||
event_type = ast.Name(id=f"EventType[{args_str}]")
|
||||
|
||||
# Wrap in Optional
|
||||
optional_type = ast.Subscript(
|
||||
value=ast.Name(id="Optional"),
|
||||
slice=ast.Index(value=event_type),
|
||||
ctx=ast.Load(),
|
||||
)
|
||||
|
||||
return ast.Name(id=ast.unparse(optional_type))
|
||||
return event_type
|
||||
|
||||
if isinstance(annotation, str) and annotation.startswith("Tuple["):
|
||||
inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]")
|
||||
|
||||
if inside_of_tuple == "()":
|
||||
return ast.Name(id="Optional[EventType[[]]]")
|
||||
return ast.Name(id="EventType[[]]")
|
||||
|
||||
arguments = [""]
|
||||
|
||||
@ -548,10 +541,8 @@ def _generate_component_create_functiondef(
|
||||
for argument in arguments
|
||||
]
|
||||
|
||||
return ast.Name(
|
||||
id=f"Optional[EventType[{', '.join(arguments_without_var)}]]"
|
||||
)
|
||||
return ast.Name(id="Optional[EventType]")
|
||||
return ast.Name(id=f"EventType[{', '.join(arguments_without_var)}]")
|
||||
return ast.Name(id="EventType")
|
||||
|
||||
event_triggers = clz().get_event_triggers()
|
||||
|
||||
@ -560,8 +551,33 @@ def _generate_component_create_functiondef(
|
||||
(
|
||||
ast.arg(
|
||||
arg=trigger,
|
||||
annotation=figure_out_return_type(
|
||||
inspect.signature(event_triggers[trigger]).return_annotation
|
||||
annotation=ast.Subscript(
|
||||
ast.Name("Optional"),
|
||||
ast.Index( # type: ignore
|
||||
value=ast.Name(
|
||||
id=ast.unparse(
|
||||
figure_out_return_type(
|
||||
inspect.signature(event_specs).return_annotation
|
||||
)
|
||||
if not isinstance(
|
||||
event_specs := event_triggers[trigger], tuple
|
||||
)
|
||||
else ast.Subscript(
|
||||
ast.Name("Union"),
|
||||
ast.Tuple(
|
||||
[
|
||||
figure_out_return_type(
|
||||
inspect.signature(
|
||||
event_spec
|
||||
).return_annotation
|
||||
)
|
||||
for event_spec in event_specs
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
),
|
||||
),
|
||||
),
|
||||
ast.Constant(value=None),
|
||||
|
Loading…
Reference in New Issue
Block a user