From b89a18f63285ee5f8da7456707a75850fb409fab Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 15 Mar 2024 16:16:09 -0700 Subject: [PATCH] Separate `get_hooks` and `get_hooks_internal` for stable output (#2710) * Separate `get_hooks` and `get_hooks_internal` for stable output When downstream component wrappers depend on State when writing hooks, they need to be assured that all internal hooks (events, var hooks, memoized handlers, etc) will be rendered prior to user-defined hooks. This also makes it less likely for downstream components to feel the need to overwrite `get_hooks` (no underscore) directly and break internal functioning of Reflex components. * Include internal hooks in AppWrap and Page * Apply get_hooks_internal in a few more places --- .../web/pages/stateful_component.js.jinja2 | 6 +++- reflex/compiler/compiler.py | 4 +-- reflex/compiler/utils.py | 2 +- reflex/components/component.py | 36 +++++++++++++++++-- reflex/components/el/elements/forms.py | 2 +- reflex/components/markdown/markdown.py | 6 ++-- 6 files changed, 46 insertions(+), 10 deletions(-) diff --git a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 index 8ac952827..ad970c2f5 100644 --- a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 @@ -1,7 +1,7 @@ {% import 'web/pages/utils.js.jinja2' as utils %} export function {{tag_name}} () { - {% for hook in component.get_hooks() %} + {% for hook in component.get_hooks_internal() %} {{ hook }} {% endfor %} @@ -9,6 +9,10 @@ export function {{tag_name}} () { {{ hook }} {% endfor %} + {% for hook in component.get_hooks() %} + {{ hook }} + {% endfor %} + return ( {{utils.render(component.render(), indent_width=0)}} ) diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 603c8b969..e030d64d4 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -50,7 +50,7 @@ def _compile_app(app_root: Component) -> str: return templates.APP_ROOT.render( imports=utils.compile_imports(app_root.get_imports()), custom_codes=app_root.get_custom_code(), - hooks=app_root.get_hooks(), + hooks=app_root.get_hooks_internal() | app_root.get_hooks(), render=app_root.render(), ) @@ -119,7 +119,7 @@ def _compile_page( imports=imports, dynamic_imports=component.get_dynamic_imports(), custom_codes=component.get_custom_code(), - hooks=component.get_hooks(), + hooks=component.get_hooks_internal() | component.get_hooks(), render=component.render(), **kwargs, ) diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 9615e126b..540447344 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -256,7 +256,7 @@ def compile_custom_component( "name": component.tag, "props": props, "render": render.render(), - "hooks": render.get_hooks(), + "hooks": render.get_hooks_internal() | render.get_hooks(), "custom_code": render.get_custom_code(), }, imports, diff --git a/reflex/components/component.py b/reflex/components/component.py index b085ee352..41c7494cc 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -74,6 +74,14 @@ class BaseComponent(Base, ABC): The dictionary for template of the component. """ + @abstractmethod + def get_hooks_internal(self) -> set[str]: + """Get the reflex internal hooks for the component and its children. + + Returns: + The code that should appear just before user-defined hooks. + """ + @abstractmethod def get_hooks(self) -> set[str]: """Get the React hooks for this component. @@ -1141,14 +1149,28 @@ class Component(BaseComponent, ABC): """ return + def get_hooks_internal(self) -> set[str]: + """Get the reflex internal hooks for the component and its children. + + Returns: + The code that should appear just before user-defined hooks. + """ + # Store the code in a set to avoid duplicates. + code = self._get_hooks_internal() + + # Add the hook code for the children. + for child in self.children: + code |= child.get_hooks_internal() + + return code + def get_hooks(self) -> Set[str]: """Get the React hooks for this component and its children. Returns: The code that should appear just before returning the rendered component. """ - # Store the code in a set to avoid duplicates. - code = self._get_hooks_internal() + code = set() # Add the hook code for this component. hooks = self._get_hooks() @@ -1785,6 +1807,14 @@ class StatefulComponent(BaseComponent): ) return trigger_memo + def get_hooks_internal(self) -> set[str]: + """Get the reflex internal hooks for the component and its children. + + Returns: + The code that should appear just before user-defined hooks. + """ + return set() + def get_hooks(self) -> set[str]: """Get the React hooks for this component. @@ -1893,7 +1923,7 @@ class MemoizationLeaf(Component): The memoization leaf """ comp = super().create(*children, **props) - if comp.get_hooks(): + if comp.get_hooks() or comp.get_hooks_internal(): comp._memoization_mode = cls._memoization_mode.copy( update={"disposition": MemoizationDisposition.ALWAYS} ) diff --git a/reflex/components/el/elements/forms.py b/reflex/components/el/elements/forms.py index 7aa73e7b5..c70398ef4 100644 --- a/reflex/components/el/elements/forms.py +++ b/reflex/components/el/elements/forms.py @@ -162,7 +162,7 @@ class Form(BaseHTML): props["handle_submit_unique_name"] = "" form = super().create(*children, **props) form.handle_submit_unique_name = md5( - str(form.get_hooks()).encode("utf-8") + str(form.get_hooks_internal().union(form.get_hooks())).encode("utf-8") ).hexdigest() return form diff --git a/reflex/components/markdown/markdown.py b/reflex/components/markdown/markdown.py index fde688f47..1a125cbe3 100644 --- a/reflex/components/markdown/markdown.py +++ b/reflex/components/markdown/markdown.py @@ -291,8 +291,10 @@ class Markdown(Component): def _get_custom_code(self) -> str | None: hooks = set() - for component in self.component_map.values(): - hooks |= component(_MOCK_ARG).get_hooks() + for _component in self.component_map.values(): + comp = _component(_MOCK_ARG) + hooks |= comp.get_hooks_internal() + hooks |= comp.get_hooks() formatted_hooks = "\n".join(hooks) return f""" function {self._get_component_map_name()} () {{