support eventspec/eventchain in var operations (#4038)

This commit is contained in:
Khaleel Al-Adhami 2024-10-03 15:33:51 -07:00 committed by GitHub
parent ad0827c59c
commit 73e8a4e0ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 310 additions and 90 deletions

View File

@ -544,13 +544,19 @@ export const uploadFiles = async (
/**
* Create an event object.
* @param name The name of the event.
* @param payload The payload of the event.
* @param handler The client handler to process event.
* @param {string} name The name of the event.
* @param {Object.<string, Any>} payload The payload of the event.
* @param {Object.<string, (number|boolean)>} event_actions The actions to take on the event.
* @param {string} handler The client handler to process event.
* @returns The event object.
*/
export const Event = (name, payload = {}, handler = null) => {
return { name, payload, handler };
export const Event = (
name,
payload = {},
event_actions = {},
handler = null
) => {
return { name, payload, handler, event_actions };
};
/**
@ -676,6 +682,12 @@ export const useEventLoop = (
if (!(args instanceof Array)) {
args = [args];
}
event_actions = events.reduce(
(acc, e) => ({ ...acc, ...e.event_actions }),
event_actions ?? {}
);
const _e = args.filter((o) => o?.preventDefault !== undefined)[0];
if (event_actions?.preventDefault && _e?.preventDefault) {

View File

@ -1536,7 +1536,9 @@ class EventNamespace(AsyncNamespace):
"""
fields = json.loads(data)
# Get the event.
event = Event(**{k: v for k, v in fields.items() if k != "handler"})
event = Event(
**{k: v for k, v in fields.items() if k not in ("handler", "event_actions")}
)
self.token_to_sid[event.token] = sid
self.sid_to_token[sid] = event.token

View File

@ -38,8 +38,10 @@ from reflex.constants import (
)
from reflex.event import (
EventChain,
EventChainVar,
EventHandler,
EventSpec,
EventVar,
call_event_fn,
call_event_handler,
get_handler_args,
@ -514,7 +516,7 @@ class Component(BaseComponent, ABC):
Var,
EventHandler,
EventSpec,
List[Union[EventHandler, EventSpec]],
List[Union[EventHandler, EventSpec, EventVar]],
Callable,
],
) -> Union[EventChain, Var]:
@ -532,11 +534,16 @@ class Component(BaseComponent, ABC):
"""
# If it's an event chain var, return it.
if isinstance(value, Var):
if value._var_type is not EventChain:
if isinstance(value, EventChainVar):
return value
elif isinstance(value, EventVar):
value = [value]
elif issubclass(value._var_type, (EventChain, EventSpec)):
return self._create_event_chain(args_spec, value.guess_type())
else:
raise ValueError(
f"Invalid event chain: {repr(value)} of type {type(value)}"
f"Invalid event chain: {str(value)} of type {value._var_type}"
)
return value
elif isinstance(value, EventChain):
# Trust that the caller knows what they're doing passing an EventChain directly
return value
@ -547,7 +554,7 @@ class Component(BaseComponent, ABC):
# If the input is a list of event handlers, create an event chain.
if isinstance(value, List):
events: list[EventSpec] = []
events: List[Union[EventSpec, EventVar]] = []
for v in value:
if isinstance(v, (EventHandler, EventSpec)):
# Call the event handler to get the event.
@ -561,6 +568,8 @@ class Component(BaseComponent, ABC):
"lambda inside an EventChain list."
)
events.extend(result)
elif isinstance(v, EventVar):
events.append(v)
else:
raise ValueError(f"Invalid event: {v}")
@ -570,32 +579,30 @@ class Component(BaseComponent, ABC):
if isinstance(result, Var):
# Recursively call this function if the lambda returned an EventChain Var.
return self._create_event_chain(args_spec, result)
events = result
events = [*result]
# Otherwise, raise an error.
else:
raise ValueError(f"Invalid event chain: {value}")
# Add args to the event specs if necessary.
events = [e.with_args(get_handler_args(e)) for e in events]
# Collect event_actions from each spec
event_actions = {}
for e in events:
event_actions.update(e.event_actions)
events = [
(e.with_args(get_handler_args(e)) if isinstance(e, EventSpec) else e)
for e in events
]
# Return the event chain.
if isinstance(args_spec, Var):
return EventChain(
events=events,
args_spec=None,
event_actions=event_actions,
event_actions={},
)
else:
return EventChain(
events=events,
args_spec=args_spec,
event_actions=event_actions,
event_actions={},
)
def get_event_triggers(self) -> Dict[str, Any]:
@ -1030,8 +1037,11 @@ class Component(BaseComponent, ABC):
elif isinstance(event, EventChain):
event_args = []
for spec in event.events:
for args in spec.args:
event_args.extend(args)
if isinstance(spec, EventSpec):
for args in spec.args:
event_args.extend(args)
else:
event_args.append(spec)
yield event_trigger, event_args
def _get_vars(self, include_children: bool = False) -> list[Var]:
@ -1105,8 +1115,12 @@ class Component(BaseComponent, ABC):
for trigger in self.event_triggers.values():
if isinstance(trigger, EventChain):
for event in trigger.events:
if event.handler.state_full_name:
return True
if isinstance(event, EventSpec):
if event.handler.state_full_name:
return True
else:
if event._var_state:
return True
elif isinstance(trigger, Var) and trigger._var_state:
return True
return False

View File

@ -4,16 +4,19 @@ from __future__ import annotations
import dataclasses
import inspect
import sys
import types
import urllib.parse
from base64 import b64encode
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Tuple,
Type,
Union,
get_type_hints,
)
@ -25,8 +28,15 @@ from reflex.utils import format
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
from reflex.utils.types import ArgsSpec, GenericType
from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
from reflex.vars.function import FunctionStringVar, FunctionVar
from reflex.vars.base import (
CachedVarOperation,
LiteralNoneVar,
LiteralVar,
ToOperation,
Var,
cached_property_no_lock,
)
from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar, FunctionVar
from reflex.vars.object import ObjectVar
try:
@ -375,7 +385,7 @@ class CallableEventSpec(EventSpec):
class EventChain(EventActionsMixin):
"""Container for a chain of events that will be executed in order."""
events: List[EventSpec] = dataclasses.field(default_factory=list)
events: List[Union[EventSpec, EventVar]] = dataclasses.field(default_factory=list)
args_spec: Optional[Callable] = dataclasses.field(default=None)
@ -478,7 +488,7 @@ class FileUpload:
if isinstance(events, Var):
raise ValueError(f"{on_upload_progress} cannot return a var {events}.")
on_upload_progress_chain = EventChain(
events=events,
events=[*events],
args_spec=self.on_upload_progress_args_spec,
)
formatted_chain = str(format.format_prop(on_upload_progress_chain))
@ -1136,3 +1146,178 @@ def get_fn_signature(fn: Callable) -> inspect.Signature:
"state", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any
)
return signature.replace(parameters=(new_param, *signature.parameters.values()))
class EventVar(ObjectVar):
"""Base class for event vars."""
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar):
"""A literal event var."""
_var_value: EventSpec = dataclasses.field(default=None) # type: ignore
def __hash__(self) -> int:
"""Get the hash of the var.
Returns:
The hash of the var.
"""
return hash((self.__class__.__name__, self._js_expr))
@cached_property_no_lock
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
return str(
FunctionStringVar("Event").call(
# event handler name
".".join(
filter(
None,
format.get_event_handler_parts(self._var_value.handler),
)
),
# event handler args
{str(name): value for name, value in self._var_value.args},
# event actions
self._var_value.event_actions,
# client handler name
*(
[self._var_value.client_handler_name]
if self._var_value.client_handler_name
else []
),
)
)
@classmethod
def create(
cls,
value: EventSpec,
_var_data: VarData | None = None,
) -> LiteralEventVar:
"""Create a new LiteralEventVar instance.
Args:
value: The value of the var.
_var_data: The data of the var.
Returns:
The created LiteralEventVar instance.
"""
return cls(
_js_expr="",
_var_type=EventSpec,
_var_data=_var_data,
_var_value=value,
)
class EventChainVar(FunctionVar):
"""Base class for event chain vars."""
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar):
"""A literal event chain var."""
_var_value: EventChain = dataclasses.field(default=None) # type: ignore
def __hash__(self) -> int:
"""Get the hash of the var.
Returns:
The hash of the var.
"""
return hash((self.__class__.__name__, self._js_expr))
@cached_property_no_lock
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
sig = inspect.signature(self._var_value.args_spec) # type: ignore
if sig.parameters:
arg_def = tuple((f"_{p}" for p in sig.parameters))
arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])
else:
# add a default argument for addEvents if none were specified in value.args_spec
# used to trigger the preventDefault() on the event.
arg_def = ("...args",)
arg_def_expr = Var(_js_expr="args")
return str(
ArgsFunctionOperation.create(
arg_def,
FunctionStringVar.create("addEvents").call(
LiteralVar.create(
[LiteralVar.create(event) for event in self._var_value.events]
),
arg_def_expr,
self._var_value.event_actions,
),
)
)
@classmethod
def create(
cls,
value: EventChain,
_var_data: VarData | None = None,
) -> LiteralEventChainVar:
"""Create a new LiteralEventChainVar instance.
Args:
value: The value of the var.
_var_data: The data of the var.
Returns:
The created LiteralEventChainVar instance.
"""
return cls(
_js_expr="",
_var_type=EventChain,
_var_data=_var_data,
_var_value=value,
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToEventVarOperation(ToOperation, EventVar):
"""Result of a cast to an event var."""
_original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create())
_default_var_type: ClassVar[Type] = EventSpec
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToEventChainVarOperation(ToOperation, EventChainVar):
"""Result of a cast to an event chain var."""
_original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create())
_default_var_type: ClassVar[Type] = EventChain

View File

@ -359,19 +359,7 @@ def format_prop(
# Handle event props.
if isinstance(prop, EventChain):
sig = inspect.signature(prop.args_spec) # type: ignore
if sig.parameters:
arg_def = ",".join(f"_{p}" for p in sig.parameters)
arg_def_expr = f"[{arg_def}]"
else:
# add a default argument for addEvents if none were specified in prop.args_spec
# used to trigger the preventDefault() on the event.
arg_def = "...args"
arg_def_expr = "args"
chain = ",".join([format_event(event) for event in prop.events])
event = f"addEvents([{chain}], {arg_def_expr}, {json_dumps(prop.event_actions)})"
prop = f"({arg_def}) => {event}"
return str(Var.create(prop))
# Handle other types.
elif isinstance(prop, str):

View File

@ -385,6 +385,15 @@ class Var(Generic[VAR_TYPE]):
Returns:
The converted var.
"""
from reflex.event import (
EventChain,
EventChainVar,
EventSpec,
EventVar,
ToEventChainVarOperation,
ToEventVarOperation,
)
from .function import FunctionVar, ToFunctionOperation
from .number import (
BooleanVar,
@ -416,6 +425,10 @@ class Var(Generic[VAR_TYPE]):
return self.to(BooleanVar, output)
if fixed_output_type is None:
return ToNoneOperation.create(self)
if fixed_output_type is EventSpec:
return self.to(EventVar, output)
if fixed_output_type is EventChain:
return self.to(EventChainVar, output)
if issubclass(fixed_output_type, Base):
return self.to(ObjectVar, output)
if dataclasses.is_dataclass(fixed_output_type) and not issubclass(
@ -453,10 +466,13 @@ class Var(Generic[VAR_TYPE]):
if issubclass(output, StringVar):
return ToStringOperation.create(self, var_type or str)
if issubclass(output, (ObjectVar, Base)):
return ToObjectOperation.create(self, var_type or dict)
if issubclass(output, EventVar):
return ToEventVarOperation.create(self, var_type or EventSpec)
if dataclasses.is_dataclass(output):
if issubclass(output, EventChainVar):
return ToEventChainVarOperation.create(self, var_type or EventChain)
if issubclass(output, (ObjectVar, Base)):
return ToObjectOperation.create(self, var_type or dict)
if issubclass(output, FunctionVar):
@ -469,6 +485,9 @@ class Var(Generic[VAR_TYPE]):
if issubclass(output, NoneVar):
return ToNoneOperation.create(self)
if dataclasses.is_dataclass(output):
return ToObjectOperation.create(self, var_type or dict)
# If we can't determine the first argument, we just replace the _var_type.
if not issubclass(output, Var) or var_type is None:
return dataclasses.replace(
@ -494,6 +513,8 @@ class Var(Generic[VAR_TYPE]):
Raises:
TypeError: If the type is not supported for guessing.
"""
from reflex.event import EventChain, EventChainVar, EventSpec, EventVar
from .number import BooleanVar, NumberVar
from .object import ObjectVar
from .sequence import ArrayVar, StringVar
@ -539,6 +560,10 @@ class Var(Generic[VAR_TYPE]):
return self.to(ArrayVar, self._var_type)
if issubclass(fixed_type, str):
return self.to(StringVar, self._var_type)
if issubclass(fixed_type, EventSpec):
return self.to(EventVar, self._var_type)
if issubclass(fixed_type, EventChain):
return self.to(EventChainVar, self._var_type)
if issubclass(fixed_type, Base):
return self.to(ObjectVar, self._var_type)
if dataclasses.is_dataclass(fixed_type):
@ -1029,47 +1054,22 @@ class LiteralVar(Var):
if value is None:
return LiteralNoneVar.create(_var_data=_var_data)
from reflex.event import EventChain, EventHandler, EventSpec
from reflex.event import (
EventChain,
EventHandler,
EventSpec,
LiteralEventChainVar,
LiteralEventVar,
)
from reflex.utils.format import get_event_handler_parts
from .function import ArgsFunctionOperation, FunctionStringVar
from .object import LiteralObjectVar
if isinstance(value, EventSpec):
event_name = LiteralVar.create(
".".join(filter(None, get_event_handler_parts(value.handler)))
)
event_args = LiteralVar.create(
{str(name): value for name, value in value.args}
)
event_client_name = LiteralVar.create(value.client_handler_name)
return FunctionStringVar("Event").call(
event_name,
event_args,
*([event_client_name] if value.client_handler_name else []),
)
return LiteralEventVar.create(value, _var_data=_var_data)
if isinstance(value, EventChain):
sig = inspect.signature(value.args_spec) # type: ignore
if sig.parameters:
arg_def = tuple((f"_{p}" for p in sig.parameters))
arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])
else:
# add a default argument for addEvents if none were specified in value.args_spec
# used to trigger the preventDefault() on the event.
arg_def = ("...args",)
arg_def_expr = Var(_js_expr="args")
return ArgsFunctionOperation.create(
arg_def,
FunctionStringVar.create("addEvents").call(
LiteralVar.create(
[LiteralVar.create(event) for event in value.events]
),
arg_def_expr,
LiteralVar.create(value.event_actions),
),
)
return LiteralEventChainVar.create(value, _var_data=_var_data)
if isinstance(value, EventHandler):
return Var(_js_expr=".".join(filter(None, get_event_handler_parts(value))))
@ -2126,9 +2126,16 @@ class NoneVar(Var[None]):
"""A var representing None."""
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralNoneVar(LiteralVar, NoneVar):
"""A var representing None."""
_var_value: None = None
def json(self) -> str:
"""Serialize the var to a JSON string.

View File

@ -58,14 +58,14 @@ def test_script_event_handler():
)
render_dict = component.render()
assert (
f'onReady={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_ready", ({{ }})))], args, ({{ }})))))}}'
f'onReady={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_ready", ({{ }}), ({{ }})))], args, ({{ }})))))}}'
in render_dict["props"]
)
assert (
f'onLoad={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_load", ({{ }})))], args, ({{ }})))))}}'
f'onLoad={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_load", ({{ }}), ({{ }})))], args, ({{ }})))))}}'
in render_dict["props"]
)
assert (
f'onError={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_error", ({{ }})))], args, ({{ }})))))}}'
f'onError={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_error", ({{ }}), ({{ }})))], args, ({{ }})))))}}'
in render_dict["props"]
)

View File

@ -832,7 +832,7 @@ def test_component_event_trigger_arbitrary_args():
assert comp.render()["props"][0] == (
"onFoo={((__e, _alpha, _bravo, _charlie) => ((addEvents("
f'[(Event("{C1State.get_full_name()}.mock_handler", ({{ ["_e"] : __e["target"]["value"], ["_bravo"] : _bravo["nested"], ["_charlie"] : (_charlie["custom"] + 42) }})))], '
f'[(Event("{C1State.get_full_name()}.mock_handler", ({{ ["_e"] : __e["target"]["value"], ["_bravo"] : _bravo["nested"], ["_charlie"] : (_charlie["custom"] + 42) }}), ({{ }})))], '
"[__e, _alpha, _bravo, _charlie], ({ })))))}"
)
@ -1178,7 +1178,7 @@ TEST_VAR = LiteralVar.create("test")._replace(
)
FORMATTED_TEST_VAR = LiteralVar.create(f"foo{TEST_VAR}bar")
STYLE_VAR = TEST_VAR._replace(_js_expr="style")
EVENT_CHAIN_VAR = TEST_VAR._replace(_var_type=EventChain)
EVENT_CHAIN_VAR = TEST_VAR.to(EventChain)
ARG_VAR = Var(_js_expr="arg")
TEST_VAR_DICT_OF_DICT = LiteralVar.create({"a": {"b": "test"}})._replace(
@ -2159,7 +2159,7 @@ class TriggerState(rx.State):
rx.text("random text", on_click=TriggerState.do_something),
rx.text(
"random text",
on_click=Var(_js_expr="toggleColorMode", _var_type=EventChain),
on_click=Var(_js_expr="toggleColorMode").to(EventChain),
),
),
True,
@ -2169,7 +2169,7 @@ class TriggerState(rx.State):
rx.text("random text", on_click=rx.console_log("log")),
rx.text(
"random text",
on_click=Var(_js_expr="toggleColorMode", _var_type=EventChain),
on_click=Var(_js_expr="toggleColorMode").to(EventChain),
),
),
False,

View File

@ -374,7 +374,7 @@ def test_format_match(
events=[EventSpec(handler=EventHandler(fn=mock_event))],
args_spec=lambda: [],
),
'((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ })))))',
'((...args) => ((addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ })))))',
),
(
EventChain(
@ -395,7 +395,7 @@ def test_format_match(
],
args_spec=lambda e: [e.target.value],
),
'((_e) => ((addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] })))], [_e], ({ })))))',
'((_e) => ((addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] }), ({ })))], [_e], ({ })))))',
),
(
EventChain(
@ -403,7 +403,19 @@ def test_format_match(
args_spec=lambda: [],
event_actions={"stopPropagation": True},
),
'((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ ["stopPropagation"] : true })))))',
'((...args) => ((addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["stopPropagation"] : true })))))',
),
(
EventChain(
events=[
EventSpec(
handler=EventHandler(fn=mock_event),
event_actions={"stopPropagation": True},
)
],
args_spec=lambda: [],
),
'((...args) => ((addEvents([(Event("mock_event", ({ }), ({ ["stopPropagation"] : true })))], args, ({ })))))',
),
(
EventChain(
@ -411,7 +423,7 @@ def test_format_match(
args_spec=lambda: [],
event_actions={"preventDefault": True},
),
'((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ ["preventDefault"] : true })))))',
'((...args) => ((addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["preventDefault"] : true })))))',
),
({"a": "red", "b": "blue"}, '({ ["a"] : "red", ["b"] : "blue" })'),
(Var(_js_expr="var", _var_type=int).guess_type(), "var"),
@ -519,7 +531,7 @@ def test_format_event_handler(input, output):
[
(
EventSpec(handler=EventHandler(fn=mock_event)),
'(Event("mock_event", ({ })))',
'(Event("mock_event", ({ }), ({ })))',
),
],
)