handle position of hooks
This commit is contained in:
parent
2d553c365c
commit
c54b736254
@ -5,7 +5,7 @@ export function {{tag_name}} () {
|
||||
{{ hook }}
|
||||
{% endfor %}
|
||||
|
||||
{% for hook in component._get_all_hooks() %}
|
||||
{% for hook, data in component._get_all_hooks().items() if not data.position or data.position == positions.PRE_TRIGGER %}
|
||||
{{ hook }}
|
||||
{% endfor %}
|
||||
|
||||
@ -13,6 +13,10 @@ export function {{tag_name}} () {
|
||||
{{ hook }}
|
||||
{% endfor %}
|
||||
|
||||
{% for hook,data in component._get_all_hooks().items() if data.position and data.position == positions.POST_TRIGGER %}
|
||||
{{ hook }}
|
||||
{% endfor %}
|
||||
|
||||
return (
|
||||
{{utils.render(component.render(), indent_width=0)}}
|
||||
)
|
||||
|
@ -56,6 +56,7 @@ from reflex.components.component import (
|
||||
Component,
|
||||
ComponentStyle,
|
||||
evaluate_style_namespaces,
|
||||
memo,
|
||||
)
|
||||
from reflex.components.core.banner import connection_pulser, connection_toaster
|
||||
from reflex.components.core.breakpoints import set_breakpoints
|
||||
@ -162,6 +163,7 @@ def default_overlay_component() -> Component:
|
||||
)
|
||||
|
||||
|
||||
@memo
|
||||
def default_error_boundary(*children: Component) -> Component:
|
||||
"""Default error_boundary attribute for App.
|
||||
|
||||
|
@ -24,6 +24,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import reflex.state
|
||||
from reflex import constants
|
||||
from reflex.base import Base
|
||||
from reflex.compiler.templates import STATEFUL_COMPONENT
|
||||
from reflex.components.core.breakpoints import Breakpoints
|
||||
@ -69,6 +70,7 @@ from reflex.vars.base import (
|
||||
cached_property_no_lock,
|
||||
)
|
||||
from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar
|
||||
from reflex.vars.hooks import HookVar
|
||||
from reflex.vars.number import ternary_operation
|
||||
from reflex.vars.object import ObjectVar
|
||||
from reflex.vars.sequence import LiteralArrayVar
|
||||
@ -1369,7 +1371,9 @@ class Component(BaseComponent, ABC):
|
||||
if user_hooks_data is not None:
|
||||
other_imports.append(user_hooks_data.imports)
|
||||
other_imports.extend(
|
||||
hook_imports for hook_imports in self._get_added_hooks().values()
|
||||
hook_vardata.imports
|
||||
for hook_vardata in self._get_added_hooks().values()
|
||||
if hook_vardata is not None
|
||||
)
|
||||
|
||||
return imports.merge_imports(_imports, *other_imports)
|
||||
@ -1523,7 +1527,7 @@ class Component(BaseComponent, ABC):
|
||||
**self._get_special_hooks(),
|
||||
}
|
||||
|
||||
def _get_added_hooks(self) -> dict[str, ImportDict]:
|
||||
def _get_added_hooks(self) -> dict[str, VarData]:
|
||||
"""Get the hooks added via `add_hooks` method.
|
||||
|
||||
Returns:
|
||||
@ -1532,17 +1536,19 @@ class Component(BaseComponent, ABC):
|
||||
code = {}
|
||||
|
||||
def extract_var_hooks(hook: Var):
|
||||
_imports = {}
|
||||
var_data = VarData.merge(hook._get_all_var_data())
|
||||
if var_data is not None:
|
||||
for sub_hook in var_data.hooks:
|
||||
code[sub_hook] = {}
|
||||
if var_data.imports:
|
||||
_imports = var_data.imports
|
||||
code[sub_hook] = None
|
||||
|
||||
if str(hook) in code:
|
||||
code[str(hook)] = imports.merge_imports(code[str(hook)], _imports)
|
||||
code[str(hook)] = VarData.merge(var_data, code[str(hook)])
|
||||
elif isinstance(hook, HookVar):
|
||||
code[str(hook)] = VarData.merge(
|
||||
var_data, VarData(position=hook.position)
|
||||
)
|
||||
else:
|
||||
code[str(hook)] = _imports
|
||||
code[str(hook)] = var_data
|
||||
|
||||
# Add the hook code from add_hooks for each parent class (this is reversed to preserve
|
||||
# the order of the hooks in the final output)
|
||||
@ -1551,7 +1557,9 @@ class Component(BaseComponent, ABC):
|
||||
if isinstance(hook, Var):
|
||||
extract_var_hooks(hook)
|
||||
else:
|
||||
code[hook] = {}
|
||||
if isinstance(hook, str):
|
||||
hook = HookVar.create(hook)
|
||||
code[hook] = VarData()
|
||||
|
||||
return code
|
||||
|
||||
@ -1593,8 +1601,8 @@ class Component(BaseComponent, ABC):
|
||||
if hooks is not None:
|
||||
code[hooks] = None
|
||||
|
||||
for hook in self._get_added_hooks():
|
||||
code[hook] = None
|
||||
for hook, var_data in self._get_added_hooks().items():
|
||||
code[hook] = var_data
|
||||
|
||||
# Add the hook code for the children.
|
||||
for child in self.children:
|
||||
@ -2168,6 +2176,7 @@ class StatefulComponent(BaseComponent):
|
||||
tag_name=tag_name,
|
||||
memo_trigger_hooks=memo_trigger_hooks,
|
||||
component=component,
|
||||
positions=constants.Hooks.HookPosition,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -2244,10 +2253,9 @@ class StatefulComponent(BaseComponent):
|
||||
imports={"react": [ImportVar(tag="useCallback")]},
|
||||
),
|
||||
)
|
||||
|
||||
# Store the memoized function name and hook code for this event trigger.
|
||||
trigger_memo[event_trigger] = (
|
||||
Var(_js_expr=memo_name)._replace(
|
||||
HookVar(_js_expr=memo_name)._replace(
|
||||
_var_type=EventChain, merge_var_data=memo_var_data
|
||||
),
|
||||
f"const {memo_name} = useCallback({rendered_chain}, [{', '.join(var_deps)}])",
|
||||
|
@ -6,11 +6,13 @@ from typing import Dict, List, Tuple, Union
|
||||
|
||||
from reflex.components.base.fragment import Fragment
|
||||
from reflex.components.tags.tag import Tag
|
||||
from reflex.constants.compiler import Hooks
|
||||
from reflex.event import EventChain, EventHandler, passthrough_event_spec
|
||||
from reflex.utils.format import format_prop, wrap
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.vars import get_unique_variable_name
|
||||
from reflex.vars.base import Var
|
||||
from reflex.vars.hooks import HookVar
|
||||
|
||||
|
||||
class Clipboard(Fragment):
|
||||
@ -72,7 +74,7 @@ class Clipboard(Fragment):
|
||||
),
|
||||
}
|
||||
|
||||
def add_hooks(self) -> list[str]:
|
||||
def add_hooks(self) -> list[str | Var]:
|
||||
"""Add hook to register paste event listener.
|
||||
|
||||
Returns:
|
||||
@ -83,13 +85,9 @@ class Clipboard(Fragment):
|
||||
return []
|
||||
if isinstance(on_paste, EventChain):
|
||||
on_paste = wrap(str(format_prop(on_paste)).strip("{}"), "(")
|
||||
hook_expr = f"usePasteHandler({self.targets!s}, {self.on_paste_event_actions!s}, {on_paste!s})"
|
||||
return [
|
||||
"usePasteHandler(%s, %s, %s)"
|
||||
% (
|
||||
str(self.targets),
|
||||
str(self.on_paste_event_actions),
|
||||
on_paste,
|
||||
)
|
||||
HookVar.create(hook_expr, _position=Hooks.HookPosition.POST_TRIGGER),
|
||||
]
|
||||
|
||||
|
||||
|
@ -71,6 +71,6 @@ class Clipboard(Fragment):
|
||||
...
|
||||
|
||||
def add_imports(self) -> dict[str, ImportVar]: ...
|
||||
def add_hooks(self) -> list[str]: ...
|
||||
def add_hooks(self) -> list[str | Var]: ...
|
||||
|
||||
clipboard = Clipboard.create
|
||||
|
@ -132,6 +132,12 @@ class Hooks(SimpleNamespace):
|
||||
}
|
||||
})"""
|
||||
|
||||
class HookPosition(enum.Enum):
|
||||
"""The position of the hook in the component."""
|
||||
|
||||
PRE_TRIGGER = "pre_trigger"
|
||||
POST_TRIGGER = "post_trigger"
|
||||
|
||||
|
||||
class MemoizationDisposition(enum.Enum):
|
||||
"""The conditions under which a component should be memoized."""
|
||||
|
@ -42,7 +42,8 @@ from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints,
|
||||
|
||||
from reflex import constants
|
||||
from reflex.base import Base
|
||||
from reflex.utils import console, imports, serializers, types
|
||||
from reflex.constants.compiler import Hooks
|
||||
from reflex.utils import console, exceptions, imports, serializers, types
|
||||
from reflex.utils.exceptions import (
|
||||
VarAttributeError,
|
||||
VarDependencyError,
|
||||
@ -115,12 +116,16 @@ class VarData:
|
||||
# Hooks that need to be present in the component to render this var
|
||||
hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
|
||||
|
||||
# Position of the hook in the component
|
||||
position: Hooks.HookPosition | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state: str = "",
|
||||
field_name: str = "",
|
||||
imports: ImportDict | ParsedImportDict | None = None,
|
||||
hooks: dict[str, None] | None = None,
|
||||
position: Hooks.HookPosition | None = None,
|
||||
):
|
||||
"""Initialize the var data.
|
||||
|
||||
@ -129,6 +134,7 @@ class VarData:
|
||||
field_name: The name of the field in the state.
|
||||
imports: Imports needed to render this var.
|
||||
hooks: Hooks that need to be present in the component to render this var.
|
||||
position: Position of the hook in the component.
|
||||
"""
|
||||
immutable_imports: ImmutableParsedImportDict = tuple(
|
||||
sorted(
|
||||
@ -139,6 +145,7 @@ class VarData:
|
||||
object.__setattr__(self, "field_name", field_name)
|
||||
object.__setattr__(self, "imports", immutable_imports)
|
||||
object.__setattr__(self, "hooks", tuple(hooks or {}))
|
||||
object.__setattr__(self, "position", position or None)
|
||||
|
||||
def old_school_imports(self) -> ImportDict:
|
||||
"""Return the imports as a mutable dict.
|
||||
@ -154,6 +161,9 @@ class VarData:
|
||||
Args:
|
||||
*all: The var data objects to merge.
|
||||
|
||||
Raises:
|
||||
ReflexError: If trying to merge VarData with different positions.
|
||||
|
||||
Returns:
|
||||
The merged var data object.
|
||||
|
||||
@ -184,12 +194,29 @@ class VarData:
|
||||
*(var_data.imports for var_data in all_var_datas)
|
||||
)
|
||||
|
||||
if state or _imports or hooks or field_name:
|
||||
positions = list(
|
||||
{
|
||||
var_data.position
|
||||
for var_data in all_var_datas
|
||||
if var_data.position is not None
|
||||
}
|
||||
)
|
||||
if positions:
|
||||
if len(positions) > 1:
|
||||
raise exceptions.ReflexError(
|
||||
f"Cannot merge var data with different positions: {positions}"
|
||||
)
|
||||
position = positions[0]
|
||||
else:
|
||||
position = None
|
||||
|
||||
if state or _imports or hooks or field_name or position:
|
||||
return VarData(
|
||||
state=state,
|
||||
field_name=field_name,
|
||||
imports=_imports,
|
||||
hooks=hooks,
|
||||
position=position,
|
||||
)
|
||||
|
||||
return None
|
||||
@ -200,7 +227,9 @@ class VarData:
|
||||
Returns:
|
||||
True if any field is set to a non-default value.
|
||||
"""
|
||||
return bool(self.state or self.imports or self.hooks or self.field_name)
|
||||
return bool(
|
||||
self.state or self.imports or self.hooks or self.field_name or self.position
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData:
|
||||
|
44
reflex/vars/hooks.py
Normal file
44
reflex/vars/hooks.py
Normal file
@ -0,0 +1,44 @@
|
||||
"""A module for hooks-related Var."""
|
||||
|
||||
import dataclasses
|
||||
import sys
|
||||
|
||||
from reflex.constants import Hooks
|
||||
|
||||
from .base import Var, VarData
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
)
|
||||
class HookVar(Var):
|
||||
"""A var class for representing a hook."""
|
||||
|
||||
position: Hooks.HookPosition | None = None
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
_hook_expr: str,
|
||||
_var_data: VarData | None = None,
|
||||
_position: Hooks.HookPosition | None = None,
|
||||
):
|
||||
"""Create a hook var.
|
||||
|
||||
Args:
|
||||
_hook_expr: The hook expression.
|
||||
_position: The position of the hook in the component.
|
||||
|
||||
Returns:
|
||||
The hook var.
|
||||
"""
|
||||
hook_var = cls(
|
||||
_js_expr=_hook_expr,
|
||||
_var_type="str",
|
||||
_var_data=_var_data,
|
||||
position=_position,
|
||||
)
|
||||
# print("HookVar.create", _hook_expr, hook_var.position)
|
||||
return hook_var
|
@ -31,6 +31,7 @@ from reflex.utils.exceptions import EventFnArgMismatch
|
||||
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
|
||||
from reflex.vars import VarData
|
||||
from reflex.vars.base import LiteralVar, Var
|
||||
from reflex.vars.hooks import HookVar
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -2078,10 +2079,10 @@ def test_component_add_hooks_var():
|
||||
]
|
||||
|
||||
assert list(HookComponent()._get_all_hooks()) == [
|
||||
"const hook3 = useRef(null)",
|
||||
HookVar.create("const hook3 = useRef(null)"),
|
||||
"const hook1 = 42",
|
||||
"const hook2 = 43",
|
||||
"useEffect(() => () => {}, [])",
|
||||
HookVar.create("useEffect(() => () => {}, [])"),
|
||||
]
|
||||
imports = HookComponent()._get_all_imports()
|
||||
assert len(imports) == 1
|
||||
|
Loading…
Reference in New Issue
Block a user