diff --git a/reflex/.templates/jinja/web/pages/index.js.jinja2 b/reflex/.templates/jinja/web/pages/index.js.jinja2 index efb086ef5..e29deee34 100644 --- a/reflex/.templates/jinja/web/pages/index.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/index.js.jinja2 @@ -8,6 +8,10 @@ {% block export %} export default function Component() { + {% for ref_hook in ref_hooks %} + {{ ref_hook }} + {% endfor %} + {% for hook in hooks %} {{ hook }} {% endfor %} diff --git a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 index ad970c2f5..ccb375893 100644 --- a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 @@ -1,6 +1,10 @@ {% import 'web/pages/utils.js.jinja2' as utils %} export function {{tag_name}} () { + {% for ref_hook in component.get_ref_hooks() %} + {{ ref_hook }} + {% endfor %} + {% for hook in component.get_hooks_internal() %} {{ hook }} {% endfor %} diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 6814159d2..ef0d3e4ca 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -119,6 +119,7 @@ def _compile_page( imports=imports, dynamic_imports=component.get_dynamic_imports(), custom_codes=component.get_custom_code(), + ref_hooks=component.get_ref_hooks(), hooks=component.get_hooks_internal() | component.get_hooks(), render=component.render(), **kwargs, diff --git a/reflex/components/component.py b/reflex/components/component.py index 6b25a4f3a..506031af9 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -74,6 +74,14 @@ class BaseComponent(Base, ABC): The dictionary for template of the component. """ + @abstractmethod + def get_ref_hooks(self) -> set[str]: + """Get the hooks required by refs in this component. + + Returns: + The hooks for the refs. + """ + @abstractmethod def get_hooks_internal(self) -> set[str]: """Get the reflex internal hooks for the component and its children. @@ -636,9 +644,11 @@ class Component(BaseComponent, ABC): ) children = [ - child - if isinstance(child, Component) - else Bare.create(contents=Var.create(child, _var_is_string=True)) + ( + child + if isinstance(child, Component) + else Bare.create(contents=Var.create(child, _var_is_string=True)) + ) for child in children ] @@ -1130,14 +1140,10 @@ class Component(BaseComponent, ABC): Set of internally managed hooks. """ return ( - set( - hook - for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()] - if hook - ) - | self._get_vars_hooks() + self._get_vars_hooks() | self._get_events_hooks() | self._get_special_hooks() + | set(hook for hook in [self._get_mount_lifecycle_hook()] if hook) ) def _get_hooks(self) -> str | None: @@ -1150,6 +1156,19 @@ class Component(BaseComponent, ABC): """ return + def get_ref_hooks(self) -> Set[str]: + """Get the ref hooks for the component and its children. + + Returns: + The ref hooks. + """ + ref_hook = self._get_ref_hook() + hooks = set() if ref_hook is None else {ref_hook} + + for child in self.children: + hooks |= child.get_ref_hooks() + return hooks + def get_hooks_internal(self) -> set[str]: """Get the reflex internal hooks for the component and its children. @@ -1424,9 +1443,9 @@ class CustomComponent(Component): return [ BaseVar( _var_name=name, - _var_type=prop._var_type - if types._isinstance(prop, Var) - else type(prop), + _var_type=( + prop._var_type if types._isinstance(prop, Var) else type(prop) + ), ) for name, prop in self.props.items() ] @@ -1808,6 +1827,14 @@ class StatefulComponent(BaseComponent): ) return trigger_memo + def get_ref_hooks(self) -> set[str]: + """Get the ref hooks for the component and its children. + + Returns: + The ref hooks. + """ + return set() + def get_hooks_internal(self) -> set[str]: """Get the reflex internal hooks for the component and its children.