From 9fafb6d5269121fbbb01d645d4f26d4f9711fa51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Mon, 6 Jan 2025 13:06:56 -0800 Subject: [PATCH] use position in vardata to mark internal hooks (#4549) * use position in vardata to mark internal hooks * update all render to use position * use macros for rendering * reduce number of iterations over hooks during rendering * cleanup code and add typing * add __future__ * use new macros to render component maps in markdown * remove calls to _get_all_hooks_internal * fix typo * forgot to replace this * unnecessary expand in utils.py --- .../.templates/jinja/web/pages/_app.js.jinja2 | 6 +-- .../web/pages/custom_component.js.jinja2 | 7 ++- .../jinja/web/pages/index.js.jinja2 | 5 +- .../jinja/web/pages/macros.js.jinja2 | 38 +++++++++++++ .../web/pages/stateful_component.js.jinja2 | 20 ++----- reflex/compiler/compiler.py | 4 +- reflex/compiler/templates.py | 41 ++++++++++++++ reflex/compiler/utils.py | 2 +- reflex/components/base/bare.py | 5 +- reflex/components/component.py | 53 ++++++++++++------- reflex/components/el/elements/forms.py | 4 +- reflex/components/markdown/markdown.py | 5 +- reflex/constants/compiler.py | 1 + reflex/experimental/client_state.py | 2 +- reflex/vars/base.py | 6 ++- 15 files changed, 140 insertions(+), 59 deletions(-) create mode 100644 reflex/.templates/jinja/web/pages/macros.js.jinja2 diff --git a/reflex/.templates/jinja/web/pages/_app.js.jinja2 b/reflex/.templates/jinja/web/pages/_app.js.jinja2 index 21cfd921a..40e31dee6 100644 --- a/reflex/.templates/jinja/web/pages/_app.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/_app.js.jinja2 @@ -1,4 +1,5 @@ {% extends "web/pages/base_page.js.jinja2" %} +{% from "web/pages/macros.js.jinja2" import renderHooks %} {% block early_imports %} import '$/styles/styles.css' @@ -18,10 +19,7 @@ import * as {{library_alias}} from "{{library_path}}"; {% block export %} function AppWrap({children}) { - - {% for hook in hooks %} - {{ hook }} - {% endfor %} + {{ renderHooks(hooks) }} return ( {{utils.render(render, indent_width=0)}} diff --git a/reflex/.templates/jinja/web/pages/custom_component.js.jinja2 b/reflex/.templates/jinja/web/pages/custom_component.js.jinja2 index 222524d2d..e729d7273 100644 --- a/reflex/.templates/jinja/web/pages/custom_component.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/custom_component.js.jinja2 @@ -1,5 +1,5 @@ {% extends "web/pages/base_page.js.jinja2" %} - +{% from "web/pages/macros.js.jinja2" import renderHooks %} {% block export %} {% for component in components %} @@ -8,9 +8,8 @@ {% endfor %} export const {{component.name}} = memo(({ {{-component.props|join(", ")-}} }) => { - {% for hook in component.hooks %} - {{ hook }} - {% endfor %} + {{ renderHooks(component.hooks) }} + return( {{utils.render(component.render)}} ) diff --git a/reflex/.templates/jinja/web/pages/index.js.jinja2 b/reflex/.templates/jinja/web/pages/index.js.jinja2 index efb086ef5..5551ad5fc 100644 --- a/reflex/.templates/jinja/web/pages/index.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/index.js.jinja2 @@ -1,4 +1,5 @@ {% extends "web/pages/base_page.js.jinja2" %} +{% from "web/pages/macros.js.jinja2" import renderHooks %} {% block declaration %} {% for custom_code in custom_codes %} @@ -8,9 +9,7 @@ {% block export %} export default function Component() { - {% for hook in hooks %} - {{ hook }} - {% endfor %} + {{ renderHooks(hooks)}} return ( {{utils.render(render, indent_width=0)}} diff --git a/reflex/.templates/jinja/web/pages/macros.js.jinja2 b/reflex/.templates/jinja/web/pages/macros.js.jinja2 new file mode 100644 index 000000000..68810d896 --- /dev/null +++ b/reflex/.templates/jinja/web/pages/macros.js.jinja2 @@ -0,0 +1,38 @@ +{% macro renderHooks(hooks) %} + {% set sorted_hooks = sort_hooks(hooks) %} + + {# Render the grouped hooks #} + {% for hook, _ in sorted_hooks[const.hook_position.INTERNAL] %} + {{ hook }} + {% endfor %} + + {% for hook, _ in sorted_hooks[const.hook_position.PRE_TRIGGER] %} + {{ hook }} + {% endfor %} + + {% for hook, _ in sorted_hooks[const.hook_position.POST_TRIGGER] %} + {{ hook }} + {% endfor %} +{% endmacro %} + +{% macro renderHooksWithMemo(hooks, memo)%} + {% set sorted_hooks = sort_hooks(hooks) %} + + {# Render the grouped hooks #} + {% for hook, _ in sorted_hooks[const.hook_position.INTERNAL] %} + {{ hook }} + {% endfor %} + + {% for hook, _ in sorted_hooks[const.hook_position.PRE_TRIGGER] %} + {{ hook }} + {% endfor %} + + {% for hook in memo %} + {{ hook }} + {% endfor %} + + {% for hook, _ in sorted_hooks[const.hook_position.POST_TRIGGER] %} + {{ hook }} + {% endfor %} + +{% endmacro %} \ No newline at end of file diff --git a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 index b04a78781..208a5755f 100644 --- a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 @@ -1,22 +1,10 @@ {% import 'web/pages/utils.js.jinja2' as utils %} +{% from 'web/pages/macros.js.jinja2' import renderHooksWithMemo %} +{% set all_hooks = component._get_all_hooks() %} export function {{tag_name}} () { - {% for hook in component._get_all_hooks_internal() %} - {{ hook }} - {% endfor %} - - {% for hook, data in component._get_all_hooks().items() if not data.position or data.position == const.hook_position.PRE_TRIGGER %} - {{ hook }} - {% endfor %} - - {% for hook in memo_trigger_hooks %} - {{ hook }} - {% endfor %} - - {% for hook, data in component._get_all_hooks().items() if data.position and data.position == const.hook_position.POST_TRIGGER %} - {{ hook }} - {% endfor %} - + {{ renderHooksWithMemo(all_hooks, memo_trigger_hooks) }} + return ( {{utils.render(component.render(), indent_width=0)}} ) diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 9f81f319d..218dd0c55 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -75,7 +75,7 @@ def _compile_app(app_root: Component) -> str: return templates.APP_ROOT.render( imports=utils.compile_imports(app_root._get_all_imports()), custom_codes=app_root._get_all_custom_code(), - hooks={**app_root._get_all_hooks_internal(), **app_root._get_all_hooks()}, + hooks=app_root._get_all_hooks(), window_libraries=window_libraries, render=app_root.render(), ) @@ -149,7 +149,7 @@ def _compile_page( imports=imports, dynamic_imports=component._get_all_dynamic_imports(), custom_codes=component._get_all_custom_code(), - hooks={**component._get_all_hooks_internal(), **component._get_all_hooks()}, + hooks=component._get_all_hooks(), render=component.render(), **kwargs, ) diff --git a/reflex/compiler/templates.py b/reflex/compiler/templates.py index 631aa4ee2..117b655a9 100644 --- a/reflex/compiler/templates.py +++ b/reflex/compiler/templates.py @@ -1,9 +1,46 @@ """Templates to use in the reflex compiler.""" +from __future__ import annotations + from jinja2 import Environment, FileSystemLoader, Template from reflex import constants +from reflex.constants import Hooks from reflex.utils.format import format_state_name, json_dumps +from reflex.vars.base import VarData + + +def _sort_hooks(hooks: dict[str, VarData | None]): + """Sort the hooks by their position. + + Args: + hooks: The hooks to sort. + + Returns: + The sorted hooks. + """ + sorted_hooks = { + Hooks.HookPosition.INTERNAL: [], + Hooks.HookPosition.PRE_TRIGGER: [], + Hooks.HookPosition.POST_TRIGGER: [], + } + + for hook, data in hooks.items(): + if data and data.position and data.position == Hooks.HookPosition.INTERNAL: + sorted_hooks[Hooks.HookPosition.INTERNAL].append((hook, data)) + elif not data or ( + not data.position + or data.position == constants.Hooks.HookPosition.PRE_TRIGGER + ): + sorted_hooks[Hooks.HookPosition.PRE_TRIGGER].append((hook, data)) + elif ( + data + and data.position + and data.position == constants.Hooks.HookPosition.POST_TRIGGER + ): + sorted_hooks[Hooks.HookPosition.POST_TRIGGER].append((hook, data)) + + return sorted_hooks class ReflexJinjaEnvironment(Environment): @@ -47,6 +84,7 @@ class ReflexJinjaEnvironment(Environment): "frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL, "hook_position": constants.Hooks.HookPosition, } + self.globals["sort_hooks"] = _sort_hooks def get_template(name: str) -> Template: @@ -103,6 +141,9 @@ STYLE = get_template("web/styles/styles.css.jinja2") # Code that generate the package json file PACKAGE_JSON = get_template("web/package.json.jinja2") +# Template containing some macros used in the web pages. +MACROS = get_template("web/pages/macros.js.jinja2") + # Code that generate the pyproject.toml file for custom components. CUSTOM_COMPONENTS_PYPROJECT_TOML = get_template( "custom_components/pyproject.toml.jinja2" diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 1d698431c..c0ba28f4b 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -290,7 +290,7 @@ def compile_custom_component( "name": component.tag, "props": props, "render": render.render(), - "hooks": {**render._get_all_hooks_internal(), **render._get_all_hooks()}, + "hooks": render._get_all_hooks(), "custom_code": render._get_all_custom_code(), }, imports, diff --git a/reflex/components/base/bare.py b/reflex/components/base/bare.py index e1b5d9237..7cd225deb 100644 --- a/reflex/components/base/bare.py +++ b/reflex/components/base/bare.py @@ -9,6 +9,7 @@ from reflex.components.tags import Tag from reflex.components.tags.tagless import Tagless from reflex.utils.imports import ParsedImportDict from reflex.vars import BooleanVar, ObjectVar, Var +from reflex.vars.base import VarData class Bare(Component): @@ -32,7 +33,7 @@ class Bare(Component): contents = str(contents) if contents is not None else "" return cls(contents=contents) # type: ignore - def _get_all_hooks_internal(self) -> dict[str, None]: + def _get_all_hooks_internal(self) -> dict[str, VarData | None]: """Include the hooks for the component. Returns: @@ -43,7 +44,7 @@ class Bare(Component): hooks |= self.contents._var_value._get_all_hooks_internal() return hooks - def _get_all_hooks(self) -> dict[str, None]: + def _get_all_hooks(self) -> dict[str, VarData | None]: """Include the hooks for the component. Returns: diff --git a/reflex/components/component.py b/reflex/components/component.py index c1d4cbc80..46708e572 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -102,7 +102,7 @@ class BaseComponent(Base, ABC): """ @abstractmethod - def _get_all_hooks_internal(self) -> dict[str, None]: + def _get_all_hooks_internal(self) -> dict[str, VarData | None]: """Get the reflex internal hooks for the component and its children. Returns: @@ -110,7 +110,7 @@ class BaseComponent(Base, ABC): """ @abstractmethod - def _get_all_hooks(self) -> dict[str, None]: + def _get_all_hooks(self) -> dict[str, VarData | None]: """Get the React hooks for this component. Returns: @@ -1272,7 +1272,7 @@ class Component(BaseComponent, ABC): """ _imports = {} - if self._get_ref_hook(): + if self._get_ref_hook() is not None: # Handle hooks needed for attaching react refs to DOM nodes. _imports.setdefault("react", set()).add(ImportVar(tag="useRef")) _imports.setdefault(f"$/{Dirs.STATE_PATH}", set()).add( @@ -1388,7 +1388,7 @@ class Component(BaseComponent, ABC): }} }}, []);""" - def _get_ref_hook(self) -> str | None: + def _get_ref_hook(self) -> Var | None: """Generate the ref hook for the component. Returns: @@ -1396,11 +1396,12 @@ class Component(BaseComponent, ABC): """ ref = self.get_ref() if ref is not None: - return ( - f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};" + return Var( + f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};", + _var_data=VarData(position=Hooks.HookPosition.INTERNAL), ) - def _get_vars_hooks(self) -> dict[str, None]: + def _get_vars_hooks(self) -> dict[str, VarData | None]: """Get the hooks required by vars referenced in this component. Returns: @@ -1413,27 +1414,38 @@ class Component(BaseComponent, ABC): vars_hooks.update( var_data.hooks if isinstance(var_data.hooks, dict) - else {k: None for k in var_data.hooks} + else { + k: VarData(position=Hooks.HookPosition.INTERNAL) + for k in var_data.hooks + } ) return vars_hooks - def _get_events_hooks(self) -> dict[str, None]: + def _get_events_hooks(self) -> dict[str, VarData | None]: """Get the hooks required by events referenced in this component. Returns: The hooks for the events. """ - return {Hooks.EVENTS: None} if self.event_triggers else {} + return ( + {Hooks.EVENTS: VarData(position=Hooks.HookPosition.INTERNAL)} + if self.event_triggers + else {} + ) - def _get_special_hooks(self) -> dict[str, None]: + def _get_special_hooks(self) -> dict[str, VarData | None]: """Get the hooks required by special actions referenced in this component. Returns: The hooks for special actions. """ - return {Hooks.AUTOFOCUS: None} if self.autofocus else {} + return ( + {Hooks.AUTOFOCUS: VarData(position=Hooks.HookPosition.INTERNAL)} + if self.autofocus + else {} + ) - def _get_hooks_internal(self) -> dict[str, None]: + def _get_hooks_internal(self) -> dict[str, VarData | None]: """Get the React hooks for this component managed by the framework. Downstream components should NOT override this method to avoid breaking @@ -1444,7 +1456,7 @@ class Component(BaseComponent, ABC): """ return { **{ - hook: None + str(hook): VarData(position=Hooks.HookPosition.INTERNAL) for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()] if hook is not None }, @@ -1493,7 +1505,7 @@ class Component(BaseComponent, ABC): """ return - def _get_all_hooks_internal(self) -> dict[str, None]: + def _get_all_hooks_internal(self) -> dict[str, VarData | None]: """Get the reflex internal hooks for the component and its children. Returns: @@ -1508,7 +1520,7 @@ class Component(BaseComponent, ABC): return code - def _get_all_hooks(self) -> dict[str, None]: + def _get_all_hooks(self) -> dict[str, VarData | None]: """Get the React hooks for this component and its children. Returns: @@ -1516,6 +1528,9 @@ class Component(BaseComponent, ABC): """ code = {} + # Add the internal hooks for this component. + code.update(self._get_hooks_internal()) + # Add the hook code for this component. hooks = self._get_hooks() if hooks is not None: @@ -2211,7 +2226,7 @@ class StatefulComponent(BaseComponent): ) return trigger_memo - def _get_all_hooks_internal(self) -> dict[str, None]: + def _get_all_hooks_internal(self) -> dict[str, VarData | None]: """Get the reflex internal hooks for the component and its children. Returns: @@ -2219,7 +2234,7 @@ class StatefulComponent(BaseComponent): """ return {} - def _get_all_hooks(self) -> dict[str, None]: + def _get_all_hooks(self) -> dict[str, VarData | None]: """Get the React hooks for this component. Returns: @@ -2337,7 +2352,7 @@ class MemoizationLeaf(Component): The memoization leaf """ comp = super().create(*children, **props) - if comp._get_all_hooks() or comp._get_all_hooks_internal(): + if comp._get_all_hooks(): comp._memoization_mode = cls._memoization_mode.copy( update={"disposition": MemoizationDisposition.ALWAYS} ) diff --git a/reflex/components/el/elements/forms.py b/reflex/components/el/elements/forms.py index 61ded4fd3..529a5e884 100644 --- a/reflex/components/el/elements/forms.py +++ b/reflex/components/el/elements/forms.py @@ -182,9 +182,7 @@ class Form(BaseHTML): props["handle_submit_unique_name"] = "" form = super().create(*children, **props) form.handle_submit_unique_name = md5( - str({**form._get_all_hooks_internal(), **form._get_all_hooks()}).encode( - "utf-8" - ) + str(form._get_all_hooks()).encode("utf-8") ).hexdigest() return form diff --git a/reflex/components/markdown/markdown.py b/reflex/components/markdown/markdown.py index cd82d5903..7c65c0d43 100644 --- a/reflex/components/markdown/markdown.py +++ b/reflex/components/markdown/markdown.py @@ -420,11 +420,12 @@ const {_LANGUAGE!s} = match ? match[1] : ''; def _get_custom_code(self) -> str | None: hooks = {} + from reflex.compiler.templates import MACROS + for _component in self.component_map.values(): comp = _component(_MOCK_ARG) - hooks.update(comp._get_all_hooks_internal()) hooks.update(comp._get_all_hooks()) - formatted_hooks = "\n".join(hooks.keys()) + formatted_hooks = MACROS.module.renderHooks(hooks) # type: ignore return f""" function {self._get_component_map_name()} () {{ {formatted_hooks} diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index 7ca55f4dd..d98c04d76 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -135,6 +135,7 @@ class Hooks(SimpleNamespace): class HookPosition(enum.Enum): """The position of the hook in the component.""" + INTERNAL = "internal" PRE_TRIGGER = "pre_trigger" POST_TRIGGER = "post_trigger" diff --git a/reflex/experimental/client_state.py b/reflex/experimental/client_state.py index 1982b3dfe..e37ceb14c 100644 --- a/reflex/experimental/client_state.py +++ b/reflex/experimental/client_state.py @@ -105,7 +105,7 @@ class ClientStateVar(Var): else: default_var = default setter_name = f"set{var_name.capitalize()}" - hooks = { + hooks: dict[str, VarData | None] = { f"const [{var_name}, {setter_name}] = useState({default_var!s})": None, } imports = { diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 31c5f2e94..a4496d77d 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -127,7 +127,7 @@ class VarData: state: str = "", field_name: str = "", imports: ImportDict | ParsedImportDict | None = None, - hooks: dict[str, None] | None = None, + hooks: dict[str, VarData | None] | None = None, deps: list[Var] | None = None, position: Hooks.HookPosition | None = None, ): @@ -194,7 +194,9 @@ class VarData: (var_data.state for var_data in all_var_datas if var_data.state), "" ) - hooks = {hook: None for var_data in all_var_datas for hook in var_data.hooks} + hooks: dict[str, VarData | None] = { + hook: None for var_data in all_var_datas for hook in var_data.hooks + } _imports = imports.merge_imports( *(var_data.imports for var_data in all_var_datas)