add deps and position field in VarData (#4518)
* fix memoized event trigger order * allow to declare deps in event signature for memoized event triggers * clean up the code to pass tests * handle position of hooks * clean up code * revert test changes * add future annotations * remove non-necessary stuff * reuse data_callback name if already set during first call to add_hooks * remove HookVar and use Var with VarData instead * remove test change * readd removed line * fix order of stmt for cleaner code * fix typing * something broke during the merge I guess * remove hack and pass proper const for position * oops, bad syntax in jinja * use "hook_position" instead of "hook_positions" match the name of the enum --------- Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
parent
76ce112002
commit
1444421766
@ -5,11 +5,15 @@ export function {{tag_name}} () {
|
||||
{{ hook }}
|
||||
{% endfor %}
|
||||
|
||||
{% for hook, data in component._get_all_hooks().items() if not data.position or data.position == const.hook_position.PRE_TRIGGER %}
|
||||
{{ hook }}
|
||||
{% endfor %}
|
||||
|
||||
{% for hook in memo_trigger_hooks %}
|
||||
{{ hook }}
|
||||
{% endfor %}
|
||||
|
||||
{% for hook in component._get_all_hooks() %}
|
||||
{% for hook, data in component._get_all_hooks().items() if data.position and data.position == const.hook_position.POST_TRIGGER %}
|
||||
{{ hook }}
|
||||
{% endfor %}
|
||||
|
||||
|
@ -45,6 +45,7 @@ class ReflexJinjaEnvironment(Environment):
|
||||
"on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL,
|
||||
"update_vars_internal": constants.CompileVars.UPDATE_VARS_INTERNAL,
|
||||
"frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL,
|
||||
"hook_position": constants.Hooks.HookPosition,
|
||||
}
|
||||
|
||||
|
||||
|
@ -1368,7 +1368,9 @@ class Component(BaseComponent, ABC):
|
||||
if user_hooks_data is not None:
|
||||
other_imports.append(user_hooks_data.imports)
|
||||
other_imports.extend(
|
||||
hook_imports for hook_imports in self._get_added_hooks().values()
|
||||
hook_vardata.imports
|
||||
for hook_vardata in self._get_added_hooks().values()
|
||||
if hook_vardata is not None
|
||||
)
|
||||
|
||||
return imports.merge_imports(_imports, *other_imports)
|
||||
@ -1516,7 +1518,7 @@ class Component(BaseComponent, ABC):
|
||||
**self._get_special_hooks(),
|
||||
}
|
||||
|
||||
def _get_added_hooks(self) -> dict[str, ImportDict]:
|
||||
def _get_added_hooks(self) -> dict[str, VarData | None]:
|
||||
"""Get the hooks added via `add_hooks` method.
|
||||
|
||||
Returns:
|
||||
@ -1525,17 +1527,15 @@ class Component(BaseComponent, ABC):
|
||||
code = {}
|
||||
|
||||
def extract_var_hooks(hook: Var):
|
||||
_imports = {}
|
||||
var_data = VarData.merge(hook._get_all_var_data())
|
||||
if var_data is not None:
|
||||
for sub_hook in var_data.hooks:
|
||||
code[sub_hook] = {}
|
||||
if var_data.imports:
|
||||
_imports = var_data.imports
|
||||
code[sub_hook] = None
|
||||
|
||||
if str(hook) in code:
|
||||
code[str(hook)] = imports.merge_imports(code[str(hook)], _imports)
|
||||
code[str(hook)] = VarData.merge(var_data, code[str(hook)])
|
||||
else:
|
||||
code[str(hook)] = _imports
|
||||
code[str(hook)] = var_data
|
||||
|
||||
# Add the hook code from add_hooks for each parent class (this is reversed to preserve
|
||||
# the order of the hooks in the final output)
|
||||
@ -1544,7 +1544,7 @@ class Component(BaseComponent, ABC):
|
||||
if isinstance(hook, Var):
|
||||
extract_var_hooks(hook)
|
||||
else:
|
||||
code[hook] = {}
|
||||
code[hook] = None
|
||||
|
||||
return code
|
||||
|
||||
@ -1586,8 +1586,8 @@ class Component(BaseComponent, ABC):
|
||||
if hooks is not None:
|
||||
code[hooks] = None
|
||||
|
||||
for hook in self._get_added_hooks():
|
||||
code[hook] = None
|
||||
for hook, var_data in self._get_added_hooks().items():
|
||||
code[hook] = var_data
|
||||
|
||||
# Add the hook code for the children.
|
||||
for child in self.children:
|
||||
@ -2189,6 +2189,31 @@ class StatefulComponent(BaseComponent):
|
||||
]
|
||||
return [var_name]
|
||||
|
||||
@staticmethod
|
||||
def _get_deps_from_event_trigger(event: EventChain | EventSpec | Var) -> set[str]:
|
||||
"""Get the dependencies accessed by event triggers.
|
||||
|
||||
Args:
|
||||
event: The event trigger to extract deps from.
|
||||
|
||||
Returns:
|
||||
The dependencies accessed by the event triggers.
|
||||
"""
|
||||
events: list = [event]
|
||||
deps = set()
|
||||
|
||||
if isinstance(event, EventChain):
|
||||
events.extend(event.events)
|
||||
|
||||
for ev in events:
|
||||
if isinstance(ev, EventSpec):
|
||||
for arg in ev.args:
|
||||
for a in arg:
|
||||
var_datas = VarData.merge(a._get_all_var_data())
|
||||
if var_datas and var_datas.deps is not None:
|
||||
deps |= {str(dep) for dep in var_datas.deps}
|
||||
return deps
|
||||
|
||||
@classmethod
|
||||
def _get_memoized_event_triggers(
|
||||
cls,
|
||||
@ -2225,6 +2250,11 @@ class StatefulComponent(BaseComponent):
|
||||
|
||||
# Calculate Var dependencies accessed by the handler for useCallback dep array.
|
||||
var_deps = ["addEvents", "Event"]
|
||||
|
||||
# Get deps from event trigger var data.
|
||||
var_deps.extend(cls._get_deps_from_event_trigger(event))
|
||||
|
||||
# Get deps from hooks.
|
||||
for arg in event_args:
|
||||
var_data = arg._get_all_var_data()
|
||||
if var_data is None:
|
||||
|
@ -6,11 +6,12 @@ from typing import Dict, List, Tuple, Union
|
||||
|
||||
from reflex.components.base.fragment import Fragment
|
||||
from reflex.components.tags.tag import Tag
|
||||
from reflex.constants.compiler import Hooks
|
||||
from reflex.event import EventChain, EventHandler, passthrough_event_spec
|
||||
from reflex.utils.format import format_prop, wrap
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.vars import get_unique_variable_name
|
||||
from reflex.vars.base import Var
|
||||
from reflex.vars.base import Var, VarData
|
||||
|
||||
|
||||
class Clipboard(Fragment):
|
||||
@ -72,7 +73,7 @@ class Clipboard(Fragment):
|
||||
),
|
||||
}
|
||||
|
||||
def add_hooks(self) -> list[str]:
|
||||
def add_hooks(self) -> list[str | Var[str]]:
|
||||
"""Add hook to register paste event listener.
|
||||
|
||||
Returns:
|
||||
@ -83,13 +84,14 @@ class Clipboard(Fragment):
|
||||
return []
|
||||
if isinstance(on_paste, EventChain):
|
||||
on_paste = wrap(str(format_prop(on_paste)).strip("{}"), "(")
|
||||
hook_expr = f"usePasteHandler({self.targets!s}, {self.on_paste_event_actions!s}, {on_paste!s})"
|
||||
|
||||
return [
|
||||
"usePasteHandler(%s, %s, %s)"
|
||||
% (
|
||||
str(self.targets),
|
||||
str(self.on_paste_event_actions),
|
||||
on_paste,
|
||||
)
|
||||
Var(
|
||||
hook_expr,
|
||||
_var_type="str",
|
||||
_var_data=VarData(position=Hooks.HookPosition.POST_TRIGGER),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
@ -71,6 +71,6 @@ class Clipboard(Fragment):
|
||||
...
|
||||
|
||||
def add_imports(self) -> dict[str, ImportVar]: ...
|
||||
def add_hooks(self) -> list[str]: ...
|
||||
def add_hooks(self) -> list[str | Var[str]]: ...
|
||||
|
||||
clipboard = Clipboard.create
|
||||
|
@ -339,8 +339,11 @@ class DataEditor(NoSSRComponent):
|
||||
editor_id = get_unique_variable_name()
|
||||
|
||||
# Define the name of the getData callback associated with this component and assign to get_cell_content.
|
||||
data_callback = f"getData_{editor_id}"
|
||||
self.get_cell_content = Var(_js_expr=data_callback) # type: ignore
|
||||
if self.get_cell_content is not None:
|
||||
data_callback = self.get_cell_content._js_expr
|
||||
else:
|
||||
data_callback = f"getData_{editor_id}"
|
||||
self.get_cell_content = Var(_js_expr=data_callback) # type: ignore
|
||||
|
||||
code = [f"function {data_callback}([col, row])" "{"]
|
||||
|
||||
|
@ -132,6 +132,12 @@ class Hooks(SimpleNamespace):
|
||||
}
|
||||
})"""
|
||||
|
||||
class HookPosition(enum.Enum):
|
||||
"""The position of the hook in the component."""
|
||||
|
||||
PRE_TRIGGER = "pre_trigger"
|
||||
POST_TRIGGER = "post_trigger"
|
||||
|
||||
|
||||
class MemoizationDisposition(enum.Enum):
|
||||
"""The conditions under which a component should be memoized."""
|
||||
|
@ -42,7 +42,8 @@ from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints,
|
||||
|
||||
from reflex import constants
|
||||
from reflex.base import Base
|
||||
from reflex.utils import console, imports, serializers, types
|
||||
from reflex.constants.compiler import Hooks
|
||||
from reflex.utils import console, exceptions, imports, serializers, types
|
||||
from reflex.utils.exceptions import (
|
||||
VarAttributeError,
|
||||
VarDependencyError,
|
||||
@ -115,12 +116,20 @@ class VarData:
|
||||
# Hooks that need to be present in the component to render this var
|
||||
hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
|
||||
|
||||
# Dependencies of the var
|
||||
deps: Tuple[Var, ...] = dataclasses.field(default_factory=tuple)
|
||||
|
||||
# Position of the hook in the component
|
||||
position: Hooks.HookPosition | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state: str = "",
|
||||
field_name: str = "",
|
||||
imports: ImportDict | ParsedImportDict | None = None,
|
||||
hooks: dict[str, None] | None = None,
|
||||
deps: list[Var] | None = None,
|
||||
position: Hooks.HookPosition | None = None,
|
||||
):
|
||||
"""Initialize the var data.
|
||||
|
||||
@ -129,6 +138,8 @@ class VarData:
|
||||
field_name: The name of the field in the state.
|
||||
imports: Imports needed to render this var.
|
||||
hooks: Hooks that need to be present in the component to render this var.
|
||||
deps: Dependencies of the var for useCallback.
|
||||
position: Position of the hook in the component.
|
||||
"""
|
||||
immutable_imports: ImmutableParsedImportDict = tuple(
|
||||
sorted(
|
||||
@ -139,6 +150,8 @@ class VarData:
|
||||
object.__setattr__(self, "field_name", field_name)
|
||||
object.__setattr__(self, "imports", immutable_imports)
|
||||
object.__setattr__(self, "hooks", tuple(hooks or {}))
|
||||
object.__setattr__(self, "deps", tuple(deps or []))
|
||||
object.__setattr__(self, "position", position or None)
|
||||
|
||||
def old_school_imports(self) -> ImportDict:
|
||||
"""Return the imports as a mutable dict.
|
||||
@ -154,6 +167,9 @@ class VarData:
|
||||
Args:
|
||||
*all: The var data objects to merge.
|
||||
|
||||
Raises:
|
||||
ReflexError: If trying to merge VarData with different positions.
|
||||
|
||||
Returns:
|
||||
The merged var data object.
|
||||
|
||||
@ -184,12 +200,32 @@ class VarData:
|
||||
*(var_data.imports for var_data in all_var_datas)
|
||||
)
|
||||
|
||||
if state or _imports or hooks or field_name:
|
||||
deps = [dep for var_data in all_var_datas for dep in var_data.deps]
|
||||
|
||||
positions = list(
|
||||
{
|
||||
var_data.position
|
||||
for var_data in all_var_datas
|
||||
if var_data.position is not None
|
||||
}
|
||||
)
|
||||
if positions:
|
||||
if len(positions) > 1:
|
||||
raise exceptions.ReflexError(
|
||||
f"Cannot merge var data with different positions: {positions}"
|
||||
)
|
||||
position = positions[0]
|
||||
else:
|
||||
position = None
|
||||
|
||||
if state or _imports or hooks or field_name or deps or position:
|
||||
return VarData(
|
||||
state=state,
|
||||
field_name=field_name,
|
||||
imports=_imports,
|
||||
hooks=hooks,
|
||||
deps=deps,
|
||||
position=position,
|
||||
)
|
||||
|
||||
return None
|
||||
@ -200,7 +236,14 @@ class VarData:
|
||||
Returns:
|
||||
True if any field is set to a non-default value.
|
||||
"""
|
||||
return bool(self.state or self.imports or self.hooks or self.field_name)
|
||||
return bool(
|
||||
self.state
|
||||
or self.imports
|
||||
or self.hooks
|
||||
or self.field_name
|
||||
or self.deps
|
||||
or self.position
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData:
|
||||
@ -480,7 +523,6 @@ class Var(Generic[VAR_TYPE]):
|
||||
raise TypeError(
|
||||
"The _var_full_name_needs_state_prefix argument is not supported for Var."
|
||||
)
|
||||
|
||||
value_with_replaced = dataclasses.replace(
|
||||
self,
|
||||
_var_type=_var_type or self._var_type,
|
||||
|
Loading…
Reference in New Issue
Block a user