diff --git a/reflex/components/component.py b/reflex/components/component.py index 470ba1145..9a9e49a9d 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -17,6 +17,7 @@ from typing import ( Iterator, List, Optional, + Sequence, Set, Type, Union, @@ -38,6 +39,7 @@ from reflex.constants import ( PageNames, ) from reflex.constants.compiler import SpecialAttributes +from reflex.constants.state import FRONTEND_EVENT_STATE from reflex.event import ( EventCallback, EventChain, @@ -533,7 +535,7 @@ class Component(BaseComponent, ABC): def _create_event_chain( self, - args_spec: Any, + args_spec: types.ArgsSpec | Sequence[types.ArgsSpec], value: Union[ Var, EventHandler, @@ -599,7 +601,7 @@ class Component(BaseComponent, ABC): # If the input is a callable, create an event chain. elif isinstance(value, Callable): - result = call_event_fn(value, args_spec) + result = call_event_fn(value, args_spec, key=key) if isinstance(result, Var): # Recursively call this function if the lambda returned an EventChain Var. return self._create_event_chain(args_spec, result, key=key) @@ -629,14 +631,16 @@ class Component(BaseComponent, ABC): event_actions={}, ) - def get_event_triggers(self) -> Dict[str, Any]: + def get_event_triggers( + self, + ) -> Dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]]: """Get the event triggers for the component. Returns: The event triggers. """ - default_triggers = { + default_triggers: Dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]] = { EventTriggers.ON_FOCUS: no_args_event_spec, EventTriggers.ON_BLUR: no_args_event_spec, EventTriggers.ON_CLICK: no_args_event_spec, @@ -1142,7 +1146,10 @@ class Component(BaseComponent, ABC): if isinstance(event, EventCallback): continue if isinstance(event, EventSpec): - if event.handler.state_full_name: + if ( + event.handler.state_full_name + and event.handler.state_full_name != FRONTEND_EVENT_STATE + ): return True else: if event._var_state: diff --git a/reflex/constants/state.py b/reflex/constants/state.py index aa0e2f97f..5ce7cd62a 100644 --- a/reflex/constants/state.py +++ b/reflex/constants/state.py @@ -9,3 +9,7 @@ class StateManagerMode(str, Enum): DISK = "disk" MEMORY = "memory" REDIS = "redis" + + +# Used for things like console_log, etc. +FRONTEND_EVENT_STATE = "__reflex_internal_frontend_event_state" diff --git a/reflex/event.py b/reflex/event.py index 245937a44..7d6249d14 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -28,10 +28,10 @@ from typing import ( from typing_extensions import ParamSpec, Protocol, get_args, get_origin from reflex import constants +from reflex.constants.state import FRONTEND_EVENT_STATE from reflex.utils import console, format from reflex.utils.exceptions import ( EventFnArgMismatch, - EventHandlerArgMismatch, EventHandlerArgTypeMismatch, ) from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass @@ -662,7 +662,7 @@ def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec: fn.__qualname__ = name fn.__signature__ = sig return EventSpec( - handler=EventHandler(fn=fn), + handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE), args=tuple( ( Var(_js_expr=k), @@ -1092,8 +1092,8 @@ def get_hydrate_event(state) -> str: def call_event_handler( - event_handler: EventHandler | EventSpec, - arg_spec: ArgsSpec | Sequence[ArgsSpec], + event_callback: EventHandler | EventSpec, + event_spec: ArgsSpec | Sequence[ArgsSpec], key: Optional[str] = None, ) -> EventSpec: """Call an event handler to get the event spec. @@ -1103,53 +1103,57 @@ def call_event_handler( Otherwise, the event handler will be called with no args. Args: - event_handler: The event handler. - arg_spec: The lambda that define the argument(s) to pass to the event handler. + event_callback: The event handler. + event_spec: The lambda that define the argument(s) to pass to the event handler. key: The key to pass to the event handler. - Raises: - EventHandlerArgMismatch: if number of arguments expected by event_handler doesn't match the spec. - Returns: The event spec from calling the event handler. # noqa: DAR401 failure """ - parsed_args = parse_args_spec(arg_spec) # type: ignore + event_spec_args = parse_args_spec(event_spec) # type: ignore - if isinstance(event_handler, EventSpec): - # Handle partial application of EventSpec args - return event_handler.add_args(*parsed_args) - - provided_callback_fullspec = inspect.getfullargspec(event_handler.fn) - - provided_callback_n_args = ( - len(provided_callback_fullspec.args) - 1 - ) # subtract 1 for bound self arg - - if provided_callback_n_args != len(parsed_args): - raise EventHandlerArgMismatch( - "The number of arguments accepted by " - f"{event_handler.fn.__qualname__} ({provided_callback_n_args}) " - "does not match the arguments passed by the event trigger: " - f"{[str(v) for v in parsed_args]}\n" - "See https://reflex.dev/docs/events/event-arguments/" + if isinstance(event_callback, EventSpec): + check_fn_match_arg_spec( + event_callback.handler.fn, + event_spec, + key, + bool(event_callback.handler.state_full_name) + len(event_callback.args), + event_callback.handler.fn.__qualname__, ) + # Handle partial application of EventSpec args + return event_callback.add_args(*event_spec_args) - all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec + check_fn_match_arg_spec( + event_callback.fn, + event_spec, + key, + bool(event_callback.state_full_name), + event_callback.fn.__qualname__, + ) + + all_acceptable_specs = ( + [event_spec] if not isinstance(event_spec, Sequence) else event_spec + ) event_spec_return_types = list( filter( lambda event_spec_return_type: event_spec_return_type is not None and get_origin(event_spec_return_type) is tuple, - (get_type_hints(arg_spec).get("return", None) for arg_spec in all_arg_spec), + ( + get_type_hints(arg_spec).get("return", None) + for arg_spec in all_acceptable_specs + ), ) ) if event_spec_return_types: failures = [] + event_callback_spec = inspect.getfullargspec(event_callback.fn) + for event_spec_index, event_spec_return_type in enumerate( event_spec_return_types ): @@ -1160,14 +1164,14 @@ def call_event_handler( ] try: - type_hints_of_provided_callback = get_type_hints(event_handler.fn) + type_hints_of_provided_callback = get_type_hints(event_callback.fn) except NameError: type_hints_of_provided_callback = {} failed_type_check = False # check that args of event handler are matching the spec if type hints are provided - for i, arg in enumerate(provided_callback_fullspec.args[1:]): + for i, arg in enumerate(event_callback_spec.args[1:]): if arg not in type_hints_of_provided_callback: continue @@ -1181,7 +1185,7 @@ def call_event_handler( # f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}." # ) from e console.warn( - f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}." + f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_callback.fn.__qualname__} provided for {key}." ) compare_result = False @@ -1189,7 +1193,7 @@ def call_event_handler( continue else: failure = EventHandlerArgTypeMismatch( - f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead." + f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_callback.fn.__qualname__} instead." ) failures.append(failure) failed_type_check = True @@ -1210,14 +1214,14 @@ def call_event_handler( given_string = ", ".join( repr(type_hints_of_provided_callback.get(arg, Any)) - for arg in provided_callback_fullspec.args[1:] + for arg in event_callback_spec.args[1:] ).replace("[", "\\[") console.warn( - f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_handler.fn.__qualname__} instead. " + f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_callback.fn.__qualname__} instead. " f"This may lead to unexpected behavior but is intentionally ignored for {key}." ) - return event_handler(*parsed_args) + return event_callback(*event_spec_args) if failures: console.deprecate( @@ -1227,7 +1231,7 @@ def call_event_handler( "0.7.0", ) - return event_handler(*parsed_args) # type: ignore + return event_callback(*event_spec_args) # type: ignore def unwrap_var_annotation(annotation: GenericType): @@ -1294,45 +1298,46 @@ def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]): def check_fn_match_arg_spec( - fn: Callable, - arg_spec: ArgsSpec, - key: Optional[str] = None, -) -> List[Var]: + user_func: Callable, + arg_spec: ArgsSpec | Sequence[ArgsSpec], + key: str | None = None, + number_of_bound_args: int = 0, + func_name: str | None = None, +): """Ensures that the function signature matches the passed argument specification or raises an EventFnArgMismatch if they do not. Args: - fn: The function to be validated. + user_func: The function to be validated. arg_spec: The argument specification for the event trigger. - key: The key to pass to the event handler. - - Returns: - The parsed arguments from the argument specification. + key: The key of the event trigger. + number_of_bound_args: The number of bound arguments to the function. + func_name: The name of the function to be validated. Raises: EventFnArgMismatch: Raised if the number of mandatory arguments do not match """ - fn_args = inspect.getfullargspec(fn).args - fn_defaults_args = inspect.getfullargspec(fn).defaults - n_fn_args = len(fn_args) - n_fn_defaults_args = len(fn_defaults_args) if fn_defaults_args else 0 - if isinstance(fn, types.MethodType): - n_fn_args -= 1 # subtract 1 for bound self arg - parsed_args = parse_args_spec(arg_spec) - if not (n_fn_args - n_fn_defaults_args <= len(parsed_args) <= n_fn_args): + user_args = inspect.getfullargspec(user_func).args + user_default_args = inspect.getfullargspec(user_func).defaults + number_of_user_args = len(user_args) - number_of_bound_args + number_of_user_default_args = len(user_default_args) if user_default_args else 0 + + parsed_event_args = parse_args_spec(arg_spec) + + number_of_event_args = len(parsed_event_args) + + if number_of_user_args - number_of_user_default_args > number_of_event_args: raise EventFnArgMismatch( - "The number of mandatory arguments accepted by " - f"{fn} ({n_fn_args - n_fn_defaults_args}) " - "does not match the arguments passed by the event trigger: " - f"{[str(v) for v in parsed_args]}\n" + f"Event {key} only provides {number_of_event_args} arguments, but " + f"{func_name or user_func} requires at least {number_of_user_args - number_of_user_default_args} " + "arguments to be passed to the event handler.\n" "See https://reflex.dev/docs/events/event-arguments/" ) - return parsed_args def call_event_fn( fn: Callable, - arg_spec: ArgsSpec, + arg_spec: ArgsSpec | Sequence[ArgsSpec], key: Optional[str] = None, ) -> list[EventSpec] | Var: """Call a function to a list of event specs. @@ -1356,10 +1361,14 @@ def call_event_fn( from reflex.utils.exceptions import EventHandlerValueError # Check that fn signature matches arg_spec - parsed_args = check_fn_match_arg_spec(fn, arg_spec, key=key) + check_fn_match_arg_spec(fn, arg_spec, key=key) + + parsed_args = parse_args_spec(arg_spec) + + number_of_fn_args = len(inspect.getfullargspec(fn).args) # Call the function with the parsed args. - out = fn(*parsed_args) + out = fn(*[*parsed_args][:number_of_fn_args]) # If the function returns a Var, assume it's an EventChain and render it directly. if isinstance(out, Var): @@ -1478,7 +1487,7 @@ def get_fn_signature(fn: Callable) -> inspect.Signature: """ signature = inspect.signature(fn) new_param = inspect.Parameter( - "state", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any + FRONTEND_EVENT_STATE, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any ) return signature.replace(parameters=(new_param, *signature.parameters.values())) diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index 661f29095..dc25a09e0 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -89,16 +89,12 @@ class MatchTypeError(ReflexError, TypeError): """Raised when the return types of match cases are different.""" -class EventHandlerArgMismatch(ReflexError, TypeError): - """Raised when the number of args accepted by an EventHandler differs from that provided by the event trigger.""" - - class EventHandlerArgTypeMismatch(ReflexError, TypeError): """Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger.""" class EventFnArgMismatch(ReflexError, TypeError): - """Raised when the number of args accepted by a lambda differs from that provided by the event trigger.""" + """Raised when the number of args required by an event handler is more than provided by the event trigger.""" class DynamicRouteArgShadowsStateVar(ReflexError, NameError): diff --git a/reflex/utils/format.py b/reflex/utils/format.py index c4fbff20b..1b3d1740f 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -9,6 +9,7 @@ import re from typing import TYPE_CHECKING, Any, List, Optional, Union from reflex import constants +from reflex.constants.state import FRONTEND_EVENT_STATE from reflex.utils import exceptions from reflex.utils.console import deprecate @@ -439,7 +440,7 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]: from reflex.state import State - if state_full_name == "state" and name not in State.__dict__: + if state_full_name == FRONTEND_EVENT_STATE and name not in State.__dict__: return ("", to_snake_case(handler.fn.__qualname__)) return (state_full_name, name) diff --git a/reflex/utils/pyi_generator.py b/reflex/utils/pyi_generator.py index 342277cad..961eea178 100644 --- a/reflex/utils/pyi_generator.py +++ b/reflex/utils/pyi_generator.py @@ -16,7 +16,7 @@ from itertools import chain from multiprocessing import Pool, cpu_count from pathlib import Path from types import ModuleType, SimpleNamespace -from typing import Any, Callable, Iterable, Type, get_args, get_origin +from typing import Any, Callable, Iterable, Sequence, Type, get_args, get_origin from reflex.components.component import Component from reflex.utils import types as rx_types @@ -560,7 +560,7 @@ def _generate_component_create_functiondef( inspect.signature(event_specs).return_annotation ) if not isinstance( - event_specs := event_triggers[trigger], tuple + event_specs := event_triggers[trigger], Sequence ) else ast.Subscript( ast.Name("Union"), diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 0574c007b..f2c0d50e9 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -29,7 +29,6 @@ from reflex.style import Style from reflex.utils import imports from reflex.utils.exceptions import ( EventFnArgMismatch, - EventHandlerArgMismatch, ) from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports from reflex.vars import VarData @@ -907,26 +906,14 @@ def test_invalid_event_handler_args(component2, test_state): test_state: A test state. """ # EventHandler args must match - with pytest.raises(EventHandlerArgMismatch): + with pytest.raises(EventFnArgMismatch): component2.create(on_click=test_state.do_something_arg) - with pytest.raises(EventHandlerArgMismatch): - component2.create(on_open=test_state.do_something) - with pytest.raises(EventHandlerArgMismatch): - component2.create(on_prop_event=test_state.do_something) # Multiple EventHandler args: all must match - with pytest.raises(EventHandlerArgMismatch): + with pytest.raises(EventFnArgMismatch): component2.create( on_click=[test_state.do_something_arg, test_state.do_something] ) - with pytest.raises(EventHandlerArgMismatch): - component2.create( - on_open=[test_state.do_something_arg, test_state.do_something] - ) - with pytest.raises(EventHandlerArgMismatch): - component2.create( - on_prop_event=[test_state.do_something_arg, test_state.do_something] - ) # Enable when 0.7.0 happens # # Event Handler types must match @@ -957,38 +944,19 @@ def test_invalid_event_handler_args(component2, test_state): # lambda signature must match event trigger. with pytest.raises(EventFnArgMismatch): component2.create(on_click=lambda _: test_state.do_something_arg(1)) - with pytest.raises(EventFnArgMismatch): - component2.create(on_open=lambda: test_state.do_something) - with pytest.raises(EventFnArgMismatch): - component2.create(on_prop_event=lambda: test_state.do_something) # lambda returning EventHandler must match spec - with pytest.raises(EventHandlerArgMismatch): + with pytest.raises(EventFnArgMismatch): component2.create(on_click=lambda: test_state.do_something_arg) - with pytest.raises(EventHandlerArgMismatch): - component2.create(on_open=lambda _: test_state.do_something) - with pytest.raises(EventHandlerArgMismatch): - component2.create(on_prop_event=lambda _: test_state.do_something) # Mixed EventSpec and EventHandler must match spec. - with pytest.raises(EventHandlerArgMismatch): + with pytest.raises(EventFnArgMismatch): component2.create( on_click=lambda: [ test_state.do_something_arg(1), test_state.do_something_arg, ] ) - with pytest.raises(EventHandlerArgMismatch): - component2.create( - on_open=lambda _: [test_state.do_something_arg(1), test_state.do_something] - ) - with pytest.raises(EventHandlerArgMismatch): - component2.create( - on_prop_event=lambda _: [ - test_state.do_something_arg(1), - test_state.do_something, - ] - ) def test_valid_event_handler_args(component2, test_state): @@ -1002,6 +970,10 @@ def test_valid_event_handler_args(component2, test_state): component2.create(on_click=test_state.do_something) component2.create(on_click=test_state.do_something_arg(1)) + # Does not raise because event handlers are allowed to have less args than the spec. + component2.create(on_open=test_state.do_something) + component2.create(on_prop_event=test_state.do_something) + # Controlled event handlers should take args. component2.create(on_open=test_state.do_something_arg) component2.create(on_prop_event=test_state.do_something_arg) @@ -1010,10 +982,20 @@ def test_valid_event_handler_args(component2, test_state): component2.create(on_open=test_state.do_something()) component2.create(on_prop_event=test_state.do_something()) + # Multiple EventHandler args: all must match + component2.create(on_open=[test_state.do_something_arg, test_state.do_something]) + component2.create( + on_prop_event=[test_state.do_something_arg, test_state.do_something] + ) + # lambda returning EventHandler is okay if the spec matches. component2.create(on_click=lambda: test_state.do_something) component2.create(on_open=lambda _: test_state.do_something_arg) component2.create(on_prop_event=lambda _: test_state.do_something_arg) + component2.create(on_open=lambda: test_state.do_something) + component2.create(on_prop_event=lambda: test_state.do_something) + component2.create(on_open=lambda _: test_state.do_something) + component2.create(on_prop_event=lambda _: test_state.do_something) # lambda can always return an EventSpec. component2.create(on_click=lambda: test_state.do_something_arg(1)) @@ -1046,6 +1028,15 @@ def test_valid_event_handler_args(component2, test_state): component2.create( on_prop_event=lambda _: [test_state.do_something_arg, test_state.do_something()] ) + component2.create( + on_open=lambda _: [test_state.do_something_arg(1), test_state.do_something] + ) + component2.create( + on_prop_event=lambda _: [ + test_state.do_something_arg(1), + test_state.do_something, + ] + ) def test_get_hooks_nested(component1, component2, component3): diff --git a/tests/units/test_event.py b/tests/units/test_event.py index 5cefa5883..5e26da5d8 100644 --- a/tests/units/test_event.py +++ b/tests/units/test_event.py @@ -107,7 +107,7 @@ def test_call_event_handler_partial(): def spec(a2: Var[str]) -> List[Var[str]]: return [a2] - handler = EventHandler(fn=test_fn_with_args) + handler = EventHandler(fn=test_fn_with_args, state_full_name="BigState") event_spec = handler(make_var("first")) event_spec2 = call_event_handler(event_spec, spec) @@ -115,7 +115,10 @@ def test_call_event_handler_partial(): assert len(event_spec.args) == 1 assert event_spec.args[0][0].equals(Var(_js_expr="arg1")) assert event_spec.args[0][1].equals(Var(_js_expr="first")) - assert format.format_event(event_spec) == 'Event("test_fn_with_args", {arg1:first})' + assert ( + format.format_event(event_spec) + == 'Event("BigState.test_fn_with_args", {arg1:first})' + ) assert event_spec2 is not event_spec assert event_spec2.handler == handler @@ -126,7 +129,7 @@ def test_call_event_handler_partial(): assert event_spec2.args[1][1].equals(Var(_js_expr="_a2", _var_type=str)) assert ( format.format_event(event_spec2) - == 'Event("test_fn_with_args", {arg1:first,arg2:_a2})' + == 'Event("BigState.test_fn_with_args", {arg1:first,arg2:_a2})' )