From c636c91c9ca6317e7b5d1b0aec8963df2db4653a Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 May 2024 17:13:55 -0700 Subject: [PATCH] [REF-2273] Implement .setvar special EventHandler (#3163) * Allow EventHandler args to be partially applied When an EventHandler is called with an incomplete set of args it creates a partial EventSpec. This change allows Component._create_event_chain to apply remaining args from an args_spec to an existing EventSpec to make it functional. Instead of requiring the use of `lambda` functions to pass arguments to an EventHandler, they can now be passed directly and any remaining args defined in the event trigger will be applied after those. * [REF-2273] Implement `.setvar` special EventHandler All State subclasses will now have a special `setvar` EventHandler which appears in the autocomplete drop down, passes static analysis, and canbe used to set State Vars in response to event triggers. Before: rx.input(value=State.a, on_change=State.set_a) After: rx.input(value=State.a, on_change=State.setvar("a")) This reduces the "magic" because `setvar` is statically defined on all State subclasses. * Catch invalid Var names and types at compile time * Add test cases for State.setvar * Use a proper redis-compatible token --- reflex/components/component.py | 5 +-- reflex/event.py | 71 +++++++++++++++++++++------------- reflex/state.py | 68 ++++++++++++++++++++++++++++++++ tests/test_event.py | 37 +++++++++++++++++- tests/test_state.py | 38 ++++++++++++++++++ 5 files changed, 187 insertions(+), 32 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index 3442d8cb8..463cc9c56 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -506,7 +506,7 @@ class Component(BaseComponent, ABC): if isinstance(value, List): events: list[EventSpec] = [] for v in value: - if isinstance(v, EventHandler): + if isinstance(v, (EventHandler, EventSpec)): # Call the event handler to get the event. try: event = call_event_handler(v, args_spec) @@ -517,9 +517,6 @@ class Component(BaseComponent, ABC): # Add the event to the chain. events.append(event) - elif isinstance(v, EventSpec): - # Add the event to the chain. - events.append(v) elif isinstance(v, Callable): # Call the lambda to get the event chain. events.extend(call_event_fn(v, args_spec)) diff --git a/reflex/event.py b/reflex/event.py index 2604cc202..ece101582 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -18,7 +18,7 @@ from typing import ( from reflex import constants from reflex.base import Base -from reflex.utils import console, format +from reflex.utils import format from reflex.utils.types import ArgsSpec from reflex.vars import BaseVar, Var @@ -168,7 +168,7 @@ class EventHandler(EventActionsMixin): """ return getattr(self.fn, BACKGROUND_TASK_MARKER, False) - def __call__(self, *args: Var) -> EventSpec: + def __call__(self, *args: Any) -> EventSpec: """Pass arguments to the handler to get an event spec. This method configures event handlers that take in arguments. @@ -246,6 +246,34 @@ class EventSpec(EventActionsMixin): event_actions=self.event_actions.copy(), ) + def add_args(self, *args: Var) -> EventSpec: + """Add arguments to the event spec. + + Args: + *args: The arguments to add positionally. + + Returns: + The event spec with the new arguments. + + Raises: + TypeError: If the arguments are invalid. + """ + # Get the remaining unfilled function args. + fn_args = inspect.getfullargspec(self.handler.fn).args[1 + len(self.args) :] + fn_args = (Var.create_safe(arg) for arg in fn_args) + + # Construct the payload. + values = [] + for arg in args: + try: + values.append(Var.create(arg, _var_is_string=isinstance(arg, str))) + except TypeError as e: + raise TypeError( + f"Arguments to event handlers must be Vars or JSON-serializable. Got {arg} of type {type(arg)}." + ) from e + new_payload = tuple(zip(fn_args, values)) + return self.with_args(self.args + new_payload) + class CallableEventSpec(EventSpec): """Decorate an EventSpec-returning function to act as both a EventSpec and a function. @@ -732,7 +760,8 @@ def get_hydrate_event(state) -> str: def call_event_handler( - event_handler: EventHandler, arg_spec: Union[Var, ArgsSpec] + event_handler: EventHandler | EventSpec, + arg_spec: ArgsSpec, ) -> EventSpec: """Call an event handler to get the event spec. @@ -750,33 +779,21 @@ def call_event_handler( Returns: The event spec from calling the event handler. """ + parsed_args = parse_args_spec(arg_spec) # type: ignore + + if isinstance(event_handler, EventSpec): + # Handle partial application of EventSpec args + return event_handler.add_args(*parsed_args) + args = inspect.getfullargspec(event_handler.fn).args - - # handle new API using lambda to define triggers - if isinstance(arg_spec, ArgsSpec): - parsed_args = parse_args_spec(arg_spec) # type: ignore - - if len(args) == len(["self", *parsed_args]): - return event_handler(*parsed_args) # type: ignore - else: - source = inspect.getsource(arg_spec) # type: ignore - raise ValueError( - f"number of arguments in {event_handler.fn.__qualname__} " - f"doesn't match the definition of the event trigger '{source.strip().strip(',')}'" - ) + if len(args) == len(["self", *parsed_args]): + return event_handler(*parsed_args) # type: ignore else: - console.deprecate( - feature_name="EVENT_ARG API for triggers", - reason="Replaced by new API using lambda allow arbitrary number of args", - deprecation_version="0.2.8", - removal_version="0.5.0", + source = inspect.getsource(arg_spec) # type: ignore + raise ValueError( + f"number of arguments in {event_handler.fn.__qualname__} " + f"doesn't match the definition of the event trigger '{source.strip().strip(',')}'" ) - if len(args) == 1: - return event_handler() - assert ( - len(args) == 2 - ), f"Event handler {event_handler.fn} must have 1 or 2 arguments." - return event_handler(arg_spec) # type: ignore def parse_args_spec(arg_spec: ArgsSpec): diff --git a/reflex/state.py b/reflex/state.py index 0cc5d837f..26dfed93e 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -247,6 +247,60 @@ def _split_substate_key(substate_key: str) -> tuple[str, str]: return token, state_name +class EventHandlerSetVar(EventHandler): + """A special event handler to wrap setvar functionality.""" + + state_cls: Type[BaseState] + + def __init__(self, state_cls: Type[BaseState]): + """Initialize the EventHandlerSetVar. + + Args: + state_cls: The state class that vars will be set on. + """ + super().__init__( + fn=type(self).setvar, + state_full_name=state_cls.get_full_name(), + state_cls=state_cls, # type: ignore + ) + + def setvar(self, var_name: str, value: Any): + """Set the state variable to the value of the event. + + Note: `self` here will be an instance of the state, not EventHandlerSetVar. + + Args: + var_name: The name of the variable to set. + value: The value to set the variable to. + """ + getattr(self, constants.SETTER_PREFIX + var_name)(value) + + def __call__(self, *args: Any) -> EventSpec: + """Performs pre-checks and munging on the provided args that will become an EventSpec. + + Args: + *args: The event args. + + Returns: + The (partial) EventSpec that will be used to create the event to setvar. + + Raises: + AttributeError: If the given Var name does not exist on the state. + ValueError: If the given Var name is not a str + """ + if args: + if not isinstance(args[0], str): + raise ValueError( + f"Var name must be passed as a string, got {args[0]!r}" + ) + # Check that the requested Var setter exists on the State at compile time. + if getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None) is None: + raise AttributeError( + f"Variable `{args[0]}` cannot be set on `{self.state_cls.get_full_name()}`" + ) + return super().__call__(*args) + + class BaseState(Base, ABC, extra=pydantic.Extra.allow): """The state of the app.""" @@ -310,6 +364,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Whether the state has ever been touched since instantiation. _was_touched: bool = False + # A special event handler for setting base vars. + setvar: ClassVar[EventHandler] + def __init__( self, *args, @@ -500,6 +557,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): value.__qualname__ = f"{cls.__name__}.{name}" events[name] = value + # Create the setvar event handler for this state + cls._create_setvar() + for name, fn in events.items(): handler = cls._create_event_handler(fn) cls.event_handlers[name] = handler @@ -833,6 +893,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): """ return EventHandler(fn=fn, state_full_name=cls.get_full_name()) + @classmethod + def _create_setvar(cls): + """Create the setvar method for the state.""" + cls.setvar = cls.event_handlers["setvar"] = EventHandlerSetVar(state_cls=cls) + @classmethod def _create_setter(cls, prop: BaseVar): """Create a setter for the var. @@ -1800,6 +1865,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): return state +EventHandlerSetVar.update_forward_refs() + + class State(BaseState): """The app Base State.""" diff --git a/tests/test_event.py b/tests/test_event.py index 5915baf12..885263157 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -1,9 +1,10 @@ import json +from typing import List import pytest from reflex import event -from reflex.event import Event, EventHandler, EventSpec, fix_events +from reflex.event import Event, EventHandler, EventSpec, call_event_handler, fix_events from reflex.state import BaseState from reflex.utils import format from reflex.vars import Var @@ -91,6 +92,40 @@ def test_call_event_handler(): handler(test_fn) # type: ignore +def test_call_event_handler_partial(): + """Calling an EventHandler with incomplete args returns an EventSpec that can be extended.""" + + def test_fn_with_args(_, arg1, arg2): + pass + + test_fn_with_args.__qualname__ = "test_fn_with_args" + + def spec(a2: str) -> List[str]: + return [a2] + + handler = EventHandler(fn=test_fn_with_args) + event_spec = handler(make_var("first")) + event_spec2 = call_event_handler(event_spec, spec) + + assert event_spec.handler == handler + assert len(event_spec.args) == 1 + assert event_spec.args[0][0].equals(Var.create_safe("arg1")) + assert event_spec.args[0][1].equals(Var.create_safe("first")) + assert format.format_event(event_spec) == 'Event("test_fn_with_args", {arg1:first})' + + assert event_spec2 is not event_spec + assert event_spec2.handler == handler + assert len(event_spec2.args) == 2 + assert event_spec2.args[0][0].equals(Var.create_safe("arg1")) + assert event_spec2.args[0][1].equals(Var.create_safe("first")) + assert event_spec2.args[1][0].equals(Var.create_safe("arg2")) + assert event_spec2.args[1][1].equals(Var.create_safe("_a2")) + assert ( + format.format_event(event_spec2) + == 'Event("test_fn_with_args", {arg1:first,arg2:_a2})' + ) + + @pytest.mark.parametrize( ("arg1", "arg2"), ( diff --git a/tests/test_state.py b/tests/test_state.py index 23fa1fa75..ce62e9c64 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -2845,3 +2845,41 @@ def test_potentially_dirty_substates(): assert RxState._potentially_dirty_substates() == {State} assert State._potentially_dirty_substates() == {C1} assert C1._potentially_dirty_substates() == set() + + +@pytest.mark.asyncio +async def test_setvar(mock_app: rx.App, token: str): + """Test that setvar works correctly. + + Args: + mock_app: An app that will be returned by `get_app()` + token: A token. + """ + state = await mock_app.state_manager.get_state(_substate_key(token, TestState)) + + # Set Var in same state (with Var type casting) + for event in rx.event.fix_events( + [TestState.setvar("num1", 42), TestState.setvar("num2", "4.2")], token + ): + async for update in state._process(event): + print(update) + assert state.num1 == 42 + assert state.num2 == 4.2 + + # Set Var in parent state + for event in rx.event.fix_events([GrandchildState.setvar("array", [43])], token): + async for update in state._process(event): + print(update) + assert state.array == [43] + + # Cannot setvar for non-existant var + with pytest.raises(AttributeError): + TestState.setvar("non_existant_var") + + # Cannot setvar for computed vars + with pytest.raises(AttributeError): + TestState.setvar("sum") + + # Cannot setvar with non-string + with pytest.raises(ValueError): + TestState.setvar(42, 42)