From c460040040eca793e96aba40c2ce023a0075161c Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 16 Oct 2024 11:35:09 -0700 Subject: [PATCH] LiteralEventChainVar becomes an ArgsFunctionOperation (#4174) * LiteralEventChainVar becomes an ArgsFunctionOperation Instead of using the ArgsFunctionOperation to create the string representation of the _js_expr, make the identity of the var an ArgsFunctionOperation so the _args_names and _return_expr remain accessible. Rely on the default behavior of ArgsFunctionOperation to create the _cached_var_name / _js_expr value. This allows the compat shim in `format_event_chain` to remain functional, as it does special handling for ArgsFunctionOperation to retain the previous behavior of that function (this was a regression introduced in 0.6.2). * _var_type is EventChain; fix parent class order * Re-fix LiteralEventChainVar inheritence list w/ comment * [ENG-3942] LiteralEventVar becomes VarCallOperation instead of using `.call` when constructing the `_js_expr`, have the identity of a LiteralEventVar as a VarCallOperation to take advantage of the _var_data carrying. * add event overlords * EventCallback descriptor always returns EventSpec from class Relax actual `__get__` definition to support the multitude of overloads * test case for event related vars carrying _var_data --------- Co-authored-by: Khaleel Al-Adhami --- reflex/event.py | 173 +++++++++++++++++++++++--------------- tests/units/test_event.py | 38 ++++++++- 2 files changed, 138 insertions(+), 73 deletions(-) diff --git a/reflex/event.py b/reflex/event.py index 04879add3..4b0bf96e2 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -22,6 +22,7 @@ from typing import ( TypeVar, Union, get_type_hints, + overload, ) from typing_extensions import ParamSpec, get_args, get_origin @@ -32,14 +33,17 @@ from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch from reflex.utils.types import ArgsSpec, GenericType from reflex.vars import VarData from reflex.vars.base import ( - CachedVarOperation, LiteralNoneVar, LiteralVar, ToOperation, Var, - cached_property_no_lock, ) -from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar, FunctionVar +from reflex.vars.function import ( + ArgsFunctionOperation, + FunctionStringVar, + FunctionVar, + VarOperationCall, +) from reflex.vars.object import ObjectVar try: @@ -1258,7 +1262,7 @@ class EventVar(ObjectVar): frozen=True, **{"slots": True} if sys.version_info >= (3, 10) else {}, ) -class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar): +class LiteralEventVar(VarOperationCall, LiteralVar, EventVar): """A literal event var.""" _var_value: EventSpec = dataclasses.field(default=None) # type: ignore @@ -1271,35 +1275,6 @@ class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar): """ return hash((self.__class__.__name__, self._js_expr)) - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return str( - FunctionStringVar("Event").call( - # event handler name - ".".join( - filter( - None, - format.get_event_handler_parts(self._var_value.handler), - ) - ), - # event handler args - {str(name): value for name, value in self._var_value.args}, - # event actions - self._var_value.event_actions, - # client handler name - *( - [self._var_value.client_handler_name] - if self._var_value.client_handler_name - else [] - ), - ) - ) - @classmethod def create( cls, @@ -1320,6 +1295,22 @@ class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar): _var_type=EventSpec, _var_data=_var_data, _var_value=value, + _func=FunctionStringVar("Event"), + _args=( + # event handler name + ".".join( + filter( + None, + format.get_event_handler_parts(value.handler), + ) + ), + # event handler args + {str(name): value for name, value in value.args}, + # event actions + value.event_actions, + # client handler name + *([value.client_handler_name] if value.client_handler_name else []), + ), ) @@ -1332,7 +1323,10 @@ class EventChainVar(FunctionVar): frozen=True, **{"slots": True} if sys.version_info >= (3, 10) else {}, ) -class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar): +# Note: LiteralVar is second in the inheritance list allowing it act like a +# CachedVarOperation (ArgsFunctionOperation) and get the _js_expr from the +# _cached_var_name property. +class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar): """A literal event chain var.""" _var_value: EventChain = dataclasses.field(default=None) # type: ignore @@ -1345,41 +1339,6 @@ class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar): """ return hash((self.__class__.__name__, self._js_expr)) - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - sig = inspect.signature(self._var_value.args_spec) # type: ignore - if sig.parameters: - arg_def = tuple((f"_{p}" for p in sig.parameters)) - arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def]) - else: - # add a default argument for addEvents if none were specified in value.args_spec - # used to trigger the preventDefault() on the event. - arg_def = ("...args",) - arg_def_expr = Var(_js_expr="args") - - if self._var_value.invocation is None: - invocation = FunctionStringVar.create("addEvents") - else: - invocation = self._var_value.invocation - - return str( - ArgsFunctionOperation.create( - arg_def, - invocation.call( - LiteralVar.create( - [LiteralVar.create(event) for event in self._var_value.events] - ), - arg_def_expr, - self._var_value.event_actions, - ), - ) - ) - @classmethod def create( cls, @@ -1395,10 +1354,31 @@ class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar): Returns: The created LiteralEventChainVar instance. """ + sig = inspect.signature(value.args_spec) # type: ignore + if sig.parameters: + arg_def = tuple((f"_{p}" for p in sig.parameters)) + arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def]) + else: + # add a default argument for addEvents if none were specified in value.args_spec + # used to trigger the preventDefault() on the event. + arg_def = ("...args",) + arg_def_expr = Var(_js_expr="args") + + if value.invocation is None: + invocation = FunctionStringVar.create("addEvents") + else: + invocation = value.invocation + return cls( _js_expr="", _var_type=EventChain, _var_data=_var_data, + _args_names=arg_def, + _return_expr=invocation.call( + LiteralVar.create([LiteralVar.create(event) for event in value.events]), + arg_def_expr, + value.event_actions, + ), _var_value=value, ) @@ -1437,6 +1417,11 @@ EventType = Union[IndividualEventType[G], List[IndividualEventType[G]]] P = ParamSpec("P") T = TypeVar("T") +V = TypeVar("V") +V2 = TypeVar("V2") +V3 = TypeVar("V3") +V4 = TypeVar("V4") +V5 = TypeVar("V5") if sys.version_info >= (3, 10): from typing import Concatenate @@ -1452,7 +1437,55 @@ if sys.version_info >= (3, 10): """ self.func = func - def __get__(self, instance, owner) -> Callable[P, T]: + @overload + def __get__( + self: EventCallback[[V], T], instance: None, owner + ) -> Callable[[Union[Var[V], V]], EventSpec]: ... + + @overload + def __get__( + self: EventCallback[[V, V2], T], instance: None, owner + ) -> Callable[[Union[Var[V], V], Union[Var[V2], V2]], EventSpec]: ... + + @overload + def __get__( + self: EventCallback[[V, V2, V3], T], instance: None, owner + ) -> Callable[ + [Union[Var[V], V], Union[Var[V2], V2], Union[Var[V3], V3]], + EventSpec, + ]: ... + + @overload + def __get__( + self: EventCallback[[V, V2, V3, V4], T], instance: None, owner + ) -> Callable[ + [ + Union[Var[V], V], + Union[Var[V2], V2], + Union[Var[V3], V3], + Union[Var[V4], V4], + ], + EventSpec, + ]: ... + + @overload + def __get__( + self: EventCallback[[V, V2, V3, V4, V5], T], instance: None, owner + ) -> Callable[ + [ + Union[Var[V], V], + Union[Var[V2], V2], + Union[Var[V3], V3], + Union[Var[V4], V4], + Union[Var[V5], V5], + ], + EventSpec, + ]: ... + + @overload + def __get__(self, instance, owner) -> Callable[P, T]: ... + + def __get__(self, instance, owner) -> Callable: """Get the function with the instance bound to it. Args: diff --git a/tests/units/test_event.py b/tests/units/test_event.py index 3996a6101..d7b7cf7a2 100644 --- a/tests/units/test_event.py +++ b/tests/units/test_event.py @@ -2,11 +2,18 @@ from typing import List import pytest -from reflex import event -from reflex.event import Event, EventHandler, EventSpec, call_event_handler, fix_events +from reflex.event import ( + Event, + EventChain, + EventHandler, + EventSpec, + call_event_handler, + event, + fix_events, +) from reflex.state import BaseState from reflex.utils import format -from reflex.vars.base import LiteralVar, Var +from reflex.vars.base import Field, LiteralVar, Var, field def make_var(value) -> Var: @@ -388,3 +395,28 @@ def test_event_actions_on_state(): assert sp_handler.event_actions == {"stopPropagation": True} # should NOT affect other references to the handler assert not handler.event_actions + + +def test_event_var_data(): + class S(BaseState): + x: Field[int] = field(0) + + @event + def s(self, value: int): + pass + + # Handler doesn't have any _var_data because it's just a str + handler_var = Var.create(S.s) + assert handler_var._get_all_var_data() is None + + # Ensure spec carries _var_data + spec_var = Var.create(S.s(S.x)) + assert spec_var._get_all_var_data() == S.x._get_all_var_data() + + # Needed to instantiate the EventChain + def _args_spec(value: Var[int]) -> tuple[Var[int]]: + return (value,) + + # Ensure chain carries _var_data + chain_var = Var.create(EventChain(events=[S.s(S.x)], args_spec=_args_spec)) + assert chain_var._get_all_var_data() == S.x._get_all_var_data()