diff --git a/reflex/components/component.py b/reflex/components/component.py index 3372357ab..311240dc8 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -2198,6 +2198,35 @@ class StatefulComponent(BaseComponent): ] return [var_name] + @staticmethod + def _get_deps_from_event_trigger(event: EventChain | EventSpec | Var) -> set[str]: + """Get the dependencies accessed by event triggers. + + Args: + event: The event trigger to extract deps from. + + Returns: + The dependencies accessed by the event triggers. + """ + events: list = [event] + deps = set() + if isinstance(event, EventChain): + events.extend(event.events) + + for ev in events: + if isinstance(ev, EventSpec): + for arg in ev.args: + deps.union( + { + str(dep) + for a in arg + if a._var_data is not None + for dep in a._var_data.deps + if a._var_data.deps is not None + } + ) + return deps + @classmethod def _get_memoized_event_triggers( cls, @@ -2234,6 +2263,11 @@ class StatefulComponent(BaseComponent): # Calculate Var dependencies accessed by the handler for useCallback dep array. var_deps = ["addEvents", "Event"] + + # Get deps from event trigger var data. + var_deps.extend(cls._get_deps_from_event_trigger(event)) + + # Get deps from hooks. for arg in event_args: var_data = arg._get_all_var_data() if var_data is None: diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 712d6e868..4136945cf 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -116,6 +116,9 @@ class VarData: # Hooks that need to be present in the component to render this var hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple) + # Dependencies of the var + deps: Tuple[Var, ...] = dataclasses.field(default_factory=tuple) + # Position of the hook in the component position: Hooks.HookPosition | None = None @@ -125,6 +128,7 @@ class VarData: field_name: str = "", imports: ImportDict | ParsedImportDict | None = None, hooks: dict[str, None] | None = None, + deps: list[Var] | None = None, position: Hooks.HookPosition | None = None, ): """Initialize the var data. @@ -134,6 +138,7 @@ class VarData: field_name: The name of the field in the state. imports: Imports needed to render this var. hooks: Hooks that need to be present in the component to render this var. + deps: Dependencies of the var for useCallback. position: Position of the hook in the component. """ immutable_imports: ImmutableParsedImportDict = tuple( @@ -146,6 +151,7 @@ class VarData: object.__setattr__(self, "imports", immutable_imports) object.__setattr__(self, "hooks", tuple(hooks or {})) object.__setattr__(self, "position", position or None) + object.__setattr__(self, "deps", tuple(deps or [])) def old_school_imports(self) -> ImportDict: """Return the imports as a mutable dict. @@ -194,6 +200,8 @@ class VarData: *(var_data.imports for var_data in all_var_datas) ) + deps = [dep for var_data in all_var_datas for dep in var_data.deps] + positions = list( { var_data.position @@ -210,12 +218,13 @@ class VarData: else: position = None - if state or _imports or hooks or field_name or position: + if state or _imports or hooks or field_name or deps or position: return VarData( state=state, field_name=field_name, imports=_imports, hooks=hooks, + deps=deps, position=position, ) @@ -228,7 +237,12 @@ class VarData: True if any field is set to a non-default value. """ return bool( - self.state or self.imports or self.hooks or self.field_name or self.position + self.state + or self.imports + or self.hooks + or self.field_name + or self.deps + or self.position ) @classmethod @@ -509,7 +523,6 @@ class Var(Generic[VAR_TYPE]): raise TypeError( "The _var_full_name_needs_state_prefix argument is not supported for Var." ) - value_with_replaced = dataclasses.replace( self, _var_type=_var_type or self._var_type,