diff --git a/reflex/.templates/jinja/web/pages/index.js.jinja2 b/reflex/.templates/jinja/web/pages/index.js.jinja2
index e29deee34..efb086ef5 100644
--- a/reflex/.templates/jinja/web/pages/index.js.jinja2
+++ b/reflex/.templates/jinja/web/pages/index.js.jinja2
@@ -8,10 +8,6 @@
{% block export %}
export default function Component() {
- {% for ref_hook in ref_hooks %}
- {{ ref_hook }}
- {% endfor %}
-
{% for hook in hooks %}
{{ hook }}
{% endfor %}
diff --git a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2
index ccb375893..ad970c2f5 100644
--- a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2
+++ b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2
@@ -1,10 +1,6 @@
{% import 'web/pages/utils.js.jinja2' as utils %}
export function {{tag_name}} () {
- {% for ref_hook in component.get_ref_hooks() %}
- {{ ref_hook }}
- {% endfor %}
-
{% for hook in component.get_hooks_internal() %}
{{ hook }}
{% endfor %}
diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py
index ef0d3e4ca..02b2a2a16 100644
--- a/reflex/compiler/compiler.py
+++ b/reflex/compiler/compiler.py
@@ -50,7 +50,7 @@ def _compile_app(app_root: Component) -> str:
return templates.APP_ROOT.render(
imports=utils.compile_imports(app_root.get_imports()),
custom_codes=app_root.get_custom_code(),
- hooks=app_root.get_hooks_internal() | app_root.get_hooks(),
+ hooks={**app_root.get_hooks_internal(), **app_root.get_hooks()},
render=app_root.render(),
)
@@ -119,8 +119,7 @@ def _compile_page(
imports=imports,
dynamic_imports=component.get_dynamic_imports(),
custom_codes=component.get_custom_code(),
- ref_hooks=component.get_ref_hooks(),
- hooks=component.get_hooks_internal() | component.get_hooks(),
+ hooks={**component.get_hooks_internal(), **component.get_hooks()},
render=component.render(),
**kwargs,
)
diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py
index 0789a4548..3ebba40e0 100644
--- a/reflex/compiler/utils.py
+++ b/reflex/compiler/utils.py
@@ -265,7 +265,7 @@ def compile_custom_component(
"name": component.tag,
"props": props,
"render": render.render(),
- "hooks": render.get_hooks_internal() | render.get_hooks(),
+ "hooks": {**render.get_hooks_internal(), **render.get_hooks()},
"custom_code": render.get_custom_code(),
},
imports,
diff --git a/reflex/components/component.py b/reflex/components/component.py
index c8c3f5b9d..c873e687b 100644
--- a/reflex/components/component.py
+++ b/reflex/components/component.py
@@ -76,15 +76,7 @@ class BaseComponent(Base, ABC):
"""
@abstractmethod
- def get_ref_hooks(self) -> set[str]:
- """Get the hooks required by refs in this component.
-
- Returns:
- The hooks for the refs.
- """
-
- @abstractmethod
- def get_hooks_internal(self) -> set[str]:
+ def get_hooks_internal(self) -> dict[str, None]:
"""Get the reflex internal hooks for the component and its children.
Returns:
@@ -92,7 +84,7 @@ class BaseComponent(Base, ABC):
"""
@abstractmethod
- def get_hooks(self) -> set[str]:
+ def get_hooks(self) -> dict[str, None]:
"""Get the React hooks for this component.
Returns:
@@ -929,7 +921,7 @@ class Component(BaseComponent, ABC):
"""
return None
- def get_custom_code(self) -> Set[str]:
+ def get_custom_code(self) -> set[str]:
"""Get custom code for the component and its children.
Returns:
@@ -1108,62 +1100,53 @@ class Component(BaseComponent, ABC):
if ref is not None:
return f"const {ref} = useRef(null); {str(Var.create_safe(ref).as_ref())} = {ref};"
- def _get_vars_hooks(self) -> set[str]:
+ def _get_vars_hooks(self) -> dict[str, None]:
"""Get the hooks required by vars referenced in this component.
Returns:
The hooks for the vars.
"""
- vars_hooks = set()
+ vars_hooks = {}
for var in self._get_vars():
if var._var_data:
vars_hooks.update(var._var_data.hooks)
return vars_hooks
- def _get_events_hooks(self) -> set[str]:
+ def _get_events_hooks(self) -> dict[str, None]:
"""Get the hooks required by events referenced in this component.
Returns:
The hooks for the events.
"""
- if self.event_triggers:
- return {Hooks.EVENTS}
- return set()
+ return {Hooks.EVENTS: None} if self.event_triggers else {}
- def _get_special_hooks(self) -> set[str]:
+ def _get_special_hooks(self) -> dict[str, None]:
"""Get the hooks required by special actions referenced in this component.
Returns:
The hooks for special actions.
"""
- if self.autofocus:
- return {
- """
- // Set focus to the specified element.
- const focusRef = useRef(null)
- useEffect(() => {
- if (focusRef.current) {
- focusRef.current.focus();
- }
- })""",
- }
- return set()
+ return {Hooks.AUTOFOCUS: None} if self.autofocus else {}
- def _get_hooks_internal(self) -> Set[str]:
+ def _get_hooks_internal(self) -> dict[str, None]:
"""Get the React hooks for this component managed by the framework.
Downstream components should NOT override this method to avoid breaking
framework functionality.
Returns:
- Set of internally managed hooks.
+ The internally managed hooks.
"""
- return (
- self._get_vars_hooks()
- | self._get_events_hooks()
- | self._get_special_hooks()
- | set(hook for hook in [self._get_mount_lifecycle_hook()] if hook)
- )
+ return {
+ **{
+ hook: None
+ for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()]
+ if hook is not None
+ },
+ **self._get_vars_hooks(),
+ **self._get_events_hooks(),
+ **self._get_special_hooks(),
+ }
def _get_hooks(self) -> str | None:
"""Get the React hooks for this component.
@@ -1175,20 +1158,7 @@ class Component(BaseComponent, ABC):
"""
return
- def get_ref_hooks(self) -> Set[str]:
- """Get the ref hooks for the component and its children.
-
- Returns:
- The ref hooks.
- """
- ref_hook = self._get_ref_hook()
- hooks = set() if ref_hook is None else {ref_hook}
-
- for child in self.children:
- hooks |= child.get_ref_hooks()
- return hooks
-
- def get_hooks_internal(self) -> set[str]:
+ def get_hooks_internal(self) -> dict[str, None]:
"""Get the reflex internal hooks for the component and its children.
Returns:
@@ -1199,26 +1169,26 @@ class Component(BaseComponent, ABC):
# Add the hook code for the children.
for child in self.children:
- code |= child.get_hooks_internal()
+ code = {**code, **child.get_hooks_internal()}
return code
- def get_hooks(self) -> Set[str]:
+ def get_hooks(self) -> dict[str, None]:
"""Get the React hooks for this component and its children.
Returns:
The code that should appear just before returning the rendered component.
"""
- code = set()
+ code = {}
# Add the hook code for this component.
hooks = self._get_hooks()
if hooks is not None:
- code.add(hooks)
+ code[hooks] = None
# Add the hook code for the children.
for child in self.children:
- code |= child.get_hooks()
+ code = {**code, **child.get_hooks()}
return code
@@ -1233,7 +1203,7 @@ class Component(BaseComponent, ABC):
return None
return format.format_ref(self.id)
- def get_refs(self) -> Set[str]:
+ def get_refs(self) -> set[str]:
"""Get the refs for the children of the component.
Returns:
@@ -1854,29 +1824,21 @@ class StatefulComponent(BaseComponent):
)
return trigger_memo
- def get_ref_hooks(self) -> set[str]:
- """Get the ref hooks for the component and its children.
-
- Returns:
- The ref hooks.
- """
- return set()
-
- def get_hooks_internal(self) -> set[str]:
+ def get_hooks_internal(self) -> dict[str, None]:
"""Get the reflex internal hooks for the component and its children.
Returns:
The code that should appear just before user-defined hooks.
"""
- return set()
+ return {}
- def get_hooks(self) -> set[str]:
+ def get_hooks(self) -> dict[str, None]:
"""Get the React hooks for this component.
Returns:
The code that should appear just before returning the rendered component.
"""
- return set()
+ return {}
def get_imports(self) -> imports.ImportDict:
"""Get all the libraries and fields that are used by the component.
diff --git a/reflex/components/core/banner.py b/reflex/components/core/banner.py
index ea20179e1..0c781fba8 100644
--- a/reflex/components/core/banner.py
+++ b/reflex/components/core/banner.py
@@ -22,7 +22,7 @@ from reflex.vars import Var, VarData
connect_error_var_data: VarData = VarData( # type: ignore
imports=Imports.EVENTS,
- hooks={Hooks.EVENTS},
+ hooks={Hooks.EVENTS: None},
)
connection_error: Var = Var.create_safe(
diff --git a/reflex/components/core/upload.py b/reflex/components/core/upload.py
index a1f9df9d0..dfc62bc0f 100644
--- a/reflex/components/core/upload.py
+++ b/reflex/components/core/upload.py
@@ -1,4 +1,5 @@
"""A file upload component."""
+
from __future__ import annotations
import os
@@ -31,7 +32,7 @@ upload_files_context_var_data: VarData = VarData( # type: ignore
},
},
hooks={
- "const [filesById, setFilesById] = useContext(UploadFilesContext);",
+ "const [filesById, setFilesById] = useContext(UploadFilesContext);": None,
},
)
diff --git a/reflex/components/el/elements/forms.py b/reflex/components/el/elements/forms.py
index 9e33f1a02..dc7aa5355 100644
--- a/reflex/components/el/elements/forms.py
+++ b/reflex/components/el/elements/forms.py
@@ -1,4 +1,5 @@
"""Element classes. This is an auto-generated file. Do not edit. See ../generate.py."""
+
from __future__ import annotations
from hashlib import md5
@@ -162,7 +163,7 @@ class Form(BaseHTML):
props["handle_submit_unique_name"] = ""
form = super().create(*children, **props)
form.handle_submit_unique_name = md5(
- str(form.get_hooks_internal().union(form.get_hooks())).encode("utf-8")
+ str({**form.get_hooks_internal(), **form.get_hooks()}).encode("utf-8")
).hexdigest()
return form
diff --git a/reflex/components/markdown/markdown.py b/reflex/components/markdown/markdown.py
index 1a125cbe3..5328cb6ff 100644
--- a/reflex/components/markdown/markdown.py
+++ b/reflex/components/markdown/markdown.py
@@ -293,8 +293,8 @@ class Markdown(Component):
hooks = set()
for _component in self.component_map.values():
comp = _component(_MOCK_ARG)
- hooks |= comp.get_hooks_internal()
- hooks |= comp.get_hooks()
+ hooks.update(comp.get_hooks_internal())
+ hooks.update(comp.get_hooks())
formatted_hooks = "\n".join(hooks)
return f"""
function {self._get_component_map_name()} () {{
diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py
index f172dfcec..b37c0d1ce 100644
--- a/reflex/constants/compiler.py
+++ b/reflex/constants/compiler.py
@@ -111,6 +111,14 @@ class Hooks(SimpleNamespace):
"""Common sets of hook declarations."""
EVENTS = f"const [{CompileVars.ADD_EVENTS}, {CompileVars.CONNECT_ERROR}] = useContext(EventLoopContext);"
+ AUTOFOCUS = """
+ // Set focus to the specified element.
+ const focusRef = useRef(null)
+ useEffect(() => {
+ if (focusRef.current) {
+ focusRef.current.focus();
+ }
+ })"""
class MemoizationDisposition(enum.Enum):
diff --git a/reflex/style.py b/reflex/style.py
index 3b916da43..159f05454 100644
--- a/reflex/style.py
+++ b/reflex/style.py
@@ -22,7 +22,7 @@ color_mode_var_data = VarData( # type: ignore
"react": {ImportVar(tag="useContext")},
},
hooks={
- f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)",
+ f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)": None,
},
)
# Var resolves to the current color mode for the app ("light" or "dark")
@@ -240,9 +240,9 @@ def format_as_emotion(style_dict: dict[str, Any]) -> Style | None:
if isinstance(value, list):
# Apply media queries from responsive value list.
mbps = {
- media_query(bp): bp_value
- if isinstance(bp_value, dict)
- else {key: bp_value}
+ media_query(bp): (
+ bp_value if isinstance(bp_value, dict) else {key: bp_value}
+ )
for bp, bp_value in enumerate(value)
}
if key.startswith("&:"):
diff --git a/reflex/vars.py b/reflex/vars.py
index 68964ed5c..ccabf9185 100644
--- a/reflex/vars.py
+++ b/reflex/vars.py
@@ -1,4 +1,5 @@
"""Define a state var."""
+
from __future__ import annotations
import contextlib
@@ -21,7 +22,6 @@ from typing import (
List,
Literal,
Optional,
- Set,
Tuple,
Type,
Union,
@@ -119,7 +119,7 @@ class VarData(Base):
imports: ImportDict = {}
# Hooks that need to be present in the component to render this var
- hooks: Set[str] = set()
+ hooks: Dict[str, None] = {}
# Positions of interpolated strings. This is used by the decoder to figure
# out where the interpolations are and only escape the non-interpolated
@@ -138,7 +138,7 @@ class VarData(Base):
"""
state = ""
_imports = {}
- hooks = set()
+ hooks = {}
interpolations = []
for var_data in others:
if var_data is None:
@@ -182,7 +182,7 @@ class VarData(Base):
# not part of the vardata itself.
return (
self.state == other.state
- and self.hooks == other.hooks
+ and self.hooks.keys() == other.hooks.keys()
and imports.collapse_imports(self.imports)
== imports.collapse_imports(other.imports)
)
@@ -200,7 +200,7 @@ class VarData(Base):
lib: [import_var.dict() for import_var in import_vars]
for lib, import_vars in self.imports.items()
},
- "hooks": list(self.hooks),
+ "hooks": self.hooks,
}
@@ -1659,7 +1659,7 @@ class Var:
hooks={
"const {0} = useContext(StateContexts.{0})".format(
format.format_state_name(state_name)
- )
+ ): None
},
imports={
f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")],
diff --git a/reflex/vars.pyi b/reflex/vars.pyi
index 959a54f74..fb2ed4657 100644
--- a/reflex/vars.pyi
+++ b/reflex/vars.pyi
@@ -1,4 +1,5 @@
""" Generated with stubgen from mypy, then manually edited, do not regen."""
+
from __future__ import annotations
from dataclasses import dataclass
@@ -35,7 +36,7 @@ def _extract_var_data(value: Iterable) -> list[VarData | None]: ...
class VarData(Base):
state: str
imports: dict[str, set[ImportVar]]
- hooks: set[str]
+ hooks: Dict[str, None]
interpolations: List[Tuple[int, int]]
@classmethod
def merge(cls, *others: VarData | None) -> VarData | None: ...
diff --git a/tests/components/test_component.py b/tests/components/test_component.py
index 9cdf14f70..f63d81fbd 100644
--- a/tests/components/test_component.py
+++ b/tests/components/test_component.py
@@ -596,7 +596,7 @@ def test_get_hooks_nested2(component3, component4):
component3: component with hooks defined.
component4: component with different hooks defined.
"""
- exp_hooks = component3().get_hooks().union(component4().get_hooks())
+ exp_hooks = {**component3().get_hooks(), **component4().get_hooks()}
assert component3.create(component4.create()).get_hooks() == exp_hooks
assert component4.create(component3.create()).get_hooks() == exp_hooks
assert (
@@ -725,7 +725,7 @@ def test_stateful_banner():
TEST_VAR = Var.create_safe("test")._replace(
merge_var_data=VarData(
- hooks={"useTest"},
+ hooks={"useTest": None},
imports={"test": {ImportVar(tag="test")}},
state="Test",
interpolations=[],
diff --git a/tests/test_var.py b/tests/test_var.py
index 797d48ed7..61286dc83 100644
--- a/tests/test_var.py
+++ b/tests/test_var.py
@@ -836,7 +836,7 @@ def test_state_with_initial_computed_var(
(f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
(
f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
- 'testing f-string with ${"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"], "string_length": 13}{state.myvar}',
+ 'testing f-string with ${"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": {"const state = useContext(StateContexts.state)": null}, "string_length": 13}{state.myvar}',
),
(
f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",