use position in vardata to mark internal hooks

This commit is contained in:
Lendemor 2024-12-17 17:18:56 +01:00
parent d7956c19d3
commit 95222e49f3
6 changed files with 47 additions and 26 deletions

View File

@ -1,11 +1,13 @@
{% import 'web/pages/utils.js.jinja2' as utils %} {% import 'web/pages/utils.js.jinja2' as utils %}
{% set all_hooks = component._get_all_hooks().items() %}
export function {{tag_name}} () { export function {{tag_name}} () {
{% for hook in component._get_all_hooks_internal() %} {% for hook, data in all_hooks if data and data.position and data.position == const.hook_position.INTERNAL %}
{{ hook }} {{ hook }}
{% endfor %} {% endfor %}
{% for hook, data in component._get_all_hooks().items() if not data.position or data.position == const.hook_position.PRE_TRIGGER %} {% for hook, data in all_hooks if not data or (not data.position or data.position == const.hook_position.PRE_TRIGGER) %}
{{ hook }} {{ hook }}
{% endfor %} {% endfor %}
@ -13,7 +15,7 @@ export function {{tag_name}} () {
{{ hook }} {{ hook }}
{% endfor %} {% endfor %}
{% for hook, data in component._get_all_hooks().items() if data.position and data.position == const.hook_position.POST_TRIGGER %} {% for hook, data in all_hooks if data and data.position and data.position == const.hook_position.POST_TRIGGER %}
{{ hook }} {{ hook }}
{% endfor %} {% endfor %}

View File

@ -9,6 +9,7 @@ from reflex.components.tags import Tag
from reflex.components.tags.tagless import Tagless from reflex.components.tags.tagless import Tagless
from reflex.utils.imports import ParsedImportDict from reflex.utils.imports import ParsedImportDict
from reflex.vars import BooleanVar, ObjectVar, Var from reflex.vars import BooleanVar, ObjectVar, Var
from reflex.vars.base import VarData
class Bare(Component): class Bare(Component):
@ -32,7 +33,7 @@ class Bare(Component):
contents = str(contents) if contents is not None else "" contents = str(contents) if contents is not None else ""
return cls(contents=contents) # type: ignore return cls(contents=contents) # type: ignore
def _get_all_hooks_internal(self) -> dict[str, None]: def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Include the hooks for the component. """Include the hooks for the component.
Returns: Returns:
@ -43,7 +44,7 @@ class Bare(Component):
hooks |= self.contents._var_value._get_all_hooks_internal() hooks |= self.contents._var_value._get_all_hooks_internal()
return hooks return hooks
def _get_all_hooks(self) -> dict[str, None]: def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Include the hooks for the component. """Include the hooks for the component.
Returns: Returns:

View File

@ -104,7 +104,7 @@ class BaseComponent(Base, ABC):
""" """
@abstractmethod @abstractmethod
def _get_all_hooks_internal(self) -> dict[str, None]: def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the reflex internal hooks for the component and its children. """Get the reflex internal hooks for the component and its children.
Returns: Returns:
@ -112,7 +112,7 @@ class BaseComponent(Base, ABC):
""" """
@abstractmethod @abstractmethod
def _get_all_hooks(self) -> dict[str, None]: def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component. """Get the React hooks for this component.
Returns: Returns:
@ -1338,7 +1338,7 @@ class Component(BaseComponent, ABC):
""" """
_imports = {} _imports = {}
if self._get_ref_hook(): if self._get_ref_hook() is not None:
# Handle hooks needed for attaching react refs to DOM nodes. # Handle hooks needed for attaching react refs to DOM nodes.
_imports.setdefault("react", set()).add(ImportVar(tag="useRef")) _imports.setdefault("react", set()).add(ImportVar(tag="useRef"))
_imports.setdefault(f"$/{Dirs.STATE_PATH}", set()).add( _imports.setdefault(f"$/{Dirs.STATE_PATH}", set()).add(
@ -1454,7 +1454,7 @@ class Component(BaseComponent, ABC):
}} }}
}}, []);""" }}, []);"""
def _get_ref_hook(self) -> str | None: def _get_ref_hook(self) -> Var | None:
"""Generate the ref hook for the component. """Generate the ref hook for the component.
Returns: Returns:
@ -1462,11 +1462,12 @@ class Component(BaseComponent, ABC):
""" """
ref = self.get_ref() ref = self.get_ref()
if ref is not None: if ref is not None:
return ( return Var(
f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};" f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};",
_var_data=VarData(position=Hooks.HookPosition.INTERNAL),
) )
def _get_vars_hooks(self) -> dict[str, None]: def _get_vars_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks required by vars referenced in this component. """Get the hooks required by vars referenced in this component.
Returns: Returns:
@ -1479,27 +1480,38 @@ class Component(BaseComponent, ABC):
vars_hooks.update( vars_hooks.update(
var_data.hooks var_data.hooks
if isinstance(var_data.hooks, dict) if isinstance(var_data.hooks, dict)
else {k: None for k in var_data.hooks} else {
k: VarData(position=Hooks.HookPosition.INTERNAL)
for k in var_data.hooks
}
) )
return vars_hooks return vars_hooks
def _get_events_hooks(self) -> dict[str, None]: def _get_events_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks required by events referenced in this component. """Get the hooks required by events referenced in this component.
Returns: Returns:
The hooks for the events. The hooks for the events.
""" """
return {Hooks.EVENTS: None} if self.event_triggers else {} return (
{Hooks.EVENTS: VarData(position=Hooks.HookPosition.INTERNAL)}
if self.event_triggers
else {}
)
def _get_special_hooks(self) -> dict[str, None]: def _get_special_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks required by special actions referenced in this component. """Get the hooks required by special actions referenced in this component.
Returns: Returns:
The hooks for special actions. The hooks for special actions.
""" """
return {Hooks.AUTOFOCUS: None} if self.autofocus else {} return (
{Hooks.AUTOFOCUS: VarData(position=Hooks.HookPosition.INTERNAL)}
if self.autofocus
else {}
)
def _get_hooks_internal(self) -> dict[str, None]: def _get_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component managed by the framework. """Get the React hooks for this component managed by the framework.
Downstream components should NOT override this method to avoid breaking Downstream components should NOT override this method to avoid breaking
@ -1510,7 +1522,7 @@ class Component(BaseComponent, ABC):
""" """
return { return {
**{ **{
hook: None str(hook): VarData(position=Hooks.HookPosition.INTERNAL)
for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()] for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()]
if hook is not None if hook is not None
}, },
@ -1559,7 +1571,7 @@ class Component(BaseComponent, ABC):
""" """
return return
def _get_all_hooks_internal(self) -> dict[str, None]: def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the reflex internal hooks for the component and its children. """Get the reflex internal hooks for the component and its children.
Returns: Returns:
@ -1574,7 +1586,7 @@ class Component(BaseComponent, ABC):
return code return code
def _get_all_hooks(self) -> dict[str, None]: def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component and its children. """Get the React hooks for this component and its children.
Returns: Returns:
@ -1582,6 +1594,9 @@ class Component(BaseComponent, ABC):
""" """
code = {} code = {}
# Add the internal hooks for this component.
code.update(self._get_all_hooks_internal())
# Add the hook code for this component. # Add the hook code for this component.
hooks = self._get_hooks() hooks = self._get_hooks()
if hooks is not None: if hooks is not None:
@ -2277,7 +2292,7 @@ class StatefulComponent(BaseComponent):
) )
return trigger_memo return trigger_memo
def _get_all_hooks_internal(self) -> dict[str, None]: def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the reflex internal hooks for the component and its children. """Get the reflex internal hooks for the component and its children.
Returns: Returns:
@ -2285,7 +2300,7 @@ class StatefulComponent(BaseComponent):
""" """
return {} return {}
def _get_all_hooks(self) -> dict[str, None]: def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component. """Get the React hooks for this component.
Returns: Returns:

View File

@ -135,6 +135,7 @@ class Hooks(SimpleNamespace):
class HookPosition(enum.Enum): class HookPosition(enum.Enum):
"""The position of the hook in the component.""" """The position of the hook in the component."""
INTERNAL = "internal"
PRE_TRIGGER = "pre_trigger" PRE_TRIGGER = "pre_trigger"
POST_TRIGGER = "post_trigger" POST_TRIGGER = "post_trigger"

View File

@ -105,7 +105,7 @@ class ClientStateVar(Var):
else: else:
default_var = default default_var = default
setter_name = f"set{var_name.capitalize()}" setter_name = f"set{var_name.capitalize()}"
hooks = { hooks: dict[str, VarData | None] = {
f"const [{var_name}, {setter_name}] = useState({default_var!s})": None, f"const [{var_name}, {setter_name}] = useState({default_var!s})": None,
} }
imports = { imports = {

View File

@ -127,7 +127,7 @@ class VarData:
state: str = "", state: str = "",
field_name: str = "", field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None, imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, None] | None = None, hooks: dict[str, VarData | None] | None = None,
deps: list[Var] | None = None, deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None, position: Hooks.HookPosition | None = None,
): ):
@ -194,7 +194,9 @@ class VarData:
(var_data.state for var_data in all_var_datas if var_data.state), "" (var_data.state for var_data in all_var_datas if var_data.state), ""
) )
hooks = {hook: None for var_data in all_var_datas for hook in var_data.hooks} hooks: dict[str, VarData | None] = {
hook: None for var_data in all_var_datas for hook in var_data.hooks
}
_imports = imports.merge_imports( _imports = imports.merge_imports(
*(var_data.imports for var_data in all_var_datas) *(var_data.imports for var_data in all_var_datas)