From 25cfde751a2b3dd95b3ee73961b8092fa8460ab4 Mon Sep 17 00:00:00 2001 From: Lendemor Date: Wed, 11 Dec 2024 01:43:30 +0100 Subject: [PATCH 1/2] allow to declare deps in event signature for memoized event triggers --- reflex/components/component.py | 12 ++++++++++++ reflex/vars/base.py | 16 +++++++++++++--- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index 75a821ac8..e82b093aa 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -2232,6 +2232,18 @@ class StatefulComponent(BaseComponent): # Calculate Var dependencies accessed by the handler for useCallback dep array. var_deps = ["addEvents", "Event"] + for ev in event.events: + if isinstance(ev, EventSpec): + for arg in ev.args: + var_deps.extend( + { + str(dep) + for a in arg + if a._var_data is not None + for dep in a._var_data.deps + if a._var_data.deps is not None + } + ) for arg in event_args: var_data = arg._get_all_var_data() if var_data is None: diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 941a9d81a..7f2402bdb 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -115,12 +115,16 @@ class VarData: # Hooks that need to be present in the component to render this var hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple) + # Dependencies of the var + deps: Tuple[Var, ...] = dataclasses.field(default_factory=tuple) + def __init__( self, state: str = "", field_name: str = "", imports: ImportDict | ParsedImportDict | None = None, hooks: dict[str, None] | None = None, + deps: list[Var] | None = None, ): """Initialize the var data. @@ -129,6 +133,7 @@ class VarData: field_name: The name of the field in the state. imports: Imports needed to render this var. hooks: Hooks that need to be present in the component to render this var. + deps: Dependencies of the var for useCallback. """ immutable_imports: ImmutableParsedImportDict = tuple( sorted( @@ -139,6 +144,7 @@ class VarData: object.__setattr__(self, "field_name", field_name) object.__setattr__(self, "imports", immutable_imports) object.__setattr__(self, "hooks", tuple(hooks or {})) + object.__setattr__(self, "deps", tuple(deps or [])) def old_school_imports(self) -> ImportDict: """Return the imports as a mutable dict. @@ -184,12 +190,15 @@ class VarData: *(var_data.imports for var_data in all_var_datas) ) - if state or _imports or hooks or field_name: + deps = [dep for var_data in all_var_datas for dep in var_data.deps] + + if state or _imports or hooks or field_name or deps: return VarData( state=state, field_name=field_name, imports=_imports, hooks=hooks, + deps=deps, ) return None @@ -200,7 +209,9 @@ class VarData: Returns: True if any field is set to a non-default value. """ - return bool(self.state or self.imports or self.hooks or self.field_name) + return bool( + self.state or self.imports or self.hooks or self.field_name or self.deps + ) @classmethod def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData: @@ -480,7 +491,6 @@ class Var(Generic[VAR_TYPE]): raise TypeError( "The _var_full_name_needs_state_prefix argument is not supported for Var." ) - value_with_replaced = dataclasses.replace( self, _var_type=_var_type or self._var_type, From bf347841b19146290f35bd6fa2f5a7d3172aef5f Mon Sep 17 00:00:00 2001 From: Lendemor Date: Wed, 11 Dec 2024 15:07:44 +0100 Subject: [PATCH 2/2] clean up the code to pass tests --- reflex/components/component.py | 46 +++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 12 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index e82b093aa..85232e5dd 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -2196,6 +2196,35 @@ class StatefulComponent(BaseComponent): ] return [var_name] + @staticmethod + def _get_deps_from_event_trigger(event: EventChain | EventSpec | Var) -> set[str]: + """Get the dependencies accessed by event triggers. + + Args: + event: The event trigger to extract deps from. + + Returns: + The dependencies accessed by the event triggers. + """ + events: list = [event] + deps = set() + if isinstance(event, EventChain): + events.extend(event.events) + + for ev in events: + if isinstance(ev, EventSpec): + for arg in ev.args: + deps.union( + { + str(dep) + for a in arg + if a._var_data is not None + for dep in a._var_data.deps + if a._var_data.deps is not None + } + ) + return deps + @classmethod def _get_memoized_event_triggers( cls, @@ -2232,18 +2261,11 @@ class StatefulComponent(BaseComponent): # Calculate Var dependencies accessed by the handler for useCallback dep array. var_deps = ["addEvents", "Event"] - for ev in event.events: - if isinstance(ev, EventSpec): - for arg in ev.args: - var_deps.extend( - { - str(dep) - for a in arg - if a._var_data is not None - for dep in a._var_data.deps - if a._var_data.deps is not None - } - ) + + # Get deps from event trigger var data. + var_deps.extend(cls._get_deps_from_event_trigger(event)) + + # Get deps from hooks. for arg in event_args: var_data = arg._get_all_var_data() if var_data is None: