handle position of hooks

This commit is contained in:
Lendemor 2024-12-11 19:14:33 +01:00
parent 2d553c365c
commit c54b736254
9 changed files with 119 additions and 27 deletions

View File

@ -5,7 +5,7 @@ export function {{tag_name}} () {
{{ hook }}
{% endfor %}
{% for hook in component._get_all_hooks() %}
{% for hook, data in component._get_all_hooks().items() if not data.position or data.position == positions.PRE_TRIGGER %}
{{ hook }}
{% endfor %}
@ -13,6 +13,10 @@ export function {{tag_name}} () {
{{ hook }}
{% endfor %}
{% for hook,data in component._get_all_hooks().items() if data.position and data.position == positions.POST_TRIGGER %}
{{ hook }}
{% endfor %}
return (
{{utils.render(component.render(), indent_width=0)}}
)

View File

@ -56,6 +56,7 @@ from reflex.components.component import (
Component,
ComponentStyle,
evaluate_style_namespaces,
memo,
)
from reflex.components.core.banner import connection_pulser, connection_toaster
from reflex.components.core.breakpoints import set_breakpoints
@ -162,6 +163,7 @@ def default_overlay_component() -> Component:
)
@memo
def default_error_boundary(*children: Component) -> Component:
"""Default error_boundary attribute for App.

View File

@ -24,6 +24,7 @@ from typing import (
)
import reflex.state
from reflex import constants
from reflex.base import Base
from reflex.compiler.templates import STATEFUL_COMPONENT
from reflex.components.core.breakpoints import Breakpoints
@ -69,6 +70,7 @@ from reflex.vars.base import (
cached_property_no_lock,
)
from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar
from reflex.vars.hooks import HookVar
from reflex.vars.number import ternary_operation
from reflex.vars.object import ObjectVar
from reflex.vars.sequence import LiteralArrayVar
@ -1369,7 +1371,9 @@ class Component(BaseComponent, ABC):
if user_hooks_data is not None:
other_imports.append(user_hooks_data.imports)
other_imports.extend(
hook_imports for hook_imports in self._get_added_hooks().values()
hook_vardata.imports
for hook_vardata in self._get_added_hooks().values()
if hook_vardata is not None
)
return imports.merge_imports(_imports, *other_imports)
@ -1523,7 +1527,7 @@ class Component(BaseComponent, ABC):
**self._get_special_hooks(),
}
def _get_added_hooks(self) -> dict[str, ImportDict]:
def _get_added_hooks(self) -> dict[str, VarData]:
"""Get the hooks added via `add_hooks` method.
Returns:
@ -1532,17 +1536,19 @@ class Component(BaseComponent, ABC):
code = {}
def extract_var_hooks(hook: Var):
_imports = {}
var_data = VarData.merge(hook._get_all_var_data())
if var_data is not None:
for sub_hook in var_data.hooks:
code[sub_hook] = {}
if var_data.imports:
_imports = var_data.imports
code[sub_hook] = None
if str(hook) in code:
code[str(hook)] = imports.merge_imports(code[str(hook)], _imports)
code[str(hook)] = VarData.merge(var_data, code[str(hook)])
elif isinstance(hook, HookVar):
code[str(hook)] = VarData.merge(
var_data, VarData(position=hook.position)
)
else:
code[str(hook)] = _imports
code[str(hook)] = var_data
# Add the hook code from add_hooks for each parent class (this is reversed to preserve
# the order of the hooks in the final output)
@ -1551,7 +1557,9 @@ class Component(BaseComponent, ABC):
if isinstance(hook, Var):
extract_var_hooks(hook)
else:
code[hook] = {}
if isinstance(hook, str):
hook = HookVar.create(hook)
code[hook] = VarData()
return code
@ -1593,8 +1601,8 @@ class Component(BaseComponent, ABC):
if hooks is not None:
code[hooks] = None
for hook in self._get_added_hooks():
code[hook] = None
for hook, var_data in self._get_added_hooks().items():
code[hook] = var_data
# Add the hook code for the children.
for child in self.children:
@ -2168,6 +2176,7 @@ class StatefulComponent(BaseComponent):
tag_name=tag_name,
memo_trigger_hooks=memo_trigger_hooks,
component=component,
positions=constants.Hooks.HookPosition,
)
@staticmethod
@ -2244,10 +2253,9 @@ class StatefulComponent(BaseComponent):
imports={"react": [ImportVar(tag="useCallback")]},
),
)
# Store the memoized function name and hook code for this event trigger.
trigger_memo[event_trigger] = (
Var(_js_expr=memo_name)._replace(
HookVar(_js_expr=memo_name)._replace(
_var_type=EventChain, merge_var_data=memo_var_data
),
f"const {memo_name} = useCallback({rendered_chain}, [{', '.join(var_deps)}])",

View File

@ -6,11 +6,13 @@ from typing import Dict, List, Tuple, Union
from reflex.components.base.fragment import Fragment
from reflex.components.tags.tag import Tag
from reflex.constants.compiler import Hooks
from reflex.event import EventChain, EventHandler, passthrough_event_spec
from reflex.utils.format import format_prop, wrap
from reflex.utils.imports import ImportVar
from reflex.vars import get_unique_variable_name
from reflex.vars.base import Var
from reflex.vars.hooks import HookVar
class Clipboard(Fragment):
@ -72,7 +74,7 @@ class Clipboard(Fragment):
),
}
def add_hooks(self) -> list[str]:
def add_hooks(self) -> list[str | Var]:
"""Add hook to register paste event listener.
Returns:
@ -83,13 +85,9 @@ class Clipboard(Fragment):
return []
if isinstance(on_paste, EventChain):
on_paste = wrap(str(format_prop(on_paste)).strip("{}"), "(")
hook_expr = f"usePasteHandler({self.targets!s}, {self.on_paste_event_actions!s}, {on_paste!s})"
return [
"usePasteHandler(%s, %s, %s)"
% (
str(self.targets),
str(self.on_paste_event_actions),
on_paste,
)
HookVar.create(hook_expr, _position=Hooks.HookPosition.POST_TRIGGER),
]

View File

@ -71,6 +71,6 @@ class Clipboard(Fragment):
...
def add_imports(self) -> dict[str, ImportVar]: ...
def add_hooks(self) -> list[str]: ...
def add_hooks(self) -> list[str | Var]: ...
clipboard = Clipboard.create

View File

@ -132,6 +132,12 @@ class Hooks(SimpleNamespace):
}
})"""
class HookPosition(enum.Enum):
"""The position of the hook in the component."""
PRE_TRIGGER = "pre_trigger"
POST_TRIGGER = "post_trigger"
class MemoizationDisposition(enum.Enum):
"""The conditions under which a component should be memoized."""

View File

@ -42,7 +42,8 @@ from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints,
from reflex import constants
from reflex.base import Base
from reflex.utils import console, imports, serializers, types
from reflex.constants.compiler import Hooks
from reflex.utils import console, exceptions, imports, serializers, types
from reflex.utils.exceptions import (
VarAttributeError,
VarDependencyError,
@ -115,12 +116,16 @@ class VarData:
# Hooks that need to be present in the component to render this var
hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
# Position of the hook in the component
position: Hooks.HookPosition | None = None
def __init__(
self,
state: str = "",
field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, None] | None = None,
position: Hooks.HookPosition | None = None,
):
"""Initialize the var data.
@ -129,6 +134,7 @@ class VarData:
field_name: The name of the field in the state.
imports: Imports needed to render this var.
hooks: Hooks that need to be present in the component to render this var.
position: Position of the hook in the component.
"""
immutable_imports: ImmutableParsedImportDict = tuple(
sorted(
@ -139,6 +145,7 @@ class VarData:
object.__setattr__(self, "field_name", field_name)
object.__setattr__(self, "imports", immutable_imports)
object.__setattr__(self, "hooks", tuple(hooks or {}))
object.__setattr__(self, "position", position or None)
def old_school_imports(self) -> ImportDict:
"""Return the imports as a mutable dict.
@ -154,6 +161,9 @@ class VarData:
Args:
*all: The var data objects to merge.
Raises:
ReflexError: If trying to merge VarData with different positions.
Returns:
The merged var data object.
@ -184,12 +194,29 @@ class VarData:
*(var_data.imports for var_data in all_var_datas)
)
if state or _imports or hooks or field_name:
positions = list(
{
var_data.position
for var_data in all_var_datas
if var_data.position is not None
}
)
if positions:
if len(positions) > 1:
raise exceptions.ReflexError(
f"Cannot merge var data with different positions: {positions}"
)
position = positions[0]
else:
position = None
if state or _imports or hooks or field_name or position:
return VarData(
state=state,
field_name=field_name,
imports=_imports,
hooks=hooks,
position=position,
)
return None
@ -200,7 +227,9 @@ class VarData:
Returns:
True if any field is set to a non-default value.
"""
return bool(self.state or self.imports or self.hooks or self.field_name)
return bool(
self.state or self.imports or self.hooks or self.field_name or self.position
)
@classmethod
def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData:

44
reflex/vars/hooks.py Normal file
View File

@ -0,0 +1,44 @@
"""A module for hooks-related Var."""
import dataclasses
import sys
from reflex.constants import Hooks
from .base import Var, VarData
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class HookVar(Var):
"""A var class for representing a hook."""
position: Hooks.HookPosition | None = None
@classmethod
def create(
cls,
_hook_expr: str,
_var_data: VarData | None = None,
_position: Hooks.HookPosition | None = None,
):
"""Create a hook var.
Args:
_hook_expr: The hook expression.
_position: The position of the hook in the component.
Returns:
The hook var.
"""
hook_var = cls(
_js_expr=_hook_expr,
_var_type="str",
_var_data=_var_data,
position=_position,
)
# print("HookVar.create", _hook_expr, hook_var.position)
return hook_var

View File

@ -31,6 +31,7 @@ from reflex.utils.exceptions import EventFnArgMismatch
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
from reflex.vars.hooks import HookVar
@pytest.fixture
@ -2078,10 +2079,10 @@ def test_component_add_hooks_var():
]
assert list(HookComponent()._get_all_hooks()) == [
"const hook3 = useRef(null)",
HookVar.create("const hook3 = useRef(null)"),
"const hook1 = 42",
"const hook2 = 43",
"useEffect(() => () => {}, [])",
HookVar.create("useEffect(() => () => {}, [])"),
]
imports = HookComponent()._get_all_imports()
assert len(imports) == 1