diff --git a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 index 8ac952827..ad970c2f5 100644 --- a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 @@ -1,7 +1,7 @@ {% import 'web/pages/utils.js.jinja2' as utils %} export function {{tag_name}} () { - {% for hook in component.get_hooks() %} + {% for hook in component.get_hooks_internal() %} {{ hook }} {% endfor %} @@ -9,6 +9,10 @@ export function {{tag_name}} () { {{ hook }} {% endfor %} + {% for hook in component.get_hooks() %} + {{ hook }} + {% endfor %} + return ( {{utils.render(component.render(), indent_width=0)}} ) diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 603c8b969..e030d64d4 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -50,7 +50,7 @@ def _compile_app(app_root: Component) -> str: return templates.APP_ROOT.render( imports=utils.compile_imports(app_root.get_imports()), custom_codes=app_root.get_custom_code(), - hooks=app_root.get_hooks(), + hooks=app_root.get_hooks_internal() | app_root.get_hooks(), render=app_root.render(), ) @@ -119,7 +119,7 @@ def _compile_page( imports=imports, dynamic_imports=component.get_dynamic_imports(), custom_codes=component.get_custom_code(), - hooks=component.get_hooks(), + hooks=component.get_hooks_internal() | component.get_hooks(), render=component.render(), **kwargs, ) diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 9615e126b..540447344 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -256,7 +256,7 @@ def compile_custom_component( "name": component.tag, "props": props, "render": render.render(), - "hooks": render.get_hooks(), + "hooks": render.get_hooks_internal() | render.get_hooks(), "custom_code": render.get_custom_code(), }, imports, diff --git a/reflex/components/component.py b/reflex/components/component.py index b085ee352..41c7494cc 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -74,6 +74,14 @@ class BaseComponent(Base, ABC): The dictionary for template of the component. """ + @abstractmethod + def get_hooks_internal(self) -> set[str]: + """Get the reflex internal hooks for the component and its children. + + Returns: + The code that should appear just before user-defined hooks. + """ + @abstractmethod def get_hooks(self) -> set[str]: """Get the React hooks for this component. @@ -1141,14 +1149,28 @@ class Component(BaseComponent, ABC): """ return + def get_hooks_internal(self) -> set[str]: + """Get the reflex internal hooks for the component and its children. + + Returns: + The code that should appear just before user-defined hooks. + """ + # Store the code in a set to avoid duplicates. + code = self._get_hooks_internal() + + # Add the hook code for the children. + for child in self.children: + code |= child.get_hooks_internal() + + return code + def get_hooks(self) -> Set[str]: """Get the React hooks for this component and its children. Returns: The code that should appear just before returning the rendered component. """ - # Store the code in a set to avoid duplicates. - code = self._get_hooks_internal() + code = set() # Add the hook code for this component. hooks = self._get_hooks() @@ -1785,6 +1807,14 @@ class StatefulComponent(BaseComponent): ) return trigger_memo + def get_hooks_internal(self) -> set[str]: + """Get the reflex internal hooks for the component and its children. + + Returns: + The code that should appear just before user-defined hooks. + """ + return set() + def get_hooks(self) -> set[str]: """Get the React hooks for this component. @@ -1893,7 +1923,7 @@ class MemoizationLeaf(Component): The memoization leaf """ comp = super().create(*children, **props) - if comp.get_hooks(): + if comp.get_hooks() or comp.get_hooks_internal(): comp._memoization_mode = cls._memoization_mode.copy( update={"disposition": MemoizationDisposition.ALWAYS} ) diff --git a/reflex/components/el/elements/forms.py b/reflex/components/el/elements/forms.py index 7aa73e7b5..c70398ef4 100644 --- a/reflex/components/el/elements/forms.py +++ b/reflex/components/el/elements/forms.py @@ -162,7 +162,7 @@ class Form(BaseHTML): props["handle_submit_unique_name"] = "" form = super().create(*children, **props) form.handle_submit_unique_name = md5( - str(form.get_hooks()).encode("utf-8") + str(form.get_hooks_internal().union(form.get_hooks())).encode("utf-8") ).hexdigest() return form diff --git a/reflex/components/markdown/markdown.py b/reflex/components/markdown/markdown.py index fde688f47..1a125cbe3 100644 --- a/reflex/components/markdown/markdown.py +++ b/reflex/components/markdown/markdown.py @@ -291,8 +291,10 @@ class Markdown(Component): def _get_custom_code(self) -> str | None: hooks = set() - for component in self.component_map.values(): - hooks |= component(_MOCK_ARG).get_hooks() + for _component in self.component_map.values(): + comp = _component(_MOCK_ARG) + hooks |= comp.get_hooks_internal() + hooks |= comp.get_hooks() formatted_hooks = "\n".join(hooks) return f""" function {self._get_component_map_name()} () {{