diff --git a/reflex/components/datadisplay/dataeditor.py b/reflex/components/datadisplay/dataeditor.py index 01255dd14..65449a7a7 100644 --- a/reflex/components/datadisplay/dataeditor.py +++ b/reflex/components/datadisplay/dataeditor.py @@ -5,6 +5,8 @@ from __future__ import annotations from enum import Enum from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing_extensions import TypedDict + from reflex.base import Base from reflex.components.component import Component, NoSSRComponent from reflex.components.literals import LiteralRowMarker @@ -120,6 +122,78 @@ def on_edit_spec(pos, data: dict[str, Any]): return [pos, data] +class Bounds(TypedDict): + """The bounds of the group header.""" + + x: int + y: int + width: int + height: int + + +class CompatSelection(TypedDict): + """The selection.""" + + items: list + + +class Rectangle(TypedDict): + """The bounds of the group header.""" + + x: int + y: int + width: int + height: int + + +class GridSelectionCurrent(TypedDict): + """The current selection.""" + + cell: list[int] + range: Rectangle + rangeStack: list[Rectangle] + + +class GridSelection(TypedDict): + """The grid selection.""" + + current: Optional[GridSelectionCurrent] + columns: CompatSelection + rows: CompatSelection + + +class GroupHeaderClickedEventArgs(TypedDict): + """The arguments for the group header clicked event.""" + + kind: str + group: str + location: list[int] + bounds: Bounds + isEdge: bool + shiftKey: bool + ctrlKey: bool + metaKey: bool + isTouch: bool + localEventX: int + localEventY: int + button: int + buttons: int + scrollEdge: list[int] + + +class GridCell(TypedDict): + """The grid cell.""" + + span: Optional[List[int]] + + +class GridColumn(TypedDict): + """The grid column.""" + + title: str + group: Optional[str] + + class DataEditor(NoSSRComponent): """The DataEditor Component.""" @@ -238,10 +312,12 @@ class DataEditor(NoSSRComponent): on_group_header_clicked: EventHandler[on_edit_spec] # Fired when a group header is right-clicked. - on_group_header_context_menu: EventHandler[lambda grp_idx, data: [grp_idx, data]] + on_group_header_context_menu: EventHandler[ + identity_event(int, GroupHeaderClickedEventArgs) + ] # Fired when a group header is renamed. - on_group_header_renamed: EventHandler[lambda idx, val: [idx, val]] + on_group_header_renamed: EventHandler[identity_event(str, str)] # Fired when a header is clicked. on_header_clicked: EventHandler[identity_event(Tuple[int, int])] @@ -250,16 +326,16 @@ class DataEditor(NoSSRComponent): on_header_context_menu: EventHandler[identity_event(Tuple[int, int])] # Fired when a header menu item is clicked. - on_header_menu_click: EventHandler[lambda col, pos: [col, pos]] + on_header_menu_click: EventHandler[identity_event(int, Rectangle)] # Fired when an item is hovered. on_item_hovered: EventHandler[identity_event(Tuple[int, int])] # Fired when a selection is deleted. - on_delete: EventHandler[lambda selection: [selection]] + on_delete: EventHandler[identity_event(GridSelection)] # Fired when editing is finished. - on_finished_editing: EventHandler[lambda new_value, movement: [new_value, movement]] + on_finished_editing: EventHandler[identity_event(Optional[GridCell], list[int])] # Fired when a row is appended. on_row_appended: EventHandler[empty_event] @@ -268,7 +344,7 @@ class DataEditor(NoSSRComponent): on_selection_cleared: EventHandler[empty_event] # Fired when a column is resized. - on_column_resize: EventHandler[lambda col, width: [col, width]] + on_column_resize: EventHandler[identity_event(GridColumn, int, int)] def add_imports(self) -> ImportDict: """Add imports for the component. diff --git a/reflex/event.py b/reflex/event.py index 8291e3465..ba100335d 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -25,7 +25,7 @@ from typing import ( overload, ) -from typing_extensions import ParamSpec, get_args, get_origin +from typing_extensions import ParamSpec, Protocol, get_args, get_origin from reflex import constants from reflex.utils import console, format @@ -468,33 +468,97 @@ prevent_default = EventChain(events=[], args_spec=empty_event).prevent_default T = TypeVar("T") +U = TypeVar("U") -def identity_event(event_type: Type[T]) -> Callable[[Var[T]], Tuple[Var[T]]]: +# def identity_event(event_type: Type[T]) -> Callable[[Var[T]], Tuple[Var[T]]]: +# """A helper function that returns the input event as output. + +# Args: +# event_type: The type of the event. + +# Returns: +# A function that returns the input event as output. +# """ + +# def inner(ev: Var[T]) -> Tuple[Var[T]]: +# return (ev,) + +# inner.__signature__ = inspect.signature(inner).replace( # type: ignore +# parameters=[ +# inspect.Parameter( +# "ev", +# kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, +# annotation=Var[event_type], +# ) +# ], +# return_annotation=Tuple[Var[event_type]], +# ) +# inner.__annotations__["ev"] = Var[event_type] +# inner.__annotations__["return"] = Tuple[Var[event_type]] + +# return inner + + +class IdentityEventReturn(Generic[T], Protocol): + """Protocol for an identity event return.""" + + def __call__(self, *values: Var[T]) -> Tuple[Var[T], ...]: + """Return the input values. + + Args: + *values: The values to return. + + Returns: + The input values. + """ + return values + + +@overload +def identity_event(event_type: Type[T], /) -> Callable[[Var[T]], Tuple[Var[T]]]: ... # type: ignore + + +@overload +def identity_event( + event_type_1: Type[T], event_type2: Type[U], / +) -> Callable[[Var[T], Var[U]], Tuple[Var[T], Var[U]]]: ... + + +@overload +def identity_event(*event_types: Type[T]) -> IdentityEventReturn[T]: ... + + +def identity_event(*event_types: Type[T]) -> IdentityEventReturn[T]: # type: ignore """A helper function that returns the input event as output. Args: - event_type: The type of the event. + *event_types: The types of the events. Returns: A function that returns the input event as output. """ - def inner(ev: Var[T]) -> Tuple[Var[T]]: - return (ev,) + def inner(*values: Var[T]) -> Tuple[Var[T], ...]: + return values + + inner_type = tuple(Var[event_type] for event_type in event_types) + return_annotation = Tuple[inner_type] # type: ignore inner.__signature__ = inspect.signature(inner).replace( # type: ignore parameters=[ inspect.Parameter( - "ev", + f"ev_{i}", kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Var[event_type], ) + for i, event_type in enumerate(event_types) ], - return_annotation=Tuple[Var[event_type]], + return_annotation=return_annotation, ) - inner.__annotations__["ev"] = Var[event_type] - inner.__annotations__["return"] = Tuple[Var[event_type]] + for i, event_type in enumerate(event_types): + inner.__annotations__[f"ev_{i}"] = Var[event_type] + inner.__annotations__["return"] = return_annotation return inner