diff --git a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 index b04a78781..dab3707a7 100644 --- a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 @@ -1,11 +1,13 @@ {% import 'web/pages/utils.js.jinja2' as utils %} +{% set all_hooks = component._get_all_hooks().items() %} + export function {{tag_name}} () { - {% for hook in component._get_all_hooks_internal() %} + {% for hook, data in all_hooks if data and data.position and data.position == const.hook_position.INTERNAL %} {{ hook }} {% endfor %} - {% for hook, data in component._get_all_hooks().items() if not data.position or data.position == const.hook_position.PRE_TRIGGER %} + {% for hook, data in all_hooks if not data or (not data.position or data.position == const.hook_position.PRE_TRIGGER) %} {{ hook }} {% endfor %} @@ -13,7 +15,7 @@ export function {{tag_name}} () { {{ hook }} {% endfor %} - {% for hook, data in component._get_all_hooks().items() if data.position and data.position == const.hook_position.POST_TRIGGER %} + {% for hook, data in all_hooks if data and data.position and data.position == const.hook_position.POST_TRIGGER %} {{ hook }} {% endfor %} diff --git a/reflex/components/base/bare.py b/reflex/components/base/bare.py index e1b5d9237..7cd225deb 100644 --- a/reflex/components/base/bare.py +++ b/reflex/components/base/bare.py @@ -9,6 +9,7 @@ from reflex.components.tags import Tag from reflex.components.tags.tagless import Tagless from reflex.utils.imports import ParsedImportDict from reflex.vars import BooleanVar, ObjectVar, Var +from reflex.vars.base import VarData class Bare(Component): @@ -32,7 +33,7 @@ class Bare(Component): contents = str(contents) if contents is not None else "" return cls(contents=contents) # type: ignore - def _get_all_hooks_internal(self) -> dict[str, None]: + def _get_all_hooks_internal(self) -> dict[str, VarData | None]: """Include the hooks for the component. Returns: @@ -43,7 +44,7 @@ class Bare(Component): hooks |= self.contents._var_value._get_all_hooks_internal() return hooks - def _get_all_hooks(self) -> dict[str, None]: + def _get_all_hooks(self) -> dict[str, VarData | None]: """Include the hooks for the component. Returns: diff --git a/reflex/components/component.py b/reflex/components/component.py index 34800ab6e..f261ae4d3 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -104,7 +104,7 @@ class BaseComponent(Base, ABC): """ @abstractmethod - def _get_all_hooks_internal(self) -> dict[str, None]: + def _get_all_hooks_internal(self) -> dict[str, VarData | None]: """Get the reflex internal hooks for the component and its children. Returns: @@ -112,7 +112,7 @@ class BaseComponent(Base, ABC): """ @abstractmethod - def _get_all_hooks(self) -> dict[str, None]: + def _get_all_hooks(self) -> dict[str, VarData | None]: """Get the React hooks for this component. Returns: @@ -1338,7 +1338,7 @@ class Component(BaseComponent, ABC): """ _imports = {} - if self._get_ref_hook(): + if self._get_ref_hook() is not None: # Handle hooks needed for attaching react refs to DOM nodes. _imports.setdefault("react", set()).add(ImportVar(tag="useRef")) _imports.setdefault(f"$/{Dirs.STATE_PATH}", set()).add( @@ -1454,7 +1454,7 @@ class Component(BaseComponent, ABC): }} }}, []);""" - def _get_ref_hook(self) -> str | None: + def _get_ref_hook(self) -> Var | None: """Generate the ref hook for the component. Returns: @@ -1462,11 +1462,12 @@ class Component(BaseComponent, ABC): """ ref = self.get_ref() if ref is not None: - return ( - f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};" + return Var( + f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};", + _var_data=VarData(position=Hooks.HookPosition.INTERNAL), ) - def _get_vars_hooks(self) -> dict[str, None]: + def _get_vars_hooks(self) -> dict[str, VarData | None]: """Get the hooks required by vars referenced in this component. Returns: @@ -1479,27 +1480,38 @@ class Component(BaseComponent, ABC): vars_hooks.update( var_data.hooks if isinstance(var_data.hooks, dict) - else {k: None for k in var_data.hooks} + else { + k: VarData(position=Hooks.HookPosition.INTERNAL) + for k in var_data.hooks + } ) return vars_hooks - def _get_events_hooks(self) -> dict[str, None]: + def _get_events_hooks(self) -> dict[str, VarData | None]: """Get the hooks required by events referenced in this component. Returns: The hooks for the events. """ - return {Hooks.EVENTS: None} if self.event_triggers else {} + return ( + {Hooks.EVENTS: VarData(position=Hooks.HookPosition.INTERNAL)} + if self.event_triggers + else {} + ) - def _get_special_hooks(self) -> dict[str, None]: + def _get_special_hooks(self) -> dict[str, VarData | None]: """Get the hooks required by special actions referenced in this component. Returns: The hooks for special actions. """ - return {Hooks.AUTOFOCUS: None} if self.autofocus else {} + return ( + {Hooks.AUTOFOCUS: VarData(position=Hooks.HookPosition.INTERNAL)} + if self.autofocus + else {} + ) - def _get_hooks_internal(self) -> dict[str, None]: + def _get_hooks_internal(self) -> dict[str, VarData | None]: """Get the React hooks for this component managed by the framework. Downstream components should NOT override this method to avoid breaking @@ -1510,7 +1522,7 @@ class Component(BaseComponent, ABC): """ return { **{ - hook: None + str(hook): VarData(position=Hooks.HookPosition.INTERNAL) for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()] if hook is not None }, @@ -1559,7 +1571,7 @@ class Component(BaseComponent, ABC): """ return - def _get_all_hooks_internal(self) -> dict[str, None]: + def _get_all_hooks_internal(self) -> dict[str, VarData | None]: """Get the reflex internal hooks for the component and its children. Returns: @@ -1574,7 +1586,7 @@ class Component(BaseComponent, ABC): return code - def _get_all_hooks(self) -> dict[str, None]: + def _get_all_hooks(self) -> dict[str, VarData | None]: """Get the React hooks for this component and its children. Returns: @@ -1582,6 +1594,9 @@ class Component(BaseComponent, ABC): """ code = {} + # Add the internal hooks for this component. + code.update(self._get_all_hooks_internal()) + # Add the hook code for this component. hooks = self._get_hooks() if hooks is not None: @@ -2277,7 +2292,7 @@ class StatefulComponent(BaseComponent): ) return trigger_memo - def _get_all_hooks_internal(self) -> dict[str, None]: + def _get_all_hooks_internal(self) -> dict[str, VarData | None]: """Get the reflex internal hooks for the component and its children. Returns: @@ -2285,7 +2300,7 @@ class StatefulComponent(BaseComponent): """ return {} - def _get_all_hooks(self) -> dict[str, None]: + def _get_all_hooks(self) -> dict[str, VarData | None]: """Get the React hooks for this component. Returns: diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index 7ca55f4dd..d98c04d76 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -135,6 +135,7 @@ class Hooks(SimpleNamespace): class HookPosition(enum.Enum): """The position of the hook in the component.""" + INTERNAL = "internal" PRE_TRIGGER = "pre_trigger" POST_TRIGGER = "post_trigger" diff --git a/reflex/experimental/client_state.py b/reflex/experimental/client_state.py index 1982b3dfe..e37ceb14c 100644 --- a/reflex/experimental/client_state.py +++ b/reflex/experimental/client_state.py @@ -105,7 +105,7 @@ class ClientStateVar(Var): else: default_var = default setter_name = f"set{var_name.capitalize()}" - hooks = { + hooks: dict[str, VarData | None] = { f"const [{var_name}, {setter_name}] = useState({default_var!s})": None, } imports = { diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 094a478c8..f2b9fb33d 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -127,7 +127,7 @@ class VarData: state: str = "", field_name: str = "", imports: ImportDict | ParsedImportDict | None = None, - hooks: dict[str, None] | None = None, + hooks: dict[str, VarData | None] | None = None, deps: list[Var] | None = None, position: Hooks.HookPosition | None = None, ): @@ -194,7 +194,9 @@ class VarData: (var_data.state for var_data in all_var_datas if var_data.state), "" ) - hooks = {hook: None for var_data in all_var_datas for hook in var_data.hooks} + hooks: dict[str, VarData | None] = { + hook: None for var_data in all_var_datas for hook in var_data.hooks + } _imports = imports.merge_imports( *(var_data.imports for var_data in all_var_datas)