use position in vardata to mark internal hooks
This commit is contained in:
parent
d7956c19d3
commit
95222e49f3
@ -1,11 +1,13 @@
|
||||
{% import 'web/pages/utils.js.jinja2' as utils %}
|
||||
|
||||
{% set all_hooks = component._get_all_hooks().items() %}
|
||||
|
||||
export function {{tag_name}} () {
|
||||
{% for hook in component._get_all_hooks_internal() %}
|
||||
{% for hook, data in all_hooks if data and data.position and data.position == const.hook_position.INTERNAL %}
|
||||
{{ hook }}
|
||||
{% endfor %}
|
||||
|
||||
{% for hook, data in component._get_all_hooks().items() if not data.position or data.position == const.hook_position.PRE_TRIGGER %}
|
||||
{% for hook, data in all_hooks if not data or (not data.position or data.position == const.hook_position.PRE_TRIGGER) %}
|
||||
{{ hook }}
|
||||
{% endfor %}
|
||||
|
||||
@ -13,7 +15,7 @@ export function {{tag_name}} () {
|
||||
{{ hook }}
|
||||
{% endfor %}
|
||||
|
||||
{% for hook, data in component._get_all_hooks().items() if data.position and data.position == const.hook_position.POST_TRIGGER %}
|
||||
{% for hook, data in all_hooks if data and data.position and data.position == const.hook_position.POST_TRIGGER %}
|
||||
{{ hook }}
|
||||
{% endfor %}
|
||||
|
||||
|
@ -9,6 +9,7 @@ from reflex.components.tags import Tag
|
||||
from reflex.components.tags.tagless import Tagless
|
||||
from reflex.utils.imports import ParsedImportDict
|
||||
from reflex.vars import BooleanVar, ObjectVar, Var
|
||||
from reflex.vars.base import VarData
|
||||
|
||||
|
||||
class Bare(Component):
|
||||
@ -32,7 +33,7 @@ class Bare(Component):
|
||||
contents = str(contents) if contents is not None else ""
|
||||
return cls(contents=contents) # type: ignore
|
||||
|
||||
def _get_all_hooks_internal(self) -> dict[str, None]:
|
||||
def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
|
||||
"""Include the hooks for the component.
|
||||
|
||||
Returns:
|
||||
@ -43,7 +44,7 @@ class Bare(Component):
|
||||
hooks |= self.contents._var_value._get_all_hooks_internal()
|
||||
return hooks
|
||||
|
||||
def _get_all_hooks(self) -> dict[str, None]:
|
||||
def _get_all_hooks(self) -> dict[str, VarData | None]:
|
||||
"""Include the hooks for the component.
|
||||
|
||||
Returns:
|
||||
|
@ -104,7 +104,7 @@ class BaseComponent(Base, ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _get_all_hooks_internal(self) -> dict[str, None]:
|
||||
def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
|
||||
"""Get the reflex internal hooks for the component and its children.
|
||||
|
||||
Returns:
|
||||
@ -112,7 +112,7 @@ class BaseComponent(Base, ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _get_all_hooks(self) -> dict[str, None]:
|
||||
def _get_all_hooks(self) -> dict[str, VarData | None]:
|
||||
"""Get the React hooks for this component.
|
||||
|
||||
Returns:
|
||||
@ -1338,7 +1338,7 @@ class Component(BaseComponent, ABC):
|
||||
"""
|
||||
_imports = {}
|
||||
|
||||
if self._get_ref_hook():
|
||||
if self._get_ref_hook() is not None:
|
||||
# Handle hooks needed for attaching react refs to DOM nodes.
|
||||
_imports.setdefault("react", set()).add(ImportVar(tag="useRef"))
|
||||
_imports.setdefault(f"$/{Dirs.STATE_PATH}", set()).add(
|
||||
@ -1454,7 +1454,7 @@ class Component(BaseComponent, ABC):
|
||||
}}
|
||||
}}, []);"""
|
||||
|
||||
def _get_ref_hook(self) -> str | None:
|
||||
def _get_ref_hook(self) -> Var | None:
|
||||
"""Generate the ref hook for the component.
|
||||
|
||||
Returns:
|
||||
@ -1462,11 +1462,12 @@ class Component(BaseComponent, ABC):
|
||||
"""
|
||||
ref = self.get_ref()
|
||||
if ref is not None:
|
||||
return (
|
||||
f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};"
|
||||
return Var(
|
||||
f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};",
|
||||
_var_data=VarData(position=Hooks.HookPosition.INTERNAL),
|
||||
)
|
||||
|
||||
def _get_vars_hooks(self) -> dict[str, None]:
|
||||
def _get_vars_hooks(self) -> dict[str, VarData | None]:
|
||||
"""Get the hooks required by vars referenced in this component.
|
||||
|
||||
Returns:
|
||||
@ -1479,27 +1480,38 @@ class Component(BaseComponent, ABC):
|
||||
vars_hooks.update(
|
||||
var_data.hooks
|
||||
if isinstance(var_data.hooks, dict)
|
||||
else {k: None for k in var_data.hooks}
|
||||
else {
|
||||
k: VarData(position=Hooks.HookPosition.INTERNAL)
|
||||
for k in var_data.hooks
|
||||
}
|
||||
)
|
||||
return vars_hooks
|
||||
|
||||
def _get_events_hooks(self) -> dict[str, None]:
|
||||
def _get_events_hooks(self) -> dict[str, VarData | None]:
|
||||
"""Get the hooks required by events referenced in this component.
|
||||
|
||||
Returns:
|
||||
The hooks for the events.
|
||||
"""
|
||||
return {Hooks.EVENTS: None} if self.event_triggers else {}
|
||||
return (
|
||||
{Hooks.EVENTS: VarData(position=Hooks.HookPosition.INTERNAL)}
|
||||
if self.event_triggers
|
||||
else {}
|
||||
)
|
||||
|
||||
def _get_special_hooks(self) -> dict[str, None]:
|
||||
def _get_special_hooks(self) -> dict[str, VarData | None]:
|
||||
"""Get the hooks required by special actions referenced in this component.
|
||||
|
||||
Returns:
|
||||
The hooks for special actions.
|
||||
"""
|
||||
return {Hooks.AUTOFOCUS: None} if self.autofocus else {}
|
||||
return (
|
||||
{Hooks.AUTOFOCUS: VarData(position=Hooks.HookPosition.INTERNAL)}
|
||||
if self.autofocus
|
||||
else {}
|
||||
)
|
||||
|
||||
def _get_hooks_internal(self) -> dict[str, None]:
|
||||
def _get_hooks_internal(self) -> dict[str, VarData | None]:
|
||||
"""Get the React hooks for this component managed by the framework.
|
||||
|
||||
Downstream components should NOT override this method to avoid breaking
|
||||
@ -1510,7 +1522,7 @@ class Component(BaseComponent, ABC):
|
||||
"""
|
||||
return {
|
||||
**{
|
||||
hook: None
|
||||
str(hook): VarData(position=Hooks.HookPosition.INTERNAL)
|
||||
for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()]
|
||||
if hook is not None
|
||||
},
|
||||
@ -1559,7 +1571,7 @@ class Component(BaseComponent, ABC):
|
||||
"""
|
||||
return
|
||||
|
||||
def _get_all_hooks_internal(self) -> dict[str, None]:
|
||||
def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
|
||||
"""Get the reflex internal hooks for the component and its children.
|
||||
|
||||
Returns:
|
||||
@ -1574,7 +1586,7 @@ class Component(BaseComponent, ABC):
|
||||
|
||||
return code
|
||||
|
||||
def _get_all_hooks(self) -> dict[str, None]:
|
||||
def _get_all_hooks(self) -> dict[str, VarData | None]:
|
||||
"""Get the React hooks for this component and its children.
|
||||
|
||||
Returns:
|
||||
@ -1582,6 +1594,9 @@ class Component(BaseComponent, ABC):
|
||||
"""
|
||||
code = {}
|
||||
|
||||
# Add the internal hooks for this component.
|
||||
code.update(self._get_all_hooks_internal())
|
||||
|
||||
# Add the hook code for this component.
|
||||
hooks = self._get_hooks()
|
||||
if hooks is not None:
|
||||
@ -2277,7 +2292,7 @@ class StatefulComponent(BaseComponent):
|
||||
)
|
||||
return trigger_memo
|
||||
|
||||
def _get_all_hooks_internal(self) -> dict[str, None]:
|
||||
def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
|
||||
"""Get the reflex internal hooks for the component and its children.
|
||||
|
||||
Returns:
|
||||
@ -2285,7 +2300,7 @@ class StatefulComponent(BaseComponent):
|
||||
"""
|
||||
return {}
|
||||
|
||||
def _get_all_hooks(self) -> dict[str, None]:
|
||||
def _get_all_hooks(self) -> dict[str, VarData | None]:
|
||||
"""Get the React hooks for this component.
|
||||
|
||||
Returns:
|
||||
|
@ -135,6 +135,7 @@ class Hooks(SimpleNamespace):
|
||||
class HookPosition(enum.Enum):
|
||||
"""The position of the hook in the component."""
|
||||
|
||||
INTERNAL = "internal"
|
||||
PRE_TRIGGER = "pre_trigger"
|
||||
POST_TRIGGER = "post_trigger"
|
||||
|
||||
|
@ -105,7 +105,7 @@ class ClientStateVar(Var):
|
||||
else:
|
||||
default_var = default
|
||||
setter_name = f"set{var_name.capitalize()}"
|
||||
hooks = {
|
||||
hooks: dict[str, VarData | None] = {
|
||||
f"const [{var_name}, {setter_name}] = useState({default_var!s})": None,
|
||||
}
|
||||
imports = {
|
||||
|
@ -127,7 +127,7 @@ class VarData:
|
||||
state: str = "",
|
||||
field_name: str = "",
|
||||
imports: ImportDict | ParsedImportDict | None = None,
|
||||
hooks: dict[str, None] | None = None,
|
||||
hooks: dict[str, VarData | None] | None = None,
|
||||
deps: list[Var] | None = None,
|
||||
position: Hooks.HookPosition | None = None,
|
||||
):
|
||||
@ -194,7 +194,9 @@ class VarData:
|
||||
(var_data.state for var_data in all_var_datas if var_data.state), ""
|
||||
)
|
||||
|
||||
hooks = {hook: None for var_data in all_var_datas for hook in var_data.hooks}
|
||||
hooks: dict[str, VarData | None] = {
|
||||
hook: None for var_data in all_var_datas for hook in var_data.hooks
|
||||
}
|
||||
|
||||
_imports = imports.merge_imports(
|
||||
*(var_data.imports for var_data in all_var_datas)
|
||||
|
Loading…
Reference in New Issue
Block a user