allow for event handlers to ignore args (#4282)

* allow for event handlers to ignore args

* use a constant

* dang it darglint

* forgor

* keep the tests but move them to valid place
This commit is contained in:
Khaleel Al-Adhami 2024-11-06 09:20:37 -08:00 committed by GitHub
parent d9ab3a0f1c
commit 6334cfab0d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 126 additions and 115 deletions

View File

@ -17,6 +17,7 @@ from typing import (
Iterator, Iterator,
List, List,
Optional, Optional,
Sequence,
Set, Set,
Type, Type,
Union, Union,
@ -38,6 +39,7 @@ from reflex.constants import (
PageNames, PageNames,
) )
from reflex.constants.compiler import SpecialAttributes from reflex.constants.compiler import SpecialAttributes
from reflex.constants.state import FRONTEND_EVENT_STATE
from reflex.event import ( from reflex.event import (
EventCallback, EventCallback,
EventChain, EventChain,
@ -533,7 +535,7 @@ class Component(BaseComponent, ABC):
def _create_event_chain( def _create_event_chain(
self, self,
args_spec: Any, args_spec: types.ArgsSpec | Sequence[types.ArgsSpec],
value: Union[ value: Union[
Var, Var,
EventHandler, EventHandler,
@ -599,7 +601,7 @@ class Component(BaseComponent, ABC):
# If the input is a callable, create an event chain. # If the input is a callable, create an event chain.
elif isinstance(value, Callable): elif isinstance(value, Callable):
result = call_event_fn(value, args_spec) result = call_event_fn(value, args_spec, key=key)
if isinstance(result, Var): if isinstance(result, Var):
# Recursively call this function if the lambda returned an EventChain Var. # Recursively call this function if the lambda returned an EventChain Var.
return self._create_event_chain(args_spec, result, key=key) return self._create_event_chain(args_spec, result, key=key)
@ -629,14 +631,16 @@ class Component(BaseComponent, ABC):
event_actions={}, event_actions={},
) )
def get_event_triggers(self) -> Dict[str, Any]: def get_event_triggers(
self,
) -> Dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]]:
"""Get the event triggers for the component. """Get the event triggers for the component.
Returns: Returns:
The event triggers. The event triggers.
""" """
default_triggers = { default_triggers: Dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]] = {
EventTriggers.ON_FOCUS: no_args_event_spec, EventTriggers.ON_FOCUS: no_args_event_spec,
EventTriggers.ON_BLUR: no_args_event_spec, EventTriggers.ON_BLUR: no_args_event_spec,
EventTriggers.ON_CLICK: no_args_event_spec, EventTriggers.ON_CLICK: no_args_event_spec,
@ -1142,7 +1146,10 @@ class Component(BaseComponent, ABC):
if isinstance(event, EventCallback): if isinstance(event, EventCallback):
continue continue
if isinstance(event, EventSpec): if isinstance(event, EventSpec):
if event.handler.state_full_name: if (
event.handler.state_full_name
and event.handler.state_full_name != FRONTEND_EVENT_STATE
):
return True return True
else: else:
if event._var_state: if event._var_state:

View File

@ -9,3 +9,7 @@ class StateManagerMode(str, Enum):
DISK = "disk" DISK = "disk"
MEMORY = "memory" MEMORY = "memory"
REDIS = "redis" REDIS = "redis"
# Used for things like console_log, etc.
FRONTEND_EVENT_STATE = "__reflex_internal_frontend_event_state"

View File

@ -28,10 +28,10 @@ from typing import (
from typing_extensions import ParamSpec, Protocol, get_args, get_origin from typing_extensions import ParamSpec, Protocol, get_args, get_origin
from reflex import constants from reflex import constants
from reflex.constants.state import FRONTEND_EVENT_STATE
from reflex.utils import console, format from reflex.utils import console, format
from reflex.utils.exceptions import ( from reflex.utils.exceptions import (
EventFnArgMismatch, EventFnArgMismatch,
EventHandlerArgMismatch,
EventHandlerArgTypeMismatch, EventHandlerArgTypeMismatch,
) )
from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass
@ -662,7 +662,7 @@ def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec:
fn.__qualname__ = name fn.__qualname__ = name
fn.__signature__ = sig fn.__signature__ = sig
return EventSpec( return EventSpec(
handler=EventHandler(fn=fn), handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE),
args=tuple( args=tuple(
( (
Var(_js_expr=k), Var(_js_expr=k),
@ -1092,8 +1092,8 @@ def get_hydrate_event(state) -> str:
def call_event_handler( def call_event_handler(
event_handler: EventHandler | EventSpec, event_callback: EventHandler | EventSpec,
arg_spec: ArgsSpec | Sequence[ArgsSpec], event_spec: ArgsSpec | Sequence[ArgsSpec],
key: Optional[str] = None, key: Optional[str] = None,
) -> EventSpec: ) -> EventSpec:
"""Call an event handler to get the event spec. """Call an event handler to get the event spec.
@ -1103,53 +1103,57 @@ def call_event_handler(
Otherwise, the event handler will be called with no args. Otherwise, the event handler will be called with no args.
Args: Args:
event_handler: The event handler. event_callback: The event handler.
arg_spec: The lambda that define the argument(s) to pass to the event handler. event_spec: The lambda that define the argument(s) to pass to the event handler.
key: The key to pass to the event handler. key: The key to pass to the event handler.
Raises:
EventHandlerArgMismatch: if number of arguments expected by event_handler doesn't match the spec.
Returns: Returns:
The event spec from calling the event handler. The event spec from calling the event handler.
# noqa: DAR401 failure # noqa: DAR401 failure
""" """
parsed_args = parse_args_spec(arg_spec) # type: ignore event_spec_args = parse_args_spec(event_spec) # type: ignore
if isinstance(event_handler, EventSpec): if isinstance(event_callback, EventSpec):
# Handle partial application of EventSpec args check_fn_match_arg_spec(
return event_handler.add_args(*parsed_args) event_callback.handler.fn,
event_spec,
provided_callback_fullspec = inspect.getfullargspec(event_handler.fn) key,
bool(event_callback.handler.state_full_name) + len(event_callback.args),
provided_callback_n_args = ( event_callback.handler.fn.__qualname__,
len(provided_callback_fullspec.args) - 1
) # subtract 1 for bound self arg
if provided_callback_n_args != len(parsed_args):
raise EventHandlerArgMismatch(
"The number of arguments accepted by "
f"{event_handler.fn.__qualname__} ({provided_callback_n_args}) "
"does not match the arguments passed by the event trigger: "
f"{[str(v) for v in parsed_args]}\n"
"See https://reflex.dev/docs/events/event-arguments/"
) )
# Handle partial application of EventSpec args
return event_callback.add_args(*event_spec_args)
all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec check_fn_match_arg_spec(
event_callback.fn,
event_spec,
key,
bool(event_callback.state_full_name),
event_callback.fn.__qualname__,
)
all_acceptable_specs = (
[event_spec] if not isinstance(event_spec, Sequence) else event_spec
)
event_spec_return_types = list( event_spec_return_types = list(
filter( filter(
lambda event_spec_return_type: event_spec_return_type is not None lambda event_spec_return_type: event_spec_return_type is not None
and get_origin(event_spec_return_type) is tuple, and get_origin(event_spec_return_type) is tuple,
(get_type_hints(arg_spec).get("return", None) for arg_spec in all_arg_spec), (
get_type_hints(arg_spec).get("return", None)
for arg_spec in all_acceptable_specs
),
) )
) )
if event_spec_return_types: if event_spec_return_types:
failures = [] failures = []
event_callback_spec = inspect.getfullargspec(event_callback.fn)
for event_spec_index, event_spec_return_type in enumerate( for event_spec_index, event_spec_return_type in enumerate(
event_spec_return_types event_spec_return_types
): ):
@ -1160,14 +1164,14 @@ def call_event_handler(
] ]
try: try:
type_hints_of_provided_callback = get_type_hints(event_handler.fn) type_hints_of_provided_callback = get_type_hints(event_callback.fn)
except NameError: except NameError:
type_hints_of_provided_callback = {} type_hints_of_provided_callback = {}
failed_type_check = False failed_type_check = False
# check that args of event handler are matching the spec if type hints are provided # check that args of event handler are matching the spec if type hints are provided
for i, arg in enumerate(provided_callback_fullspec.args[1:]): for i, arg in enumerate(event_callback_spec.args[1:]):
if arg not in type_hints_of_provided_callback: if arg not in type_hints_of_provided_callback:
continue continue
@ -1181,7 +1185,7 @@ def call_event_handler(
# f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}." # f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
# ) from e # ) from e
console.warn( console.warn(
f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}." f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_callback.fn.__qualname__} provided for {key}."
) )
compare_result = False compare_result = False
@ -1189,7 +1193,7 @@ def call_event_handler(
continue continue
else: else:
failure = EventHandlerArgTypeMismatch( failure = EventHandlerArgTypeMismatch(
f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead." f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_callback.fn.__qualname__} instead."
) )
failures.append(failure) failures.append(failure)
failed_type_check = True failed_type_check = True
@ -1210,14 +1214,14 @@ def call_event_handler(
given_string = ", ".join( given_string = ", ".join(
repr(type_hints_of_provided_callback.get(arg, Any)) repr(type_hints_of_provided_callback.get(arg, Any))
for arg in provided_callback_fullspec.args[1:] for arg in event_callback_spec.args[1:]
).replace("[", "\\[") ).replace("[", "\\[")
console.warn( console.warn(
f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_handler.fn.__qualname__} instead. " f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_callback.fn.__qualname__} instead. "
f"This may lead to unexpected behavior but is intentionally ignored for {key}." f"This may lead to unexpected behavior but is intentionally ignored for {key}."
) )
return event_handler(*parsed_args) return event_callback(*event_spec_args)
if failures: if failures:
console.deprecate( console.deprecate(
@ -1227,7 +1231,7 @@ def call_event_handler(
"0.7.0", "0.7.0",
) )
return event_handler(*parsed_args) # type: ignore return event_callback(*event_spec_args) # type: ignore
def unwrap_var_annotation(annotation: GenericType): def unwrap_var_annotation(annotation: GenericType):
@ -1294,45 +1298,46 @@ def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]):
def check_fn_match_arg_spec( def check_fn_match_arg_spec(
fn: Callable, user_func: Callable,
arg_spec: ArgsSpec, arg_spec: ArgsSpec | Sequence[ArgsSpec],
key: Optional[str] = None, key: str | None = None,
) -> List[Var]: number_of_bound_args: int = 0,
func_name: str | None = None,
):
"""Ensures that the function signature matches the passed argument specification """Ensures that the function signature matches the passed argument specification
or raises an EventFnArgMismatch if they do not. or raises an EventFnArgMismatch if they do not.
Args: Args:
fn: The function to be validated. user_func: The function to be validated.
arg_spec: The argument specification for the event trigger. arg_spec: The argument specification for the event trigger.
key: The key to pass to the event handler. key: The key of the event trigger.
number_of_bound_args: The number of bound arguments to the function.
Returns: func_name: The name of the function to be validated.
The parsed arguments from the argument specification.
Raises: Raises:
EventFnArgMismatch: Raised if the number of mandatory arguments do not match EventFnArgMismatch: Raised if the number of mandatory arguments do not match
""" """
fn_args = inspect.getfullargspec(fn).args user_args = inspect.getfullargspec(user_func).args
fn_defaults_args = inspect.getfullargspec(fn).defaults user_default_args = inspect.getfullargspec(user_func).defaults
n_fn_args = len(fn_args) number_of_user_args = len(user_args) - number_of_bound_args
n_fn_defaults_args = len(fn_defaults_args) if fn_defaults_args else 0 number_of_user_default_args = len(user_default_args) if user_default_args else 0
if isinstance(fn, types.MethodType):
n_fn_args -= 1 # subtract 1 for bound self arg parsed_event_args = parse_args_spec(arg_spec)
parsed_args = parse_args_spec(arg_spec)
if not (n_fn_args - n_fn_defaults_args <= len(parsed_args) <= n_fn_args): number_of_event_args = len(parsed_event_args)
if number_of_user_args - number_of_user_default_args > number_of_event_args:
raise EventFnArgMismatch( raise EventFnArgMismatch(
"The number of mandatory arguments accepted by " f"Event {key} only provides {number_of_event_args} arguments, but "
f"{fn} ({n_fn_args - n_fn_defaults_args}) " f"{func_name or user_func} requires at least {number_of_user_args - number_of_user_default_args} "
"does not match the arguments passed by the event trigger: " "arguments to be passed to the event handler.\n"
f"{[str(v) for v in parsed_args]}\n"
"See https://reflex.dev/docs/events/event-arguments/" "See https://reflex.dev/docs/events/event-arguments/"
) )
return parsed_args
def call_event_fn( def call_event_fn(
fn: Callable, fn: Callable,
arg_spec: ArgsSpec, arg_spec: ArgsSpec | Sequence[ArgsSpec],
key: Optional[str] = None, key: Optional[str] = None,
) -> list[EventSpec] | Var: ) -> list[EventSpec] | Var:
"""Call a function to a list of event specs. """Call a function to a list of event specs.
@ -1356,10 +1361,14 @@ def call_event_fn(
from reflex.utils.exceptions import EventHandlerValueError from reflex.utils.exceptions import EventHandlerValueError
# Check that fn signature matches arg_spec # Check that fn signature matches arg_spec
parsed_args = check_fn_match_arg_spec(fn, arg_spec, key=key) check_fn_match_arg_spec(fn, arg_spec, key=key)
parsed_args = parse_args_spec(arg_spec)
number_of_fn_args = len(inspect.getfullargspec(fn).args)
# Call the function with the parsed args. # Call the function with the parsed args.
out = fn(*parsed_args) out = fn(*[*parsed_args][:number_of_fn_args])
# If the function returns a Var, assume it's an EventChain and render it directly. # If the function returns a Var, assume it's an EventChain and render it directly.
if isinstance(out, Var): if isinstance(out, Var):
@ -1478,7 +1487,7 @@ def get_fn_signature(fn: Callable) -> inspect.Signature:
""" """
signature = inspect.signature(fn) signature = inspect.signature(fn)
new_param = inspect.Parameter( new_param = inspect.Parameter(
"state", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any FRONTEND_EVENT_STATE, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any
) )
return signature.replace(parameters=(new_param, *signature.parameters.values())) return signature.replace(parameters=(new_param, *signature.parameters.values()))

View File

@ -89,16 +89,12 @@ class MatchTypeError(ReflexError, TypeError):
"""Raised when the return types of match cases are different.""" """Raised when the return types of match cases are different."""
class EventHandlerArgMismatch(ReflexError, TypeError):
"""Raised when the number of args accepted by an EventHandler differs from that provided by the event trigger."""
class EventHandlerArgTypeMismatch(ReflexError, TypeError): class EventHandlerArgTypeMismatch(ReflexError, TypeError):
"""Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger.""" """Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger."""
class EventFnArgMismatch(ReflexError, TypeError): class EventFnArgMismatch(ReflexError, TypeError):
"""Raised when the number of args accepted by a lambda differs from that provided by the event trigger.""" """Raised when the number of args required by an event handler is more than provided by the event trigger."""
class DynamicRouteArgShadowsStateVar(ReflexError, NameError): class DynamicRouteArgShadowsStateVar(ReflexError, NameError):

View File

@ -9,6 +9,7 @@ import re
from typing import TYPE_CHECKING, Any, List, Optional, Union from typing import TYPE_CHECKING, Any, List, Optional, Union
from reflex import constants from reflex import constants
from reflex.constants.state import FRONTEND_EVENT_STATE
from reflex.utils import exceptions from reflex.utils import exceptions
from reflex.utils.console import deprecate from reflex.utils.console import deprecate
@ -439,7 +440,7 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]:
from reflex.state import State from reflex.state import State
if state_full_name == "state" and name not in State.__dict__: if state_full_name == FRONTEND_EVENT_STATE and name not in State.__dict__:
return ("", to_snake_case(handler.fn.__qualname__)) return ("", to_snake_case(handler.fn.__qualname__))
return (state_full_name, name) return (state_full_name, name)

View File

@ -16,7 +16,7 @@ from itertools import chain
from multiprocessing import Pool, cpu_count from multiprocessing import Pool, cpu_count
from pathlib import Path from pathlib import Path
from types import ModuleType, SimpleNamespace from types import ModuleType, SimpleNamespace
from typing import Any, Callable, Iterable, Type, get_args, get_origin from typing import Any, Callable, Iterable, Sequence, Type, get_args, get_origin
from reflex.components.component import Component from reflex.components.component import Component
from reflex.utils import types as rx_types from reflex.utils import types as rx_types
@ -560,7 +560,7 @@ def _generate_component_create_functiondef(
inspect.signature(event_specs).return_annotation inspect.signature(event_specs).return_annotation
) )
if not isinstance( if not isinstance(
event_specs := event_triggers[trigger], tuple event_specs := event_triggers[trigger], Sequence
) )
else ast.Subscript( else ast.Subscript(
ast.Name("Union"), ast.Name("Union"),

View File

@ -29,7 +29,6 @@ from reflex.style import Style
from reflex.utils import imports from reflex.utils import imports
from reflex.utils.exceptions import ( from reflex.utils.exceptions import (
EventFnArgMismatch, EventFnArgMismatch,
EventHandlerArgMismatch,
) )
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.vars import VarData from reflex.vars import VarData
@ -907,26 +906,14 @@ def test_invalid_event_handler_args(component2, test_state):
test_state: A test state. test_state: A test state.
""" """
# EventHandler args must match # EventHandler args must match
with pytest.raises(EventHandlerArgMismatch): with pytest.raises(EventFnArgMismatch):
component2.create(on_click=test_state.do_something_arg) component2.create(on_click=test_state.do_something_arg)
with pytest.raises(EventHandlerArgMismatch):
component2.create(on_open=test_state.do_something)
with pytest.raises(EventHandlerArgMismatch):
component2.create(on_prop_event=test_state.do_something)
# Multiple EventHandler args: all must match # Multiple EventHandler args: all must match
with pytest.raises(EventHandlerArgMismatch): with pytest.raises(EventFnArgMismatch):
component2.create( component2.create(
on_click=[test_state.do_something_arg, test_state.do_something] on_click=[test_state.do_something_arg, test_state.do_something]
) )
with pytest.raises(EventHandlerArgMismatch):
component2.create(
on_open=[test_state.do_something_arg, test_state.do_something]
)
with pytest.raises(EventHandlerArgMismatch):
component2.create(
on_prop_event=[test_state.do_something_arg, test_state.do_something]
)
# Enable when 0.7.0 happens # Enable when 0.7.0 happens
# # Event Handler types must match # # Event Handler types must match
@ -957,38 +944,19 @@ def test_invalid_event_handler_args(component2, test_state):
# lambda signature must match event trigger. # lambda signature must match event trigger.
with pytest.raises(EventFnArgMismatch): with pytest.raises(EventFnArgMismatch):
component2.create(on_click=lambda _: test_state.do_something_arg(1)) component2.create(on_click=lambda _: test_state.do_something_arg(1))
with pytest.raises(EventFnArgMismatch):
component2.create(on_open=lambda: test_state.do_something)
with pytest.raises(EventFnArgMismatch):
component2.create(on_prop_event=lambda: test_state.do_something)
# lambda returning EventHandler must match spec # lambda returning EventHandler must match spec
with pytest.raises(EventHandlerArgMismatch): with pytest.raises(EventFnArgMismatch):
component2.create(on_click=lambda: test_state.do_something_arg) component2.create(on_click=lambda: test_state.do_something_arg)
with pytest.raises(EventHandlerArgMismatch):
component2.create(on_open=lambda _: test_state.do_something)
with pytest.raises(EventHandlerArgMismatch):
component2.create(on_prop_event=lambda _: test_state.do_something)
# Mixed EventSpec and EventHandler must match spec. # Mixed EventSpec and EventHandler must match spec.
with pytest.raises(EventHandlerArgMismatch): with pytest.raises(EventFnArgMismatch):
component2.create( component2.create(
on_click=lambda: [ on_click=lambda: [
test_state.do_something_arg(1), test_state.do_something_arg(1),
test_state.do_something_arg, test_state.do_something_arg,
] ]
) )
with pytest.raises(EventHandlerArgMismatch):
component2.create(
on_open=lambda _: [test_state.do_something_arg(1), test_state.do_something]
)
with pytest.raises(EventHandlerArgMismatch):
component2.create(
on_prop_event=lambda _: [
test_state.do_something_arg(1),
test_state.do_something,
]
)
def test_valid_event_handler_args(component2, test_state): def test_valid_event_handler_args(component2, test_state):
@ -1002,6 +970,10 @@ def test_valid_event_handler_args(component2, test_state):
component2.create(on_click=test_state.do_something) component2.create(on_click=test_state.do_something)
component2.create(on_click=test_state.do_something_arg(1)) component2.create(on_click=test_state.do_something_arg(1))
# Does not raise because event handlers are allowed to have less args than the spec.
component2.create(on_open=test_state.do_something)
component2.create(on_prop_event=test_state.do_something)
# Controlled event handlers should take args. # Controlled event handlers should take args.
component2.create(on_open=test_state.do_something_arg) component2.create(on_open=test_state.do_something_arg)
component2.create(on_prop_event=test_state.do_something_arg) component2.create(on_prop_event=test_state.do_something_arg)
@ -1010,10 +982,20 @@ def test_valid_event_handler_args(component2, test_state):
component2.create(on_open=test_state.do_something()) component2.create(on_open=test_state.do_something())
component2.create(on_prop_event=test_state.do_something()) component2.create(on_prop_event=test_state.do_something())
# Multiple EventHandler args: all must match
component2.create(on_open=[test_state.do_something_arg, test_state.do_something])
component2.create(
on_prop_event=[test_state.do_something_arg, test_state.do_something]
)
# lambda returning EventHandler is okay if the spec matches. # lambda returning EventHandler is okay if the spec matches.
component2.create(on_click=lambda: test_state.do_something) component2.create(on_click=lambda: test_state.do_something)
component2.create(on_open=lambda _: test_state.do_something_arg) component2.create(on_open=lambda _: test_state.do_something_arg)
component2.create(on_prop_event=lambda _: test_state.do_something_arg) component2.create(on_prop_event=lambda _: test_state.do_something_arg)
component2.create(on_open=lambda: test_state.do_something)
component2.create(on_prop_event=lambda: test_state.do_something)
component2.create(on_open=lambda _: test_state.do_something)
component2.create(on_prop_event=lambda _: test_state.do_something)
# lambda can always return an EventSpec. # lambda can always return an EventSpec.
component2.create(on_click=lambda: test_state.do_something_arg(1)) component2.create(on_click=lambda: test_state.do_something_arg(1))
@ -1046,6 +1028,15 @@ def test_valid_event_handler_args(component2, test_state):
component2.create( component2.create(
on_prop_event=lambda _: [test_state.do_something_arg, test_state.do_something()] on_prop_event=lambda _: [test_state.do_something_arg, test_state.do_something()]
) )
component2.create(
on_open=lambda _: [test_state.do_something_arg(1), test_state.do_something]
)
component2.create(
on_prop_event=lambda _: [
test_state.do_something_arg(1),
test_state.do_something,
]
)
def test_get_hooks_nested(component1, component2, component3): def test_get_hooks_nested(component1, component2, component3):

View File

@ -107,7 +107,7 @@ def test_call_event_handler_partial():
def spec(a2: Var[str]) -> List[Var[str]]: def spec(a2: Var[str]) -> List[Var[str]]:
return [a2] return [a2]
handler = EventHandler(fn=test_fn_with_args) handler = EventHandler(fn=test_fn_with_args, state_full_name="BigState")
event_spec = handler(make_var("first")) event_spec = handler(make_var("first"))
event_spec2 = call_event_handler(event_spec, spec) event_spec2 = call_event_handler(event_spec, spec)
@ -115,7 +115,10 @@ def test_call_event_handler_partial():
assert len(event_spec.args) == 1 assert len(event_spec.args) == 1
assert event_spec.args[0][0].equals(Var(_js_expr="arg1")) assert event_spec.args[0][0].equals(Var(_js_expr="arg1"))
assert event_spec.args[0][1].equals(Var(_js_expr="first")) assert event_spec.args[0][1].equals(Var(_js_expr="first"))
assert format.format_event(event_spec) == 'Event("test_fn_with_args", {arg1:first})' assert (
format.format_event(event_spec)
== 'Event("BigState.test_fn_with_args", {arg1:first})'
)
assert event_spec2 is not event_spec assert event_spec2 is not event_spec
assert event_spec2.handler == handler assert event_spec2.handler == handler
@ -126,7 +129,7 @@ def test_call_event_handler_partial():
assert event_spec2.args[1][1].equals(Var(_js_expr="_a2", _var_type=str)) assert event_spec2.args[1][1].equals(Var(_js_expr="_a2", _var_type=str))
assert ( assert (
format.format_event(event_spec2) format.format_event(event_spec2)
== 'Event("test_fn_with_args", {arg1:first,arg2:_a2})' == 'Event("BigState.test_fn_with_args", {arg1:first,arg2:_a2})'
) )