add multiple argspec

This commit is contained in:
Khaleel Al-Adhami 2024-10-22 14:26:50 -07:00
parent fcf6aa6cf3
commit 0ec59aaf53
4 changed files with 137 additions and 73 deletions

View File

@ -2,11 +2,11 @@
from __future__ import annotations from __future__ import annotations
from typing import List, Literal, Optional, Tuple, Union from typing import List, Literal, Optional, Union
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.core.breakpoints import Responsive from reflex.components.core.breakpoints import Responsive
from reflex.event import EventHandler from reflex.event import EventHandler, identity_event
from reflex.vars.base import Var from reflex.vars.base import Var
from ..base import ( from ..base import (
@ -14,19 +14,11 @@ from ..base import (
RadixThemesComponent, RadixThemesComponent,
) )
on_value_event_spec = (
def on_value_event_spec( identity_event(list[Union[int, float]]),
value: Var[List[int | float]], identity_event(list[int]),
) -> Tuple[Var[List[int | float]]]: identity_event(list[float]),
"""Event handler spec for the value event. )
Args:
value: The value of the event.
Returns:
The event handler spec.
"""
return (value,) # type: ignore
class Slider(RadixThemesComponent): class Slider(RadixThemesComponent):

View File

@ -3,18 +3,20 @@
# ------------------- DO NOT EDIT ---------------------- # ------------------- DO NOT EDIT ----------------------
# This file was generated by `reflex/utils/pyi_generator.py`! # This file was generated by `reflex/utils/pyi_generator.py`!
# ------------------------------------------------------ # ------------------------------------------------------
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload from typing import Any, Dict, List, Literal, Optional, Union, overload
from reflex.components.core.breakpoints import Breakpoints from reflex.components.core.breakpoints import Breakpoints
from reflex.event import EventType from reflex.event import EventType, identity_event
from reflex.style import Style from reflex.style import Style
from reflex.vars.base import Var from reflex.vars.base import Var
from ..base import RadixThemesComponent from ..base import RadixThemesComponent
def on_value_event_spec( on_value_event_spec = (
value: Var[List[int | float]], identity_event(list[int]),
) -> Tuple[Var[List[int | float]]]: ... identity_event(list[Union[int, float]]),
identity_event(list[float]),
)
class Slider(RadixThemesComponent): class Slider(RadixThemesComponent):
@overload @overload
@ -138,7 +140,13 @@ class Slider(RadixThemesComponent):
autofocus: Optional[bool] = None, autofocus: Optional[bool] = None,
custom_attrs: Optional[Dict[str, Union[Var, str]]] = None, custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
on_blur: Optional[EventType[[]]] = None, on_blur: Optional[EventType[[]]] = None,
on_change: Optional[EventType[List[int | float]]] = None, on_change: Optional[
Union[
EventType[list[int]],
EventType[list[Union[int, float]]],
EventType[list[float]],
]
] = None,
on_click: Optional[EventType[[]]] = None, on_click: Optional[EventType[[]]] = None,
on_context_menu: Optional[EventType[[]]] = None, on_context_menu: Optional[EventType[[]]] = None,
on_double_click: Optional[EventType[[]]] = None, on_double_click: Optional[EventType[[]]] = None,
@ -153,7 +161,13 @@ class Slider(RadixThemesComponent):
on_mouse_up: Optional[EventType[[]]] = None, on_mouse_up: Optional[EventType[[]]] = None,
on_scroll: Optional[EventType[[]]] = None, on_scroll: Optional[EventType[[]]] = None,
on_unmount: Optional[EventType[[]]] = None, on_unmount: Optional[EventType[[]]] = None,
on_value_commit: Optional[EventType[List[int | float]]] = None, on_value_commit: Optional[
Union[
EventType[list[int]],
EventType[list[Union[int, float]]],
EventType[list[float]],
]
] = None,
**props, **props,
) -> "Slider": ) -> "Slider":
"""Create a Slider component. """Create a Slider component.

View File

@ -16,6 +16,7 @@ from typing import (
Generic, Generic,
List, List,
Optional, Optional,
Sequence,
Tuple, Tuple,
Type, Type,
TypeVar, TypeVar,
@ -395,7 +396,9 @@ class EventChain(EventActionsMixin):
events: List[Union[EventSpec, EventVar]] = dataclasses.field(default_factory=list) events: List[Union[EventSpec, EventVar]] = dataclasses.field(default_factory=list)
args_spec: Optional[Callable] = dataclasses.field(default=None) args_spec: Optional[Union[Callable, Sequence[Callable]]] = dataclasses.field(
default=None
)
invocation: Optional[Var] = dataclasses.field(default=None) invocation: Optional[Var] = dataclasses.field(default=None)
@ -1040,7 +1043,7 @@ def get_hydrate_event(state) -> str:
def call_event_handler( def call_event_handler(
event_handler: EventHandler | EventSpec, event_handler: EventHandler | EventSpec,
arg_spec: ArgsSpec, arg_spec: ArgsSpec | Sequence[ArgsSpec],
key: Optional[str] = None, key: Optional[str] = None,
) -> EventSpec: ) -> EventSpec:
"""Call an event handler to get the event spec. """Call an event handler to get the event spec.
@ -1084,6 +1087,9 @@ def call_event_handler(
) )
def compare_types(provided_type, accepted_type): def compare_types(provided_type, accepted_type):
if accepted_type is Any:
return True
provided_type_origin = get_origin(provided_type) provided_type_origin = get_origin(provided_type)
accepted_type_origin = get_origin(accepted_type) accepted_type_origin = get_origin(accepted_type)
@ -1091,6 +1097,13 @@ def call_event_handler(
# Check if both are concrete types (e.g., int) # Check if both are concrete types (e.g., int)
return issubclass(provided_type, accepted_type) return issubclass(provided_type, accepted_type)
provided_type_origin = (
Union if provided_type_origin is types.UnionType else provided_type_origin
)
accepted_type_origin = (
Union if accepted_type_origin is types.UnionType else accepted_type_origin
)
# Check if both are generic types (e.g., List) # Check if both are generic types (e.g., List)
if (provided_type_origin or provided_type) != ( if (provided_type_origin or provided_type) != (
accepted_type_origin or accepted_type accepted_type_origin or accepted_type
@ -1103,17 +1116,25 @@ def call_event_handler(
# Ensure all specific types are compatible with accepted types # Ensure all specific types are compatible with accepted types
return all( return all(
issubclass(provided_arg, accepted_arg) compare_types(provided_arg, accepted_arg)
for provided_arg, accepted_arg in zip(provided_args, accepted_args) for provided_arg, accepted_arg in zip(provided_args, accepted_args)
if accepted_arg is not Any if accepted_arg is not Any
) )
event_spec_return_type = get_type_hints(arg_spec).get("return", None) all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec
if ( event_spec_return_types = list(
event_spec_return_type is not None filter(
and get_origin(event_spec_return_type) is tuple lambda event_spec_return_type: event_spec_return_type is not None
): and get_origin(event_spec_return_type) is tuple,
(get_type_hints(arg_spec).get("return", None) for arg_spec in all_arg_spec),
)
)
if event_spec_return_types:
failures = []
for event_spec_return_type in event_spec_return_types:
args = get_args(event_spec_return_type) args = get_args(event_spec_return_type)
args_types_without_vars = [ args_types_without_vars = [
@ -1125,6 +1146,8 @@ def call_event_handler(
except NameError: except NameError:
type_hints_of_provided_callback = {} type_hints_of_provided_callback = {}
failed_type_check = False
# check that args of event handler are matching the spec if type hints are provided # check that args of event handler are matching the spec if type hints are provided
for i, arg in enumerate(provided_callback_fullspec.args[1:]): for i, arg in enumerate(provided_callback_fullspec.args[1:]):
if arg not in type_hints_of_provided_callback: if arg not in type_hints_of_provided_callback:
@ -1142,9 +1165,21 @@ def call_event_handler(
if compare_result: if compare_result:
continue continue
else: else:
raise EventHandlerArgTypeMismatch( failure = EventHandlerArgTypeMismatch(
f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead." f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead."
) )
if len(event_spec_return_types) == 1:
raise failure
else:
failures.append(failure)
failed_type_check = True
break
if not failed_type_check:
return event_handler(*parsed_args)
if failures:
raise EventHandlerArgTypeMismatch("\n".join([str(f) for f in failures]))
return event_handler(*parsed_args) # type: ignore return event_handler(*parsed_args) # type: ignore
@ -1186,7 +1221,7 @@ def resolve_annotation(annotations: dict[str, Any], arg_name: str):
return annotation return annotation
def parse_args_spec(arg_spec: ArgsSpec): def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]):
"""Parse the args provided in the ArgsSpec of an event trigger. """Parse the args provided in the ArgsSpec of an event trigger.
Args: Args:
@ -1195,6 +1230,8 @@ def parse_args_spec(arg_spec: ArgsSpec):
Returns: Returns:
The parsed args. The parsed args.
""" """
# if there's multiple, the first is the default
arg_spec = arg_spec[0] if isinstance(arg_spec, Sequence) else arg_spec
spec = inspect.getfullargspec(arg_spec) spec = inspect.getfullargspec(arg_spec)
annotations = get_type_hints(arg_spec) annotations = get_type_hints(arg_spec)
@ -1501,7 +1538,12 @@ class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
Returns: Returns:
The created LiteralEventChainVar instance. The created LiteralEventChainVar instance.
""" """
sig = inspect.signature(value.args_spec) # type: ignore arg_spec = (
value.args_spec[0]
if isinstance(value.args_spec, Sequence)
else value.args_spec
)
sig = inspect.signature(arg_spec) # type: ignore
if sig.parameters: if sig.parameters:
arg_def = tuple((f"_{p}" for p in sig.parameters)) arg_def = tuple((f"_{p}" for p in sig.parameters))
arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def]) arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])

View File

@ -490,7 +490,7 @@ def _generate_component_create_functiondef(
def figure_out_return_type(annotation: Any): def figure_out_return_type(annotation: Any):
if inspect.isclass(annotation) and issubclass(annotation, inspect._empty): if inspect.isclass(annotation) and issubclass(annotation, inspect._empty):
return ast.Name(id="Optional[EventType]") return ast.Name(id="EventType")
if not isinstance(annotation, str) and get_origin(annotation) is tuple: if not isinstance(annotation, str) and get_origin(annotation) is tuple:
arguments = get_args(annotation) arguments = get_args(annotation)
@ -509,20 +509,13 @@ def _generate_component_create_functiondef(
# Create EventType using the joined string # Create EventType using the joined string
event_type = ast.Name(id=f"EventType[{args_str}]") event_type = ast.Name(id=f"EventType[{args_str}]")
# Wrap in Optional return event_type
optional_type = ast.Subscript(
value=ast.Name(id="Optional"),
slice=ast.Index(value=event_type),
ctx=ast.Load(),
)
return ast.Name(id=ast.unparse(optional_type))
if isinstance(annotation, str) and annotation.startswith("Tuple["): if isinstance(annotation, str) and annotation.startswith("Tuple["):
inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]") inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]")
if inside_of_tuple == "()": if inside_of_tuple == "()":
return ast.Name(id="Optional[EventType[[]]]") return ast.Name(id="EventType[[]]")
arguments = [""] arguments = [""]
@ -548,10 +541,8 @@ def _generate_component_create_functiondef(
for argument in arguments for argument in arguments
] ]
return ast.Name( return ast.Name(id=f"EventType[{', '.join(arguments_without_var)}]")
id=f"Optional[EventType[{', '.join(arguments_without_var)}]]" return ast.Name(id="EventType")
)
return ast.Name(id="Optional[EventType]")
event_triggers = clz().get_event_triggers() event_triggers = clz().get_event_triggers()
@ -560,8 +551,33 @@ def _generate_component_create_functiondef(
( (
ast.arg( ast.arg(
arg=trigger, arg=trigger,
annotation=figure_out_return_type( annotation=ast.Subscript(
inspect.signature(event_triggers[trigger]).return_annotation ast.Name("Optional"),
ast.Index( # type: ignore
value=ast.Name(
id=ast.unparse(
figure_out_return_type(
inspect.signature(event_specs).return_annotation
)
if not isinstance(
event_specs := event_triggers[trigger], tuple
)
else ast.Subscript(
ast.Name("Union"),
ast.Tuple(
[
figure_out_return_type(
inspect.signature(
event_spec
).return_annotation
)
for event_spec in event_specs
]
),
)
)
)
),
), ),
), ),
ast.Constant(value=None), ast.Constant(value=None),