handle position of hooks
This commit is contained in:
parent
2d553c365c
commit
c54b736254
@ -5,7 +5,7 @@ export function {{tag_name}} () {
|
|||||||
{{ hook }}
|
{{ hook }}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
|
||||||
{% for hook in component._get_all_hooks() %}
|
{% for hook, data in component._get_all_hooks().items() if not data.position or data.position == positions.PRE_TRIGGER %}
|
||||||
{{ hook }}
|
{{ hook }}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
|
||||||
@ -13,6 +13,10 @@ export function {{tag_name}} () {
|
|||||||
{{ hook }}
|
{{ hook }}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
|
||||||
|
{% for hook,data in component._get_all_hooks().items() if data.position and data.position == positions.POST_TRIGGER %}
|
||||||
|
{{ hook }}
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
{{utils.render(component.render(), indent_width=0)}}
|
{{utils.render(component.render(), indent_width=0)}}
|
||||||
)
|
)
|
||||||
|
@ -56,6 +56,7 @@ from reflex.components.component import (
|
|||||||
Component,
|
Component,
|
||||||
ComponentStyle,
|
ComponentStyle,
|
||||||
evaluate_style_namespaces,
|
evaluate_style_namespaces,
|
||||||
|
memo,
|
||||||
)
|
)
|
||||||
from reflex.components.core.banner import connection_pulser, connection_toaster
|
from reflex.components.core.banner import connection_pulser, connection_toaster
|
||||||
from reflex.components.core.breakpoints import set_breakpoints
|
from reflex.components.core.breakpoints import set_breakpoints
|
||||||
@ -162,6 +163,7 @@ def default_overlay_component() -> Component:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@memo
|
||||||
def default_error_boundary(*children: Component) -> Component:
|
def default_error_boundary(*children: Component) -> Component:
|
||||||
"""Default error_boundary attribute for App.
|
"""Default error_boundary attribute for App.
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import reflex.state
|
import reflex.state
|
||||||
|
from reflex import constants
|
||||||
from reflex.base import Base
|
from reflex.base import Base
|
||||||
from reflex.compiler.templates import STATEFUL_COMPONENT
|
from reflex.compiler.templates import STATEFUL_COMPONENT
|
||||||
from reflex.components.core.breakpoints import Breakpoints
|
from reflex.components.core.breakpoints import Breakpoints
|
||||||
@ -69,6 +70,7 @@ from reflex.vars.base import (
|
|||||||
cached_property_no_lock,
|
cached_property_no_lock,
|
||||||
)
|
)
|
||||||
from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar
|
from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar
|
||||||
|
from reflex.vars.hooks import HookVar
|
||||||
from reflex.vars.number import ternary_operation
|
from reflex.vars.number import ternary_operation
|
||||||
from reflex.vars.object import ObjectVar
|
from reflex.vars.object import ObjectVar
|
||||||
from reflex.vars.sequence import LiteralArrayVar
|
from reflex.vars.sequence import LiteralArrayVar
|
||||||
@ -1369,7 +1371,9 @@ class Component(BaseComponent, ABC):
|
|||||||
if user_hooks_data is not None:
|
if user_hooks_data is not None:
|
||||||
other_imports.append(user_hooks_data.imports)
|
other_imports.append(user_hooks_data.imports)
|
||||||
other_imports.extend(
|
other_imports.extend(
|
||||||
hook_imports for hook_imports in self._get_added_hooks().values()
|
hook_vardata.imports
|
||||||
|
for hook_vardata in self._get_added_hooks().values()
|
||||||
|
if hook_vardata is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
return imports.merge_imports(_imports, *other_imports)
|
return imports.merge_imports(_imports, *other_imports)
|
||||||
@ -1523,7 +1527,7 @@ class Component(BaseComponent, ABC):
|
|||||||
**self._get_special_hooks(),
|
**self._get_special_hooks(),
|
||||||
}
|
}
|
||||||
|
|
||||||
def _get_added_hooks(self) -> dict[str, ImportDict]:
|
def _get_added_hooks(self) -> dict[str, VarData]:
|
||||||
"""Get the hooks added via `add_hooks` method.
|
"""Get the hooks added via `add_hooks` method.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -1532,17 +1536,19 @@ class Component(BaseComponent, ABC):
|
|||||||
code = {}
|
code = {}
|
||||||
|
|
||||||
def extract_var_hooks(hook: Var):
|
def extract_var_hooks(hook: Var):
|
||||||
_imports = {}
|
|
||||||
var_data = VarData.merge(hook._get_all_var_data())
|
var_data = VarData.merge(hook._get_all_var_data())
|
||||||
if var_data is not None:
|
if var_data is not None:
|
||||||
for sub_hook in var_data.hooks:
|
for sub_hook in var_data.hooks:
|
||||||
code[sub_hook] = {}
|
code[sub_hook] = None
|
||||||
if var_data.imports:
|
|
||||||
_imports = var_data.imports
|
|
||||||
if str(hook) in code:
|
if str(hook) in code:
|
||||||
code[str(hook)] = imports.merge_imports(code[str(hook)], _imports)
|
code[str(hook)] = VarData.merge(var_data, code[str(hook)])
|
||||||
|
elif isinstance(hook, HookVar):
|
||||||
|
code[str(hook)] = VarData.merge(
|
||||||
|
var_data, VarData(position=hook.position)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
code[str(hook)] = _imports
|
code[str(hook)] = var_data
|
||||||
|
|
||||||
# Add the hook code from add_hooks for each parent class (this is reversed to preserve
|
# Add the hook code from add_hooks for each parent class (this is reversed to preserve
|
||||||
# the order of the hooks in the final output)
|
# the order of the hooks in the final output)
|
||||||
@ -1551,7 +1557,9 @@ class Component(BaseComponent, ABC):
|
|||||||
if isinstance(hook, Var):
|
if isinstance(hook, Var):
|
||||||
extract_var_hooks(hook)
|
extract_var_hooks(hook)
|
||||||
else:
|
else:
|
||||||
code[hook] = {}
|
if isinstance(hook, str):
|
||||||
|
hook = HookVar.create(hook)
|
||||||
|
code[hook] = VarData()
|
||||||
|
|
||||||
return code
|
return code
|
||||||
|
|
||||||
@ -1593,8 +1601,8 @@ class Component(BaseComponent, ABC):
|
|||||||
if hooks is not None:
|
if hooks is not None:
|
||||||
code[hooks] = None
|
code[hooks] = None
|
||||||
|
|
||||||
for hook in self._get_added_hooks():
|
for hook, var_data in self._get_added_hooks().items():
|
||||||
code[hook] = None
|
code[hook] = var_data
|
||||||
|
|
||||||
# Add the hook code for the children.
|
# Add the hook code for the children.
|
||||||
for child in self.children:
|
for child in self.children:
|
||||||
@ -2168,6 +2176,7 @@ class StatefulComponent(BaseComponent):
|
|||||||
tag_name=tag_name,
|
tag_name=tag_name,
|
||||||
memo_trigger_hooks=memo_trigger_hooks,
|
memo_trigger_hooks=memo_trigger_hooks,
|
||||||
component=component,
|
component=component,
|
||||||
|
positions=constants.Hooks.HookPosition,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -2244,10 +2253,9 @@ class StatefulComponent(BaseComponent):
|
|||||||
imports={"react": [ImportVar(tag="useCallback")]},
|
imports={"react": [ImportVar(tag="useCallback")]},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store the memoized function name and hook code for this event trigger.
|
# Store the memoized function name and hook code for this event trigger.
|
||||||
trigger_memo[event_trigger] = (
|
trigger_memo[event_trigger] = (
|
||||||
Var(_js_expr=memo_name)._replace(
|
HookVar(_js_expr=memo_name)._replace(
|
||||||
_var_type=EventChain, merge_var_data=memo_var_data
|
_var_type=EventChain, merge_var_data=memo_var_data
|
||||||
),
|
),
|
||||||
f"const {memo_name} = useCallback({rendered_chain}, [{', '.join(var_deps)}])",
|
f"const {memo_name} = useCallback({rendered_chain}, [{', '.join(var_deps)}])",
|
||||||
|
@ -6,11 +6,13 @@ from typing import Dict, List, Tuple, Union
|
|||||||
|
|
||||||
from reflex.components.base.fragment import Fragment
|
from reflex.components.base.fragment import Fragment
|
||||||
from reflex.components.tags.tag import Tag
|
from reflex.components.tags.tag import Tag
|
||||||
|
from reflex.constants.compiler import Hooks
|
||||||
from reflex.event import EventChain, EventHandler, passthrough_event_spec
|
from reflex.event import EventChain, EventHandler, passthrough_event_spec
|
||||||
from reflex.utils.format import format_prop, wrap
|
from reflex.utils.format import format_prop, wrap
|
||||||
from reflex.utils.imports import ImportVar
|
from reflex.utils.imports import ImportVar
|
||||||
from reflex.vars import get_unique_variable_name
|
from reflex.vars import get_unique_variable_name
|
||||||
from reflex.vars.base import Var
|
from reflex.vars.base import Var
|
||||||
|
from reflex.vars.hooks import HookVar
|
||||||
|
|
||||||
|
|
||||||
class Clipboard(Fragment):
|
class Clipboard(Fragment):
|
||||||
@ -72,7 +74,7 @@ class Clipboard(Fragment):
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
def add_hooks(self) -> list[str]:
|
def add_hooks(self) -> list[str | Var]:
|
||||||
"""Add hook to register paste event listener.
|
"""Add hook to register paste event listener.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -83,13 +85,9 @@ class Clipboard(Fragment):
|
|||||||
return []
|
return []
|
||||||
if isinstance(on_paste, EventChain):
|
if isinstance(on_paste, EventChain):
|
||||||
on_paste = wrap(str(format_prop(on_paste)).strip("{}"), "(")
|
on_paste = wrap(str(format_prop(on_paste)).strip("{}"), "(")
|
||||||
|
hook_expr = f"usePasteHandler({self.targets!s}, {self.on_paste_event_actions!s}, {on_paste!s})"
|
||||||
return [
|
return [
|
||||||
"usePasteHandler(%s, %s, %s)"
|
HookVar.create(hook_expr, _position=Hooks.HookPosition.POST_TRIGGER),
|
||||||
% (
|
|
||||||
str(self.targets),
|
|
||||||
str(self.on_paste_event_actions),
|
|
||||||
on_paste,
|
|
||||||
)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -71,6 +71,6 @@ class Clipboard(Fragment):
|
|||||||
...
|
...
|
||||||
|
|
||||||
def add_imports(self) -> dict[str, ImportVar]: ...
|
def add_imports(self) -> dict[str, ImportVar]: ...
|
||||||
def add_hooks(self) -> list[str]: ...
|
def add_hooks(self) -> list[str | Var]: ...
|
||||||
|
|
||||||
clipboard = Clipboard.create
|
clipboard = Clipboard.create
|
||||||
|
@ -132,6 +132,12 @@ class Hooks(SimpleNamespace):
|
|||||||
}
|
}
|
||||||
})"""
|
})"""
|
||||||
|
|
||||||
|
class HookPosition(enum.Enum):
|
||||||
|
"""The position of the hook in the component."""
|
||||||
|
|
||||||
|
PRE_TRIGGER = "pre_trigger"
|
||||||
|
POST_TRIGGER = "post_trigger"
|
||||||
|
|
||||||
|
|
||||||
class MemoizationDisposition(enum.Enum):
|
class MemoizationDisposition(enum.Enum):
|
||||||
"""The conditions under which a component should be memoized."""
|
"""The conditions under which a component should be memoized."""
|
||||||
|
@ -42,7 +42,8 @@ from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints,
|
|||||||
|
|
||||||
from reflex import constants
|
from reflex import constants
|
||||||
from reflex.base import Base
|
from reflex.base import Base
|
||||||
from reflex.utils import console, imports, serializers, types
|
from reflex.constants.compiler import Hooks
|
||||||
|
from reflex.utils import console, exceptions, imports, serializers, types
|
||||||
from reflex.utils.exceptions import (
|
from reflex.utils.exceptions import (
|
||||||
VarAttributeError,
|
VarAttributeError,
|
||||||
VarDependencyError,
|
VarDependencyError,
|
||||||
@ -115,12 +116,16 @@ class VarData:
|
|||||||
# Hooks that need to be present in the component to render this var
|
# Hooks that need to be present in the component to render this var
|
||||||
hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
|
hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
|
||||||
|
|
||||||
|
# Position of the hook in the component
|
||||||
|
position: Hooks.HookPosition | None = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
state: str = "",
|
state: str = "",
|
||||||
field_name: str = "",
|
field_name: str = "",
|
||||||
imports: ImportDict | ParsedImportDict | None = None,
|
imports: ImportDict | ParsedImportDict | None = None,
|
||||||
hooks: dict[str, None] | None = None,
|
hooks: dict[str, None] | None = None,
|
||||||
|
position: Hooks.HookPosition | None = None,
|
||||||
):
|
):
|
||||||
"""Initialize the var data.
|
"""Initialize the var data.
|
||||||
|
|
||||||
@ -129,6 +134,7 @@ class VarData:
|
|||||||
field_name: The name of the field in the state.
|
field_name: The name of the field in the state.
|
||||||
imports: Imports needed to render this var.
|
imports: Imports needed to render this var.
|
||||||
hooks: Hooks that need to be present in the component to render this var.
|
hooks: Hooks that need to be present in the component to render this var.
|
||||||
|
position: Position of the hook in the component.
|
||||||
"""
|
"""
|
||||||
immutable_imports: ImmutableParsedImportDict = tuple(
|
immutable_imports: ImmutableParsedImportDict = tuple(
|
||||||
sorted(
|
sorted(
|
||||||
@ -139,6 +145,7 @@ class VarData:
|
|||||||
object.__setattr__(self, "field_name", field_name)
|
object.__setattr__(self, "field_name", field_name)
|
||||||
object.__setattr__(self, "imports", immutable_imports)
|
object.__setattr__(self, "imports", immutable_imports)
|
||||||
object.__setattr__(self, "hooks", tuple(hooks or {}))
|
object.__setattr__(self, "hooks", tuple(hooks or {}))
|
||||||
|
object.__setattr__(self, "position", position or None)
|
||||||
|
|
||||||
def old_school_imports(self) -> ImportDict:
|
def old_school_imports(self) -> ImportDict:
|
||||||
"""Return the imports as a mutable dict.
|
"""Return the imports as a mutable dict.
|
||||||
@ -154,6 +161,9 @@ class VarData:
|
|||||||
Args:
|
Args:
|
||||||
*all: The var data objects to merge.
|
*all: The var data objects to merge.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ReflexError: If trying to merge VarData with different positions.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The merged var data object.
|
The merged var data object.
|
||||||
|
|
||||||
@ -184,12 +194,29 @@ class VarData:
|
|||||||
*(var_data.imports for var_data in all_var_datas)
|
*(var_data.imports for var_data in all_var_datas)
|
||||||
)
|
)
|
||||||
|
|
||||||
if state or _imports or hooks or field_name:
|
positions = list(
|
||||||
|
{
|
||||||
|
var_data.position
|
||||||
|
for var_data in all_var_datas
|
||||||
|
if var_data.position is not None
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if positions:
|
||||||
|
if len(positions) > 1:
|
||||||
|
raise exceptions.ReflexError(
|
||||||
|
f"Cannot merge var data with different positions: {positions}"
|
||||||
|
)
|
||||||
|
position = positions[0]
|
||||||
|
else:
|
||||||
|
position = None
|
||||||
|
|
||||||
|
if state or _imports or hooks or field_name or position:
|
||||||
return VarData(
|
return VarData(
|
||||||
state=state,
|
state=state,
|
||||||
field_name=field_name,
|
field_name=field_name,
|
||||||
imports=_imports,
|
imports=_imports,
|
||||||
hooks=hooks,
|
hooks=hooks,
|
||||||
|
position=position,
|
||||||
)
|
)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
@ -200,7 +227,9 @@ class VarData:
|
|||||||
Returns:
|
Returns:
|
||||||
True if any field is set to a non-default value.
|
True if any field is set to a non-default value.
|
||||||
"""
|
"""
|
||||||
return bool(self.state or self.imports or self.hooks or self.field_name)
|
return bool(
|
||||||
|
self.state or self.imports or self.hooks or self.field_name or self.position
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData:
|
def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData:
|
||||||
|
44
reflex/vars/hooks.py
Normal file
44
reflex/vars/hooks.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
"""A module for hooks-related Var."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from reflex.constants import Hooks
|
||||||
|
|
||||||
|
from .base import Var, VarData
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(
|
||||||
|
eq=False,
|
||||||
|
frozen=True,
|
||||||
|
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||||
|
)
|
||||||
|
class HookVar(Var):
|
||||||
|
"""A var class for representing a hook."""
|
||||||
|
|
||||||
|
position: Hooks.HookPosition | None = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
_hook_expr: str,
|
||||||
|
_var_data: VarData | None = None,
|
||||||
|
_position: Hooks.HookPosition | None = None,
|
||||||
|
):
|
||||||
|
"""Create a hook var.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
_hook_expr: The hook expression.
|
||||||
|
_position: The position of the hook in the component.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The hook var.
|
||||||
|
"""
|
||||||
|
hook_var = cls(
|
||||||
|
_js_expr=_hook_expr,
|
||||||
|
_var_type="str",
|
||||||
|
_var_data=_var_data,
|
||||||
|
position=_position,
|
||||||
|
)
|
||||||
|
# print("HookVar.create", _hook_expr, hook_var.position)
|
||||||
|
return hook_var
|
@ -31,6 +31,7 @@ from reflex.utils.exceptions import EventFnArgMismatch
|
|||||||
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
|
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
|
||||||
from reflex.vars import VarData
|
from reflex.vars import VarData
|
||||||
from reflex.vars.base import LiteralVar, Var
|
from reflex.vars.base import LiteralVar, Var
|
||||||
|
from reflex.vars.hooks import HookVar
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -2078,10 +2079,10 @@ def test_component_add_hooks_var():
|
|||||||
]
|
]
|
||||||
|
|
||||||
assert list(HookComponent()._get_all_hooks()) == [
|
assert list(HookComponent()._get_all_hooks()) == [
|
||||||
"const hook3 = useRef(null)",
|
HookVar.create("const hook3 = useRef(null)"),
|
||||||
"const hook1 = 42",
|
"const hook1 = 42",
|
||||||
"const hook2 = 43",
|
"const hook2 = 43",
|
||||||
"useEffect(() => () => {}, [])",
|
HookVar.create("useEffect(() => () => {}, [])"),
|
||||||
]
|
]
|
||||||
imports = HookComponent()._get_all_imports()
|
imports = HookComponent()._get_all_imports()
|
||||||
assert len(imports) == 1
|
assert len(imports) == 1
|
||||||
|
Loading…
Reference in New Issue
Block a user