From 144442176687a55e1add5bd115bc6b0e9c73fa96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Fri, 13 Dec 2024 13:28:55 -0800 Subject: [PATCH] add deps and position field in VarData (#4518) * fix memoized event trigger order * allow to declare deps in event signature for memoized event triggers * clean up the code to pass tests * handle position of hooks * clean up code * revert test changes * add future annotations * remove non-necessary stuff * reuse data_callback name if already set during first call to add_hooks * remove HookVar and use Var with VarData instead * remove test change * readd removed line * fix order of stmt for cleaner code * fix typing * something broke during the merge I guess * remove hack and pass proper const for position * oops, bad syntax in jinja * use "hook_position" instead of "hook_positions" match the name of the enum --------- Co-authored-by: Masen Furer --- .../web/pages/stateful_component.js.jinja2 | 6 ++- reflex/compiler/templates.py | 1 + reflex/components/component.py | 52 +++++++++++++++---- reflex/components/core/clipboard.py | 18 ++++--- reflex/components/core/clipboard.pyi | 2 +- reflex/components/datadisplay/dataeditor.py | 7 ++- reflex/constants/compiler.py | 6 +++ reflex/vars/base.py | 50 ++++++++++++++++-- 8 files changed, 115 insertions(+), 27 deletions(-) diff --git a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 index 4a40ef545..b04a78781 100644 --- a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 @@ -5,11 +5,15 @@ export function {{tag_name}} () { {{ hook }} {% endfor %} + {% for hook, data in component._get_all_hooks().items() if not data.position or data.position == const.hook_position.PRE_TRIGGER %} + {{ hook }} + {% endfor %} + {% for hook in memo_trigger_hooks %} {{ hook }} {% endfor %} - {% for hook in component._get_all_hooks() %} + {% for hook, data in component._get_all_hooks().items() if data.position and data.position == const.hook_position.POST_TRIGGER %} {{ hook }} {% endfor %} diff --git a/reflex/compiler/templates.py b/reflex/compiler/templates.py index c868a0cbb..631aa4ee2 100644 --- a/reflex/compiler/templates.py +++ b/reflex/compiler/templates.py @@ -45,6 +45,7 @@ class ReflexJinjaEnvironment(Environment): "on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL, "update_vars_internal": constants.CompileVars.UPDATE_VARS_INTERNAL, "frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL, + "hook_position": constants.Hooks.HookPosition, } diff --git a/reflex/components/component.py b/reflex/components/component.py index 85458f16c..46318a30b 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -1368,7 +1368,9 @@ class Component(BaseComponent, ABC): if user_hooks_data is not None: other_imports.append(user_hooks_data.imports) other_imports.extend( - hook_imports for hook_imports in self._get_added_hooks().values() + hook_vardata.imports + for hook_vardata in self._get_added_hooks().values() + if hook_vardata is not None ) return imports.merge_imports(_imports, *other_imports) @@ -1516,7 +1518,7 @@ class Component(BaseComponent, ABC): **self._get_special_hooks(), } - def _get_added_hooks(self) -> dict[str, ImportDict]: + def _get_added_hooks(self) -> dict[str, VarData | None]: """Get the hooks added via `add_hooks` method. Returns: @@ -1525,17 +1527,15 @@ class Component(BaseComponent, ABC): code = {} def extract_var_hooks(hook: Var): - _imports = {} var_data = VarData.merge(hook._get_all_var_data()) if var_data is not None: for sub_hook in var_data.hooks: - code[sub_hook] = {} - if var_data.imports: - _imports = var_data.imports + code[sub_hook] = None + if str(hook) in code: - code[str(hook)] = imports.merge_imports(code[str(hook)], _imports) + code[str(hook)] = VarData.merge(var_data, code[str(hook)]) else: - code[str(hook)] = _imports + code[str(hook)] = var_data # Add the hook code from add_hooks for each parent class (this is reversed to preserve # the order of the hooks in the final output) @@ -1544,7 +1544,7 @@ class Component(BaseComponent, ABC): if isinstance(hook, Var): extract_var_hooks(hook) else: - code[hook] = {} + code[hook] = None return code @@ -1586,8 +1586,8 @@ class Component(BaseComponent, ABC): if hooks is not None: code[hooks] = None - for hook in self._get_added_hooks(): - code[hook] = None + for hook, var_data in self._get_added_hooks().items(): + code[hook] = var_data # Add the hook code for the children. for child in self.children: @@ -2189,6 +2189,31 @@ class StatefulComponent(BaseComponent): ] return [var_name] + @staticmethod + def _get_deps_from_event_trigger(event: EventChain | EventSpec | Var) -> set[str]: + """Get the dependencies accessed by event triggers. + + Args: + event: The event trigger to extract deps from. + + Returns: + The dependencies accessed by the event triggers. + """ + events: list = [event] + deps = set() + + if isinstance(event, EventChain): + events.extend(event.events) + + for ev in events: + if isinstance(ev, EventSpec): + for arg in ev.args: + for a in arg: + var_datas = VarData.merge(a._get_all_var_data()) + if var_datas and var_datas.deps is not None: + deps |= {str(dep) for dep in var_datas.deps} + return deps + @classmethod def _get_memoized_event_triggers( cls, @@ -2225,6 +2250,11 @@ class StatefulComponent(BaseComponent): # Calculate Var dependencies accessed by the handler for useCallback dep array. var_deps = ["addEvents", "Event"] + + # Get deps from event trigger var data. + var_deps.extend(cls._get_deps_from_event_trigger(event)) + + # Get deps from hooks. for arg in event_args: var_data = arg._get_all_var_data() if var_data is None: diff --git a/reflex/components/core/clipboard.py b/reflex/components/core/clipboard.py index 938cd13c0..644de80d0 100644 --- a/reflex/components/core/clipboard.py +++ b/reflex/components/core/clipboard.py @@ -6,11 +6,12 @@ from typing import Dict, List, Tuple, Union from reflex.components.base.fragment import Fragment from reflex.components.tags.tag import Tag +from reflex.constants.compiler import Hooks from reflex.event import EventChain, EventHandler, passthrough_event_spec from reflex.utils.format import format_prop, wrap from reflex.utils.imports import ImportVar from reflex.vars import get_unique_variable_name -from reflex.vars.base import Var +from reflex.vars.base import Var, VarData class Clipboard(Fragment): @@ -72,7 +73,7 @@ class Clipboard(Fragment): ), } - def add_hooks(self) -> list[str]: + def add_hooks(self) -> list[str | Var[str]]: """Add hook to register paste event listener. Returns: @@ -83,13 +84,14 @@ class Clipboard(Fragment): return [] if isinstance(on_paste, EventChain): on_paste = wrap(str(format_prop(on_paste)).strip("{}"), "(") + hook_expr = f"usePasteHandler({self.targets!s}, {self.on_paste_event_actions!s}, {on_paste!s})" + return [ - "usePasteHandler(%s, %s, %s)" - % ( - str(self.targets), - str(self.on_paste_event_actions), - on_paste, - ) + Var( + hook_expr, + _var_type="str", + _var_data=VarData(position=Hooks.HookPosition.POST_TRIGGER), + ), ] diff --git a/reflex/components/core/clipboard.pyi b/reflex/components/core/clipboard.pyi index 69e0e866d..328554f2a 100644 --- a/reflex/components/core/clipboard.pyi +++ b/reflex/components/core/clipboard.pyi @@ -71,6 +71,6 @@ class Clipboard(Fragment): ... def add_imports(self) -> dict[str, ImportVar]: ... - def add_hooks(self) -> list[str]: ... + def add_hooks(self) -> list[str | Var[str]]: ... clipboard = Clipboard.create diff --git a/reflex/components/datadisplay/dataeditor.py b/reflex/components/datadisplay/dataeditor.py index 79813205f..2b80720ea 100644 --- a/reflex/components/datadisplay/dataeditor.py +++ b/reflex/components/datadisplay/dataeditor.py @@ -339,8 +339,11 @@ class DataEditor(NoSSRComponent): editor_id = get_unique_variable_name() # Define the name of the getData callback associated with this component and assign to get_cell_content. - data_callback = f"getData_{editor_id}" - self.get_cell_content = Var(_js_expr=data_callback) # type: ignore + if self.get_cell_content is not None: + data_callback = self.get_cell_content._js_expr + else: + data_callback = f"getData_{editor_id}" + self.get_cell_content = Var(_js_expr=data_callback) # type: ignore code = [f"function {data_callback}([col, row])" "{"] diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index b7ffef161..7ca55f4dd 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -132,6 +132,12 @@ class Hooks(SimpleNamespace): } })""" + class HookPosition(enum.Enum): + """The position of the hook in the component.""" + + PRE_TRIGGER = "pre_trigger" + POST_TRIGGER = "post_trigger" + class MemoizationDisposition(enum.Enum): """The conditions under which a component should be memoized.""" diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 3ff3c52de..094a478c8 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -42,7 +42,8 @@ from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, from reflex import constants from reflex.base import Base -from reflex.utils import console, imports, serializers, types +from reflex.constants.compiler import Hooks +from reflex.utils import console, exceptions, imports, serializers, types from reflex.utils.exceptions import ( VarAttributeError, VarDependencyError, @@ -115,12 +116,20 @@ class VarData: # Hooks that need to be present in the component to render this var hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple) + # Dependencies of the var + deps: Tuple[Var, ...] = dataclasses.field(default_factory=tuple) + + # Position of the hook in the component + position: Hooks.HookPosition | None = None + def __init__( self, state: str = "", field_name: str = "", imports: ImportDict | ParsedImportDict | None = None, hooks: dict[str, None] | None = None, + deps: list[Var] | None = None, + position: Hooks.HookPosition | None = None, ): """Initialize the var data. @@ -129,6 +138,8 @@ class VarData: field_name: The name of the field in the state. imports: Imports needed to render this var. hooks: Hooks that need to be present in the component to render this var. + deps: Dependencies of the var for useCallback. + position: Position of the hook in the component. """ immutable_imports: ImmutableParsedImportDict = tuple( sorted( @@ -139,6 +150,8 @@ class VarData: object.__setattr__(self, "field_name", field_name) object.__setattr__(self, "imports", immutable_imports) object.__setattr__(self, "hooks", tuple(hooks or {})) + object.__setattr__(self, "deps", tuple(deps or [])) + object.__setattr__(self, "position", position or None) def old_school_imports(self) -> ImportDict: """Return the imports as a mutable dict. @@ -154,6 +167,9 @@ class VarData: Args: *all: The var data objects to merge. + Raises: + ReflexError: If trying to merge VarData with different positions. + Returns: The merged var data object. @@ -184,12 +200,32 @@ class VarData: *(var_data.imports for var_data in all_var_datas) ) - if state or _imports or hooks or field_name: + deps = [dep for var_data in all_var_datas for dep in var_data.deps] + + positions = list( + { + var_data.position + for var_data in all_var_datas + if var_data.position is not None + } + ) + if positions: + if len(positions) > 1: + raise exceptions.ReflexError( + f"Cannot merge var data with different positions: {positions}" + ) + position = positions[0] + else: + position = None + + if state or _imports or hooks or field_name or deps or position: return VarData( state=state, field_name=field_name, imports=_imports, hooks=hooks, + deps=deps, + position=position, ) return None @@ -200,7 +236,14 @@ class VarData: Returns: True if any field is set to a non-default value. """ - return bool(self.state or self.imports or self.hooks or self.field_name) + return bool( + self.state + or self.imports + or self.hooks + or self.field_name + or self.deps + or self.position + ) @classmethod def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData: @@ -480,7 +523,6 @@ class Var(Generic[VAR_TYPE]): raise TypeError( "The _var_full_name_needs_state_prefix argument is not supported for Var." ) - value_with_replaced = dataclasses.replace( self, _var_type=_var_type or self._var_type,