add deps and position field in VarData (#4518)
* fix memoized event trigger order * allow to declare deps in event signature for memoized event triggers * clean up the code to pass tests * handle position of hooks * clean up code * revert test changes * add future annotations * remove non-necessary stuff * reuse data_callback name if already set during first call to add_hooks * remove HookVar and use Var with VarData instead * remove test change * readd removed line * fix order of stmt for cleaner code * fix typing * something broke during the merge I guess * remove hack and pass proper const for position * oops, bad syntax in jinja * use "hook_position" instead of "hook_positions" match the name of the enum --------- Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
parent
76ce112002
commit
1444421766
@ -5,11 +5,15 @@ export function {{tag_name}} () {
|
|||||||
{{ hook }}
|
{{ hook }}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
|
||||||
|
{% for hook, data in component._get_all_hooks().items() if not data.position or data.position == const.hook_position.PRE_TRIGGER %}
|
||||||
|
{{ hook }}
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
{% for hook in memo_trigger_hooks %}
|
{% for hook in memo_trigger_hooks %}
|
||||||
{{ hook }}
|
{{ hook }}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
|
||||||
{% for hook in component._get_all_hooks() %}
|
{% for hook, data in component._get_all_hooks().items() if data.position and data.position == const.hook_position.POST_TRIGGER %}
|
||||||
{{ hook }}
|
{{ hook }}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
|
||||||
|
@ -45,6 +45,7 @@ class ReflexJinjaEnvironment(Environment):
|
|||||||
"on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL,
|
"on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL,
|
||||||
"update_vars_internal": constants.CompileVars.UPDATE_VARS_INTERNAL,
|
"update_vars_internal": constants.CompileVars.UPDATE_VARS_INTERNAL,
|
||||||
"frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL,
|
"frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL,
|
||||||
|
"hook_position": constants.Hooks.HookPosition,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1368,7 +1368,9 @@ class Component(BaseComponent, ABC):
|
|||||||
if user_hooks_data is not None:
|
if user_hooks_data is not None:
|
||||||
other_imports.append(user_hooks_data.imports)
|
other_imports.append(user_hooks_data.imports)
|
||||||
other_imports.extend(
|
other_imports.extend(
|
||||||
hook_imports for hook_imports in self._get_added_hooks().values()
|
hook_vardata.imports
|
||||||
|
for hook_vardata in self._get_added_hooks().values()
|
||||||
|
if hook_vardata is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
return imports.merge_imports(_imports, *other_imports)
|
return imports.merge_imports(_imports, *other_imports)
|
||||||
@ -1516,7 +1518,7 @@ class Component(BaseComponent, ABC):
|
|||||||
**self._get_special_hooks(),
|
**self._get_special_hooks(),
|
||||||
}
|
}
|
||||||
|
|
||||||
def _get_added_hooks(self) -> dict[str, ImportDict]:
|
def _get_added_hooks(self) -> dict[str, VarData | None]:
|
||||||
"""Get the hooks added via `add_hooks` method.
|
"""Get the hooks added via `add_hooks` method.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -1525,17 +1527,15 @@ class Component(BaseComponent, ABC):
|
|||||||
code = {}
|
code = {}
|
||||||
|
|
||||||
def extract_var_hooks(hook: Var):
|
def extract_var_hooks(hook: Var):
|
||||||
_imports = {}
|
|
||||||
var_data = VarData.merge(hook._get_all_var_data())
|
var_data = VarData.merge(hook._get_all_var_data())
|
||||||
if var_data is not None:
|
if var_data is not None:
|
||||||
for sub_hook in var_data.hooks:
|
for sub_hook in var_data.hooks:
|
||||||
code[sub_hook] = {}
|
code[sub_hook] = None
|
||||||
if var_data.imports:
|
|
||||||
_imports = var_data.imports
|
|
||||||
if str(hook) in code:
|
if str(hook) in code:
|
||||||
code[str(hook)] = imports.merge_imports(code[str(hook)], _imports)
|
code[str(hook)] = VarData.merge(var_data, code[str(hook)])
|
||||||
else:
|
else:
|
||||||
code[str(hook)] = _imports
|
code[str(hook)] = var_data
|
||||||
|
|
||||||
# Add the hook code from add_hooks for each parent class (this is reversed to preserve
|
# Add the hook code from add_hooks for each parent class (this is reversed to preserve
|
||||||
# the order of the hooks in the final output)
|
# the order of the hooks in the final output)
|
||||||
@ -1544,7 +1544,7 @@ class Component(BaseComponent, ABC):
|
|||||||
if isinstance(hook, Var):
|
if isinstance(hook, Var):
|
||||||
extract_var_hooks(hook)
|
extract_var_hooks(hook)
|
||||||
else:
|
else:
|
||||||
code[hook] = {}
|
code[hook] = None
|
||||||
|
|
||||||
return code
|
return code
|
||||||
|
|
||||||
@ -1586,8 +1586,8 @@ class Component(BaseComponent, ABC):
|
|||||||
if hooks is not None:
|
if hooks is not None:
|
||||||
code[hooks] = None
|
code[hooks] = None
|
||||||
|
|
||||||
for hook in self._get_added_hooks():
|
for hook, var_data in self._get_added_hooks().items():
|
||||||
code[hook] = None
|
code[hook] = var_data
|
||||||
|
|
||||||
# Add the hook code for the children.
|
# Add the hook code for the children.
|
||||||
for child in self.children:
|
for child in self.children:
|
||||||
@ -2189,6 +2189,31 @@ class StatefulComponent(BaseComponent):
|
|||||||
]
|
]
|
||||||
return [var_name]
|
return [var_name]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_deps_from_event_trigger(event: EventChain | EventSpec | Var) -> set[str]:
|
||||||
|
"""Get the dependencies accessed by event triggers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: The event trigger to extract deps from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The dependencies accessed by the event triggers.
|
||||||
|
"""
|
||||||
|
events: list = [event]
|
||||||
|
deps = set()
|
||||||
|
|
||||||
|
if isinstance(event, EventChain):
|
||||||
|
events.extend(event.events)
|
||||||
|
|
||||||
|
for ev in events:
|
||||||
|
if isinstance(ev, EventSpec):
|
||||||
|
for arg in ev.args:
|
||||||
|
for a in arg:
|
||||||
|
var_datas = VarData.merge(a._get_all_var_data())
|
||||||
|
if var_datas and var_datas.deps is not None:
|
||||||
|
deps |= {str(dep) for dep in var_datas.deps}
|
||||||
|
return deps
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_memoized_event_triggers(
|
def _get_memoized_event_triggers(
|
||||||
cls,
|
cls,
|
||||||
@ -2225,6 +2250,11 @@ class StatefulComponent(BaseComponent):
|
|||||||
|
|
||||||
# Calculate Var dependencies accessed by the handler for useCallback dep array.
|
# Calculate Var dependencies accessed by the handler for useCallback dep array.
|
||||||
var_deps = ["addEvents", "Event"]
|
var_deps = ["addEvents", "Event"]
|
||||||
|
|
||||||
|
# Get deps from event trigger var data.
|
||||||
|
var_deps.extend(cls._get_deps_from_event_trigger(event))
|
||||||
|
|
||||||
|
# Get deps from hooks.
|
||||||
for arg in event_args:
|
for arg in event_args:
|
||||||
var_data = arg._get_all_var_data()
|
var_data = arg._get_all_var_data()
|
||||||
if var_data is None:
|
if var_data is None:
|
||||||
|
@ -6,11 +6,12 @@ from typing import Dict, List, Tuple, Union
|
|||||||
|
|
||||||
from reflex.components.base.fragment import Fragment
|
from reflex.components.base.fragment import Fragment
|
||||||
from reflex.components.tags.tag import Tag
|
from reflex.components.tags.tag import Tag
|
||||||
|
from reflex.constants.compiler import Hooks
|
||||||
from reflex.event import EventChain, EventHandler, passthrough_event_spec
|
from reflex.event import EventChain, EventHandler, passthrough_event_spec
|
||||||
from reflex.utils.format import format_prop, wrap
|
from reflex.utils.format import format_prop, wrap
|
||||||
from reflex.utils.imports import ImportVar
|
from reflex.utils.imports import ImportVar
|
||||||
from reflex.vars import get_unique_variable_name
|
from reflex.vars import get_unique_variable_name
|
||||||
from reflex.vars.base import Var
|
from reflex.vars.base import Var, VarData
|
||||||
|
|
||||||
|
|
||||||
class Clipboard(Fragment):
|
class Clipboard(Fragment):
|
||||||
@ -72,7 +73,7 @@ class Clipboard(Fragment):
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
def add_hooks(self) -> list[str]:
|
def add_hooks(self) -> list[str | Var[str]]:
|
||||||
"""Add hook to register paste event listener.
|
"""Add hook to register paste event listener.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -83,13 +84,14 @@ class Clipboard(Fragment):
|
|||||||
return []
|
return []
|
||||||
if isinstance(on_paste, EventChain):
|
if isinstance(on_paste, EventChain):
|
||||||
on_paste = wrap(str(format_prop(on_paste)).strip("{}"), "(")
|
on_paste = wrap(str(format_prop(on_paste)).strip("{}"), "(")
|
||||||
|
hook_expr = f"usePasteHandler({self.targets!s}, {self.on_paste_event_actions!s}, {on_paste!s})"
|
||||||
|
|
||||||
return [
|
return [
|
||||||
"usePasteHandler(%s, %s, %s)"
|
Var(
|
||||||
% (
|
hook_expr,
|
||||||
str(self.targets),
|
_var_type="str",
|
||||||
str(self.on_paste_event_actions),
|
_var_data=VarData(position=Hooks.HookPosition.POST_TRIGGER),
|
||||||
on_paste,
|
),
|
||||||
)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -71,6 +71,6 @@ class Clipboard(Fragment):
|
|||||||
...
|
...
|
||||||
|
|
||||||
def add_imports(self) -> dict[str, ImportVar]: ...
|
def add_imports(self) -> dict[str, ImportVar]: ...
|
||||||
def add_hooks(self) -> list[str]: ...
|
def add_hooks(self) -> list[str | Var[str]]: ...
|
||||||
|
|
||||||
clipboard = Clipboard.create
|
clipboard = Clipboard.create
|
||||||
|
@ -339,6 +339,9 @@ class DataEditor(NoSSRComponent):
|
|||||||
editor_id = get_unique_variable_name()
|
editor_id = get_unique_variable_name()
|
||||||
|
|
||||||
# Define the name of the getData callback associated with this component and assign to get_cell_content.
|
# Define the name of the getData callback associated with this component and assign to get_cell_content.
|
||||||
|
if self.get_cell_content is not None:
|
||||||
|
data_callback = self.get_cell_content._js_expr
|
||||||
|
else:
|
||||||
data_callback = f"getData_{editor_id}"
|
data_callback = f"getData_{editor_id}"
|
||||||
self.get_cell_content = Var(_js_expr=data_callback) # type: ignore
|
self.get_cell_content = Var(_js_expr=data_callback) # type: ignore
|
||||||
|
|
||||||
|
@ -132,6 +132,12 @@ class Hooks(SimpleNamespace):
|
|||||||
}
|
}
|
||||||
})"""
|
})"""
|
||||||
|
|
||||||
|
class HookPosition(enum.Enum):
|
||||||
|
"""The position of the hook in the component."""
|
||||||
|
|
||||||
|
PRE_TRIGGER = "pre_trigger"
|
||||||
|
POST_TRIGGER = "post_trigger"
|
||||||
|
|
||||||
|
|
||||||
class MemoizationDisposition(enum.Enum):
|
class MemoizationDisposition(enum.Enum):
|
||||||
"""The conditions under which a component should be memoized."""
|
"""The conditions under which a component should be memoized."""
|
||||||
|
@ -42,7 +42,8 @@ from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints,
|
|||||||
|
|
||||||
from reflex import constants
|
from reflex import constants
|
||||||
from reflex.base import Base
|
from reflex.base import Base
|
||||||
from reflex.utils import console, imports, serializers, types
|
from reflex.constants.compiler import Hooks
|
||||||
|
from reflex.utils import console, exceptions, imports, serializers, types
|
||||||
from reflex.utils.exceptions import (
|
from reflex.utils.exceptions import (
|
||||||
VarAttributeError,
|
VarAttributeError,
|
||||||
VarDependencyError,
|
VarDependencyError,
|
||||||
@ -115,12 +116,20 @@ class VarData:
|
|||||||
# Hooks that need to be present in the component to render this var
|
# Hooks that need to be present in the component to render this var
|
||||||
hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
|
hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
|
||||||
|
|
||||||
|
# Dependencies of the var
|
||||||
|
deps: Tuple[Var, ...] = dataclasses.field(default_factory=tuple)
|
||||||
|
|
||||||
|
# Position of the hook in the component
|
||||||
|
position: Hooks.HookPosition | None = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
state: str = "",
|
state: str = "",
|
||||||
field_name: str = "",
|
field_name: str = "",
|
||||||
imports: ImportDict | ParsedImportDict | None = None,
|
imports: ImportDict | ParsedImportDict | None = None,
|
||||||
hooks: dict[str, None] | None = None,
|
hooks: dict[str, None] | None = None,
|
||||||
|
deps: list[Var] | None = None,
|
||||||
|
position: Hooks.HookPosition | None = None,
|
||||||
):
|
):
|
||||||
"""Initialize the var data.
|
"""Initialize the var data.
|
||||||
|
|
||||||
@ -129,6 +138,8 @@ class VarData:
|
|||||||
field_name: The name of the field in the state.
|
field_name: The name of the field in the state.
|
||||||
imports: Imports needed to render this var.
|
imports: Imports needed to render this var.
|
||||||
hooks: Hooks that need to be present in the component to render this var.
|
hooks: Hooks that need to be present in the component to render this var.
|
||||||
|
deps: Dependencies of the var for useCallback.
|
||||||
|
position: Position of the hook in the component.
|
||||||
"""
|
"""
|
||||||
immutable_imports: ImmutableParsedImportDict = tuple(
|
immutable_imports: ImmutableParsedImportDict = tuple(
|
||||||
sorted(
|
sorted(
|
||||||
@ -139,6 +150,8 @@ class VarData:
|
|||||||
object.__setattr__(self, "field_name", field_name)
|
object.__setattr__(self, "field_name", field_name)
|
||||||
object.__setattr__(self, "imports", immutable_imports)
|
object.__setattr__(self, "imports", immutable_imports)
|
||||||
object.__setattr__(self, "hooks", tuple(hooks or {}))
|
object.__setattr__(self, "hooks", tuple(hooks or {}))
|
||||||
|
object.__setattr__(self, "deps", tuple(deps or []))
|
||||||
|
object.__setattr__(self, "position", position or None)
|
||||||
|
|
||||||
def old_school_imports(self) -> ImportDict:
|
def old_school_imports(self) -> ImportDict:
|
||||||
"""Return the imports as a mutable dict.
|
"""Return the imports as a mutable dict.
|
||||||
@ -154,6 +167,9 @@ class VarData:
|
|||||||
Args:
|
Args:
|
||||||
*all: The var data objects to merge.
|
*all: The var data objects to merge.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ReflexError: If trying to merge VarData with different positions.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The merged var data object.
|
The merged var data object.
|
||||||
|
|
||||||
@ -184,12 +200,32 @@ class VarData:
|
|||||||
*(var_data.imports for var_data in all_var_datas)
|
*(var_data.imports for var_data in all_var_datas)
|
||||||
)
|
)
|
||||||
|
|
||||||
if state or _imports or hooks or field_name:
|
deps = [dep for var_data in all_var_datas for dep in var_data.deps]
|
||||||
|
|
||||||
|
positions = list(
|
||||||
|
{
|
||||||
|
var_data.position
|
||||||
|
for var_data in all_var_datas
|
||||||
|
if var_data.position is not None
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if positions:
|
||||||
|
if len(positions) > 1:
|
||||||
|
raise exceptions.ReflexError(
|
||||||
|
f"Cannot merge var data with different positions: {positions}"
|
||||||
|
)
|
||||||
|
position = positions[0]
|
||||||
|
else:
|
||||||
|
position = None
|
||||||
|
|
||||||
|
if state or _imports or hooks or field_name or deps or position:
|
||||||
return VarData(
|
return VarData(
|
||||||
state=state,
|
state=state,
|
||||||
field_name=field_name,
|
field_name=field_name,
|
||||||
imports=_imports,
|
imports=_imports,
|
||||||
hooks=hooks,
|
hooks=hooks,
|
||||||
|
deps=deps,
|
||||||
|
position=position,
|
||||||
)
|
)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
@ -200,7 +236,14 @@ class VarData:
|
|||||||
Returns:
|
Returns:
|
||||||
True if any field is set to a non-default value.
|
True if any field is set to a non-default value.
|
||||||
"""
|
"""
|
||||||
return bool(self.state or self.imports or self.hooks or self.field_name)
|
return bool(
|
||||||
|
self.state
|
||||||
|
or self.imports
|
||||||
|
or self.hooks
|
||||||
|
or self.field_name
|
||||||
|
or self.deps
|
||||||
|
or self.position
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData:
|
def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData:
|
||||||
@ -480,7 +523,6 @@ class Var(Generic[VAR_TYPE]):
|
|||||||
raise TypeError(
|
raise TypeError(
|
||||||
"The _var_full_name_needs_state_prefix argument is not supported for Var."
|
"The _var_full_name_needs_state_prefix argument is not supported for Var."
|
||||||
)
|
)
|
||||||
|
|
||||||
value_with_replaced = dataclasses.replace(
|
value_with_replaced = dataclasses.replace(
|
||||||
self,
|
self,
|
||||||
_var_type=_var_type or self._var_type,
|
_var_type=_var_type or self._var_type,
|
||||||
|
Loading…
Reference in New Issue
Block a user