[REF-2273] Implement .setvar special EventHandler (#3163)

* Allow EventHandler args to be partially applied

When an EventHandler is called with an incomplete set of args it creates a
partial EventSpec. This change allows Component._create_event_chain to apply
remaining args from an args_spec to an existing EventSpec to make it
functional.

Instead of requiring the use of `lambda` functions to pass arguments to an
EventHandler, they can now be passed directly and any remaining args defined in
the event trigger will be applied after those.

* [REF-2273] Implement `.setvar` special EventHandler

All State subclasses will now have a special `setvar` EventHandler which
appears in the autocomplete drop down, passes static analysis, and canbe used
to set State Vars in response to event triggers.

Before:
    rx.input(value=State.a, on_change=State.set_a)

After:
    rx.input(value=State.a, on_change=State.setvar("a"))

This reduces the "magic" because `setvar` is statically defined on all State
subclasses.

* Catch invalid Var names and types at compile time

* Add test cases for State.setvar

* Use a proper redis-compatible token
This commit is contained in:
Masen Furer 2024-05-01 17:13:55 -07:00 committed by GitHub
parent 9ead091fec
commit c636c91c9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 187 additions and 32 deletions

View File

@ -506,7 +506,7 @@ class Component(BaseComponent, ABC):
if isinstance(value, List):
events: list[EventSpec] = []
for v in value:
if isinstance(v, EventHandler):
if isinstance(v, (EventHandler, EventSpec)):
# Call the event handler to get the event.
try:
event = call_event_handler(v, args_spec)
@ -517,9 +517,6 @@ class Component(BaseComponent, ABC):
# Add the event to the chain.
events.append(event)
elif isinstance(v, EventSpec):
# Add the event to the chain.
events.append(v)
elif isinstance(v, Callable):
# Call the lambda to get the event chain.
events.extend(call_event_fn(v, args_spec))

View File

@ -18,7 +18,7 @@ from typing import (
from reflex import constants
from reflex.base import Base
from reflex.utils import console, format
from reflex.utils import format
from reflex.utils.types import ArgsSpec
from reflex.vars import BaseVar, Var
@ -168,7 +168,7 @@ class EventHandler(EventActionsMixin):
"""
return getattr(self.fn, BACKGROUND_TASK_MARKER, False)
def __call__(self, *args: Var) -> EventSpec:
def __call__(self, *args: Any) -> EventSpec:
"""Pass arguments to the handler to get an event spec.
This method configures event handlers that take in arguments.
@ -246,6 +246,34 @@ class EventSpec(EventActionsMixin):
event_actions=self.event_actions.copy(),
)
def add_args(self, *args: Var) -> EventSpec:
"""Add arguments to the event spec.
Args:
*args: The arguments to add positionally.
Returns:
The event spec with the new arguments.
Raises:
TypeError: If the arguments are invalid.
"""
# Get the remaining unfilled function args.
fn_args = inspect.getfullargspec(self.handler.fn).args[1 + len(self.args) :]
fn_args = (Var.create_safe(arg) for arg in fn_args)
# Construct the payload.
values = []
for arg in args:
try:
values.append(Var.create(arg, _var_is_string=isinstance(arg, str)))
except TypeError as e:
raise TypeError(
f"Arguments to event handlers must be Vars or JSON-serializable. Got {arg} of type {type(arg)}."
) from e
new_payload = tuple(zip(fn_args, values))
return self.with_args(self.args + new_payload)
class CallableEventSpec(EventSpec):
"""Decorate an EventSpec-returning function to act as both a EventSpec and a function.
@ -732,7 +760,8 @@ def get_hydrate_event(state) -> str:
def call_event_handler(
event_handler: EventHandler, arg_spec: Union[Var, ArgsSpec]
event_handler: EventHandler | EventSpec,
arg_spec: ArgsSpec,
) -> EventSpec:
"""Call an event handler to get the event spec.
@ -750,33 +779,21 @@ def call_event_handler(
Returns:
The event spec from calling the event handler.
"""
parsed_args = parse_args_spec(arg_spec) # type: ignore
if isinstance(event_handler, EventSpec):
# Handle partial application of EventSpec args
return event_handler.add_args(*parsed_args)
args = inspect.getfullargspec(event_handler.fn).args
# handle new API using lambda to define triggers
if isinstance(arg_spec, ArgsSpec):
parsed_args = parse_args_spec(arg_spec) # type: ignore
if len(args) == len(["self", *parsed_args]):
return event_handler(*parsed_args) # type: ignore
else:
source = inspect.getsource(arg_spec) # type: ignore
raise ValueError(
f"number of arguments in {event_handler.fn.__qualname__} "
f"doesn't match the definition of the event trigger '{source.strip().strip(',')}'"
)
if len(args) == len(["self", *parsed_args]):
return event_handler(*parsed_args) # type: ignore
else:
console.deprecate(
feature_name="EVENT_ARG API for triggers",
reason="Replaced by new API using lambda allow arbitrary number of args",
deprecation_version="0.2.8",
removal_version="0.5.0",
source = inspect.getsource(arg_spec) # type: ignore
raise ValueError(
f"number of arguments in {event_handler.fn.__qualname__} "
f"doesn't match the definition of the event trigger '{source.strip().strip(',')}'"
)
if len(args) == 1:
return event_handler()
assert (
len(args) == 2
), f"Event handler {event_handler.fn} must have 1 or 2 arguments."
return event_handler(arg_spec) # type: ignore
def parse_args_spec(arg_spec: ArgsSpec):

View File

@ -247,6 +247,60 @@ def _split_substate_key(substate_key: str) -> tuple[str, str]:
return token, state_name
class EventHandlerSetVar(EventHandler):
"""A special event handler to wrap setvar functionality."""
state_cls: Type[BaseState]
def __init__(self, state_cls: Type[BaseState]):
"""Initialize the EventHandlerSetVar.
Args:
state_cls: The state class that vars will be set on.
"""
super().__init__(
fn=type(self).setvar,
state_full_name=state_cls.get_full_name(),
state_cls=state_cls, # type: ignore
)
def setvar(self, var_name: str, value: Any):
"""Set the state variable to the value of the event.
Note: `self` here will be an instance of the state, not EventHandlerSetVar.
Args:
var_name: The name of the variable to set.
value: The value to set the variable to.
"""
getattr(self, constants.SETTER_PREFIX + var_name)(value)
def __call__(self, *args: Any) -> EventSpec:
"""Performs pre-checks and munging on the provided args that will become an EventSpec.
Args:
*args: The event args.
Returns:
The (partial) EventSpec that will be used to create the event to setvar.
Raises:
AttributeError: If the given Var name does not exist on the state.
ValueError: If the given Var name is not a str
"""
if args:
if not isinstance(args[0], str):
raise ValueError(
f"Var name must be passed as a string, got {args[0]!r}"
)
# Check that the requested Var setter exists on the State at compile time.
if getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None) is None:
raise AttributeError(
f"Variable `{args[0]}` cannot be set on `{self.state_cls.get_full_name()}`"
)
return super().__call__(*args)
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""The state of the app."""
@ -310,6 +364,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Whether the state has ever been touched since instantiation.
_was_touched: bool = False
# A special event handler for setting base vars.
setvar: ClassVar[EventHandler]
def __init__(
self,
*args,
@ -500,6 +557,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
value.__qualname__ = f"{cls.__name__}.{name}"
events[name] = value
# Create the setvar event handler for this state
cls._create_setvar()
for name, fn in events.items():
handler = cls._create_event_handler(fn)
cls.event_handlers[name] = handler
@ -833,6 +893,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""
return EventHandler(fn=fn, state_full_name=cls.get_full_name())
@classmethod
def _create_setvar(cls):
"""Create the setvar method for the state."""
cls.setvar = cls.event_handlers["setvar"] = EventHandlerSetVar(state_cls=cls)
@classmethod
def _create_setter(cls, prop: BaseVar):
"""Create a setter for the var.
@ -1800,6 +1865,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
return state
EventHandlerSetVar.update_forward_refs()
class State(BaseState):
"""The app Base State."""

View File

@ -1,9 +1,10 @@
import json
from typing import List
import pytest
from reflex import event
from reflex.event import Event, EventHandler, EventSpec, fix_events
from reflex.event import Event, EventHandler, EventSpec, call_event_handler, fix_events
from reflex.state import BaseState
from reflex.utils import format
from reflex.vars import Var
@ -91,6 +92,40 @@ def test_call_event_handler():
handler(test_fn) # type: ignore
def test_call_event_handler_partial():
"""Calling an EventHandler with incomplete args returns an EventSpec that can be extended."""
def test_fn_with_args(_, arg1, arg2):
pass
test_fn_with_args.__qualname__ = "test_fn_with_args"
def spec(a2: str) -> List[str]:
return [a2]
handler = EventHandler(fn=test_fn_with_args)
event_spec = handler(make_var("first"))
event_spec2 = call_event_handler(event_spec, spec)
assert event_spec.handler == handler
assert len(event_spec.args) == 1
assert event_spec.args[0][0].equals(Var.create_safe("arg1"))
assert event_spec.args[0][1].equals(Var.create_safe("first"))
assert format.format_event(event_spec) == 'Event("test_fn_with_args", {arg1:first})'
assert event_spec2 is not event_spec
assert event_spec2.handler == handler
assert len(event_spec2.args) == 2
assert event_spec2.args[0][0].equals(Var.create_safe("arg1"))
assert event_spec2.args[0][1].equals(Var.create_safe("first"))
assert event_spec2.args[1][0].equals(Var.create_safe("arg2"))
assert event_spec2.args[1][1].equals(Var.create_safe("_a2"))
assert (
format.format_event(event_spec2)
== 'Event("test_fn_with_args", {arg1:first,arg2:_a2})'
)
@pytest.mark.parametrize(
("arg1", "arg2"),
(

View File

@ -2845,3 +2845,41 @@ def test_potentially_dirty_substates():
assert RxState._potentially_dirty_substates() == {State}
assert State._potentially_dirty_substates() == {C1}
assert C1._potentially_dirty_substates() == set()
@pytest.mark.asyncio
async def test_setvar(mock_app: rx.App, token: str):
"""Test that setvar works correctly.
Args:
mock_app: An app that will be returned by `get_app()`
token: A token.
"""
state = await mock_app.state_manager.get_state(_substate_key(token, TestState))
# Set Var in same state (with Var type casting)
for event in rx.event.fix_events(
[TestState.setvar("num1", 42), TestState.setvar("num2", "4.2")], token
):
async for update in state._process(event):
print(update)
assert state.num1 == 42
assert state.num2 == 4.2
# Set Var in parent state
for event in rx.event.fix_events([GrandchildState.setvar("array", [43])], token):
async for update in state._process(event):
print(update)
assert state.array == [43]
# Cannot setvar for non-existant var
with pytest.raises(AttributeError):
TestState.setvar("non_existant_var")
# Cannot setvar for computed vars
with pytest.raises(AttributeError):
TestState.setvar("sum")
# Cannot setvar with non-string
with pytest.raises(ValueError):
TestState.setvar(42, 42)