diff --git a/reflex/.templates/jinja/web/pages/index.js.jinja2 b/reflex/.templates/jinja/web/pages/index.js.jinja2 index e29deee34..efb086ef5 100644 --- a/reflex/.templates/jinja/web/pages/index.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/index.js.jinja2 @@ -8,10 +8,6 @@ {% block export %} export default function Component() { - {% for ref_hook in ref_hooks %} - {{ ref_hook }} - {% endfor %} - {% for hook in hooks %} {{ hook }} {% endfor %} diff --git a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 index ccb375893..ad970c2f5 100644 --- a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 @@ -1,10 +1,6 @@ {% import 'web/pages/utils.js.jinja2' as utils %} export function {{tag_name}} () { - {% for ref_hook in component.get_ref_hooks() %} - {{ ref_hook }} - {% endfor %} - {% for hook in component.get_hooks_internal() %} {{ hook }} {% endfor %} diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index ef0d3e4ca..02b2a2a16 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -50,7 +50,7 @@ def _compile_app(app_root: Component) -> str: return templates.APP_ROOT.render( imports=utils.compile_imports(app_root.get_imports()), custom_codes=app_root.get_custom_code(), - hooks=app_root.get_hooks_internal() | app_root.get_hooks(), + hooks={**app_root.get_hooks_internal(), **app_root.get_hooks()}, render=app_root.render(), ) @@ -119,8 +119,7 @@ def _compile_page( imports=imports, dynamic_imports=component.get_dynamic_imports(), custom_codes=component.get_custom_code(), - ref_hooks=component.get_ref_hooks(), - hooks=component.get_hooks_internal() | component.get_hooks(), + hooks={**component.get_hooks_internal(), **component.get_hooks()}, render=component.render(), **kwargs, ) diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 0789a4548..3ebba40e0 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -265,7 +265,7 @@ def compile_custom_component( "name": component.tag, "props": props, "render": render.render(), - "hooks": render.get_hooks_internal() | render.get_hooks(), + "hooks": {**render.get_hooks_internal(), **render.get_hooks()}, "custom_code": render.get_custom_code(), }, imports, diff --git a/reflex/components/component.py b/reflex/components/component.py index c8c3f5b9d..c873e687b 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -76,15 +76,7 @@ class BaseComponent(Base, ABC): """ @abstractmethod - def get_ref_hooks(self) -> set[str]: - """Get the hooks required by refs in this component. - - Returns: - The hooks for the refs. - """ - - @abstractmethod - def get_hooks_internal(self) -> set[str]: + def get_hooks_internal(self) -> dict[str, None]: """Get the reflex internal hooks for the component and its children. Returns: @@ -92,7 +84,7 @@ class BaseComponent(Base, ABC): """ @abstractmethod - def get_hooks(self) -> set[str]: + def get_hooks(self) -> dict[str, None]: """Get the React hooks for this component. Returns: @@ -929,7 +921,7 @@ class Component(BaseComponent, ABC): """ return None - def get_custom_code(self) -> Set[str]: + def get_custom_code(self) -> set[str]: """Get custom code for the component and its children. Returns: @@ -1108,62 +1100,53 @@ class Component(BaseComponent, ABC): if ref is not None: return f"const {ref} = useRef(null); {str(Var.create_safe(ref).as_ref())} = {ref};" - def _get_vars_hooks(self) -> set[str]: + def _get_vars_hooks(self) -> dict[str, None]: """Get the hooks required by vars referenced in this component. Returns: The hooks for the vars. """ - vars_hooks = set() + vars_hooks = {} for var in self._get_vars(): if var._var_data: vars_hooks.update(var._var_data.hooks) return vars_hooks - def _get_events_hooks(self) -> set[str]: + def _get_events_hooks(self) -> dict[str, None]: """Get the hooks required by events referenced in this component. Returns: The hooks for the events. """ - if self.event_triggers: - return {Hooks.EVENTS} - return set() + return {Hooks.EVENTS: None} if self.event_triggers else {} - def _get_special_hooks(self) -> set[str]: + def _get_special_hooks(self) -> dict[str, None]: """Get the hooks required by special actions referenced in this component. Returns: The hooks for special actions. """ - if self.autofocus: - return { - """ - // Set focus to the specified element. - const focusRef = useRef(null) - useEffect(() => { - if (focusRef.current) { - focusRef.current.focus(); - } - })""", - } - return set() + return {Hooks.AUTOFOCUS: None} if self.autofocus else {} - def _get_hooks_internal(self) -> Set[str]: + def _get_hooks_internal(self) -> dict[str, None]: """Get the React hooks for this component managed by the framework. Downstream components should NOT override this method to avoid breaking framework functionality. Returns: - Set of internally managed hooks. + The internally managed hooks. """ - return ( - self._get_vars_hooks() - | self._get_events_hooks() - | self._get_special_hooks() - | set(hook for hook in [self._get_mount_lifecycle_hook()] if hook) - ) + return { + **{ + hook: None + for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()] + if hook is not None + }, + **self._get_vars_hooks(), + **self._get_events_hooks(), + **self._get_special_hooks(), + } def _get_hooks(self) -> str | None: """Get the React hooks for this component. @@ -1175,20 +1158,7 @@ class Component(BaseComponent, ABC): """ return - def get_ref_hooks(self) -> Set[str]: - """Get the ref hooks for the component and its children. - - Returns: - The ref hooks. - """ - ref_hook = self._get_ref_hook() - hooks = set() if ref_hook is None else {ref_hook} - - for child in self.children: - hooks |= child.get_ref_hooks() - return hooks - - def get_hooks_internal(self) -> set[str]: + def get_hooks_internal(self) -> dict[str, None]: """Get the reflex internal hooks for the component and its children. Returns: @@ -1199,26 +1169,26 @@ class Component(BaseComponent, ABC): # Add the hook code for the children. for child in self.children: - code |= child.get_hooks_internal() + code = {**code, **child.get_hooks_internal()} return code - def get_hooks(self) -> Set[str]: + def get_hooks(self) -> dict[str, None]: """Get the React hooks for this component and its children. Returns: The code that should appear just before returning the rendered component. """ - code = set() + code = {} # Add the hook code for this component. hooks = self._get_hooks() if hooks is not None: - code.add(hooks) + code[hooks] = None # Add the hook code for the children. for child in self.children: - code |= child.get_hooks() + code = {**code, **child.get_hooks()} return code @@ -1233,7 +1203,7 @@ class Component(BaseComponent, ABC): return None return format.format_ref(self.id) - def get_refs(self) -> Set[str]: + def get_refs(self) -> set[str]: """Get the refs for the children of the component. Returns: @@ -1854,29 +1824,21 @@ class StatefulComponent(BaseComponent): ) return trigger_memo - def get_ref_hooks(self) -> set[str]: - """Get the ref hooks for the component and its children. - - Returns: - The ref hooks. - """ - return set() - - def get_hooks_internal(self) -> set[str]: + def get_hooks_internal(self) -> dict[str, None]: """Get the reflex internal hooks for the component and its children. Returns: The code that should appear just before user-defined hooks. """ - return set() + return {} - def get_hooks(self) -> set[str]: + def get_hooks(self) -> dict[str, None]: """Get the React hooks for this component. Returns: The code that should appear just before returning the rendered component. """ - return set() + return {} def get_imports(self) -> imports.ImportDict: """Get all the libraries and fields that are used by the component. diff --git a/reflex/components/core/banner.py b/reflex/components/core/banner.py index ea20179e1..0c781fba8 100644 --- a/reflex/components/core/banner.py +++ b/reflex/components/core/banner.py @@ -22,7 +22,7 @@ from reflex.vars import Var, VarData connect_error_var_data: VarData = VarData( # type: ignore imports=Imports.EVENTS, - hooks={Hooks.EVENTS}, + hooks={Hooks.EVENTS: None}, ) connection_error: Var = Var.create_safe( diff --git a/reflex/components/core/upload.py b/reflex/components/core/upload.py index a1f9df9d0..dfc62bc0f 100644 --- a/reflex/components/core/upload.py +++ b/reflex/components/core/upload.py @@ -1,4 +1,5 @@ """A file upload component.""" + from __future__ import annotations import os @@ -31,7 +32,7 @@ upload_files_context_var_data: VarData = VarData( # type: ignore }, }, hooks={ - "const [filesById, setFilesById] = useContext(UploadFilesContext);", + "const [filesById, setFilesById] = useContext(UploadFilesContext);": None, }, ) diff --git a/reflex/components/el/elements/forms.py b/reflex/components/el/elements/forms.py index 9e33f1a02..dc7aa5355 100644 --- a/reflex/components/el/elements/forms.py +++ b/reflex/components/el/elements/forms.py @@ -1,4 +1,5 @@ """Element classes. This is an auto-generated file. Do not edit. See ../generate.py.""" + from __future__ import annotations from hashlib import md5 @@ -162,7 +163,7 @@ class Form(BaseHTML): props["handle_submit_unique_name"] = "" form = super().create(*children, **props) form.handle_submit_unique_name = md5( - str(form.get_hooks_internal().union(form.get_hooks())).encode("utf-8") + str({**form.get_hooks_internal(), **form.get_hooks()}).encode("utf-8") ).hexdigest() return form diff --git a/reflex/components/markdown/markdown.py b/reflex/components/markdown/markdown.py index 1a125cbe3..5328cb6ff 100644 --- a/reflex/components/markdown/markdown.py +++ b/reflex/components/markdown/markdown.py @@ -293,8 +293,8 @@ class Markdown(Component): hooks = set() for _component in self.component_map.values(): comp = _component(_MOCK_ARG) - hooks |= comp.get_hooks_internal() - hooks |= comp.get_hooks() + hooks.update(comp.get_hooks_internal()) + hooks.update(comp.get_hooks()) formatted_hooks = "\n".join(hooks) return f""" function {self._get_component_map_name()} () {{ diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index f172dfcec..b37c0d1ce 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -111,6 +111,14 @@ class Hooks(SimpleNamespace): """Common sets of hook declarations.""" EVENTS = f"const [{CompileVars.ADD_EVENTS}, {CompileVars.CONNECT_ERROR}] = useContext(EventLoopContext);" + AUTOFOCUS = """ + // Set focus to the specified element. + const focusRef = useRef(null) + useEffect(() => { + if (focusRef.current) { + focusRef.current.focus(); + } + })""" class MemoizationDisposition(enum.Enum): diff --git a/reflex/style.py b/reflex/style.py index 3b916da43..159f05454 100644 --- a/reflex/style.py +++ b/reflex/style.py @@ -22,7 +22,7 @@ color_mode_var_data = VarData( # type: ignore "react": {ImportVar(tag="useContext")}, }, hooks={ - f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)", + f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)": None, }, ) # Var resolves to the current color mode for the app ("light" or "dark") @@ -240,9 +240,9 @@ def format_as_emotion(style_dict: dict[str, Any]) -> Style | None: if isinstance(value, list): # Apply media queries from responsive value list. mbps = { - media_query(bp): bp_value - if isinstance(bp_value, dict) - else {key: bp_value} + media_query(bp): ( + bp_value if isinstance(bp_value, dict) else {key: bp_value} + ) for bp, bp_value in enumerate(value) } if key.startswith("&:"): diff --git a/reflex/vars.py b/reflex/vars.py index 68964ed5c..ccabf9185 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -1,4 +1,5 @@ """Define a state var.""" + from __future__ import annotations import contextlib @@ -21,7 +22,6 @@ from typing import ( List, Literal, Optional, - Set, Tuple, Type, Union, @@ -119,7 +119,7 @@ class VarData(Base): imports: ImportDict = {} # Hooks that need to be present in the component to render this var - hooks: Set[str] = set() + hooks: Dict[str, None] = {} # Positions of interpolated strings. This is used by the decoder to figure # out where the interpolations are and only escape the non-interpolated @@ -138,7 +138,7 @@ class VarData(Base): """ state = "" _imports = {} - hooks = set() + hooks = {} interpolations = [] for var_data in others: if var_data is None: @@ -182,7 +182,7 @@ class VarData(Base): # not part of the vardata itself. return ( self.state == other.state - and self.hooks == other.hooks + and self.hooks.keys() == other.hooks.keys() and imports.collapse_imports(self.imports) == imports.collapse_imports(other.imports) ) @@ -200,7 +200,7 @@ class VarData(Base): lib: [import_var.dict() for import_var in import_vars] for lib, import_vars in self.imports.items() }, - "hooks": list(self.hooks), + "hooks": self.hooks, } @@ -1659,7 +1659,7 @@ class Var: hooks={ "const {0} = useContext(StateContexts.{0})".format( format.format_state_name(state_name) - ) + ): None }, imports={ f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")], diff --git a/reflex/vars.pyi b/reflex/vars.pyi index 959a54f74..fb2ed4657 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -1,4 +1,5 @@ """ Generated with stubgen from mypy, then manually edited, do not regen.""" + from __future__ import annotations from dataclasses import dataclass @@ -35,7 +36,7 @@ def _extract_var_data(value: Iterable) -> list[VarData | None]: ... class VarData(Base): state: str imports: dict[str, set[ImportVar]] - hooks: set[str] + hooks: Dict[str, None] interpolations: List[Tuple[int, int]] @classmethod def merge(cls, *others: VarData | None) -> VarData | None: ... diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 9cdf14f70..f63d81fbd 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -596,7 +596,7 @@ def test_get_hooks_nested2(component3, component4): component3: component with hooks defined. component4: component with different hooks defined. """ - exp_hooks = component3().get_hooks().union(component4().get_hooks()) + exp_hooks = {**component3().get_hooks(), **component4().get_hooks()} assert component3.create(component4.create()).get_hooks() == exp_hooks assert component4.create(component3.create()).get_hooks() == exp_hooks assert ( @@ -725,7 +725,7 @@ def test_stateful_banner(): TEST_VAR = Var.create_safe("test")._replace( merge_var_data=VarData( - hooks={"useTest"}, + hooks={"useTest": None}, imports={"test": {ImportVar(tag="test")}}, state="Test", interpolations=[], diff --git a/tests/test_var.py b/tests/test_var.py index 797d48ed7..61286dc83 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -836,7 +836,7 @@ def test_state_with_initial_computed_var( (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"), ( f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}", - 'testing f-string with ${"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"], "string_length": 13}{state.myvar}', + 'testing f-string with ${"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": {"const state = useContext(StateContexts.state)": null}, "string_length": 13}{state.myvar}', ), ( f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",