use dict instead of set to store hooks ()

This commit is contained in:
Thomas Brandého 2024-04-04 02:13:42 +02:00 committed by GitHub
parent b2c51b82a5
commit 34ee07ecd1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 65 additions and 101 deletions

View File

@ -8,10 +8,6 @@
{% block export %} {% block export %}
export default function Component() { export default function Component() {
{% for ref_hook in ref_hooks %}
{{ ref_hook }}
{% endfor %}
{% for hook in hooks %} {% for hook in hooks %}
{{ hook }} {{ hook }}
{% endfor %} {% endfor %}

View File

@ -1,10 +1,6 @@
{% import 'web/pages/utils.js.jinja2' as utils %} {% import 'web/pages/utils.js.jinja2' as utils %}
export function {{tag_name}} () { export function {{tag_name}} () {
{% for ref_hook in component.get_ref_hooks() %}
{{ ref_hook }}
{% endfor %}
{% for hook in component.get_hooks_internal() %} {% for hook in component.get_hooks_internal() %}
{{ hook }} {{ hook }}
{% endfor %} {% endfor %}

View File

@ -50,7 +50,7 @@ def _compile_app(app_root: Component) -> str:
return templates.APP_ROOT.render( return templates.APP_ROOT.render(
imports=utils.compile_imports(app_root.get_imports()), imports=utils.compile_imports(app_root.get_imports()),
custom_codes=app_root.get_custom_code(), custom_codes=app_root.get_custom_code(),
hooks=app_root.get_hooks_internal() | app_root.get_hooks(), hooks={**app_root.get_hooks_internal(), **app_root.get_hooks()},
render=app_root.render(), render=app_root.render(),
) )
@ -119,8 +119,7 @@ def _compile_page(
imports=imports, imports=imports,
dynamic_imports=component.get_dynamic_imports(), dynamic_imports=component.get_dynamic_imports(),
custom_codes=component.get_custom_code(), custom_codes=component.get_custom_code(),
ref_hooks=component.get_ref_hooks(), hooks={**component.get_hooks_internal(), **component.get_hooks()},
hooks=component.get_hooks_internal() | component.get_hooks(),
render=component.render(), render=component.render(),
**kwargs, **kwargs,
) )

View File

@ -265,7 +265,7 @@ def compile_custom_component(
"name": component.tag, "name": component.tag,
"props": props, "props": props,
"render": render.render(), "render": render.render(),
"hooks": render.get_hooks_internal() | render.get_hooks(), "hooks": {**render.get_hooks_internal(), **render.get_hooks()},
"custom_code": render.get_custom_code(), "custom_code": render.get_custom_code(),
}, },
imports, imports,

View File

@ -76,15 +76,7 @@ class BaseComponent(Base, ABC):
""" """
@abstractmethod @abstractmethod
def get_ref_hooks(self) -> set[str]: def get_hooks_internal(self) -> dict[str, None]:
"""Get the hooks required by refs in this component.
Returns:
The hooks for the refs.
"""
@abstractmethod
def get_hooks_internal(self) -> set[str]:
"""Get the reflex internal hooks for the component and its children. """Get the reflex internal hooks for the component and its children.
Returns: Returns:
@ -92,7 +84,7 @@ class BaseComponent(Base, ABC):
""" """
@abstractmethod @abstractmethod
def get_hooks(self) -> set[str]: def get_hooks(self) -> dict[str, None]:
"""Get the React hooks for this component. """Get the React hooks for this component.
Returns: Returns:
@ -929,7 +921,7 @@ class Component(BaseComponent, ABC):
""" """
return None return None
def get_custom_code(self) -> Set[str]: def get_custom_code(self) -> set[str]:
"""Get custom code for the component and its children. """Get custom code for the component and its children.
Returns: Returns:
@ -1108,62 +1100,53 @@ class Component(BaseComponent, ABC):
if ref is not None: if ref is not None:
return f"const {ref} = useRef(null); {str(Var.create_safe(ref).as_ref())} = {ref};" return f"const {ref} = useRef(null); {str(Var.create_safe(ref).as_ref())} = {ref};"
def _get_vars_hooks(self) -> set[str]: def _get_vars_hooks(self) -> dict[str, None]:
"""Get the hooks required by vars referenced in this component. """Get the hooks required by vars referenced in this component.
Returns: Returns:
The hooks for the vars. The hooks for the vars.
""" """
vars_hooks = set() vars_hooks = {}
for var in self._get_vars(): for var in self._get_vars():
if var._var_data: if var._var_data:
vars_hooks.update(var._var_data.hooks) vars_hooks.update(var._var_data.hooks)
return vars_hooks return vars_hooks
def _get_events_hooks(self) -> set[str]: def _get_events_hooks(self) -> dict[str, None]:
"""Get the hooks required by events referenced in this component. """Get the hooks required by events referenced in this component.
Returns: Returns:
The hooks for the events. The hooks for the events.
""" """
if self.event_triggers: return {Hooks.EVENTS: None} if self.event_triggers else {}
return {Hooks.EVENTS}
return set()
def _get_special_hooks(self) -> set[str]: def _get_special_hooks(self) -> dict[str, None]:
"""Get the hooks required by special actions referenced in this component. """Get the hooks required by special actions referenced in this component.
Returns: Returns:
The hooks for special actions. The hooks for special actions.
""" """
if self.autofocus: return {Hooks.AUTOFOCUS: None} if self.autofocus else {}
return {
"""
// Set focus to the specified element.
const focusRef = useRef(null)
useEffect(() => {
if (focusRef.current) {
focusRef.current.focus();
}
})""",
}
return set()
def _get_hooks_internal(self) -> Set[str]: def _get_hooks_internal(self) -> dict[str, None]:
"""Get the React hooks for this component managed by the framework. """Get the React hooks for this component managed by the framework.
Downstream components should NOT override this method to avoid breaking Downstream components should NOT override this method to avoid breaking
framework functionality. framework functionality.
Returns: Returns:
Set of internally managed hooks. The internally managed hooks.
""" """
return ( return {
self._get_vars_hooks() **{
| self._get_events_hooks() hook: None
| self._get_special_hooks() for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()]
| set(hook for hook in [self._get_mount_lifecycle_hook()] if hook) if hook is not None
) },
**self._get_vars_hooks(),
**self._get_events_hooks(),
**self._get_special_hooks(),
}
def _get_hooks(self) -> str | None: def _get_hooks(self) -> str | None:
"""Get the React hooks for this component. """Get the React hooks for this component.
@ -1175,20 +1158,7 @@ class Component(BaseComponent, ABC):
""" """
return return
def get_ref_hooks(self) -> Set[str]: def get_hooks_internal(self) -> dict[str, None]:
"""Get the ref hooks for the component and its children.
Returns:
The ref hooks.
"""
ref_hook = self._get_ref_hook()
hooks = set() if ref_hook is None else {ref_hook}
for child in self.children:
hooks |= child.get_ref_hooks()
return hooks
def get_hooks_internal(self) -> set[str]:
"""Get the reflex internal hooks for the component and its children. """Get the reflex internal hooks for the component and its children.
Returns: Returns:
@ -1199,26 +1169,26 @@ class Component(BaseComponent, ABC):
# Add the hook code for the children. # Add the hook code for the children.
for child in self.children: for child in self.children:
code |= child.get_hooks_internal() code = {**code, **child.get_hooks_internal()}
return code return code
def get_hooks(self) -> Set[str]: def get_hooks(self) -> dict[str, None]:
"""Get the React hooks for this component and its children. """Get the React hooks for this component and its children.
Returns: Returns:
The code that should appear just before returning the rendered component. The code that should appear just before returning the rendered component.
""" """
code = set() code = {}
# Add the hook code for this component. # Add the hook code for this component.
hooks = self._get_hooks() hooks = self._get_hooks()
if hooks is not None: if hooks is not None:
code.add(hooks) code[hooks] = None
# Add the hook code for the children. # Add the hook code for the children.
for child in self.children: for child in self.children:
code |= child.get_hooks() code = {**code, **child.get_hooks()}
return code return code
@ -1233,7 +1203,7 @@ class Component(BaseComponent, ABC):
return None return None
return format.format_ref(self.id) return format.format_ref(self.id)
def get_refs(self) -> Set[str]: def get_refs(self) -> set[str]:
"""Get the refs for the children of the component. """Get the refs for the children of the component.
Returns: Returns:
@ -1854,29 +1824,21 @@ class StatefulComponent(BaseComponent):
) )
return trigger_memo return trigger_memo
def get_ref_hooks(self) -> set[str]: def get_hooks_internal(self) -> dict[str, None]:
"""Get the ref hooks for the component and its children.
Returns:
The ref hooks.
"""
return set()
def get_hooks_internal(self) -> set[str]:
"""Get the reflex internal hooks for the component and its children. """Get the reflex internal hooks for the component and its children.
Returns: Returns:
The code that should appear just before user-defined hooks. The code that should appear just before user-defined hooks.
""" """
return set() return {}
def get_hooks(self) -> set[str]: def get_hooks(self) -> dict[str, None]:
"""Get the React hooks for this component. """Get the React hooks for this component.
Returns: Returns:
The code that should appear just before returning the rendered component. The code that should appear just before returning the rendered component.
""" """
return set() return {}
def get_imports(self) -> imports.ImportDict: def get_imports(self) -> imports.ImportDict:
"""Get all the libraries and fields that are used by the component. """Get all the libraries and fields that are used by the component.

View File

@ -22,7 +22,7 @@ from reflex.vars import Var, VarData
connect_error_var_data: VarData = VarData( # type: ignore connect_error_var_data: VarData = VarData( # type: ignore
imports=Imports.EVENTS, imports=Imports.EVENTS,
hooks={Hooks.EVENTS}, hooks={Hooks.EVENTS: None},
) )
connection_error: Var = Var.create_safe( connection_error: Var = Var.create_safe(

View File

@ -1,4 +1,5 @@
"""A file upload component.""" """A file upload component."""
from __future__ import annotations from __future__ import annotations
import os import os
@ -31,7 +32,7 @@ upload_files_context_var_data: VarData = VarData( # type: ignore
}, },
}, },
hooks={ hooks={
"const [filesById, setFilesById] = useContext(UploadFilesContext);", "const [filesById, setFilesById] = useContext(UploadFilesContext);": None,
}, },
) )

View File

@ -1,4 +1,5 @@
"""Element classes. This is an auto-generated file. Do not edit. See ../generate.py.""" """Element classes. This is an auto-generated file. Do not edit. See ../generate.py."""
from __future__ import annotations from __future__ import annotations
from hashlib import md5 from hashlib import md5
@ -162,7 +163,7 @@ class Form(BaseHTML):
props["handle_submit_unique_name"] = "" props["handle_submit_unique_name"] = ""
form = super().create(*children, **props) form = super().create(*children, **props)
form.handle_submit_unique_name = md5( form.handle_submit_unique_name = md5(
str(form.get_hooks_internal().union(form.get_hooks())).encode("utf-8") str({**form.get_hooks_internal(), **form.get_hooks()}).encode("utf-8")
).hexdigest() ).hexdigest()
return form return form

View File

@ -293,8 +293,8 @@ class Markdown(Component):
hooks = set() hooks = set()
for _component in self.component_map.values(): for _component in self.component_map.values():
comp = _component(_MOCK_ARG) comp = _component(_MOCK_ARG)
hooks |= comp.get_hooks_internal() hooks.update(comp.get_hooks_internal())
hooks |= comp.get_hooks() hooks.update(comp.get_hooks())
formatted_hooks = "\n".join(hooks) formatted_hooks = "\n".join(hooks)
return f""" return f"""
function {self._get_component_map_name()} () {{ function {self._get_component_map_name()} () {{

View File

@ -111,6 +111,14 @@ class Hooks(SimpleNamespace):
"""Common sets of hook declarations.""" """Common sets of hook declarations."""
EVENTS = f"const [{CompileVars.ADD_EVENTS}, {CompileVars.CONNECT_ERROR}] = useContext(EventLoopContext);" EVENTS = f"const [{CompileVars.ADD_EVENTS}, {CompileVars.CONNECT_ERROR}] = useContext(EventLoopContext);"
AUTOFOCUS = """
// Set focus to the specified element.
const focusRef = useRef(null)
useEffect(() => {
if (focusRef.current) {
focusRef.current.focus();
}
})"""
class MemoizationDisposition(enum.Enum): class MemoizationDisposition(enum.Enum):

View File

@ -22,7 +22,7 @@ color_mode_var_data = VarData( # type: ignore
"react": {ImportVar(tag="useContext")}, "react": {ImportVar(tag="useContext")},
}, },
hooks={ hooks={
f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)", f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)": None,
}, },
) )
# Var resolves to the current color mode for the app ("light" or "dark") # Var resolves to the current color mode for the app ("light" or "dark")
@ -240,9 +240,9 @@ def format_as_emotion(style_dict: dict[str, Any]) -> Style | None:
if isinstance(value, list): if isinstance(value, list):
# Apply media queries from responsive value list. # Apply media queries from responsive value list.
mbps = { mbps = {
media_query(bp): bp_value media_query(bp): (
if isinstance(bp_value, dict) bp_value if isinstance(bp_value, dict) else {key: bp_value}
else {key: bp_value} )
for bp, bp_value in enumerate(value) for bp, bp_value in enumerate(value)
} }
if key.startswith("&:"): if key.startswith("&:"):

View File

@ -1,4 +1,5 @@
"""Define a state var.""" """Define a state var."""
from __future__ import annotations from __future__ import annotations
import contextlib import contextlib
@ -21,7 +22,6 @@ from typing import (
List, List,
Literal, Literal,
Optional, Optional,
Set,
Tuple, Tuple,
Type, Type,
Union, Union,
@ -119,7 +119,7 @@ class VarData(Base):
imports: ImportDict = {} imports: ImportDict = {}
# Hooks that need to be present in the component to render this var # Hooks that need to be present in the component to render this var
hooks: Set[str] = set() hooks: Dict[str, None] = {}
# Positions of interpolated strings. This is used by the decoder to figure # Positions of interpolated strings. This is used by the decoder to figure
# out where the interpolations are and only escape the non-interpolated # out where the interpolations are and only escape the non-interpolated
@ -138,7 +138,7 @@ class VarData(Base):
""" """
state = "" state = ""
_imports = {} _imports = {}
hooks = set() hooks = {}
interpolations = [] interpolations = []
for var_data in others: for var_data in others:
if var_data is None: if var_data is None:
@ -182,7 +182,7 @@ class VarData(Base):
# not part of the vardata itself. # not part of the vardata itself.
return ( return (
self.state == other.state self.state == other.state
and self.hooks == other.hooks and self.hooks.keys() == other.hooks.keys()
and imports.collapse_imports(self.imports) and imports.collapse_imports(self.imports)
== imports.collapse_imports(other.imports) == imports.collapse_imports(other.imports)
) )
@ -200,7 +200,7 @@ class VarData(Base):
lib: [import_var.dict() for import_var in import_vars] lib: [import_var.dict() for import_var in import_vars]
for lib, import_vars in self.imports.items() for lib, import_vars in self.imports.items()
}, },
"hooks": list(self.hooks), "hooks": self.hooks,
} }
@ -1659,7 +1659,7 @@ class Var:
hooks={ hooks={
"const {0} = useContext(StateContexts.{0})".format( "const {0} = useContext(StateContexts.{0})".format(
format.format_state_name(state_name) format.format_state_name(state_name)
) ): None
}, },
imports={ imports={
f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")], f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")],

View File

@ -1,4 +1,5 @@
""" Generated with stubgen from mypy, then manually edited, do not regen.""" """ Generated with stubgen from mypy, then manually edited, do not regen."""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
@ -35,7 +36,7 @@ def _extract_var_data(value: Iterable) -> list[VarData | None]: ...
class VarData(Base): class VarData(Base):
state: str state: str
imports: dict[str, set[ImportVar]] imports: dict[str, set[ImportVar]]
hooks: set[str] hooks: Dict[str, None]
interpolations: List[Tuple[int, int]] interpolations: List[Tuple[int, int]]
@classmethod @classmethod
def merge(cls, *others: VarData | None) -> VarData | None: ... def merge(cls, *others: VarData | None) -> VarData | None: ...

View File

@ -596,7 +596,7 @@ def test_get_hooks_nested2(component3, component4):
component3: component with hooks defined. component3: component with hooks defined.
component4: component with different hooks defined. component4: component with different hooks defined.
""" """
exp_hooks = component3().get_hooks().union(component4().get_hooks()) exp_hooks = {**component3().get_hooks(), **component4().get_hooks()}
assert component3.create(component4.create()).get_hooks() == exp_hooks assert component3.create(component4.create()).get_hooks() == exp_hooks
assert component4.create(component3.create()).get_hooks() == exp_hooks assert component4.create(component3.create()).get_hooks() == exp_hooks
assert ( assert (
@ -725,7 +725,7 @@ def test_stateful_banner():
TEST_VAR = Var.create_safe("test")._replace( TEST_VAR = Var.create_safe("test")._replace(
merge_var_data=VarData( merge_var_data=VarData(
hooks={"useTest"}, hooks={"useTest": None},
imports={"test": {ImportVar(tag="test")}}, imports={"test": {ImportVar(tag="test")}},
state="Test", state="Test",
interpolations=[], interpolations=[],

View File

@ -836,7 +836,7 @@ def test_state_with_initial_computed_var(
(f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"), (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
( (
f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}", f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
'testing f-string with $<reflex.Var>{"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"], "string_length": 13}</reflex.Var>{state.myvar}', 'testing f-string with $<reflex.Var>{"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": {"const state = useContext(StateContexts.state)": null}, "string_length": 13}</reflex.Var>{state.myvar}',
), ),
( (
f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}", f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",