use position in vardata to mark internal hooks (#4549)

* use position in vardata to mark internal hooks

* update all render to use position

* use macros for rendering

* reduce number of iterations over hooks during rendering

* cleanup code and add typing

* add __future__

* use new macros to render component maps in markdown

* remove calls to _get_all_hooks_internal

* fix typo

* forgot to replace this

* unnecessary expand in utils.py
This commit is contained in:
Thomas Brandého 2025-01-06 13:06:56 -08:00 committed by GitHub
parent 59b3aaca42
commit 9fafb6d526
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 140 additions and 59 deletions

View File

@ -1,4 +1,5 @@
{% extends "web/pages/base_page.js.jinja2" %} {% extends "web/pages/base_page.js.jinja2" %}
{% from "web/pages/macros.js.jinja2" import renderHooks %}
{% block early_imports %} {% block early_imports %}
import '$/styles/styles.css' import '$/styles/styles.css'
@ -18,10 +19,7 @@ import * as {{library_alias}} from "{{library_path}}";
{% block export %} {% block export %}
function AppWrap({children}) { function AppWrap({children}) {
{{ renderHooks(hooks) }}
{% for hook in hooks %}
{{ hook }}
{% endfor %}
return ( return (
{{utils.render(render, indent_width=0)}} {{utils.render(render, indent_width=0)}}

View File

@ -1,5 +1,5 @@
{% extends "web/pages/base_page.js.jinja2" %} {% extends "web/pages/base_page.js.jinja2" %}
{% from "web/pages/macros.js.jinja2" import renderHooks %}
{% block export %} {% block export %}
{% for component in components %} {% for component in components %}
@ -8,9 +8,8 @@
{% endfor %} {% endfor %}
export const {{component.name}} = memo(({ {{-component.props|join(", ")-}} }) => { export const {{component.name}} = memo(({ {{-component.props|join(", ")-}} }) => {
{% for hook in component.hooks %} {{ renderHooks(component.hooks) }}
{{ hook }}
{% endfor %}
return( return(
{{utils.render(component.render)}} {{utils.render(component.render)}}
) )

View File

@ -1,4 +1,5 @@
{% extends "web/pages/base_page.js.jinja2" %} {% extends "web/pages/base_page.js.jinja2" %}
{% from "web/pages/macros.js.jinja2" import renderHooks %}
{% block declaration %} {% block declaration %}
{% for custom_code in custom_codes %} {% for custom_code in custom_codes %}
@ -8,9 +9,7 @@
{% block export %} {% block export %}
export default function Component() { export default function Component() {
{% for hook in hooks %} {{ renderHooks(hooks)}}
{{ hook }}
{% endfor %}
return ( return (
{{utils.render(render, indent_width=0)}} {{utils.render(render, indent_width=0)}}

View File

@ -0,0 +1,38 @@
{% macro renderHooks(hooks) %}
{% set sorted_hooks = sort_hooks(hooks) %}
{# Render the grouped hooks #}
{% for hook, _ in sorted_hooks[const.hook_position.INTERNAL] %}
{{ hook }}
{% endfor %}
{% for hook, _ in sorted_hooks[const.hook_position.PRE_TRIGGER] %}
{{ hook }}
{% endfor %}
{% for hook, _ in sorted_hooks[const.hook_position.POST_TRIGGER] %}
{{ hook }}
{% endfor %}
{% endmacro %}
{% macro renderHooksWithMemo(hooks, memo)%}
{% set sorted_hooks = sort_hooks(hooks) %}
{# Render the grouped hooks #}
{% for hook, _ in sorted_hooks[const.hook_position.INTERNAL] %}
{{ hook }}
{% endfor %}
{% for hook, _ in sorted_hooks[const.hook_position.PRE_TRIGGER] %}
{{ hook }}
{% endfor %}
{% for hook in memo %}
{{ hook }}
{% endfor %}
{% for hook, _ in sorted_hooks[const.hook_position.POST_TRIGGER] %}
{{ hook }}
{% endfor %}
{% endmacro %}

View File

@ -1,22 +1,10 @@
{% import 'web/pages/utils.js.jinja2' as utils %} {% import 'web/pages/utils.js.jinja2' as utils %}
{% from 'web/pages/macros.js.jinja2' import renderHooksWithMemo %}
{% set all_hooks = component._get_all_hooks() %}
export function {{tag_name}} () { export function {{tag_name}} () {
{% for hook in component._get_all_hooks_internal() %} {{ renderHooksWithMemo(all_hooks, memo_trigger_hooks) }}
{{ hook }}
{% endfor %}
{% for hook, data in component._get_all_hooks().items() if not data.position or data.position == const.hook_position.PRE_TRIGGER %}
{{ hook }}
{% endfor %}
{% for hook in memo_trigger_hooks %}
{{ hook }}
{% endfor %}
{% for hook, data in component._get_all_hooks().items() if data.position and data.position == const.hook_position.POST_TRIGGER %}
{{ hook }}
{% endfor %}
return ( return (
{{utils.render(component.render(), indent_width=0)}} {{utils.render(component.render(), indent_width=0)}}
) )

View File

@ -75,7 +75,7 @@ def _compile_app(app_root: Component) -> str:
return templates.APP_ROOT.render( return templates.APP_ROOT.render(
imports=utils.compile_imports(app_root._get_all_imports()), imports=utils.compile_imports(app_root._get_all_imports()),
custom_codes=app_root._get_all_custom_code(), custom_codes=app_root._get_all_custom_code(),
hooks={**app_root._get_all_hooks_internal(), **app_root._get_all_hooks()}, hooks=app_root._get_all_hooks(),
window_libraries=window_libraries, window_libraries=window_libraries,
render=app_root.render(), render=app_root.render(),
) )
@ -149,7 +149,7 @@ def _compile_page(
imports=imports, imports=imports,
dynamic_imports=component._get_all_dynamic_imports(), dynamic_imports=component._get_all_dynamic_imports(),
custom_codes=component._get_all_custom_code(), custom_codes=component._get_all_custom_code(),
hooks={**component._get_all_hooks_internal(), **component._get_all_hooks()}, hooks=component._get_all_hooks(),
render=component.render(), render=component.render(),
**kwargs, **kwargs,
) )

View File

@ -1,9 +1,46 @@
"""Templates to use in the reflex compiler.""" """Templates to use in the reflex compiler."""
from __future__ import annotations
from jinja2 import Environment, FileSystemLoader, Template from jinja2 import Environment, FileSystemLoader, Template
from reflex import constants from reflex import constants
from reflex.constants import Hooks
from reflex.utils.format import format_state_name, json_dumps from reflex.utils.format import format_state_name, json_dumps
from reflex.vars.base import VarData
def _sort_hooks(hooks: dict[str, VarData | None]):
"""Sort the hooks by their position.
Args:
hooks: The hooks to sort.
Returns:
The sorted hooks.
"""
sorted_hooks = {
Hooks.HookPosition.INTERNAL: [],
Hooks.HookPosition.PRE_TRIGGER: [],
Hooks.HookPosition.POST_TRIGGER: [],
}
for hook, data in hooks.items():
if data and data.position and data.position == Hooks.HookPosition.INTERNAL:
sorted_hooks[Hooks.HookPosition.INTERNAL].append((hook, data))
elif not data or (
not data.position
or data.position == constants.Hooks.HookPosition.PRE_TRIGGER
):
sorted_hooks[Hooks.HookPosition.PRE_TRIGGER].append((hook, data))
elif (
data
and data.position
and data.position == constants.Hooks.HookPosition.POST_TRIGGER
):
sorted_hooks[Hooks.HookPosition.POST_TRIGGER].append((hook, data))
return sorted_hooks
class ReflexJinjaEnvironment(Environment): class ReflexJinjaEnvironment(Environment):
@ -47,6 +84,7 @@ class ReflexJinjaEnvironment(Environment):
"frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL, "frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL,
"hook_position": constants.Hooks.HookPosition, "hook_position": constants.Hooks.HookPosition,
} }
self.globals["sort_hooks"] = _sort_hooks
def get_template(name: str) -> Template: def get_template(name: str) -> Template:
@ -103,6 +141,9 @@ STYLE = get_template("web/styles/styles.css.jinja2")
# Code that generate the package json file # Code that generate the package json file
PACKAGE_JSON = get_template("web/package.json.jinja2") PACKAGE_JSON = get_template("web/package.json.jinja2")
# Template containing some macros used in the web pages.
MACROS = get_template("web/pages/macros.js.jinja2")
# Code that generate the pyproject.toml file for custom components. # Code that generate the pyproject.toml file for custom components.
CUSTOM_COMPONENTS_PYPROJECT_TOML = get_template( CUSTOM_COMPONENTS_PYPROJECT_TOML = get_template(
"custom_components/pyproject.toml.jinja2" "custom_components/pyproject.toml.jinja2"

View File

@ -290,7 +290,7 @@ def compile_custom_component(
"name": component.tag, "name": component.tag,
"props": props, "props": props,
"render": render.render(), "render": render.render(),
"hooks": {**render._get_all_hooks_internal(), **render._get_all_hooks()}, "hooks": render._get_all_hooks(),
"custom_code": render._get_all_custom_code(), "custom_code": render._get_all_custom_code(),
}, },
imports, imports,

View File

@ -9,6 +9,7 @@ from reflex.components.tags import Tag
from reflex.components.tags.tagless import Tagless from reflex.components.tags.tagless import Tagless
from reflex.utils.imports import ParsedImportDict from reflex.utils.imports import ParsedImportDict
from reflex.vars import BooleanVar, ObjectVar, Var from reflex.vars import BooleanVar, ObjectVar, Var
from reflex.vars.base import VarData
class Bare(Component): class Bare(Component):
@ -32,7 +33,7 @@ class Bare(Component):
contents = str(contents) if contents is not None else "" contents = str(contents) if contents is not None else ""
return cls(contents=contents) # type: ignore return cls(contents=contents) # type: ignore
def _get_all_hooks_internal(self) -> dict[str, None]: def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Include the hooks for the component. """Include the hooks for the component.
Returns: Returns:
@ -43,7 +44,7 @@ class Bare(Component):
hooks |= self.contents._var_value._get_all_hooks_internal() hooks |= self.contents._var_value._get_all_hooks_internal()
return hooks return hooks
def _get_all_hooks(self) -> dict[str, None]: def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Include the hooks for the component. """Include the hooks for the component.
Returns: Returns:

View File

@ -102,7 +102,7 @@ class BaseComponent(Base, ABC):
""" """
@abstractmethod @abstractmethod
def _get_all_hooks_internal(self) -> dict[str, None]: def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the reflex internal hooks for the component and its children. """Get the reflex internal hooks for the component and its children.
Returns: Returns:
@ -110,7 +110,7 @@ class BaseComponent(Base, ABC):
""" """
@abstractmethod @abstractmethod
def _get_all_hooks(self) -> dict[str, None]: def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component. """Get the React hooks for this component.
Returns: Returns:
@ -1272,7 +1272,7 @@ class Component(BaseComponent, ABC):
""" """
_imports = {} _imports = {}
if self._get_ref_hook(): if self._get_ref_hook() is not None:
# Handle hooks needed for attaching react refs to DOM nodes. # Handle hooks needed for attaching react refs to DOM nodes.
_imports.setdefault("react", set()).add(ImportVar(tag="useRef")) _imports.setdefault("react", set()).add(ImportVar(tag="useRef"))
_imports.setdefault(f"$/{Dirs.STATE_PATH}", set()).add( _imports.setdefault(f"$/{Dirs.STATE_PATH}", set()).add(
@ -1388,7 +1388,7 @@ class Component(BaseComponent, ABC):
}} }}
}}, []);""" }}, []);"""
def _get_ref_hook(self) -> str | None: def _get_ref_hook(self) -> Var | None:
"""Generate the ref hook for the component. """Generate the ref hook for the component.
Returns: Returns:
@ -1396,11 +1396,12 @@ class Component(BaseComponent, ABC):
""" """
ref = self.get_ref() ref = self.get_ref()
if ref is not None: if ref is not None:
return ( return Var(
f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};" f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};",
_var_data=VarData(position=Hooks.HookPosition.INTERNAL),
) )
def _get_vars_hooks(self) -> dict[str, None]: def _get_vars_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks required by vars referenced in this component. """Get the hooks required by vars referenced in this component.
Returns: Returns:
@ -1413,27 +1414,38 @@ class Component(BaseComponent, ABC):
vars_hooks.update( vars_hooks.update(
var_data.hooks var_data.hooks
if isinstance(var_data.hooks, dict) if isinstance(var_data.hooks, dict)
else {k: None for k in var_data.hooks} else {
k: VarData(position=Hooks.HookPosition.INTERNAL)
for k in var_data.hooks
}
) )
return vars_hooks return vars_hooks
def _get_events_hooks(self) -> dict[str, None]: def _get_events_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks required by events referenced in this component. """Get the hooks required by events referenced in this component.
Returns: Returns:
The hooks for the events. The hooks for the events.
""" """
return {Hooks.EVENTS: None} if self.event_triggers else {} return (
{Hooks.EVENTS: VarData(position=Hooks.HookPosition.INTERNAL)}
if self.event_triggers
else {}
)
def _get_special_hooks(self) -> dict[str, None]: def _get_special_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks required by special actions referenced in this component. """Get the hooks required by special actions referenced in this component.
Returns: Returns:
The hooks for special actions. The hooks for special actions.
""" """
return {Hooks.AUTOFOCUS: None} if self.autofocus else {} return (
{Hooks.AUTOFOCUS: VarData(position=Hooks.HookPosition.INTERNAL)}
if self.autofocus
else {}
)
def _get_hooks_internal(self) -> dict[str, None]: def _get_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component managed by the framework. """Get the React hooks for this component managed by the framework.
Downstream components should NOT override this method to avoid breaking Downstream components should NOT override this method to avoid breaking
@ -1444,7 +1456,7 @@ class Component(BaseComponent, ABC):
""" """
return { return {
**{ **{
hook: None str(hook): VarData(position=Hooks.HookPosition.INTERNAL)
for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()] for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()]
if hook is not None if hook is not None
}, },
@ -1493,7 +1505,7 @@ class Component(BaseComponent, ABC):
""" """
return return
def _get_all_hooks_internal(self) -> dict[str, None]: def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the reflex internal hooks for the component and its children. """Get the reflex internal hooks for the component and its children.
Returns: Returns:
@ -1508,7 +1520,7 @@ class Component(BaseComponent, ABC):
return code return code
def _get_all_hooks(self) -> dict[str, None]: def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component and its children. """Get the React hooks for this component and its children.
Returns: Returns:
@ -1516,6 +1528,9 @@ class Component(BaseComponent, ABC):
""" """
code = {} code = {}
# Add the internal hooks for this component.
code.update(self._get_hooks_internal())
# Add the hook code for this component. # Add the hook code for this component.
hooks = self._get_hooks() hooks = self._get_hooks()
if hooks is not None: if hooks is not None:
@ -2211,7 +2226,7 @@ class StatefulComponent(BaseComponent):
) )
return trigger_memo return trigger_memo
def _get_all_hooks_internal(self) -> dict[str, None]: def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the reflex internal hooks for the component and its children. """Get the reflex internal hooks for the component and its children.
Returns: Returns:
@ -2219,7 +2234,7 @@ class StatefulComponent(BaseComponent):
""" """
return {} return {}
def _get_all_hooks(self) -> dict[str, None]: def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component. """Get the React hooks for this component.
Returns: Returns:
@ -2337,7 +2352,7 @@ class MemoizationLeaf(Component):
The memoization leaf The memoization leaf
""" """
comp = super().create(*children, **props) comp = super().create(*children, **props)
if comp._get_all_hooks() or comp._get_all_hooks_internal(): if comp._get_all_hooks():
comp._memoization_mode = cls._memoization_mode.copy( comp._memoization_mode = cls._memoization_mode.copy(
update={"disposition": MemoizationDisposition.ALWAYS} update={"disposition": MemoizationDisposition.ALWAYS}
) )

View File

@ -182,9 +182,7 @@ class Form(BaseHTML):
props["handle_submit_unique_name"] = "" props["handle_submit_unique_name"] = ""
form = super().create(*children, **props) form = super().create(*children, **props)
form.handle_submit_unique_name = md5( form.handle_submit_unique_name = md5(
str({**form._get_all_hooks_internal(), **form._get_all_hooks()}).encode( str(form._get_all_hooks()).encode("utf-8")
"utf-8"
)
).hexdigest() ).hexdigest()
return form return form

View File

@ -420,11 +420,12 @@ const {_LANGUAGE!s} = match ? match[1] : '';
def _get_custom_code(self) -> str | None: def _get_custom_code(self) -> str | None:
hooks = {} hooks = {}
from reflex.compiler.templates import MACROS
for _component in self.component_map.values(): for _component in self.component_map.values():
comp = _component(_MOCK_ARG) comp = _component(_MOCK_ARG)
hooks.update(comp._get_all_hooks_internal())
hooks.update(comp._get_all_hooks()) hooks.update(comp._get_all_hooks())
formatted_hooks = "\n".join(hooks.keys()) formatted_hooks = MACROS.module.renderHooks(hooks) # type: ignore
return f""" return f"""
function {self._get_component_map_name()} () {{ function {self._get_component_map_name()} () {{
{formatted_hooks} {formatted_hooks}

View File

@ -135,6 +135,7 @@ class Hooks(SimpleNamespace):
class HookPosition(enum.Enum): class HookPosition(enum.Enum):
"""The position of the hook in the component.""" """The position of the hook in the component."""
INTERNAL = "internal"
PRE_TRIGGER = "pre_trigger" PRE_TRIGGER = "pre_trigger"
POST_TRIGGER = "post_trigger" POST_TRIGGER = "post_trigger"

View File

@ -105,7 +105,7 @@ class ClientStateVar(Var):
else: else:
default_var = default default_var = default
setter_name = f"set{var_name.capitalize()}" setter_name = f"set{var_name.capitalize()}"
hooks = { hooks: dict[str, VarData | None] = {
f"const [{var_name}, {setter_name}] = useState({default_var!s})": None, f"const [{var_name}, {setter_name}] = useState({default_var!s})": None,
} }
imports = { imports = {

View File

@ -127,7 +127,7 @@ class VarData:
state: str = "", state: str = "",
field_name: str = "", field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None, imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, None] | None = None, hooks: dict[str, VarData | None] | None = None,
deps: list[Var] | None = None, deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None, position: Hooks.HookPosition | None = None,
): ):
@ -194,7 +194,9 @@ class VarData:
(var_data.state for var_data in all_var_datas if var_data.state), "" (var_data.state for var_data in all_var_datas if var_data.state), ""
) )
hooks = {hook: None for var_data in all_var_datas for hook in var_data.hooks} hooks: dict[str, VarData | None] = {
hook: None for var_data in all_var_datas for hook in var_data.hooks
}
_imports = imports.merge_imports( _imports = imports.merge_imports(
*(var_data.imports for var_data in all_var_datas) *(var_data.imports for var_data in all_var_datas)