This commit is contained in:
Khaleel Al-Adhami 2024-10-01 15:36:05 -07:00
parent ca7df11bf2
commit 20aca15644
3 changed files with 48 additions and 14 deletions

View File

@ -38,8 +38,10 @@ from reflex.constants import (
) )
from reflex.event import ( from reflex.event import (
EventChain, EventChain,
EventChainVar,
EventHandler, EventHandler,
EventSpec, EventSpec,
EventVar,
call_event_fn, call_event_fn,
call_event_handler, call_event_handler,
get_handler_args, get_handler_args,
@ -514,7 +516,7 @@ class Component(BaseComponent, ABC):
Var, Var,
EventHandler, EventHandler,
EventSpec, EventSpec,
List[Union[EventHandler, EventSpec]], List[Union[EventHandler, EventSpec, EventVar]],
Callable, Callable,
], ],
) -> Union[EventChain, Var]: ) -> Union[EventChain, Var]:
@ -532,11 +534,14 @@ class Component(BaseComponent, ABC):
""" """
# If it's an event chain var, return it. # If it's an event chain var, return it.
if isinstance(value, Var): if isinstance(value, Var):
if value._var_type is not EventChain: if isinstance(value, EventChainVar):
raise ValueError(
f"Invalid event chain: {repr(value)} of type {type(value)}"
)
return value return value
if isinstance(value, EventVar):
value = [value]
else:
raise ValueError(
f"Invalid event chain: {str(value)} of type {value._var_type}"
)
elif isinstance(value, EventChain): elif isinstance(value, EventChain):
# Trust that the caller knows what they're doing passing an EventChain directly # Trust that the caller knows what they're doing passing an EventChain directly
return value return value
@ -547,7 +552,7 @@ class Component(BaseComponent, ABC):
# If the input is a list of event handlers, create an event chain. # If the input is a list of event handlers, create an event chain.
if isinstance(value, List): if isinstance(value, List):
events: list[EventSpec] = [] events: List[Union[EventSpec, EventVar]] = []
for v in value: for v in value:
if isinstance(v, (EventHandler, EventSpec)): if isinstance(v, (EventHandler, EventSpec)):
# Call the event handler to get the event. # Call the event handler to get the event.
@ -561,6 +566,8 @@ class Component(BaseComponent, ABC):
"lambda inside an EventChain list." "lambda inside an EventChain list."
) )
events.extend(result) events.extend(result)
elif isinstance(v, EventVar):
events.append(v)
else: else:
raise ValueError(f"Invalid event: {v}") raise ValueError(f"Invalid event: {v}")
@ -577,11 +584,15 @@ class Component(BaseComponent, ABC):
raise ValueError(f"Invalid event chain: {value}") raise ValueError(f"Invalid event chain: {value}")
# Add args to the event specs if necessary. # Add args to the event specs if necessary.
events = [e.with_args(get_handler_args(e)) for e in events] events = [
(e.with_args(get_handler_args(e)) if isinstance(e, EventSpec) else e)
for e in events
]
# Collect event_actions from each spec # Collect event_actions from each spec
event_actions = {} event_actions = {}
for e in events: for e in events:
if isinstance(e, EventSpec):
event_actions.update(e.event_actions) event_actions.update(e.event_actions)
# Return the event chain. # Return the event chain.
@ -1030,8 +1041,11 @@ class Component(BaseComponent, ABC):
elif isinstance(event, EventChain): elif isinstance(event, EventChain):
event_args = [] event_args = []
for spec in event.events: for spec in event.events:
if isinstance(spec, EventSpec):
for args in spec.args: for args in spec.args:
event_args.extend(args) event_args.extend(args)
else:
event_args.append(spec)
yield event_trigger, event_args yield event_trigger, event_args
def _get_vars(self, include_children: bool = False) -> list[Var]: def _get_vars(self, include_children: bool = False) -> list[Var]:
@ -1105,8 +1119,12 @@ class Component(BaseComponent, ABC):
for trigger in self.event_triggers.values(): for trigger in self.event_triggers.values():
if isinstance(trigger, EventChain): if isinstance(trigger, EventChain):
for event in trigger.events: for event in trigger.events:
if isinstance(event, EventSpec):
if event.handler.state_full_name: if event.handler.state_full_name:
return True return True
else:
if event._var_state:
return True
elif isinstance(trigger, Var) and trigger._var_state: elif isinstance(trigger, Var) and trigger._var_state:
return True return True
return False return False

View File

@ -1150,6 +1150,14 @@ class EventVar(ObjectVar):
class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar): class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar):
_var_value: EventSpec = dataclasses.field(default=None) # type: ignore _var_value: EventSpec = dataclasses.field(default=None) # type: ignore
def __hash__(self) -> int:
"""Get the hash of the var.
Returns:
The hash of the var.
"""
return hash((self.__class__.__name__, self._js_expr))
@cached_property_no_lock @cached_property_no_lock
def _cached_var_name(self) -> str: def _cached_var_name(self) -> str:
"""The name of the var. """The name of the var.
@ -1209,6 +1217,14 @@ class EventChainVar(FunctionVar):
class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar): class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar):
_var_value: EventChain = dataclasses.field(default=None) # type: ignore _var_value: EventChain = dataclasses.field(default=None) # type: ignore
def __hash__(self) -> int:
"""Get the hash of the var.
Returns:
The hash of the var.
"""
return hash((self.__class__.__name__, self._js_expr))
@cached_property_no_lock @cached_property_no_lock
def _cached_var_name(self) -> str: def _cached_var_name(self) -> str:
"""The name of the var. """The name of the var.

View File

@ -475,9 +475,6 @@ class Var(Generic[VAR_TYPE]):
if issubclass(output, (ObjectVar, Base)): if issubclass(output, (ObjectVar, Base)):
return ToObjectOperation.create(self, var_type or dict) return ToObjectOperation.create(self, var_type or dict)
if dataclasses.is_dataclass(output):
return ToObjectOperation.create(self, var_type or dict)
if issubclass(output, FunctionVar): if issubclass(output, FunctionVar):
# if fixed_type is not None and not issubclass(fixed_type, Callable): # if fixed_type is not None and not issubclass(fixed_type, Callable):
# raise TypeError( # raise TypeError(
@ -488,6 +485,9 @@ class Var(Generic[VAR_TYPE]):
if issubclass(output, NoneVar): if issubclass(output, NoneVar):
return ToNoneOperation.create(self) return ToNoneOperation.create(self)
if dataclasses.is_dataclass(output):
return ToObjectOperation.create(self, var_type or dict)
# If we can't determine the first argument, we just replace the _var_type. # If we can't determine the first argument, we just replace the _var_type.
if not issubclass(output, Var) or var_type is None: if not issubclass(output, Var) or var_type is None:
return dataclasses.replace( return dataclasses.replace(