This commit is contained in:
Khaleel Al-Adhami 2024-10-01 15:36:05 -07:00
parent ca7df11bf2
commit 20aca15644
3 changed files with 48 additions and 14 deletions

View File

@ -38,8 +38,10 @@ from reflex.constants import (
)
from reflex.event import (
EventChain,
EventChainVar,
EventHandler,
EventSpec,
EventVar,
call_event_fn,
call_event_handler,
get_handler_args,
@ -514,7 +516,7 @@ class Component(BaseComponent, ABC):
Var,
EventHandler,
EventSpec,
List[Union[EventHandler, EventSpec]],
List[Union[EventHandler, EventSpec, EventVar]],
Callable,
],
) -> Union[EventChain, Var]:
@ -532,11 +534,14 @@ class Component(BaseComponent, ABC):
"""
# If it's an event chain var, return it.
if isinstance(value, Var):
if value._var_type is not EventChain:
raise ValueError(
f"Invalid event chain: {repr(value)} of type {type(value)}"
)
if isinstance(value, EventChainVar):
return value
if isinstance(value, EventVar):
value = [value]
else:
raise ValueError(
f"Invalid event chain: {str(value)} of type {value._var_type}"
)
elif isinstance(value, EventChain):
# Trust that the caller knows what they're doing passing an EventChain directly
return value
@ -547,7 +552,7 @@ class Component(BaseComponent, ABC):
# If the input is a list of event handlers, create an event chain.
if isinstance(value, List):
events: list[EventSpec] = []
events: List[Union[EventSpec, EventVar]] = []
for v in value:
if isinstance(v, (EventHandler, EventSpec)):
# Call the event handler to get the event.
@ -561,6 +566,8 @@ class Component(BaseComponent, ABC):
"lambda inside an EventChain list."
)
events.extend(result)
elif isinstance(v, EventVar):
events.append(v)
else:
raise ValueError(f"Invalid event: {v}")
@ -577,11 +584,15 @@ class Component(BaseComponent, ABC):
raise ValueError(f"Invalid event chain: {value}")
# Add args to the event specs if necessary.
events = [e.with_args(get_handler_args(e)) for e in events]
events = [
(e.with_args(get_handler_args(e)) if isinstance(e, EventSpec) else e)
for e in events
]
# Collect event_actions from each spec
event_actions = {}
for e in events:
if isinstance(e, EventSpec):
event_actions.update(e.event_actions)
# Return the event chain.
@ -1030,8 +1041,11 @@ class Component(BaseComponent, ABC):
elif isinstance(event, EventChain):
event_args = []
for spec in event.events:
if isinstance(spec, EventSpec):
for args in spec.args:
event_args.extend(args)
else:
event_args.append(spec)
yield event_trigger, event_args
def _get_vars(self, include_children: bool = False) -> list[Var]:
@ -1105,8 +1119,12 @@ class Component(BaseComponent, ABC):
for trigger in self.event_triggers.values():
if isinstance(trigger, EventChain):
for event in trigger.events:
if isinstance(event, EventSpec):
if event.handler.state_full_name:
return True
else:
if event._var_state:
return True
elif isinstance(trigger, Var) and trigger._var_state:
return True
return False

View File

@ -1150,6 +1150,14 @@ class EventVar(ObjectVar):
class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar):
_var_value: EventSpec = dataclasses.field(default=None) # type: ignore
def __hash__(self) -> int:
"""Get the hash of the var.
Returns:
The hash of the var.
"""
return hash((self.__class__.__name__, self._js_expr))
@cached_property_no_lock
def _cached_var_name(self) -> str:
"""The name of the var.
@ -1209,6 +1217,14 @@ class EventChainVar(FunctionVar):
class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar):
_var_value: EventChain = dataclasses.field(default=None) # type: ignore
def __hash__(self) -> int:
"""Get the hash of the var.
Returns:
The hash of the var.
"""
return hash((self.__class__.__name__, self._js_expr))
@cached_property_no_lock
def _cached_var_name(self) -> str:
"""The name of the var.

View File

@ -475,9 +475,6 @@ class Var(Generic[VAR_TYPE]):
if issubclass(output, (ObjectVar, Base)):
return ToObjectOperation.create(self, var_type or dict)
if dataclasses.is_dataclass(output):
return ToObjectOperation.create(self, var_type or dict)
if issubclass(output, FunctionVar):
# if fixed_type is not None and not issubclass(fixed_type, Callable):
# raise TypeError(
@ -488,6 +485,9 @@ class Var(Generic[VAR_TYPE]):
if issubclass(output, NoneVar):
return ToNoneOperation.create(self)
if dataclasses.is_dataclass(output):
return ToObjectOperation.create(self, var_type or dict)
# If we can't determine the first argument, we just replace the _var_type.
if not issubclass(output, Var) or var_type is None:
return dataclasses.replace(