add deps and position field in VarData (#4518)

* fix memoized event trigger order

* allow to declare deps in event signature for memoized event triggers

* clean up the code to pass tests

* handle position of hooks

* clean up code

* revert test changes

* add future annotations

* remove non-necessary stuff

* reuse data_callback name if already set during first call to add_hooks

* remove HookVar and use Var with VarData instead

* remove test change

* readd removed line

* fix order of stmt for cleaner code

* fix typing

* something broke during the merge I guess

* remove hack and pass proper const for position

* oops, bad syntax in jinja

* use "hook_position" instead of "hook_positions"

match the name of the enum

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
Thomas Brandého 2024-12-13 13:28:55 -08:00 committed by GitHub
parent 76ce112002
commit 1444421766
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 115 additions and 27 deletions

View File

@ -5,11 +5,15 @@ export function {{tag_name}} () {
{{ hook }} {{ hook }}
{% endfor %} {% endfor %}
{% for hook, data in component._get_all_hooks().items() if not data.position or data.position == const.hook_position.PRE_TRIGGER %}
{{ hook }}
{% endfor %}
{% for hook in memo_trigger_hooks %} {% for hook in memo_trigger_hooks %}
{{ hook }} {{ hook }}
{% endfor %} {% endfor %}
{% for hook in component._get_all_hooks() %} {% for hook, data in component._get_all_hooks().items() if data.position and data.position == const.hook_position.POST_TRIGGER %}
{{ hook }} {{ hook }}
{% endfor %} {% endfor %}

View File

@ -45,6 +45,7 @@ class ReflexJinjaEnvironment(Environment):
"on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL, "on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL,
"update_vars_internal": constants.CompileVars.UPDATE_VARS_INTERNAL, "update_vars_internal": constants.CompileVars.UPDATE_VARS_INTERNAL,
"frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL, "frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL,
"hook_position": constants.Hooks.HookPosition,
} }

View File

@ -1368,7 +1368,9 @@ class Component(BaseComponent, ABC):
if user_hooks_data is not None: if user_hooks_data is not None:
other_imports.append(user_hooks_data.imports) other_imports.append(user_hooks_data.imports)
other_imports.extend( other_imports.extend(
hook_imports for hook_imports in self._get_added_hooks().values() hook_vardata.imports
for hook_vardata in self._get_added_hooks().values()
if hook_vardata is not None
) )
return imports.merge_imports(_imports, *other_imports) return imports.merge_imports(_imports, *other_imports)
@ -1516,7 +1518,7 @@ class Component(BaseComponent, ABC):
**self._get_special_hooks(), **self._get_special_hooks(),
} }
def _get_added_hooks(self) -> dict[str, ImportDict]: def _get_added_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks added via `add_hooks` method. """Get the hooks added via `add_hooks` method.
Returns: Returns:
@ -1525,17 +1527,15 @@ class Component(BaseComponent, ABC):
code = {} code = {}
def extract_var_hooks(hook: Var): def extract_var_hooks(hook: Var):
_imports = {}
var_data = VarData.merge(hook._get_all_var_data()) var_data = VarData.merge(hook._get_all_var_data())
if var_data is not None: if var_data is not None:
for sub_hook in var_data.hooks: for sub_hook in var_data.hooks:
code[sub_hook] = {} code[sub_hook] = None
if var_data.imports:
_imports = var_data.imports
if str(hook) in code: if str(hook) in code:
code[str(hook)] = imports.merge_imports(code[str(hook)], _imports) code[str(hook)] = VarData.merge(var_data, code[str(hook)])
else: else:
code[str(hook)] = _imports code[str(hook)] = var_data
# Add the hook code from add_hooks for each parent class (this is reversed to preserve # Add the hook code from add_hooks for each parent class (this is reversed to preserve
# the order of the hooks in the final output) # the order of the hooks in the final output)
@ -1544,7 +1544,7 @@ class Component(BaseComponent, ABC):
if isinstance(hook, Var): if isinstance(hook, Var):
extract_var_hooks(hook) extract_var_hooks(hook)
else: else:
code[hook] = {} code[hook] = None
return code return code
@ -1586,8 +1586,8 @@ class Component(BaseComponent, ABC):
if hooks is not None: if hooks is not None:
code[hooks] = None code[hooks] = None
for hook in self._get_added_hooks(): for hook, var_data in self._get_added_hooks().items():
code[hook] = None code[hook] = var_data
# Add the hook code for the children. # Add the hook code for the children.
for child in self.children: for child in self.children:
@ -2189,6 +2189,31 @@ class StatefulComponent(BaseComponent):
] ]
return [var_name] return [var_name]
@staticmethod
def _get_deps_from_event_trigger(event: EventChain | EventSpec | Var) -> set[str]:
"""Get the dependencies accessed by event triggers.
Args:
event: The event trigger to extract deps from.
Returns:
The dependencies accessed by the event triggers.
"""
events: list = [event]
deps = set()
if isinstance(event, EventChain):
events.extend(event.events)
for ev in events:
if isinstance(ev, EventSpec):
for arg in ev.args:
for a in arg:
var_datas = VarData.merge(a._get_all_var_data())
if var_datas and var_datas.deps is not None:
deps |= {str(dep) for dep in var_datas.deps}
return deps
@classmethod @classmethod
def _get_memoized_event_triggers( def _get_memoized_event_triggers(
cls, cls,
@ -2225,6 +2250,11 @@ class StatefulComponent(BaseComponent):
# Calculate Var dependencies accessed by the handler for useCallback dep array. # Calculate Var dependencies accessed by the handler for useCallback dep array.
var_deps = ["addEvents", "Event"] var_deps = ["addEvents", "Event"]
# Get deps from event trigger var data.
var_deps.extend(cls._get_deps_from_event_trigger(event))
# Get deps from hooks.
for arg in event_args: for arg in event_args:
var_data = arg._get_all_var_data() var_data = arg._get_all_var_data()
if var_data is None: if var_data is None:

View File

@ -6,11 +6,12 @@ from typing import Dict, List, Tuple, Union
from reflex.components.base.fragment import Fragment from reflex.components.base.fragment import Fragment
from reflex.components.tags.tag import Tag from reflex.components.tags.tag import Tag
from reflex.constants.compiler import Hooks
from reflex.event import EventChain, EventHandler, passthrough_event_spec from reflex.event import EventChain, EventHandler, passthrough_event_spec
from reflex.utils.format import format_prop, wrap from reflex.utils.format import format_prop, wrap
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportVar
from reflex.vars import get_unique_variable_name from reflex.vars import get_unique_variable_name
from reflex.vars.base import Var from reflex.vars.base import Var, VarData
class Clipboard(Fragment): class Clipboard(Fragment):
@ -72,7 +73,7 @@ class Clipboard(Fragment):
), ),
} }
def add_hooks(self) -> list[str]: def add_hooks(self) -> list[str | Var[str]]:
"""Add hook to register paste event listener. """Add hook to register paste event listener.
Returns: Returns:
@ -83,13 +84,14 @@ class Clipboard(Fragment):
return [] return []
if isinstance(on_paste, EventChain): if isinstance(on_paste, EventChain):
on_paste = wrap(str(format_prop(on_paste)).strip("{}"), "(") on_paste = wrap(str(format_prop(on_paste)).strip("{}"), "(")
hook_expr = f"usePasteHandler({self.targets!s}, {self.on_paste_event_actions!s}, {on_paste!s})"
return [ return [
"usePasteHandler(%s, %s, %s)" Var(
% ( hook_expr,
str(self.targets), _var_type="str",
str(self.on_paste_event_actions), _var_data=VarData(position=Hooks.HookPosition.POST_TRIGGER),
on_paste, ),
)
] ]

View File

@ -71,6 +71,6 @@ class Clipboard(Fragment):
... ...
def add_imports(self) -> dict[str, ImportVar]: ... def add_imports(self) -> dict[str, ImportVar]: ...
def add_hooks(self) -> list[str]: ... def add_hooks(self) -> list[str | Var[str]]: ...
clipboard = Clipboard.create clipboard = Clipboard.create

View File

@ -339,6 +339,9 @@ class DataEditor(NoSSRComponent):
editor_id = get_unique_variable_name() editor_id = get_unique_variable_name()
# Define the name of the getData callback associated with this component and assign to get_cell_content. # Define the name of the getData callback associated with this component and assign to get_cell_content.
if self.get_cell_content is not None:
data_callback = self.get_cell_content._js_expr
else:
data_callback = f"getData_{editor_id}" data_callback = f"getData_{editor_id}"
self.get_cell_content = Var(_js_expr=data_callback) # type: ignore self.get_cell_content = Var(_js_expr=data_callback) # type: ignore

View File

@ -132,6 +132,12 @@ class Hooks(SimpleNamespace):
} }
})""" })"""
class HookPosition(enum.Enum):
"""The position of the hook in the component."""
PRE_TRIGGER = "pre_trigger"
POST_TRIGGER = "post_trigger"
class MemoizationDisposition(enum.Enum): class MemoizationDisposition(enum.Enum):
"""The conditions under which a component should be memoized.""" """The conditions under which a component should be memoized."""

View File

@ -42,7 +42,8 @@ from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints,
from reflex import constants from reflex import constants
from reflex.base import Base from reflex.base import Base
from reflex.utils import console, imports, serializers, types from reflex.constants.compiler import Hooks
from reflex.utils import console, exceptions, imports, serializers, types
from reflex.utils.exceptions import ( from reflex.utils.exceptions import (
VarAttributeError, VarAttributeError,
VarDependencyError, VarDependencyError,
@ -115,12 +116,20 @@ class VarData:
# Hooks that need to be present in the component to render this var # Hooks that need to be present in the component to render this var
hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple) hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
# Dependencies of the var
deps: Tuple[Var, ...] = dataclasses.field(default_factory=tuple)
# Position of the hook in the component
position: Hooks.HookPosition | None = None
def __init__( def __init__(
self, self,
state: str = "", state: str = "",
field_name: str = "", field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None, imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, None] | None = None, hooks: dict[str, None] | None = None,
deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None,
): ):
"""Initialize the var data. """Initialize the var data.
@ -129,6 +138,8 @@ class VarData:
field_name: The name of the field in the state. field_name: The name of the field in the state.
imports: Imports needed to render this var. imports: Imports needed to render this var.
hooks: Hooks that need to be present in the component to render this var. hooks: Hooks that need to be present in the component to render this var.
deps: Dependencies of the var for useCallback.
position: Position of the hook in the component.
""" """
immutable_imports: ImmutableParsedImportDict = tuple( immutable_imports: ImmutableParsedImportDict = tuple(
sorted( sorted(
@ -139,6 +150,8 @@ class VarData:
object.__setattr__(self, "field_name", field_name) object.__setattr__(self, "field_name", field_name)
object.__setattr__(self, "imports", immutable_imports) object.__setattr__(self, "imports", immutable_imports)
object.__setattr__(self, "hooks", tuple(hooks or {})) object.__setattr__(self, "hooks", tuple(hooks or {}))
object.__setattr__(self, "deps", tuple(deps or []))
object.__setattr__(self, "position", position or None)
def old_school_imports(self) -> ImportDict: def old_school_imports(self) -> ImportDict:
"""Return the imports as a mutable dict. """Return the imports as a mutable dict.
@ -154,6 +167,9 @@ class VarData:
Args: Args:
*all: The var data objects to merge. *all: The var data objects to merge.
Raises:
ReflexError: If trying to merge VarData with different positions.
Returns: Returns:
The merged var data object. The merged var data object.
@ -184,12 +200,32 @@ class VarData:
*(var_data.imports for var_data in all_var_datas) *(var_data.imports for var_data in all_var_datas)
) )
if state or _imports or hooks or field_name: deps = [dep for var_data in all_var_datas for dep in var_data.deps]
positions = list(
{
var_data.position
for var_data in all_var_datas
if var_data.position is not None
}
)
if positions:
if len(positions) > 1:
raise exceptions.ReflexError(
f"Cannot merge var data with different positions: {positions}"
)
position = positions[0]
else:
position = None
if state or _imports or hooks or field_name or deps or position:
return VarData( return VarData(
state=state, state=state,
field_name=field_name, field_name=field_name,
imports=_imports, imports=_imports,
hooks=hooks, hooks=hooks,
deps=deps,
position=position,
) )
return None return None
@ -200,7 +236,14 @@ class VarData:
Returns: Returns:
True if any field is set to a non-default value. True if any field is set to a non-default value.
""" """
return bool(self.state or self.imports or self.hooks or self.field_name) return bool(
self.state
or self.imports
or self.hooks
or self.field_name
or self.deps
or self.position
)
@classmethod @classmethod
def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData: def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData:
@ -480,7 +523,6 @@ class Var(Generic[VAR_TYPE]):
raise TypeError( raise TypeError(
"The _var_full_name_needs_state_prefix argument is not supported for Var." "The _var_full_name_needs_state_prefix argument is not supported for Var."
) )
value_with_replaced = dataclasses.replace( value_with_replaced = dataclasses.replace(
self, self,
_var_type=_var_type or self._var_type, _var_type=_var_type or self._var_type,