handle position of hooks

This commit is contained in:
Lendemor 2024-12-11 19:14:33 +01:00
parent 2d553c365c
commit c54b736254
9 changed files with 119 additions and 27 deletions

View File

@ -5,7 +5,7 @@ export function {{tag_name}} () {
{{ hook }} {{ hook }}
{% endfor %} {% endfor %}
{% for hook in component._get_all_hooks() %} {% for hook, data in component._get_all_hooks().items() if not data.position or data.position == positions.PRE_TRIGGER %}
{{ hook }} {{ hook }}
{% endfor %} {% endfor %}
@ -13,6 +13,10 @@ export function {{tag_name}} () {
{{ hook }} {{ hook }}
{% endfor %} {% endfor %}
{% for hook,data in component._get_all_hooks().items() if data.position and data.position == positions.POST_TRIGGER %}
{{ hook }}
{% endfor %}
return ( return (
{{utils.render(component.render(), indent_width=0)}} {{utils.render(component.render(), indent_width=0)}}
) )

View File

@ -56,6 +56,7 @@ from reflex.components.component import (
Component, Component,
ComponentStyle, ComponentStyle,
evaluate_style_namespaces, evaluate_style_namespaces,
memo,
) )
from reflex.components.core.banner import connection_pulser, connection_toaster from reflex.components.core.banner import connection_pulser, connection_toaster
from reflex.components.core.breakpoints import set_breakpoints from reflex.components.core.breakpoints import set_breakpoints
@ -162,6 +163,7 @@ def default_overlay_component() -> Component:
) )
@memo
def default_error_boundary(*children: Component) -> Component: def default_error_boundary(*children: Component) -> Component:
"""Default error_boundary attribute for App. """Default error_boundary attribute for App.

View File

@ -24,6 +24,7 @@ from typing import (
) )
import reflex.state import reflex.state
from reflex import constants
from reflex.base import Base from reflex.base import Base
from reflex.compiler.templates import STATEFUL_COMPONENT from reflex.compiler.templates import STATEFUL_COMPONENT
from reflex.components.core.breakpoints import Breakpoints from reflex.components.core.breakpoints import Breakpoints
@ -69,6 +70,7 @@ from reflex.vars.base import (
cached_property_no_lock, cached_property_no_lock,
) )
from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar
from reflex.vars.hooks import HookVar
from reflex.vars.number import ternary_operation from reflex.vars.number import ternary_operation
from reflex.vars.object import ObjectVar from reflex.vars.object import ObjectVar
from reflex.vars.sequence import LiteralArrayVar from reflex.vars.sequence import LiteralArrayVar
@ -1369,7 +1371,9 @@ class Component(BaseComponent, ABC):
if user_hooks_data is not None: if user_hooks_data is not None:
other_imports.append(user_hooks_data.imports) other_imports.append(user_hooks_data.imports)
other_imports.extend( other_imports.extend(
hook_imports for hook_imports in self._get_added_hooks().values() hook_vardata.imports
for hook_vardata in self._get_added_hooks().values()
if hook_vardata is not None
) )
return imports.merge_imports(_imports, *other_imports) return imports.merge_imports(_imports, *other_imports)
@ -1523,7 +1527,7 @@ class Component(BaseComponent, ABC):
**self._get_special_hooks(), **self._get_special_hooks(),
} }
def _get_added_hooks(self) -> dict[str, ImportDict]: def _get_added_hooks(self) -> dict[str, VarData]:
"""Get the hooks added via `add_hooks` method. """Get the hooks added via `add_hooks` method.
Returns: Returns:
@ -1532,17 +1536,19 @@ class Component(BaseComponent, ABC):
code = {} code = {}
def extract_var_hooks(hook: Var): def extract_var_hooks(hook: Var):
_imports = {}
var_data = VarData.merge(hook._get_all_var_data()) var_data = VarData.merge(hook._get_all_var_data())
if var_data is not None: if var_data is not None:
for sub_hook in var_data.hooks: for sub_hook in var_data.hooks:
code[sub_hook] = {} code[sub_hook] = None
if var_data.imports:
_imports = var_data.imports
if str(hook) in code: if str(hook) in code:
code[str(hook)] = imports.merge_imports(code[str(hook)], _imports) code[str(hook)] = VarData.merge(var_data, code[str(hook)])
elif isinstance(hook, HookVar):
code[str(hook)] = VarData.merge(
var_data, VarData(position=hook.position)
)
else: else:
code[str(hook)] = _imports code[str(hook)] = var_data
# Add the hook code from add_hooks for each parent class (this is reversed to preserve # Add the hook code from add_hooks for each parent class (this is reversed to preserve
# the order of the hooks in the final output) # the order of the hooks in the final output)
@ -1551,7 +1557,9 @@ class Component(BaseComponent, ABC):
if isinstance(hook, Var): if isinstance(hook, Var):
extract_var_hooks(hook) extract_var_hooks(hook)
else: else:
code[hook] = {} if isinstance(hook, str):
hook = HookVar.create(hook)
code[hook] = VarData()
return code return code
@ -1593,8 +1601,8 @@ class Component(BaseComponent, ABC):
if hooks is not None: if hooks is not None:
code[hooks] = None code[hooks] = None
for hook in self._get_added_hooks(): for hook, var_data in self._get_added_hooks().items():
code[hook] = None code[hook] = var_data
# Add the hook code for the children. # Add the hook code for the children.
for child in self.children: for child in self.children:
@ -2168,6 +2176,7 @@ class StatefulComponent(BaseComponent):
tag_name=tag_name, tag_name=tag_name,
memo_trigger_hooks=memo_trigger_hooks, memo_trigger_hooks=memo_trigger_hooks,
component=component, component=component,
positions=constants.Hooks.HookPosition,
) )
@staticmethod @staticmethod
@ -2244,10 +2253,9 @@ class StatefulComponent(BaseComponent):
imports={"react": [ImportVar(tag="useCallback")]}, imports={"react": [ImportVar(tag="useCallback")]},
), ),
) )
# Store the memoized function name and hook code for this event trigger. # Store the memoized function name and hook code for this event trigger.
trigger_memo[event_trigger] = ( trigger_memo[event_trigger] = (
Var(_js_expr=memo_name)._replace( HookVar(_js_expr=memo_name)._replace(
_var_type=EventChain, merge_var_data=memo_var_data _var_type=EventChain, merge_var_data=memo_var_data
), ),
f"const {memo_name} = useCallback({rendered_chain}, [{', '.join(var_deps)}])", f"const {memo_name} = useCallback({rendered_chain}, [{', '.join(var_deps)}])",

View File

@ -6,11 +6,13 @@ from typing import Dict, List, Tuple, Union
from reflex.components.base.fragment import Fragment from reflex.components.base.fragment import Fragment
from reflex.components.tags.tag import Tag from reflex.components.tags.tag import Tag
from reflex.constants.compiler import Hooks
from reflex.event import EventChain, EventHandler, passthrough_event_spec from reflex.event import EventChain, EventHandler, passthrough_event_spec
from reflex.utils.format import format_prop, wrap from reflex.utils.format import format_prop, wrap
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportVar
from reflex.vars import get_unique_variable_name from reflex.vars import get_unique_variable_name
from reflex.vars.base import Var from reflex.vars.base import Var
from reflex.vars.hooks import HookVar
class Clipboard(Fragment): class Clipboard(Fragment):
@ -72,7 +74,7 @@ class Clipboard(Fragment):
), ),
} }
def add_hooks(self) -> list[str]: def add_hooks(self) -> list[str | Var]:
"""Add hook to register paste event listener. """Add hook to register paste event listener.
Returns: Returns:
@ -83,13 +85,9 @@ class Clipboard(Fragment):
return [] return []
if isinstance(on_paste, EventChain): if isinstance(on_paste, EventChain):
on_paste = wrap(str(format_prop(on_paste)).strip("{}"), "(") on_paste = wrap(str(format_prop(on_paste)).strip("{}"), "(")
hook_expr = f"usePasteHandler({self.targets!s}, {self.on_paste_event_actions!s}, {on_paste!s})"
return [ return [
"usePasteHandler(%s, %s, %s)" HookVar.create(hook_expr, _position=Hooks.HookPosition.POST_TRIGGER),
% (
str(self.targets),
str(self.on_paste_event_actions),
on_paste,
)
] ]

View File

@ -71,6 +71,6 @@ class Clipboard(Fragment):
... ...
def add_imports(self) -> dict[str, ImportVar]: ... def add_imports(self) -> dict[str, ImportVar]: ...
def add_hooks(self) -> list[str]: ... def add_hooks(self) -> list[str | Var]: ...
clipboard = Clipboard.create clipboard = Clipboard.create

View File

@ -132,6 +132,12 @@ class Hooks(SimpleNamespace):
} }
})""" })"""
class HookPosition(enum.Enum):
"""The position of the hook in the component."""
PRE_TRIGGER = "pre_trigger"
POST_TRIGGER = "post_trigger"
class MemoizationDisposition(enum.Enum): class MemoizationDisposition(enum.Enum):
"""The conditions under which a component should be memoized.""" """The conditions under which a component should be memoized."""

View File

@ -42,7 +42,8 @@ from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints,
from reflex import constants from reflex import constants
from reflex.base import Base from reflex.base import Base
from reflex.utils import console, imports, serializers, types from reflex.constants.compiler import Hooks
from reflex.utils import console, exceptions, imports, serializers, types
from reflex.utils.exceptions import ( from reflex.utils.exceptions import (
VarAttributeError, VarAttributeError,
VarDependencyError, VarDependencyError,
@ -115,12 +116,16 @@ class VarData:
# Hooks that need to be present in the component to render this var # Hooks that need to be present in the component to render this var
hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple) hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
# Position of the hook in the component
position: Hooks.HookPosition | None = None
def __init__( def __init__(
self, self,
state: str = "", state: str = "",
field_name: str = "", field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None, imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, None] | None = None, hooks: dict[str, None] | None = None,
position: Hooks.HookPosition | None = None,
): ):
"""Initialize the var data. """Initialize the var data.
@ -129,6 +134,7 @@ class VarData:
field_name: The name of the field in the state. field_name: The name of the field in the state.
imports: Imports needed to render this var. imports: Imports needed to render this var.
hooks: Hooks that need to be present in the component to render this var. hooks: Hooks that need to be present in the component to render this var.
position: Position of the hook in the component.
""" """
immutable_imports: ImmutableParsedImportDict = tuple( immutable_imports: ImmutableParsedImportDict = tuple(
sorted( sorted(
@ -139,6 +145,7 @@ class VarData:
object.__setattr__(self, "field_name", field_name) object.__setattr__(self, "field_name", field_name)
object.__setattr__(self, "imports", immutable_imports) object.__setattr__(self, "imports", immutable_imports)
object.__setattr__(self, "hooks", tuple(hooks or {})) object.__setattr__(self, "hooks", tuple(hooks or {}))
object.__setattr__(self, "position", position or None)
def old_school_imports(self) -> ImportDict: def old_school_imports(self) -> ImportDict:
"""Return the imports as a mutable dict. """Return the imports as a mutable dict.
@ -154,6 +161,9 @@ class VarData:
Args: Args:
*all: The var data objects to merge. *all: The var data objects to merge.
Raises:
ReflexError: If trying to merge VarData with different positions.
Returns: Returns:
The merged var data object. The merged var data object.
@ -184,12 +194,29 @@ class VarData:
*(var_data.imports for var_data in all_var_datas) *(var_data.imports for var_data in all_var_datas)
) )
if state or _imports or hooks or field_name: positions = list(
{
var_data.position
for var_data in all_var_datas
if var_data.position is not None
}
)
if positions:
if len(positions) > 1:
raise exceptions.ReflexError(
f"Cannot merge var data with different positions: {positions}"
)
position = positions[0]
else:
position = None
if state or _imports or hooks or field_name or position:
return VarData( return VarData(
state=state, state=state,
field_name=field_name, field_name=field_name,
imports=_imports, imports=_imports,
hooks=hooks, hooks=hooks,
position=position,
) )
return None return None
@ -200,7 +227,9 @@ class VarData:
Returns: Returns:
True if any field is set to a non-default value. True if any field is set to a non-default value.
""" """
return bool(self.state or self.imports or self.hooks or self.field_name) return bool(
self.state or self.imports or self.hooks or self.field_name or self.position
)
@classmethod @classmethod
def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData: def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData:

44
reflex/vars/hooks.py Normal file
View File

@ -0,0 +1,44 @@
"""A module for hooks-related Var."""
import dataclasses
import sys
from reflex.constants import Hooks
from .base import Var, VarData
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class HookVar(Var):
"""A var class for representing a hook."""
position: Hooks.HookPosition | None = None
@classmethod
def create(
cls,
_hook_expr: str,
_var_data: VarData | None = None,
_position: Hooks.HookPosition | None = None,
):
"""Create a hook var.
Args:
_hook_expr: The hook expression.
_position: The position of the hook in the component.
Returns:
The hook var.
"""
hook_var = cls(
_js_expr=_hook_expr,
_var_type="str",
_var_data=_var_data,
position=_position,
)
# print("HookVar.create", _hook_expr, hook_var.position)
return hook_var

View File

@ -31,6 +31,7 @@ from reflex.utils.exceptions import EventFnArgMismatch
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.vars import VarData from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var from reflex.vars.base import LiteralVar, Var
from reflex.vars.hooks import HookVar
@pytest.fixture @pytest.fixture
@ -2078,10 +2079,10 @@ def test_component_add_hooks_var():
] ]
assert list(HookComponent()._get_all_hooks()) == [ assert list(HookComponent()._get_all_hooks()) == [
"const hook3 = useRef(null)", HookVar.create("const hook3 = useRef(null)"),
"const hook1 = 42", "const hook1 = 42",
"const hook2 = 43", "const hook2 = 43",
"useEffect(() => () => {}, [])", HookVar.create("useEffect(() => () => {}, [])"),
] ]
imports = HookComponent()._get_all_imports() imports = HookComponent()._get_all_imports()
assert len(imports) == 1 assert len(imports) == 1