From 20aca156445142516f68bb536d113fe5ff2d97ec Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Tue, 1 Oct 2024 15:36:05 -0700 Subject: [PATCH] add hash --- reflex/components/component.py | 40 ++++++++++++++++++++++++---------- reflex/event.py | 16 ++++++++++++++ reflex/vars/base.py | 6 ++--- 3 files changed, 48 insertions(+), 14 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index 9bdd12f0e..87a3a10ad 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -38,8 +38,10 @@ from reflex.constants import ( ) from reflex.event import ( EventChain, + EventChainVar, EventHandler, EventSpec, + EventVar, call_event_fn, call_event_handler, get_handler_args, @@ -514,7 +516,7 @@ class Component(BaseComponent, ABC): Var, EventHandler, EventSpec, - List[Union[EventHandler, EventSpec]], + List[Union[EventHandler, EventSpec, EventVar]], Callable, ], ) -> Union[EventChain, Var]: @@ -532,11 +534,14 @@ class Component(BaseComponent, ABC): """ # If it's an event chain var, return it. if isinstance(value, Var): - if value._var_type is not EventChain: + if isinstance(value, EventChainVar): + return value + if isinstance(value, EventVar): + value = [value] + else: raise ValueError( - f"Invalid event chain: {repr(value)} of type {type(value)}" + f"Invalid event chain: {str(value)} of type {value._var_type}" ) - return value elif isinstance(value, EventChain): # Trust that the caller knows what they're doing passing an EventChain directly return value @@ -547,7 +552,7 @@ class Component(BaseComponent, ABC): # If the input is a list of event handlers, create an event chain. if isinstance(value, List): - events: list[EventSpec] = [] + events: List[Union[EventSpec, EventVar]] = [] for v in value: if isinstance(v, (EventHandler, EventSpec)): # Call the event handler to get the event. @@ -561,6 +566,8 @@ class Component(BaseComponent, ABC): "lambda inside an EventChain list." ) events.extend(result) + elif isinstance(v, EventVar): + events.append(v) else: raise ValueError(f"Invalid event: {v}") @@ -577,12 +584,16 @@ class Component(BaseComponent, ABC): raise ValueError(f"Invalid event chain: {value}") # Add args to the event specs if necessary. - events = [e.with_args(get_handler_args(e)) for e in events] + events = [ + (e.with_args(get_handler_args(e)) if isinstance(e, EventSpec) else e) + for e in events + ] # Collect event_actions from each spec event_actions = {} for e in events: - event_actions.update(e.event_actions) + if isinstance(e, EventSpec): + event_actions.update(e.event_actions) # Return the event chain. if isinstance(args_spec, Var): @@ -1030,8 +1041,11 @@ class Component(BaseComponent, ABC): elif isinstance(event, EventChain): event_args = [] for spec in event.events: - for args in spec.args: - event_args.extend(args) + if isinstance(spec, EventSpec): + for args in spec.args: + event_args.extend(args) + else: + event_args.append(spec) yield event_trigger, event_args def _get_vars(self, include_children: bool = False) -> list[Var]: @@ -1105,8 +1119,12 @@ class Component(BaseComponent, ABC): for trigger in self.event_triggers.values(): if isinstance(trigger, EventChain): for event in trigger.events: - if event.handler.state_full_name: - return True + if isinstance(event, EventSpec): + if event.handler.state_full_name: + return True + else: + if event._var_state: + return True elif isinstance(trigger, Var) and trigger._var_state: return True return False diff --git a/reflex/event.py b/reflex/event.py index 904af252d..df261a862 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -1150,6 +1150,14 @@ class EventVar(ObjectVar): class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar): _var_value: EventSpec = dataclasses.field(default=None) # type: ignore + def __hash__(self) -> int: + """Get the hash of the var. + + Returns: + The hash of the var. + """ + return hash((self.__class__.__name__, self._js_expr)) + @cached_property_no_lock def _cached_var_name(self) -> str: """The name of the var. @@ -1209,6 +1217,14 @@ class EventChainVar(FunctionVar): class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar): _var_value: EventChain = dataclasses.field(default=None) # type: ignore + def __hash__(self) -> int: + """Get the hash of the var. + + Returns: + The hash of the var. + """ + return hash((self.__class__.__name__, self._js_expr)) + @cached_property_no_lock def _cached_var_name(self) -> str: """The name of the var. diff --git a/reflex/vars/base.py b/reflex/vars/base.py index dc324d07a..84cb589e1 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -475,9 +475,6 @@ class Var(Generic[VAR_TYPE]): if issubclass(output, (ObjectVar, Base)): return ToObjectOperation.create(self, var_type or dict) - if dataclasses.is_dataclass(output): - return ToObjectOperation.create(self, var_type or dict) - if issubclass(output, FunctionVar): # if fixed_type is not None and not issubclass(fixed_type, Callable): # raise TypeError( @@ -488,6 +485,9 @@ class Var(Generic[VAR_TYPE]): if issubclass(output, NoneVar): return ToNoneOperation.create(self) + if dataclasses.is_dataclass(output): + return ToObjectOperation.create(self, var_type or dict) + # If we can't determine the first argument, we just replace the _var_type. if not issubclass(output, Var) or var_type is None: return dataclasses.replace(