From 73e8a4e0abfd2fb63fae10b69a9af830bf7a5c93 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 3 Oct 2024 15:33:51 -0700 Subject: [PATCH] support eventspec/eventchain in var operations (#4038) --- reflex/.templates/web/utils/state.js | 22 ++- reflex/app.py | 4 +- reflex/components/component.py | 50 ++++-- reflex/event.py | 193 ++++++++++++++++++++- reflex/utils/format.py | 14 +- reflex/vars/base.py | 81 +++++---- tests/units/components/base/test_script.py | 6 +- tests/units/components/test_component.py | 8 +- tests/units/utils/test_format.py | 22 ++- 9 files changed, 310 insertions(+), 90 deletions(-) diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 78e671809..0fe0db8c1 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -544,13 +544,19 @@ export const uploadFiles = async ( /** * Create an event object. - * @param name The name of the event. - * @param payload The payload of the event. - * @param handler The client handler to process event. + * @param {string} name The name of the event. + * @param {Object.} payload The payload of the event. + * @param {Object.} event_actions The actions to take on the event. + * @param {string} handler The client handler to process event. * @returns The event object. */ -export const Event = (name, payload = {}, handler = null) => { - return { name, payload, handler }; +export const Event = ( + name, + payload = {}, + event_actions = {}, + handler = null +) => { + return { name, payload, handler, event_actions }; }; /** @@ -676,6 +682,12 @@ export const useEventLoop = ( if (!(args instanceof Array)) { args = [args]; } + + event_actions = events.reduce( + (acc, e) => ({ ...acc, ...e.event_actions }), + event_actions ?? {} + ); + const _e = args.filter((o) => o?.preventDefault !== undefined)[0]; if (event_actions?.preventDefault && _e?.preventDefault) { diff --git a/reflex/app.py b/reflex/app.py index 111dd9dfd..d8a6f2590 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1536,7 +1536,9 @@ class EventNamespace(AsyncNamespace): """ fields = json.loads(data) # Get the event. - event = Event(**{k: v for k, v in fields.items() if k != "handler"}) + event = Event( + **{k: v for k, v in fields.items() if k not in ("handler", "event_actions")} + ) self.token_to_sid[event.token] = sid self.sid_to_token[sid] = event.token diff --git a/reflex/components/component.py b/reflex/components/component.py index 9bdd12f0e..26ea2fd3f 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -38,8 +38,10 @@ from reflex.constants import ( ) from reflex.event import ( EventChain, + EventChainVar, EventHandler, EventSpec, + EventVar, call_event_fn, call_event_handler, get_handler_args, @@ -514,7 +516,7 @@ class Component(BaseComponent, ABC): Var, EventHandler, EventSpec, - List[Union[EventHandler, EventSpec]], + List[Union[EventHandler, EventSpec, EventVar]], Callable, ], ) -> Union[EventChain, Var]: @@ -532,11 +534,16 @@ class Component(BaseComponent, ABC): """ # If it's an event chain var, return it. if isinstance(value, Var): - if value._var_type is not EventChain: + if isinstance(value, EventChainVar): + return value + elif isinstance(value, EventVar): + value = [value] + elif issubclass(value._var_type, (EventChain, EventSpec)): + return self._create_event_chain(args_spec, value.guess_type()) + else: raise ValueError( - f"Invalid event chain: {repr(value)} of type {type(value)}" + f"Invalid event chain: {str(value)} of type {value._var_type}" ) - return value elif isinstance(value, EventChain): # Trust that the caller knows what they're doing passing an EventChain directly return value @@ -547,7 +554,7 @@ class Component(BaseComponent, ABC): # If the input is a list of event handlers, create an event chain. if isinstance(value, List): - events: list[EventSpec] = [] + events: List[Union[EventSpec, EventVar]] = [] for v in value: if isinstance(v, (EventHandler, EventSpec)): # Call the event handler to get the event. @@ -561,6 +568,8 @@ class Component(BaseComponent, ABC): "lambda inside an EventChain list." ) events.extend(result) + elif isinstance(v, EventVar): + events.append(v) else: raise ValueError(f"Invalid event: {v}") @@ -570,32 +579,30 @@ class Component(BaseComponent, ABC): if isinstance(result, Var): # Recursively call this function if the lambda returned an EventChain Var. return self._create_event_chain(args_spec, result) - events = result + events = [*result] # Otherwise, raise an error. else: raise ValueError(f"Invalid event chain: {value}") # Add args to the event specs if necessary. - events = [e.with_args(get_handler_args(e)) for e in events] - - # Collect event_actions from each spec - event_actions = {} - for e in events: - event_actions.update(e.event_actions) + events = [ + (e.with_args(get_handler_args(e)) if isinstance(e, EventSpec) else e) + for e in events + ] # Return the event chain. if isinstance(args_spec, Var): return EventChain( events=events, args_spec=None, - event_actions=event_actions, + event_actions={}, ) else: return EventChain( events=events, args_spec=args_spec, - event_actions=event_actions, + event_actions={}, ) def get_event_triggers(self) -> Dict[str, Any]: @@ -1030,8 +1037,11 @@ class Component(BaseComponent, ABC): elif isinstance(event, EventChain): event_args = [] for spec in event.events: - for args in spec.args: - event_args.extend(args) + if isinstance(spec, EventSpec): + for args in spec.args: + event_args.extend(args) + else: + event_args.append(spec) yield event_trigger, event_args def _get_vars(self, include_children: bool = False) -> list[Var]: @@ -1105,8 +1115,12 @@ class Component(BaseComponent, ABC): for trigger in self.event_triggers.values(): if isinstance(trigger, EventChain): for event in trigger.events: - if event.handler.state_full_name: - return True + if isinstance(event, EventSpec): + if event.handler.state_full_name: + return True + else: + if event._var_state: + return True elif isinstance(trigger, Var) and trigger._var_state: return True return False diff --git a/reflex/event.py b/reflex/event.py index 9b54eddeb..7384cf5bf 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -4,16 +4,19 @@ from __future__ import annotations import dataclasses import inspect +import sys import types import urllib.parse from base64 import b64encode from typing import ( Any, Callable, + ClassVar, Dict, List, Optional, Tuple, + Type, Union, get_type_hints, ) @@ -25,8 +28,15 @@ from reflex.utils import format from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch from reflex.utils.types import ArgsSpec, GenericType from reflex.vars import VarData -from reflex.vars.base import LiteralVar, Var -from reflex.vars.function import FunctionStringVar, FunctionVar +from reflex.vars.base import ( + CachedVarOperation, + LiteralNoneVar, + LiteralVar, + ToOperation, + Var, + cached_property_no_lock, +) +from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar, FunctionVar from reflex.vars.object import ObjectVar try: @@ -375,7 +385,7 @@ class CallableEventSpec(EventSpec): class EventChain(EventActionsMixin): """Container for a chain of events that will be executed in order.""" - events: List[EventSpec] = dataclasses.field(default_factory=list) + events: List[Union[EventSpec, EventVar]] = dataclasses.field(default_factory=list) args_spec: Optional[Callable] = dataclasses.field(default=None) @@ -478,7 +488,7 @@ class FileUpload: if isinstance(events, Var): raise ValueError(f"{on_upload_progress} cannot return a var {events}.") on_upload_progress_chain = EventChain( - events=events, + events=[*events], args_spec=self.on_upload_progress_args_spec, ) formatted_chain = str(format.format_prop(on_upload_progress_chain)) @@ -1136,3 +1146,178 @@ def get_fn_signature(fn: Callable) -> inspect.Signature: "state", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any ) return signature.replace(parameters=(new_param, *signature.parameters.values())) + + +class EventVar(ObjectVar): + """Base class for event vars.""" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar): + """A literal event var.""" + + _var_value: EventSpec = dataclasses.field(default=None) # type: ignore + + def __hash__(self) -> int: + """Get the hash of the var. + + Returns: + The hash of the var. + """ + return hash((self.__class__.__name__, self._js_expr)) + + @cached_property_no_lock + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return str( + FunctionStringVar("Event").call( + # event handler name + ".".join( + filter( + None, + format.get_event_handler_parts(self._var_value.handler), + ) + ), + # event handler args + {str(name): value for name, value in self._var_value.args}, + # event actions + self._var_value.event_actions, + # client handler name + *( + [self._var_value.client_handler_name] + if self._var_value.client_handler_name + else [] + ), + ) + ) + + @classmethod + def create( + cls, + value: EventSpec, + _var_data: VarData | None = None, + ) -> LiteralEventVar: + """Create a new LiteralEventVar instance. + + Args: + value: The value of the var. + _var_data: The data of the var. + + Returns: + The created LiteralEventVar instance. + """ + return cls( + _js_expr="", + _var_type=EventSpec, + _var_data=_var_data, + _var_value=value, + ) + + +class EventChainVar(FunctionVar): + """Base class for event chain vars.""" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar): + """A literal event chain var.""" + + _var_value: EventChain = dataclasses.field(default=None) # type: ignore + + def __hash__(self) -> int: + """Get the hash of the var. + + Returns: + The hash of the var. + """ + return hash((self.__class__.__name__, self._js_expr)) + + @cached_property_no_lock + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + sig = inspect.signature(self._var_value.args_spec) # type: ignore + if sig.parameters: + arg_def = tuple((f"_{p}" for p in sig.parameters)) + arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def]) + else: + # add a default argument for addEvents if none were specified in value.args_spec + # used to trigger the preventDefault() on the event. + arg_def = ("...args",) + arg_def_expr = Var(_js_expr="args") + + return str( + ArgsFunctionOperation.create( + arg_def, + FunctionStringVar.create("addEvents").call( + LiteralVar.create( + [LiteralVar.create(event) for event in self._var_value.events] + ), + arg_def_expr, + self._var_value.event_actions, + ), + ) + ) + + @classmethod + def create( + cls, + value: EventChain, + _var_data: VarData | None = None, + ) -> LiteralEventChainVar: + """Create a new LiteralEventChainVar instance. + + Args: + value: The value of the var. + _var_data: The data of the var. + + Returns: + The created LiteralEventChainVar instance. + """ + return cls( + _js_expr="", + _var_type=EventChain, + _var_data=_var_data, + _var_value=value, + ) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToEventVarOperation(ToOperation, EventVar): + """Result of a cast to an event var.""" + + _original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create()) + + _default_var_type: ClassVar[Type] = EventSpec + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToEventChainVarOperation(ToOperation, EventChainVar): + """Result of a cast to an event chain var.""" + + _original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create()) + + _default_var_type: ClassVar[Type] = EventChain diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 4029bd275..65c0f049b 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -359,19 +359,7 @@ def format_prop( # Handle event props. if isinstance(prop, EventChain): - sig = inspect.signature(prop.args_spec) # type: ignore - if sig.parameters: - arg_def = ",".join(f"_{p}" for p in sig.parameters) - arg_def_expr = f"[{arg_def}]" - else: - # add a default argument for addEvents if none were specified in prop.args_spec - # used to trigger the preventDefault() on the event. - arg_def = "...args" - arg_def_expr = "args" - - chain = ",".join([format_event(event) for event in prop.events]) - event = f"addEvents([{chain}], {arg_def_expr}, {json_dumps(prop.event_actions)})" - prop = f"({arg_def}) => {event}" + return str(Var.create(prop)) # Handle other types. elif isinstance(prop, str): diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 2d78a14be..0f8a80f8d 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -385,6 +385,15 @@ class Var(Generic[VAR_TYPE]): Returns: The converted var. """ + from reflex.event import ( + EventChain, + EventChainVar, + EventSpec, + EventVar, + ToEventChainVarOperation, + ToEventVarOperation, + ) + from .function import FunctionVar, ToFunctionOperation from .number import ( BooleanVar, @@ -416,6 +425,10 @@ class Var(Generic[VAR_TYPE]): return self.to(BooleanVar, output) if fixed_output_type is None: return ToNoneOperation.create(self) + if fixed_output_type is EventSpec: + return self.to(EventVar, output) + if fixed_output_type is EventChain: + return self.to(EventChainVar, output) if issubclass(fixed_output_type, Base): return self.to(ObjectVar, output) if dataclasses.is_dataclass(fixed_output_type) and not issubclass( @@ -453,10 +466,13 @@ class Var(Generic[VAR_TYPE]): if issubclass(output, StringVar): return ToStringOperation.create(self, var_type or str) - if issubclass(output, (ObjectVar, Base)): - return ToObjectOperation.create(self, var_type or dict) + if issubclass(output, EventVar): + return ToEventVarOperation.create(self, var_type or EventSpec) - if dataclasses.is_dataclass(output): + if issubclass(output, EventChainVar): + return ToEventChainVarOperation.create(self, var_type or EventChain) + + if issubclass(output, (ObjectVar, Base)): return ToObjectOperation.create(self, var_type or dict) if issubclass(output, FunctionVar): @@ -469,6 +485,9 @@ class Var(Generic[VAR_TYPE]): if issubclass(output, NoneVar): return ToNoneOperation.create(self) + if dataclasses.is_dataclass(output): + return ToObjectOperation.create(self, var_type or dict) + # If we can't determine the first argument, we just replace the _var_type. if not issubclass(output, Var) or var_type is None: return dataclasses.replace( @@ -494,6 +513,8 @@ class Var(Generic[VAR_TYPE]): Raises: TypeError: If the type is not supported for guessing. """ + from reflex.event import EventChain, EventChainVar, EventSpec, EventVar + from .number import BooleanVar, NumberVar from .object import ObjectVar from .sequence import ArrayVar, StringVar @@ -539,6 +560,10 @@ class Var(Generic[VAR_TYPE]): return self.to(ArrayVar, self._var_type) if issubclass(fixed_type, str): return self.to(StringVar, self._var_type) + if issubclass(fixed_type, EventSpec): + return self.to(EventVar, self._var_type) + if issubclass(fixed_type, EventChain): + return self.to(EventChainVar, self._var_type) if issubclass(fixed_type, Base): return self.to(ObjectVar, self._var_type) if dataclasses.is_dataclass(fixed_type): @@ -1029,47 +1054,22 @@ class LiteralVar(Var): if value is None: return LiteralNoneVar.create(_var_data=_var_data) - from reflex.event import EventChain, EventHandler, EventSpec + from reflex.event import ( + EventChain, + EventHandler, + EventSpec, + LiteralEventChainVar, + LiteralEventVar, + ) from reflex.utils.format import get_event_handler_parts - from .function import ArgsFunctionOperation, FunctionStringVar from .object import LiteralObjectVar if isinstance(value, EventSpec): - event_name = LiteralVar.create( - ".".join(filter(None, get_event_handler_parts(value.handler))) - ) - event_args = LiteralVar.create( - {str(name): value for name, value in value.args} - ) - event_client_name = LiteralVar.create(value.client_handler_name) - return FunctionStringVar("Event").call( - event_name, - event_args, - *([event_client_name] if value.client_handler_name else []), - ) + return LiteralEventVar.create(value, _var_data=_var_data) if isinstance(value, EventChain): - sig = inspect.signature(value.args_spec) # type: ignore - if sig.parameters: - arg_def = tuple((f"_{p}" for p in sig.parameters)) - arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def]) - else: - # add a default argument for addEvents if none were specified in value.args_spec - # used to trigger the preventDefault() on the event. - arg_def = ("...args",) - arg_def_expr = Var(_js_expr="args") - - return ArgsFunctionOperation.create( - arg_def, - FunctionStringVar.create("addEvents").call( - LiteralVar.create( - [LiteralVar.create(event) for event in value.events] - ), - arg_def_expr, - LiteralVar.create(value.event_actions), - ), - ) + return LiteralEventChainVar.create(value, _var_data=_var_data) if isinstance(value, EventHandler): return Var(_js_expr=".".join(filter(None, get_event_handler_parts(value)))) @@ -2126,9 +2126,16 @@ class NoneVar(Var[None]): """A var representing None.""" +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) class LiteralNoneVar(LiteralVar, NoneVar): """A var representing None.""" + _var_value: None = None + def json(self) -> str: """Serialize the var to a JSON string. diff --git a/tests/units/components/base/test_script.py b/tests/units/components/base/test_script.py index c6b67da11..be62276f2 100644 --- a/tests/units/components/base/test_script.py +++ b/tests/units/components/base/test_script.py @@ -58,14 +58,14 @@ def test_script_event_handler(): ) render_dict = component.render() assert ( - f'onReady={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_ready", ({{ }})))], args, ({{ }})))))}}' + f'onReady={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_ready", ({{ }}), ({{ }})))], args, ({{ }})))))}}' in render_dict["props"] ) assert ( - f'onLoad={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_load", ({{ }})))], args, ({{ }})))))}}' + f'onLoad={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_load", ({{ }}), ({{ }})))], args, ({{ }})))))}}' in render_dict["props"] ) assert ( - f'onError={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_error", ({{ }})))], args, ({{ }})))))}}' + f'onError={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_error", ({{ }}), ({{ }})))], args, ({{ }})))))}}' in render_dict["props"] ) diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 73d3f611b..5e94db052 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -832,7 +832,7 @@ def test_component_event_trigger_arbitrary_args(): assert comp.render()["props"][0] == ( "onFoo={((__e, _alpha, _bravo, _charlie) => ((addEvents(" - f'[(Event("{C1State.get_full_name()}.mock_handler", ({{ ["_e"] : __e["target"]["value"], ["_bravo"] : _bravo["nested"], ["_charlie"] : (_charlie["custom"] + 42) }})))], ' + f'[(Event("{C1State.get_full_name()}.mock_handler", ({{ ["_e"] : __e["target"]["value"], ["_bravo"] : _bravo["nested"], ["_charlie"] : (_charlie["custom"] + 42) }}), ({{ }})))], ' "[__e, _alpha, _bravo, _charlie], ({ })))))}" ) @@ -1178,7 +1178,7 @@ TEST_VAR = LiteralVar.create("test")._replace( ) FORMATTED_TEST_VAR = LiteralVar.create(f"foo{TEST_VAR}bar") STYLE_VAR = TEST_VAR._replace(_js_expr="style") -EVENT_CHAIN_VAR = TEST_VAR._replace(_var_type=EventChain) +EVENT_CHAIN_VAR = TEST_VAR.to(EventChain) ARG_VAR = Var(_js_expr="arg") TEST_VAR_DICT_OF_DICT = LiteralVar.create({"a": {"b": "test"}})._replace( @@ -2159,7 +2159,7 @@ class TriggerState(rx.State): rx.text("random text", on_click=TriggerState.do_something), rx.text( "random text", - on_click=Var(_js_expr="toggleColorMode", _var_type=EventChain), + on_click=Var(_js_expr="toggleColorMode").to(EventChain), ), ), True, @@ -2169,7 +2169,7 @@ class TriggerState(rx.State): rx.text("random text", on_click=rx.console_log("log")), rx.text( "random text", - on_click=Var(_js_expr="toggleColorMode", _var_type=EventChain), + on_click=Var(_js_expr="toggleColorMode").to(EventChain), ), ), False, diff --git a/tests/units/utils/test_format.py b/tests/units/utils/test_format.py index 042c3f323..d7b0c791e 100644 --- a/tests/units/utils/test_format.py +++ b/tests/units/utils/test_format.py @@ -374,7 +374,7 @@ def test_format_match( events=[EventSpec(handler=EventHandler(fn=mock_event))], args_spec=lambda: [], ), - '((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ })))))', + '((...args) => ((addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ })))))', ), ( EventChain( @@ -395,7 +395,7 @@ def test_format_match( ], args_spec=lambda e: [e.target.value], ), - '((_e) => ((addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] })))], [_e], ({ })))))', + '((_e) => ((addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] }), ({ })))], [_e], ({ })))))', ), ( EventChain( @@ -403,7 +403,19 @@ def test_format_match( args_spec=lambda: [], event_actions={"stopPropagation": True}, ), - '((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ ["stopPropagation"] : true })))))', + '((...args) => ((addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["stopPropagation"] : true })))))', + ), + ( + EventChain( + events=[ + EventSpec( + handler=EventHandler(fn=mock_event), + event_actions={"stopPropagation": True}, + ) + ], + args_spec=lambda: [], + ), + '((...args) => ((addEvents([(Event("mock_event", ({ }), ({ ["stopPropagation"] : true })))], args, ({ })))))', ), ( EventChain( @@ -411,7 +423,7 @@ def test_format_match( args_spec=lambda: [], event_actions={"preventDefault": True}, ), - '((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ ["preventDefault"] : true })))))', + '((...args) => ((addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["preventDefault"] : true })))))', ), ({"a": "red", "b": "blue"}, '({ ["a"] : "red", ["b"] : "blue" })'), (Var(_js_expr="var", _var_type=int).guess_type(), "var"), @@ -519,7 +531,7 @@ def test_format_event_handler(input, output): [ ( EventSpec(handler=EventHandler(fn=mock_event)), - '(Event("mock_event", ({ })))', + '(Event("mock_event", ({ }), ({ })))', ), ], )