allow for event handlers to ignore args (#4282)

* allow for event handlers to ignore args

* use a constant

* dang it darglint

* forgor

* keep the tests but move them to valid place
This commit is contained in:
Khaleel Al-Adhami 2024-11-06 09:20:37 -08:00 committed by GitHub
parent d9ab3a0f1c
commit 6334cfab0d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 126 additions and 115 deletions

View File

@ -17,6 +17,7 @@ from typing import (
Iterator,
List,
Optional,
Sequence,
Set,
Type,
Union,
@ -38,6 +39,7 @@ from reflex.constants import (
PageNames,
)
from reflex.constants.compiler import SpecialAttributes
from reflex.constants.state import FRONTEND_EVENT_STATE
from reflex.event import (
EventCallback,
EventChain,
@ -533,7 +535,7 @@ class Component(BaseComponent, ABC):
def _create_event_chain(
self,
args_spec: Any,
args_spec: types.ArgsSpec | Sequence[types.ArgsSpec],
value: Union[
Var,
EventHandler,
@ -599,7 +601,7 @@ class Component(BaseComponent, ABC):
# If the input is a callable, create an event chain.
elif isinstance(value, Callable):
result = call_event_fn(value, args_spec)
result = call_event_fn(value, args_spec, key=key)
if isinstance(result, Var):
# Recursively call this function if the lambda returned an EventChain Var.
return self._create_event_chain(args_spec, result, key=key)
@ -629,14 +631,16 @@ class Component(BaseComponent, ABC):
event_actions={},
)
def get_event_triggers(self) -> Dict[str, Any]:
def get_event_triggers(
self,
) -> Dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]]:
"""Get the event triggers for the component.
Returns:
The event triggers.
"""
default_triggers = {
default_triggers: Dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]] = {
EventTriggers.ON_FOCUS: no_args_event_spec,
EventTriggers.ON_BLUR: no_args_event_spec,
EventTriggers.ON_CLICK: no_args_event_spec,
@ -1142,7 +1146,10 @@ class Component(BaseComponent, ABC):
if isinstance(event, EventCallback):
continue
if isinstance(event, EventSpec):
if event.handler.state_full_name:
if (
event.handler.state_full_name
and event.handler.state_full_name != FRONTEND_EVENT_STATE
):
return True
else:
if event._var_state:

View File

@ -9,3 +9,7 @@ class StateManagerMode(str, Enum):
DISK = "disk"
MEMORY = "memory"
REDIS = "redis"
# Used for things like console_log, etc.
FRONTEND_EVENT_STATE = "__reflex_internal_frontend_event_state"

View File

@ -28,10 +28,10 @@ from typing import (
from typing_extensions import ParamSpec, Protocol, get_args, get_origin
from reflex import constants
from reflex.constants.state import FRONTEND_EVENT_STATE
from reflex.utils import console, format
from reflex.utils.exceptions import (
EventFnArgMismatch,
EventHandlerArgMismatch,
EventHandlerArgTypeMismatch,
)
from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass
@ -662,7 +662,7 @@ def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec:
fn.__qualname__ = name
fn.__signature__ = sig
return EventSpec(
handler=EventHandler(fn=fn),
handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE),
args=tuple(
(
Var(_js_expr=k),
@ -1092,8 +1092,8 @@ def get_hydrate_event(state) -> str:
def call_event_handler(
event_handler: EventHandler | EventSpec,
arg_spec: ArgsSpec | Sequence[ArgsSpec],
event_callback: EventHandler | EventSpec,
event_spec: ArgsSpec | Sequence[ArgsSpec],
key: Optional[str] = None,
) -> EventSpec:
"""Call an event handler to get the event spec.
@ -1103,53 +1103,57 @@ def call_event_handler(
Otherwise, the event handler will be called with no args.
Args:
event_handler: The event handler.
arg_spec: The lambda that define the argument(s) to pass to the event handler.
event_callback: The event handler.
event_spec: The lambda that define the argument(s) to pass to the event handler.
key: The key to pass to the event handler.
Raises:
EventHandlerArgMismatch: if number of arguments expected by event_handler doesn't match the spec.
Returns:
The event spec from calling the event handler.
# noqa: DAR401 failure
"""
parsed_args = parse_args_spec(arg_spec) # type: ignore
event_spec_args = parse_args_spec(event_spec) # type: ignore
if isinstance(event_handler, EventSpec):
# Handle partial application of EventSpec args
return event_handler.add_args(*parsed_args)
provided_callback_fullspec = inspect.getfullargspec(event_handler.fn)
provided_callback_n_args = (
len(provided_callback_fullspec.args) - 1
) # subtract 1 for bound self arg
if provided_callback_n_args != len(parsed_args):
raise EventHandlerArgMismatch(
"The number of arguments accepted by "
f"{event_handler.fn.__qualname__} ({provided_callback_n_args}) "
"does not match the arguments passed by the event trigger: "
f"{[str(v) for v in parsed_args]}\n"
"See https://reflex.dev/docs/events/event-arguments/"
if isinstance(event_callback, EventSpec):
check_fn_match_arg_spec(
event_callback.handler.fn,
event_spec,
key,
bool(event_callback.handler.state_full_name) + len(event_callback.args),
event_callback.handler.fn.__qualname__,
)
# Handle partial application of EventSpec args
return event_callback.add_args(*event_spec_args)
all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec
check_fn_match_arg_spec(
event_callback.fn,
event_spec,
key,
bool(event_callback.state_full_name),
event_callback.fn.__qualname__,
)
all_acceptable_specs = (
[event_spec] if not isinstance(event_spec, Sequence) else event_spec
)
event_spec_return_types = list(
filter(
lambda event_spec_return_type: event_spec_return_type is not None
and get_origin(event_spec_return_type) is tuple,
(get_type_hints(arg_spec).get("return", None) for arg_spec in all_arg_spec),
(
get_type_hints(arg_spec).get("return", None)
for arg_spec in all_acceptable_specs
),
)
)
if event_spec_return_types:
failures = []
event_callback_spec = inspect.getfullargspec(event_callback.fn)
for event_spec_index, event_spec_return_type in enumerate(
event_spec_return_types
):
@ -1160,14 +1164,14 @@ def call_event_handler(
]
try:
type_hints_of_provided_callback = get_type_hints(event_handler.fn)
type_hints_of_provided_callback = get_type_hints(event_callback.fn)
except NameError:
type_hints_of_provided_callback = {}
failed_type_check = False
# check that args of event handler are matching the spec if type hints are provided
for i, arg in enumerate(provided_callback_fullspec.args[1:]):
for i, arg in enumerate(event_callback_spec.args[1:]):
if arg not in type_hints_of_provided_callback:
continue
@ -1181,7 +1185,7 @@ def call_event_handler(
# f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
# ) from e
console.warn(
f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_callback.fn.__qualname__} provided for {key}."
)
compare_result = False
@ -1189,7 +1193,7 @@ def call_event_handler(
continue
else:
failure = EventHandlerArgTypeMismatch(
f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead."
f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_callback.fn.__qualname__} instead."
)
failures.append(failure)
failed_type_check = True
@ -1210,14 +1214,14 @@ def call_event_handler(
given_string = ", ".join(
repr(type_hints_of_provided_callback.get(arg, Any))
for arg in provided_callback_fullspec.args[1:]
for arg in event_callback_spec.args[1:]
).replace("[", "\\[")
console.warn(
f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_handler.fn.__qualname__} instead. "
f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_callback.fn.__qualname__} instead. "
f"This may lead to unexpected behavior but is intentionally ignored for {key}."
)
return event_handler(*parsed_args)
return event_callback(*event_spec_args)
if failures:
console.deprecate(
@ -1227,7 +1231,7 @@ def call_event_handler(
"0.7.0",
)
return event_handler(*parsed_args) # type: ignore
return event_callback(*event_spec_args) # type: ignore
def unwrap_var_annotation(annotation: GenericType):
@ -1294,45 +1298,46 @@ def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]):
def check_fn_match_arg_spec(
fn: Callable,
arg_spec: ArgsSpec,
key: Optional[str] = None,
) -> List[Var]:
user_func: Callable,
arg_spec: ArgsSpec | Sequence[ArgsSpec],
key: str | None = None,
number_of_bound_args: int = 0,
func_name: str | None = None,
):
"""Ensures that the function signature matches the passed argument specification
or raises an EventFnArgMismatch if they do not.
Args:
fn: The function to be validated.
user_func: The function to be validated.
arg_spec: The argument specification for the event trigger.
key: The key to pass to the event handler.
Returns:
The parsed arguments from the argument specification.
key: The key of the event trigger.
number_of_bound_args: The number of bound arguments to the function.
func_name: The name of the function to be validated.
Raises:
EventFnArgMismatch: Raised if the number of mandatory arguments do not match
"""
fn_args = inspect.getfullargspec(fn).args
fn_defaults_args = inspect.getfullargspec(fn).defaults
n_fn_args = len(fn_args)
n_fn_defaults_args = len(fn_defaults_args) if fn_defaults_args else 0
if isinstance(fn, types.MethodType):
n_fn_args -= 1 # subtract 1 for bound self arg
parsed_args = parse_args_spec(arg_spec)
if not (n_fn_args - n_fn_defaults_args <= len(parsed_args) <= n_fn_args):
user_args = inspect.getfullargspec(user_func).args
user_default_args = inspect.getfullargspec(user_func).defaults
number_of_user_args = len(user_args) - number_of_bound_args
number_of_user_default_args = len(user_default_args) if user_default_args else 0
parsed_event_args = parse_args_spec(arg_spec)
number_of_event_args = len(parsed_event_args)
if number_of_user_args - number_of_user_default_args > number_of_event_args:
raise EventFnArgMismatch(
"The number of mandatory arguments accepted by "
f"{fn} ({n_fn_args - n_fn_defaults_args}) "
"does not match the arguments passed by the event trigger: "
f"{[str(v) for v in parsed_args]}\n"
f"Event {key} only provides {number_of_event_args} arguments, but "
f"{func_name or user_func} requires at least {number_of_user_args - number_of_user_default_args} "
"arguments to be passed to the event handler.\n"
"See https://reflex.dev/docs/events/event-arguments/"
)
return parsed_args
def call_event_fn(
fn: Callable,
arg_spec: ArgsSpec,
arg_spec: ArgsSpec | Sequence[ArgsSpec],
key: Optional[str] = None,
) -> list[EventSpec] | Var:
"""Call a function to a list of event specs.
@ -1356,10 +1361,14 @@ def call_event_fn(
from reflex.utils.exceptions import EventHandlerValueError
# Check that fn signature matches arg_spec
parsed_args = check_fn_match_arg_spec(fn, arg_spec, key=key)
check_fn_match_arg_spec(fn, arg_spec, key=key)
parsed_args = parse_args_spec(arg_spec)
number_of_fn_args = len(inspect.getfullargspec(fn).args)
# Call the function with the parsed args.
out = fn(*parsed_args)
out = fn(*[*parsed_args][:number_of_fn_args])
# If the function returns a Var, assume it's an EventChain and render it directly.
if isinstance(out, Var):
@ -1478,7 +1487,7 @@ def get_fn_signature(fn: Callable) -> inspect.Signature:
"""
signature = inspect.signature(fn)
new_param = inspect.Parameter(
"state", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any
FRONTEND_EVENT_STATE, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any
)
return signature.replace(parameters=(new_param, *signature.parameters.values()))

View File

@ -89,16 +89,12 @@ class MatchTypeError(ReflexError, TypeError):
"""Raised when the return types of match cases are different."""
class EventHandlerArgMismatch(ReflexError, TypeError):
"""Raised when the number of args accepted by an EventHandler differs from that provided by the event trigger."""
class EventHandlerArgTypeMismatch(ReflexError, TypeError):
"""Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger."""
class EventFnArgMismatch(ReflexError, TypeError):
"""Raised when the number of args accepted by a lambda differs from that provided by the event trigger."""
"""Raised when the number of args required by an event handler is more than provided by the event trigger."""
class DynamicRouteArgShadowsStateVar(ReflexError, NameError):

View File

@ -9,6 +9,7 @@ import re
from typing import TYPE_CHECKING, Any, List, Optional, Union
from reflex import constants
from reflex.constants.state import FRONTEND_EVENT_STATE
from reflex.utils import exceptions
from reflex.utils.console import deprecate
@ -439,7 +440,7 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]:
from reflex.state import State
if state_full_name == "state" and name not in State.__dict__:
if state_full_name == FRONTEND_EVENT_STATE and name not in State.__dict__:
return ("", to_snake_case(handler.fn.__qualname__))
return (state_full_name, name)

View File

@ -16,7 +16,7 @@ from itertools import chain
from multiprocessing import Pool, cpu_count
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any, Callable, Iterable, Type, get_args, get_origin
from typing import Any, Callable, Iterable, Sequence, Type, get_args, get_origin
from reflex.components.component import Component
from reflex.utils import types as rx_types
@ -560,7 +560,7 @@ def _generate_component_create_functiondef(
inspect.signature(event_specs).return_annotation
)
if not isinstance(
event_specs := event_triggers[trigger], tuple
event_specs := event_triggers[trigger], Sequence
)
else ast.Subscript(
ast.Name("Union"),

View File

@ -29,7 +29,6 @@ from reflex.style import Style
from reflex.utils import imports
from reflex.utils.exceptions import (
EventFnArgMismatch,
EventHandlerArgMismatch,
)
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.vars import VarData
@ -907,26 +906,14 @@ def test_invalid_event_handler_args(component2, test_state):
test_state: A test state.
"""
# EventHandler args must match
with pytest.raises(EventHandlerArgMismatch):
with pytest.raises(EventFnArgMismatch):
component2.create(on_click=test_state.do_something_arg)
with pytest.raises(EventHandlerArgMismatch):
component2.create(on_open=test_state.do_something)
with pytest.raises(EventHandlerArgMismatch):
component2.create(on_prop_event=test_state.do_something)
# Multiple EventHandler args: all must match
with pytest.raises(EventHandlerArgMismatch):
with pytest.raises(EventFnArgMismatch):
component2.create(
on_click=[test_state.do_something_arg, test_state.do_something]
)
with pytest.raises(EventHandlerArgMismatch):
component2.create(
on_open=[test_state.do_something_arg, test_state.do_something]
)
with pytest.raises(EventHandlerArgMismatch):
component2.create(
on_prop_event=[test_state.do_something_arg, test_state.do_something]
)
# Enable when 0.7.0 happens
# # Event Handler types must match
@ -957,38 +944,19 @@ def test_invalid_event_handler_args(component2, test_state):
# lambda signature must match event trigger.
with pytest.raises(EventFnArgMismatch):
component2.create(on_click=lambda _: test_state.do_something_arg(1))
with pytest.raises(EventFnArgMismatch):
component2.create(on_open=lambda: test_state.do_something)
with pytest.raises(EventFnArgMismatch):
component2.create(on_prop_event=lambda: test_state.do_something)
# lambda returning EventHandler must match spec
with pytest.raises(EventHandlerArgMismatch):
with pytest.raises(EventFnArgMismatch):
component2.create(on_click=lambda: test_state.do_something_arg)
with pytest.raises(EventHandlerArgMismatch):
component2.create(on_open=lambda _: test_state.do_something)
with pytest.raises(EventHandlerArgMismatch):
component2.create(on_prop_event=lambda _: test_state.do_something)
# Mixed EventSpec and EventHandler must match spec.
with pytest.raises(EventHandlerArgMismatch):
with pytest.raises(EventFnArgMismatch):
component2.create(
on_click=lambda: [
test_state.do_something_arg(1),
test_state.do_something_arg,
]
)
with pytest.raises(EventHandlerArgMismatch):
component2.create(
on_open=lambda _: [test_state.do_something_arg(1), test_state.do_something]
)
with pytest.raises(EventHandlerArgMismatch):
component2.create(
on_prop_event=lambda _: [
test_state.do_something_arg(1),
test_state.do_something,
]
)
def test_valid_event_handler_args(component2, test_state):
@ -1002,6 +970,10 @@ def test_valid_event_handler_args(component2, test_state):
component2.create(on_click=test_state.do_something)
component2.create(on_click=test_state.do_something_arg(1))
# Does not raise because event handlers are allowed to have less args than the spec.
component2.create(on_open=test_state.do_something)
component2.create(on_prop_event=test_state.do_something)
# Controlled event handlers should take args.
component2.create(on_open=test_state.do_something_arg)
component2.create(on_prop_event=test_state.do_something_arg)
@ -1010,10 +982,20 @@ def test_valid_event_handler_args(component2, test_state):
component2.create(on_open=test_state.do_something())
component2.create(on_prop_event=test_state.do_something())
# Multiple EventHandler args: all must match
component2.create(on_open=[test_state.do_something_arg, test_state.do_something])
component2.create(
on_prop_event=[test_state.do_something_arg, test_state.do_something]
)
# lambda returning EventHandler is okay if the spec matches.
component2.create(on_click=lambda: test_state.do_something)
component2.create(on_open=lambda _: test_state.do_something_arg)
component2.create(on_prop_event=lambda _: test_state.do_something_arg)
component2.create(on_open=lambda: test_state.do_something)
component2.create(on_prop_event=lambda: test_state.do_something)
component2.create(on_open=lambda _: test_state.do_something)
component2.create(on_prop_event=lambda _: test_state.do_something)
# lambda can always return an EventSpec.
component2.create(on_click=lambda: test_state.do_something_arg(1))
@ -1046,6 +1028,15 @@ def test_valid_event_handler_args(component2, test_state):
component2.create(
on_prop_event=lambda _: [test_state.do_something_arg, test_state.do_something()]
)
component2.create(
on_open=lambda _: [test_state.do_something_arg(1), test_state.do_something]
)
component2.create(
on_prop_event=lambda _: [
test_state.do_something_arg(1),
test_state.do_something,
]
)
def test_get_hooks_nested(component1, component2, component3):

View File

@ -107,7 +107,7 @@ def test_call_event_handler_partial():
def spec(a2: Var[str]) -> List[Var[str]]:
return [a2]
handler = EventHandler(fn=test_fn_with_args)
handler = EventHandler(fn=test_fn_with_args, state_full_name="BigState")
event_spec = handler(make_var("first"))
event_spec2 = call_event_handler(event_spec, spec)
@ -115,7 +115,10 @@ def test_call_event_handler_partial():
assert len(event_spec.args) == 1
assert event_spec.args[0][0].equals(Var(_js_expr="arg1"))
assert event_spec.args[0][1].equals(Var(_js_expr="first"))
assert format.format_event(event_spec) == 'Event("test_fn_with_args", {arg1:first})'
assert (
format.format_event(event_spec)
== 'Event("BigState.test_fn_with_args", {arg1:first})'
)
assert event_spec2 is not event_spec
assert event_spec2.handler == handler
@ -126,7 +129,7 @@ def test_call_event_handler_partial():
assert event_spec2.args[1][1].equals(Var(_js_expr="_a2", _var_type=str))
assert (
format.format_event(event_spec2)
== 'Event("test_fn_with_args", {arg1:first,arg2:_a2})'
== 'Event("BigState.test_fn_with_args", {arg1:first,arg2:_a2})'
)