add multiple argspec

This commit is contained in:
Khaleel Al-Adhami 2024-10-22 14:26:50 -07:00
parent fcf6aa6cf3
commit 0ec59aaf53
4 changed files with 137 additions and 73 deletions

View File

@ -2,11 +2,11 @@
from __future__ import annotations
from typing import List, Literal, Optional, Tuple, Union
from typing import List, Literal, Optional, Union
from reflex.components.component import Component
from reflex.components.core.breakpoints import Responsive
from reflex.event import EventHandler
from reflex.event import EventHandler, identity_event
from reflex.vars.base import Var
from ..base import (
@ -14,19 +14,11 @@ from ..base import (
RadixThemesComponent,
)
def on_value_event_spec(
value: Var[List[int | float]],
) -> Tuple[Var[List[int | float]]]:
"""Event handler spec for the value event.
Args:
value: The value of the event.
Returns:
The event handler spec.
"""
return (value,) # type: ignore
on_value_event_spec = (
identity_event(list[Union[int, float]]),
identity_event(list[int]),
identity_event(list[float]),
)
class Slider(RadixThemesComponent):

View File

@ -3,18 +3,20 @@
# ------------------- DO NOT EDIT ----------------------
# This file was generated by `reflex/utils/pyi_generator.py`!
# ------------------------------------------------------
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload
from typing import Any, Dict, List, Literal, Optional, Union, overload
from reflex.components.core.breakpoints import Breakpoints
from reflex.event import EventType
from reflex.event import EventType, identity_event
from reflex.style import Style
from reflex.vars.base import Var
from ..base import RadixThemesComponent
def on_value_event_spec(
value: Var[List[int | float]],
) -> Tuple[Var[List[int | float]]]: ...
on_value_event_spec = (
identity_event(list[int]),
identity_event(list[Union[int, float]]),
identity_event(list[float]),
)
class Slider(RadixThemesComponent):
@overload
@ -138,7 +140,13 @@ class Slider(RadixThemesComponent):
autofocus: Optional[bool] = None,
custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
on_blur: Optional[EventType[[]]] = None,
on_change: Optional[EventType[List[int | float]]] = None,
on_change: Optional[
Union[
EventType[list[int]],
EventType[list[Union[int, float]]],
EventType[list[float]],
]
] = None,
on_click: Optional[EventType[[]]] = None,
on_context_menu: Optional[EventType[[]]] = None,
on_double_click: Optional[EventType[[]]] = None,
@ -153,7 +161,13 @@ class Slider(RadixThemesComponent):
on_mouse_up: Optional[EventType[[]]] = None,
on_scroll: Optional[EventType[[]]] = None,
on_unmount: Optional[EventType[[]]] = None,
on_value_commit: Optional[EventType[List[int | float]]] = None,
on_value_commit: Optional[
Union[
EventType[list[int]],
EventType[list[Union[int, float]]],
EventType[list[float]],
]
] = None,
**props,
) -> "Slider":
"""Create a Slider component.

View File

@ -16,6 +16,7 @@ from typing import (
Generic,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
@ -395,7 +396,9 @@ class EventChain(EventActionsMixin):
events: List[Union[EventSpec, EventVar]] = dataclasses.field(default_factory=list)
args_spec: Optional[Callable] = dataclasses.field(default=None)
args_spec: Optional[Union[Callable, Sequence[Callable]]] = dataclasses.field(
default=None
)
invocation: Optional[Var] = dataclasses.field(default=None)
@ -1040,7 +1043,7 @@ def get_hydrate_event(state) -> str:
def call_event_handler(
event_handler: EventHandler | EventSpec,
arg_spec: ArgsSpec,
arg_spec: ArgsSpec | Sequence[ArgsSpec],
key: Optional[str] = None,
) -> EventSpec:
"""Call an event handler to get the event spec.
@ -1084,6 +1087,9 @@ def call_event_handler(
)
def compare_types(provided_type, accepted_type):
if accepted_type is Any:
return True
provided_type_origin = get_origin(provided_type)
accepted_type_origin = get_origin(accepted_type)
@ -1091,6 +1097,13 @@ def call_event_handler(
# Check if both are concrete types (e.g., int)
return issubclass(provided_type, accepted_type)
provided_type_origin = (
Union if provided_type_origin is types.UnionType else provided_type_origin
)
accepted_type_origin = (
Union if accepted_type_origin is types.UnionType else accepted_type_origin
)
# Check if both are generic types (e.g., List)
if (provided_type_origin or provided_type) != (
accepted_type_origin or accepted_type
@ -1103,48 +1116,70 @@ def call_event_handler(
# Ensure all specific types are compatible with accepted types
return all(
issubclass(provided_arg, accepted_arg)
compare_types(provided_arg, accepted_arg)
for provided_arg, accepted_arg in zip(provided_args, accepted_args)
if accepted_arg is not Any
)
event_spec_return_type = get_type_hints(arg_spec).get("return", None)
all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec
if (
event_spec_return_type is not None
and get_origin(event_spec_return_type) is tuple
):
args = get_args(event_spec_return_type)
event_spec_return_types = list(
filter(
lambda event_spec_return_type: event_spec_return_type is not None
and get_origin(event_spec_return_type) is tuple,
(get_type_hints(arg_spec).get("return", None) for arg_spec in all_arg_spec),
)
)
args_types_without_vars = [
arg if get_origin(arg) is not Var else get_args(arg)[0] for arg in args
]
if event_spec_return_types:
failures = []
try:
type_hints_of_provided_callback = get_type_hints(event_handler.fn)
except NameError:
type_hints_of_provided_callback = {}
for event_spec_return_type in event_spec_return_types:
args = get_args(event_spec_return_type)
# check that args of event handler are matching the spec if type hints are provided
for i, arg in enumerate(provided_callback_fullspec.args[1:]):
if arg not in type_hints_of_provided_callback:
continue
args_types_without_vars = [
arg if get_origin(arg) is not Var else get_args(arg)[0] for arg in args
]
try:
compare_result = compare_types(
args_types_without_vars[i], type_hints_of_provided_callback[arg]
)
except TypeError as e:
raise TypeError(
f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
) from e
type_hints_of_provided_callback = get_type_hints(event_handler.fn)
except NameError:
type_hints_of_provided_callback = {}
if compare_result:
continue
else:
raise EventHandlerArgTypeMismatch(
f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead."
)
failed_type_check = False
# check that args of event handler are matching the spec if type hints are provided
for i, arg in enumerate(provided_callback_fullspec.args[1:]):
if arg not in type_hints_of_provided_callback:
continue
try:
compare_result = compare_types(
args_types_without_vars[i], type_hints_of_provided_callback[arg]
)
except TypeError as e:
raise TypeError(
f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
) from e
if compare_result:
continue
else:
failure = EventHandlerArgTypeMismatch(
f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead."
)
if len(event_spec_return_types) == 1:
raise failure
else:
failures.append(failure)
failed_type_check = True
break
if not failed_type_check:
return event_handler(*parsed_args)
if failures:
raise EventHandlerArgTypeMismatch("\n".join([str(f) for f in failures]))
return event_handler(*parsed_args) # type: ignore
@ -1186,7 +1221,7 @@ def resolve_annotation(annotations: dict[str, Any], arg_name: str):
return annotation
def parse_args_spec(arg_spec: ArgsSpec):
def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]):
"""Parse the args provided in the ArgsSpec of an event trigger.
Args:
@ -1195,6 +1230,8 @@ def parse_args_spec(arg_spec: ArgsSpec):
Returns:
The parsed args.
"""
# if there's multiple, the first is the default
arg_spec = arg_spec[0] if isinstance(arg_spec, Sequence) else arg_spec
spec = inspect.getfullargspec(arg_spec)
annotations = get_type_hints(arg_spec)
@ -1501,7 +1538,12 @@ class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
Returns:
The created LiteralEventChainVar instance.
"""
sig = inspect.signature(value.args_spec) # type: ignore
arg_spec = (
value.args_spec[0]
if isinstance(value.args_spec, Sequence)
else value.args_spec
)
sig = inspect.signature(arg_spec) # type: ignore
if sig.parameters:
arg_def = tuple((f"_{p}" for p in sig.parameters))
arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])

View File

@ -490,7 +490,7 @@ def _generate_component_create_functiondef(
def figure_out_return_type(annotation: Any):
if inspect.isclass(annotation) and issubclass(annotation, inspect._empty):
return ast.Name(id="Optional[EventType]")
return ast.Name(id="EventType")
if not isinstance(annotation, str) and get_origin(annotation) is tuple:
arguments = get_args(annotation)
@ -509,20 +509,13 @@ def _generate_component_create_functiondef(
# Create EventType using the joined string
event_type = ast.Name(id=f"EventType[{args_str}]")
# Wrap in Optional
optional_type = ast.Subscript(
value=ast.Name(id="Optional"),
slice=ast.Index(value=event_type),
ctx=ast.Load(),
)
return ast.Name(id=ast.unparse(optional_type))
return event_type
if isinstance(annotation, str) and annotation.startswith("Tuple["):
inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]")
if inside_of_tuple == "()":
return ast.Name(id="Optional[EventType[[]]]")
return ast.Name(id="EventType[[]]")
arguments = [""]
@ -548,10 +541,8 @@ def _generate_component_create_functiondef(
for argument in arguments
]
return ast.Name(
id=f"Optional[EventType[{', '.join(arguments_without_var)}]]"
)
return ast.Name(id="Optional[EventType]")
return ast.Name(id=f"EventType[{', '.join(arguments_without_var)}]")
return ast.Name(id="EventType")
event_triggers = clz().get_event_triggers()
@ -560,8 +551,33 @@ def _generate_component_create_functiondef(
(
ast.arg(
arg=trigger,
annotation=figure_out_return_type(
inspect.signature(event_triggers[trigger]).return_annotation
annotation=ast.Subscript(
ast.Name("Optional"),
ast.Index( # type: ignore
value=ast.Name(
id=ast.unparse(
figure_out_return_type(
inspect.signature(event_specs).return_annotation
)
if not isinstance(
event_specs := event_triggers[trigger], tuple
)
else ast.Subscript(
ast.Name("Union"),
ast.Tuple(
[
figure_out_return_type(
inspect.signature(
event_spec
).return_annotation
)
for event_spec in event_specs
]
),
)
)
)
),
),
),
ast.Constant(value=None),