diff --git a/reflex/event.py b/reflex/event.py index 95358ace1..904af252d 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -4,16 +4,19 @@ from __future__ import annotations import dataclasses import inspect +import sys import types import urllib.parse from base64 import b64encode from typing import ( Any, Callable, + ClassVar, Dict, List, Optional, Tuple, + Type, Union, get_type_hints, ) @@ -25,8 +28,15 @@ from reflex.utils import format from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch from reflex.utils.types import ArgsSpec, GenericType from reflex.vars import VarData -from reflex.vars.base import LiteralVar, Var -from reflex.vars.function import FunctionStringVar, FunctionVar +from reflex.vars.base import ( + CachedVarOperation, + LiteralNoneVar, + LiteralVar, + ToOperation, + Var, + cached_property_no_lock, +) +from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar, FunctionVar from reflex.vars.object import ObjectVar try: @@ -375,7 +385,7 @@ class CallableEventSpec(EventSpec): class EventChain(EventActionsMixin): """Container for a chain of events that will be executed in order.""" - events: List[EventSpec] = dataclasses.field(default_factory=list) + events: List[Union[EventSpec, EventVar]] = dataclasses.field(default_factory=list) args_spec: Optional[Callable] = dataclasses.field(default=None) @@ -1126,3 +1136,153 @@ def get_fn_signature(fn: Callable) -> inspect.Signature: "state", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any ) return signature.replace(parameters=(new_param, *signature.parameters.values())) + + +class EventVar(ObjectVar): + pass + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar): + _var_value: EventSpec = dataclasses.field(default=None) # type: ignore + + @cached_property_no_lock + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + event_name = LiteralVar.create( + ".".join( + filter(None, format.get_event_handler_parts(self._var_value.handler)) + ) + ) + event_args = LiteralVar.create( + {str(name): value for name, value in self._var_value.args} + ) + event_client_name = LiteralVar.create(self._var_value.client_handler_name) + return str( + FunctionStringVar("Event").call( + event_name, + event_args, + *([event_client_name] if self._var_value.client_handler_name else []), + ) + ) + + @classmethod + def create( + cls, + value: EventSpec, + _var_data: VarData | None = None, + ) -> LiteralEventVar: + """Create a new LiteralEventVar instance. + + Args: + value: The value of the var. + _var_data: The data of the var. + + Returns: + The created LiteralEventVar instance. + """ + return cls( + _js_expr="", + _var_type=EventSpec, + _var_data=_var_data, + _var_value=value, + ) + + +class EventChainVar(FunctionVar): + pass + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar): + _var_value: EventChain = dataclasses.field(default=None) # type: ignore + + @cached_property_no_lock + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + sig = inspect.signature(self._var_value.args_spec) # type: ignore + if sig.parameters: + arg_def = tuple((f"_{p}" for p in sig.parameters)) + arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def]) + else: + # add a default argument for addEvents if none were specified in value.args_spec + # used to trigger the preventDefault() on the event. + arg_def = ("...args",) + arg_def_expr = Var(_js_expr="args") + + return str( + ArgsFunctionOperation.create( + arg_def, + FunctionStringVar.create("addEvents").call( + LiteralVar.create( + [LiteralVar.create(event) for event in self._var_value.events] + ), + arg_def_expr, + LiteralVar.create(self._var_value.event_actions), + ), + ) + ) + + @classmethod + def create( + cls, + value: EventChain, + _var_data: VarData | None = None, + ) -> LiteralEventChainVar: + """Create a new LiteralEventChainVar instance. + + Args: + value: The value of the var. + _var_data: The data of the var. + + Returns: + The created LiteralEventChainVar instance. + """ + return cls( + _js_expr="", + _var_type=EventChain, + _var_data=_var_data, + _var_value=value, + ) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToEventVarOperation(ToOperation, EventVar): + """Base class for immutable number vars that are the result of a number operation.""" + + _original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create()) + + _default_var_type: ClassVar[Type] = EventSpec + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToEventChainVarOperation(ToOperation, EventChainVar): + """Base class for immutable number vars that are the result of a number operation.""" + + _original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create()) + + _default_var_type: ClassVar[Type] = EventChain diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 2d78a14be..dc324d07a 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -385,6 +385,15 @@ class Var(Generic[VAR_TYPE]): Returns: The converted var. """ + from reflex.event import ( + EventChain, + EventChainVar, + EventSpec, + EventVar, + ToEventVarOperation, + ToEventChainVarOperation, + ) + from .function import FunctionVar, ToFunctionOperation from .number import ( BooleanVar, @@ -416,6 +425,10 @@ class Var(Generic[VAR_TYPE]): return self.to(BooleanVar, output) if fixed_output_type is None: return ToNoneOperation.create(self) + if fixed_output_type is EventSpec: + return self.to(EventVar, output) + if fixed_output_type is EventChain: + return self.to(EventChainVar, output) if issubclass(fixed_output_type, Base): return self.to(ObjectVar, output) if dataclasses.is_dataclass(fixed_output_type) and not issubclass( @@ -453,6 +466,12 @@ class Var(Generic[VAR_TYPE]): if issubclass(output, StringVar): return ToStringOperation.create(self, var_type or str) + if issubclass(output, EventVar): + return ToEventVarOperation.create(self, var_type or EventSpec) + + if issubclass(output, EventChainVar): + return ToEventChainVarOperation.create(self, var_type or EventChain) + if issubclass(output, (ObjectVar, Base)): return ToObjectOperation.create(self, var_type or dict) @@ -494,6 +513,8 @@ class Var(Generic[VAR_TYPE]): Raises: TypeError: If the type is not supported for guessing. """ + from reflex.event import EventChain, EventChainVar, EventSpec, EventVar + from .number import BooleanVar, NumberVar from .object import ObjectVar from .sequence import ArrayVar, StringVar @@ -539,6 +560,10 @@ class Var(Generic[VAR_TYPE]): return self.to(ArrayVar, self._var_type) if issubclass(fixed_type, str): return self.to(StringVar, self._var_type) + if issubclass(fixed_type, EventSpec): + return self.to(EventVar, self._var_type) + if issubclass(fixed_type, EventChain): + return self.to(EventChainVar, self._var_type) if issubclass(fixed_type, Base): return self.to(ObjectVar, self._var_type) if dataclasses.is_dataclass(fixed_type): @@ -1029,47 +1054,22 @@ class LiteralVar(Var): if value is None: return LiteralNoneVar.create(_var_data=_var_data) - from reflex.event import EventChain, EventHandler, EventSpec + from reflex.event import ( + EventChain, + EventHandler, + EventSpec, + LiteralEventChainVar, + LiteralEventVar, + ) from reflex.utils.format import get_event_handler_parts - from .function import ArgsFunctionOperation, FunctionStringVar from .object import LiteralObjectVar if isinstance(value, EventSpec): - event_name = LiteralVar.create( - ".".join(filter(None, get_event_handler_parts(value.handler))) - ) - event_args = LiteralVar.create( - {str(name): value for name, value in value.args} - ) - event_client_name = LiteralVar.create(value.client_handler_name) - return FunctionStringVar("Event").call( - event_name, - event_args, - *([event_client_name] if value.client_handler_name else []), - ) + return LiteralEventVar.create(value, _var_data=_var_data) if isinstance(value, EventChain): - sig = inspect.signature(value.args_spec) # type: ignore - if sig.parameters: - arg_def = tuple((f"_{p}" for p in sig.parameters)) - arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def]) - else: - # add a default argument for addEvents if none were specified in value.args_spec - # used to trigger the preventDefault() on the event. - arg_def = ("...args",) - arg_def_expr = Var(_js_expr="args") - - return ArgsFunctionOperation.create( - arg_def, - FunctionStringVar.create("addEvents").call( - LiteralVar.create( - [LiteralVar.create(event) for event in value.events] - ), - arg_def_expr, - LiteralVar.create(value.event_actions), - ), - ) + return LiteralEventChainVar.create(value, _var_data=_var_data) if isinstance(value, EventHandler): return Var(_js_expr=".".join(filter(None, get_event_handler_parts(value)))) @@ -2126,9 +2126,16 @@ class NoneVar(Var[None]): """A var representing None.""" +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) class LiteralNoneVar(LiteralVar, NoneVar): """A var representing None.""" + _var_value: None = None + def json(self) -> str: """Serialize the var to a JSON string.