From c54b736254d2d5f6bae3c3b211ff608598b53499 Mon Sep 17 00:00:00 2001 From: Lendemor Date: Wed, 11 Dec 2024 19:14:33 +0100 Subject: [PATCH] handle position of hooks --- .../web/pages/stateful_component.js.jinja2 | 6 ++- reflex/app.py | 2 + reflex/components/component.py | 34 ++++++++------ reflex/components/core/clipboard.py | 12 +++-- reflex/components/core/clipboard.pyi | 2 +- reflex/constants/compiler.py | 6 +++ reflex/vars/base.py | 35 +++++++++++++-- reflex/vars/hooks.py | 44 +++++++++++++++++++ tests/units/components/test_component.py | 5 ++- 9 files changed, 119 insertions(+), 27 deletions(-) create mode 100644 reflex/vars/hooks.py diff --git a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 index 206472d2e..f993502d2 100644 --- a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 @@ -5,7 +5,7 @@ export function {{tag_name}} () { {{ hook }} {% endfor %} - {% for hook in component._get_all_hooks() %} + {% for hook, data in component._get_all_hooks().items() if not data.position or data.position == positions.PRE_TRIGGER %} {{ hook }} {% endfor %} @@ -13,6 +13,10 @@ export function {{tag_name}} () { {{ hook }} {% endfor %} + {% for hook,data in component._get_all_hooks().items() if data.position and data.position == positions.POST_TRIGGER %} + {{ hook }} + {% endfor %} + return ( {{utils.render(component.render(), indent_width=0)}} ) diff --git a/reflex/app.py b/reflex/app.py index eb453bd0b..6cb2c7530 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -56,6 +56,7 @@ from reflex.components.component import ( Component, ComponentStyle, evaluate_style_namespaces, + memo, ) from reflex.components.core.banner import connection_pulser, connection_toaster from reflex.components.core.breakpoints import set_breakpoints @@ -162,6 +163,7 @@ def default_overlay_component() -> Component: ) +@memo def default_error_boundary(*children: Component) -> Component: """Default error_boundary attribute for App. diff --git a/reflex/components/component.py b/reflex/components/component.py index 75a821ac8..26ddc058f 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -24,6 +24,7 @@ from typing import ( ) import reflex.state +from reflex import constants from reflex.base import Base from reflex.compiler.templates import STATEFUL_COMPONENT from reflex.components.core.breakpoints import Breakpoints @@ -69,6 +70,7 @@ from reflex.vars.base import ( cached_property_no_lock, ) from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar +from reflex.vars.hooks import HookVar from reflex.vars.number import ternary_operation from reflex.vars.object import ObjectVar from reflex.vars.sequence import LiteralArrayVar @@ -1369,7 +1371,9 @@ class Component(BaseComponent, ABC): if user_hooks_data is not None: other_imports.append(user_hooks_data.imports) other_imports.extend( - hook_imports for hook_imports in self._get_added_hooks().values() + hook_vardata.imports + for hook_vardata in self._get_added_hooks().values() + if hook_vardata is not None ) return imports.merge_imports(_imports, *other_imports) @@ -1523,7 +1527,7 @@ class Component(BaseComponent, ABC): **self._get_special_hooks(), } - def _get_added_hooks(self) -> dict[str, ImportDict]: + def _get_added_hooks(self) -> dict[str, VarData]: """Get the hooks added via `add_hooks` method. Returns: @@ -1532,17 +1536,19 @@ class Component(BaseComponent, ABC): code = {} def extract_var_hooks(hook: Var): - _imports = {} var_data = VarData.merge(hook._get_all_var_data()) if var_data is not None: for sub_hook in var_data.hooks: - code[sub_hook] = {} - if var_data.imports: - _imports = var_data.imports + code[sub_hook] = None + if str(hook) in code: - code[str(hook)] = imports.merge_imports(code[str(hook)], _imports) + code[str(hook)] = VarData.merge(var_data, code[str(hook)]) + elif isinstance(hook, HookVar): + code[str(hook)] = VarData.merge( + var_data, VarData(position=hook.position) + ) else: - code[str(hook)] = _imports + code[str(hook)] = var_data # Add the hook code from add_hooks for each parent class (this is reversed to preserve # the order of the hooks in the final output) @@ -1551,7 +1557,9 @@ class Component(BaseComponent, ABC): if isinstance(hook, Var): extract_var_hooks(hook) else: - code[hook] = {} + if isinstance(hook, str): + hook = HookVar.create(hook) + code[hook] = VarData() return code @@ -1593,8 +1601,8 @@ class Component(BaseComponent, ABC): if hooks is not None: code[hooks] = None - for hook in self._get_added_hooks(): - code[hook] = None + for hook, var_data in self._get_added_hooks().items(): + code[hook] = var_data # Add the hook code for the children. for child in self.children: @@ -2168,6 +2176,7 @@ class StatefulComponent(BaseComponent): tag_name=tag_name, memo_trigger_hooks=memo_trigger_hooks, component=component, + positions=constants.Hooks.HookPosition, ) @staticmethod @@ -2244,10 +2253,9 @@ class StatefulComponent(BaseComponent): imports={"react": [ImportVar(tag="useCallback")]}, ), ) - # Store the memoized function name and hook code for this event trigger. trigger_memo[event_trigger] = ( - Var(_js_expr=memo_name)._replace( + HookVar(_js_expr=memo_name)._replace( _var_type=EventChain, merge_var_data=memo_var_data ), f"const {memo_name} = useCallback({rendered_chain}, [{', '.join(var_deps)}])", diff --git a/reflex/components/core/clipboard.py b/reflex/components/core/clipboard.py index 938cd13c0..a0c0b7914 100644 --- a/reflex/components/core/clipboard.py +++ b/reflex/components/core/clipboard.py @@ -6,11 +6,13 @@ from typing import Dict, List, Tuple, Union from reflex.components.base.fragment import Fragment from reflex.components.tags.tag import Tag +from reflex.constants.compiler import Hooks from reflex.event import EventChain, EventHandler, passthrough_event_spec from reflex.utils.format import format_prop, wrap from reflex.utils.imports import ImportVar from reflex.vars import get_unique_variable_name from reflex.vars.base import Var +from reflex.vars.hooks import HookVar class Clipboard(Fragment): @@ -72,7 +74,7 @@ class Clipboard(Fragment): ), } - def add_hooks(self) -> list[str]: + def add_hooks(self) -> list[str | Var]: """Add hook to register paste event listener. Returns: @@ -83,13 +85,9 @@ class Clipboard(Fragment): return [] if isinstance(on_paste, EventChain): on_paste = wrap(str(format_prop(on_paste)).strip("{}"), "(") + hook_expr = f"usePasteHandler({self.targets!s}, {self.on_paste_event_actions!s}, {on_paste!s})" return [ - "usePasteHandler(%s, %s, %s)" - % ( - str(self.targets), - str(self.on_paste_event_actions), - on_paste, - ) + HookVar.create(hook_expr, _position=Hooks.HookPosition.POST_TRIGGER), ] diff --git a/reflex/components/core/clipboard.pyi b/reflex/components/core/clipboard.pyi index 69e0e866d..761467973 100644 --- a/reflex/components/core/clipboard.pyi +++ b/reflex/components/core/clipboard.pyi @@ -71,6 +71,6 @@ class Clipboard(Fragment): ... def add_imports(self) -> dict[str, ImportVar]: ... - def add_hooks(self) -> list[str]: ... + def add_hooks(self) -> list[str | Var]: ... clipboard = Clipboard.create diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index b7ffef161..7ca55f4dd 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -132,6 +132,12 @@ class Hooks(SimpleNamespace): } })""" + class HookPosition(enum.Enum): + """The position of the hook in the component.""" + + PRE_TRIGGER = "pre_trigger" + POST_TRIGGER = "post_trigger" + class MemoizationDisposition(enum.Enum): """The conditions under which a component should be memoized.""" diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 941a9d81a..712d6e868 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -42,7 +42,8 @@ from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, from reflex import constants from reflex.base import Base -from reflex.utils import console, imports, serializers, types +from reflex.constants.compiler import Hooks +from reflex.utils import console, exceptions, imports, serializers, types from reflex.utils.exceptions import ( VarAttributeError, VarDependencyError, @@ -115,12 +116,16 @@ class VarData: # Hooks that need to be present in the component to render this var hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple) + # Position of the hook in the component + position: Hooks.HookPosition | None = None + def __init__( self, state: str = "", field_name: str = "", imports: ImportDict | ParsedImportDict | None = None, hooks: dict[str, None] | None = None, + position: Hooks.HookPosition | None = None, ): """Initialize the var data. @@ -129,6 +134,7 @@ class VarData: field_name: The name of the field in the state. imports: Imports needed to render this var. hooks: Hooks that need to be present in the component to render this var. + position: Position of the hook in the component. """ immutable_imports: ImmutableParsedImportDict = tuple( sorted( @@ -139,6 +145,7 @@ class VarData: object.__setattr__(self, "field_name", field_name) object.__setattr__(self, "imports", immutable_imports) object.__setattr__(self, "hooks", tuple(hooks or {})) + object.__setattr__(self, "position", position or None) def old_school_imports(self) -> ImportDict: """Return the imports as a mutable dict. @@ -154,6 +161,9 @@ class VarData: Args: *all: The var data objects to merge. + Raises: + ReflexError: If trying to merge VarData with different positions. + Returns: The merged var data object. @@ -184,12 +194,29 @@ class VarData: *(var_data.imports for var_data in all_var_datas) ) - if state or _imports or hooks or field_name: + positions = list( + { + var_data.position + for var_data in all_var_datas + if var_data.position is not None + } + ) + if positions: + if len(positions) > 1: + raise exceptions.ReflexError( + f"Cannot merge var data with different positions: {positions}" + ) + position = positions[0] + else: + position = None + + if state or _imports or hooks or field_name or position: return VarData( state=state, field_name=field_name, imports=_imports, hooks=hooks, + position=position, ) return None @@ -200,7 +227,9 @@ class VarData: Returns: True if any field is set to a non-default value. """ - return bool(self.state or self.imports or self.hooks or self.field_name) + return bool( + self.state or self.imports or self.hooks or self.field_name or self.position + ) @classmethod def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData: diff --git a/reflex/vars/hooks.py b/reflex/vars/hooks.py new file mode 100644 index 000000000..7a1568778 --- /dev/null +++ b/reflex/vars/hooks.py @@ -0,0 +1,44 @@ +"""A module for hooks-related Var.""" + +import dataclasses +import sys + +from reflex.constants import Hooks + +from .base import Var, VarData + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class HookVar(Var): + """A var class for representing a hook.""" + + position: Hooks.HookPosition | None = None + + @classmethod + def create( + cls, + _hook_expr: str, + _var_data: VarData | None = None, + _position: Hooks.HookPosition | None = None, + ): + """Create a hook var. + + Args: + _hook_expr: The hook expression. + _position: The position of the hook in the component. + + Returns: + The hook var. + """ + hook_var = cls( + _js_expr=_hook_expr, + _var_type="str", + _var_data=_var_data, + position=_position, + ) + # print("HookVar.create", _hook_expr, hook_var.position) + return hook_var diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index e2b035a8f..e375be414 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -31,6 +31,7 @@ from reflex.utils.exceptions import EventFnArgMismatch from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports from reflex.vars import VarData from reflex.vars.base import LiteralVar, Var +from reflex.vars.hooks import HookVar @pytest.fixture @@ -2078,10 +2079,10 @@ def test_component_add_hooks_var(): ] assert list(HookComponent()._get_all_hooks()) == [ - "const hook3 = useRef(null)", + HookVar.create("const hook3 = useRef(null)"), "const hook1 = 42", "const hook2 = 43", - "useEffect(() => () => {}, [])", + HookVar.create("useEffect(() => () => {}, [])"), ] imports = HookComponent()._get_all_imports() assert len(imports) == 1