[REF-2296] Rename recursive functions (#2999)

This commit is contained in:
Martin Xu 2024-04-04 09:26:31 -07:00 committed by GitHub
parent 619c0b0d06
commit 5274f04b66
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 112 additions and 107 deletions

View File

@ -1,7 +1,7 @@
{% import 'web/pages/utils.js.jinja2' as utils %} {% import 'web/pages/utils.js.jinja2' as utils %}
export function {{tag_name}} () { export function {{tag_name}} () {
{% for hook in component.get_hooks_internal() %} {% for hook in component._get_all_hooks_internal() %}
{{ hook }} {{ hook }}
{% endfor %} {% endfor %}
@ -9,7 +9,7 @@ export function {{tag_name}} () {
{{ hook }} {{ hook }}
{% endfor %} {% endfor %}
{% for hook in component.get_hooks() %} {% for hook in component._get_all_hooks() %}
{{ hook }} {{ hook }}
{% endfor %} {% endfor %}

View File

@ -653,7 +653,7 @@ class App(Base):
def _app_root(self, app_wrappers: dict[tuple[int, str], Component]) -> Component: def _app_root(self, app_wrappers: dict[tuple[int, str], Component]) -> Component:
for component in tuple(app_wrappers.values()): for component in tuple(app_wrappers.values()):
app_wrappers.update(component.get_app_wrap_components()) app_wrappers.update(component._get_all_app_wrap_components())
order = sorted(app_wrappers, key=lambda k: k[0], reverse=True) order = sorted(app_wrappers, key=lambda k: k[0], reverse=True)
root = parent = copy.deepcopy(app_wrappers[order[0]]) root = parent = copy.deepcopy(app_wrappers[order[0]])
for key in order[1:]: for key in order[1:]:
@ -791,18 +791,18 @@ class App(Base):
for _route, component in self.pages.items(): for _route, component in self.pages.items():
# Merge the component style with the app style. # Merge the component style with the app style.
component.add_style(self.style) component._add_style_recursive(self.style)
component.apply_theme(self.theme) component.apply_theme(self.theme)
# Add component.get_imports() to all_imports. # Add component._get_all_imports() to all_imports.
all_imports.update(component.get_imports()) all_imports.update(component._get_all_imports())
# Add the app wrappers from this component. # Add the app wrappers from this component.
app_wrappers.update(component.get_app_wrap_components()) app_wrappers.update(component._get_all_app_wrap_components())
# Add the custom components from the page to the set. # Add the custom components from the page to the set.
custom_components |= component.get_custom_components() custom_components |= component._get_all_custom_components()
progress.advance(task) progress.advance(task)
@ -924,7 +924,7 @@ class App(Base):
all_imports.update(custom_components_imports) all_imports.update(custom_components_imports)
# Get imports from AppWrap components. # Get imports from AppWrap components.
all_imports.update(app_root.get_imports()) all_imports.update(app_root._get_all_imports())
progress.advance(task) progress.advance(task)

View File

@ -33,7 +33,7 @@ def _compile_document_root(root: Component) -> str:
The compiled document root. The compiled document root.
""" """
return templates.DOCUMENT_ROOT.render( return templates.DOCUMENT_ROOT.render(
imports=utils.compile_imports(root.get_imports()), imports=utils.compile_imports(root._get_all_imports()),
document=root.render(), document=root.render(),
) )
@ -48,9 +48,9 @@ def _compile_app(app_root: Component) -> str:
The compiled app. The compiled app.
""" """
return templates.APP_ROOT.render( return templates.APP_ROOT.render(
imports=utils.compile_imports(app_root.get_imports()), imports=utils.compile_imports(app_root._get_all_imports()),
custom_codes=app_root.get_custom_code(), custom_codes=app_root._get_all_custom_code(),
hooks={**app_root.get_hooks_internal(), **app_root.get_hooks()}, hooks={**app_root._get_all_hooks_internal(), **app_root._get_all_hooks()},
render=app_root.render(), render=app_root.render(),
) )
@ -109,7 +109,7 @@ def _compile_page(
Returns: Returns:
The compiled component. The compiled component.
""" """
imports = component.get_imports() imports = component._get_all_imports()
imports = utils.compile_imports(imports) imports = utils.compile_imports(imports)
# Compile the code to render the component. # Compile the code to render the component.
@ -117,9 +117,9 @@ def _compile_page(
return templates.PAGE.render( return templates.PAGE.render(
imports=imports, imports=imports,
dynamic_imports=component.get_dynamic_imports(), dynamic_imports=component._get_all_dynamic_imports(),
custom_codes=component.get_custom_code(), custom_codes=component._get_all_custom_code(),
hooks={**component.get_hooks_internal(), **component.get_hooks()}, hooks={**component._get_all_hooks_internal(), **component._get_all_hooks()},
render=component.render(), render=component.render(),
**kwargs, **kwargs,
) )
@ -264,9 +264,9 @@ def _compile_stateful_components(
component.rendered_as_shared = False component.rendered_as_shared = False
rendered_components.update( rendered_components.update(
{code: None for code in component.get_custom_code()}, {code: None for code in component._get_all_custom_code()},
) )
all_import_dicts.append(component.get_imports()) all_import_dicts.append(component._get_all_imports())
# Indicate that this component now imports from the shared file. # Indicate that this component now imports from the shared file.
component.rendered_as_shared = True component.rendered_as_shared = True

View File

@ -252,7 +252,7 @@ def compile_custom_component(
# Get the imports. # Get the imports.
imports = { imports = {
lib: fields lib: fields
for lib, fields in render.get_imports().items() for lib, fields in render._get_all_imports().items()
if lib != component.library if lib != component.library
} }
@ -265,8 +265,8 @@ def compile_custom_component(
"name": component.tag, "name": component.tag,
"props": props, "props": props,
"render": render.render(), "render": render.render(),
"hooks": {**render.get_hooks_internal(), **render.get_hooks()}, "hooks": {**render._get_all_hooks_internal(), **render._get_all_hooks()},
"custom_code": render.get_custom_code(), "custom_code": render._get_all_custom_code(),
}, },
imports, imports,
) )

View File

@ -1,4 +1,5 @@
"""A pin input component.""" """A pin input component."""
from __future__ import annotations from __future__ import annotations
from typing import Any, Optional, Union from typing import Any, Optional, Union
@ -71,7 +72,7 @@ class PinInput(ChakraComponent):
range_var = Var.range(0) range_var = Var.range(0)
return merge_imports( return merge_imports(
super()._get_imports(), super()._get_imports(),
PinInputField().get_imports(), # type: ignore PinInputField()._get_all_imports(), # type: ignore
range_var._var_data.imports if range_var._var_data is not None else {}, range_var._var_data.imports if range_var._var_data is not None else {},
) )

View File

@ -76,7 +76,7 @@ class BaseComponent(Base, ABC):
""" """
@abstractmethod @abstractmethod
def get_hooks_internal(self) -> dict[str, None]: def _get_all_hooks_internal(self) -> dict[str, None]:
"""Get the reflex internal hooks for the component and its children. """Get the reflex internal hooks for the component and its children.
Returns: Returns:
@ -84,7 +84,7 @@ class BaseComponent(Base, ABC):
""" """
@abstractmethod @abstractmethod
def get_hooks(self) -> dict[str, None]: def _get_all_hooks(self) -> dict[str, None]:
"""Get the React hooks for this component. """Get the React hooks for this component.
Returns: Returns:
@ -92,7 +92,7 @@ class BaseComponent(Base, ABC):
""" """
@abstractmethod @abstractmethod
def get_imports(self) -> imports.ImportDict: def _get_all_imports(self) -> imports.ImportDict:
"""Get all the libraries and fields that are used by the component. """Get all the libraries and fields that are used by the component.
Returns: Returns:
@ -100,7 +100,7 @@ class BaseComponent(Base, ABC):
""" """
@abstractmethod @abstractmethod
def get_dynamic_imports(self) -> set[str]: def _get_all_dynamic_imports(self) -> set[str]:
"""Get dynamic imports for the component. """Get dynamic imports for the component.
Returns: Returns:
@ -108,7 +108,7 @@ class BaseComponent(Base, ABC):
""" """
@abstractmethod @abstractmethod
def get_custom_code(self) -> set[str]: def _get_all_custom_code(self) -> set[str]:
"""Get custom code for the component. """Get custom code for the component.
Returns: Returns:
@ -116,7 +116,7 @@ class BaseComponent(Base, ABC):
""" """
@abstractmethod @abstractmethod
def get_refs(self) -> set[str]: def _get_all_refs(self) -> set[str]:
"""Get the refs for the children of the component. """Get the refs for the children of the component.
Returns: Returns:
@ -669,7 +669,7 @@ class Component(BaseComponent, ABC):
""" """
self.style.update(style) self.style.update(style)
def add_style(self, style: ComponentStyle) -> Component: def _add_style_recursive(self, style: ComponentStyle) -> Component:
"""Add additional style to the component and its children. """Add additional style to the component and its children.
Args: Args:
@ -698,7 +698,7 @@ class Component(BaseComponent, ABC):
# Skip BaseComponent and StatefulComponent children. # Skip BaseComponent and StatefulComponent children.
if not isinstance(child, Component): if not isinstance(child, Component):
continue continue
child.add_style(style) child._add_style_recursive(style)
return self return self
def _get_style(self) -> dict: def _get_style(self) -> dict:
@ -921,7 +921,7 @@ class Component(BaseComponent, ABC):
""" """
return None return None
def get_custom_code(self) -> set[str]: def _get_all_custom_code(self) -> set[str]:
"""Get custom code for the component and its children. """Get custom code for the component and its children.
Returns: Returns:
@ -937,7 +937,7 @@ class Component(BaseComponent, ABC):
# Add the custom code for the children. # Add the custom code for the children.
for child in self.children: for child in self.children:
code |= child.get_custom_code() code |= child._get_all_custom_code()
# Return the code. # Return the code.
return code return code
@ -950,7 +950,7 @@ class Component(BaseComponent, ABC):
""" """
return None return None
def get_dynamic_imports(self) -> Set[str]: def _get_all_dynamic_imports(self) -> Set[str]:
"""Get dynamic imports for the component and its children. """Get dynamic imports for the component and its children.
Returns: Returns:
@ -966,11 +966,11 @@ class Component(BaseComponent, ABC):
# Get the dynamic imports from children # Get the dynamic imports from children
for child in self.children: for child in self.children:
dynamic_imports |= child.get_dynamic_imports() dynamic_imports |= child._get_all_dynamic_imports()
for prop in self.get_component_props(): for prop in self.get_component_props():
if getattr(self, prop) is not None: if getattr(self, prop) is not None:
dynamic_imports |= getattr(self, prop).get_dynamic_imports() dynamic_imports |= getattr(self, prop)._get_all_dynamic_imports()
# Return the dynamic imports # Return the dynamic imports
return dynamic_imports return dynamic_imports
@ -982,7 +982,7 @@ class Component(BaseComponent, ABC):
The imports for the components props of the component. The imports for the components props of the component.
""" """
return [ return [
getattr(self, prop).get_imports() getattr(self, prop)._get_all_imports()
for prop in self.get_component_props() for prop in self.get_component_props()
if getattr(self, prop) is not None if getattr(self, prop) is not None
] ]
@ -1053,7 +1053,7 @@ class Component(BaseComponent, ABC):
*var_imports, *var_imports,
) )
def get_imports(self, collapse: bool = False) -> imports.ImportDict: def _get_all_imports(self, collapse: bool = False) -> imports.ImportDict:
"""Get all the libraries and fields that are used by the component and its children. """Get all the libraries and fields that are used by the component and its children.
Args: Args:
@ -1063,7 +1063,7 @@ class Component(BaseComponent, ABC):
The import dict with the required imports. The import dict with the required imports.
""" """
_imports = imports.merge_imports( _imports = imports.merge_imports(
self._get_imports(), *[child.get_imports() for child in self.children] self._get_imports(), *[child._get_all_imports() for child in self.children]
) )
return imports.collapse_imports(_imports) if collapse else _imports return imports.collapse_imports(_imports) if collapse else _imports
@ -1158,7 +1158,7 @@ class Component(BaseComponent, ABC):
""" """
return return
def get_hooks_internal(self) -> dict[str, None]: def _get_all_hooks_internal(self) -> dict[str, None]:
"""Get the reflex internal hooks for the component and its children. """Get the reflex internal hooks for the component and its children.
Returns: Returns:
@ -1169,11 +1169,11 @@ class Component(BaseComponent, ABC):
# Add the hook code for the children. # Add the hook code for the children.
for child in self.children: for child in self.children:
code = {**code, **child.get_hooks_internal()} code = {**code, **child._get_all_hooks_internal()}
return code return code
def get_hooks(self) -> dict[str, None]: def _get_all_hooks(self) -> dict[str, None]:
"""Get the React hooks for this component and its children. """Get the React hooks for this component and its children.
Returns: Returns:
@ -1188,7 +1188,7 @@ class Component(BaseComponent, ABC):
# Add the hook code for the children. # Add the hook code for the children.
for child in self.children: for child in self.children:
code = {**code, **child.get_hooks()} code = {**code, **child._get_all_hooks()}
return code return code
@ -1203,7 +1203,7 @@ class Component(BaseComponent, ABC):
return None return None
return format.format_ref(self.id) return format.format_ref(self.id)
def get_refs(self) -> set[str]: def _get_all_refs(self) -> set[str]:
"""Get the refs for the children of the component. """Get the refs for the children of the component.
Returns: Returns:
@ -1214,10 +1214,10 @@ class Component(BaseComponent, ABC):
if ref is not None: if ref is not None:
refs.add(ref) refs.add(ref)
for child in self.children: for child in self.children:
refs |= child.get_refs() refs |= child._get_all_refs()
return refs return refs
def get_custom_components( def _get_all_custom_components(
self, seen: set[str] | None = None self, seen: set[str] | None = None
) -> Set[CustomComponent]: ) -> Set[CustomComponent]:
"""Get all the custom components used by the component. """Get all the custom components used by the component.
@ -1237,7 +1237,7 @@ class Component(BaseComponent, ABC):
# Skip BaseComponent and StatefulComponent children. # Skip BaseComponent and StatefulComponent children.
if not isinstance(child, Component): if not isinstance(child, Component):
continue continue
custom_components |= child.get_custom_components(seen=seen) custom_components |= child._get_all_custom_components(seen=seen)
return custom_components return custom_components
@property @property
@ -1261,7 +1261,7 @@ class Component(BaseComponent, ABC):
""" """
return {} return {}
def get_app_wrap_components(self) -> dict[tuple[int, str], Component]: def _get_all_app_wrap_components(self) -> dict[tuple[int, str], Component]:
"""Get the app wrap components for the component and its children. """Get the app wrap components for the component and its children.
Returns: Returns:
@ -1271,14 +1271,14 @@ class Component(BaseComponent, ABC):
components = self._get_app_wrap_components() components = self._get_app_wrap_components()
for component in tuple(components.values()): for component in tuple(components.values()):
components.update(component.get_app_wrap_components()) components.update(component._get_all_app_wrap_components())
# Add the app wrap components for the children. # Add the app wrap components for the children.
for child in self.children: for child in self.children:
# Skip BaseComponent and StatefulComponent children. # Skip BaseComponent and StatefulComponent children.
if not isinstance(child, Component): if not isinstance(child, Component):
continue continue
components.update(child.get_app_wrap_components()) components.update(child._get_all_app_wrap_components())
# Return the components. # Return the components.
return components return components
@ -1347,8 +1347,8 @@ class CustomComponent(Component):
self.component_props[key] = value self.component_props[key] = value
value = base_value._replace( value = base_value._replace(
merge_var_data=VarData( # type: ignore merge_var_data=VarData( # type: ignore
imports=value.get_imports(), imports=value._get_all_imports(),
hooks=value.get_hooks(), hooks=value._get_all_hooks(),
) )
) )
else: else:
@ -1387,7 +1387,7 @@ class CustomComponent(Component):
""" """
return set() return set()
def get_custom_components( def _get_all_custom_components(
self, seen: set[str] | None = None self, seen: set[str] | None = None
) -> Set[CustomComponent]: ) -> Set[CustomComponent]:
"""Get all the custom components used by the component. """Get all the custom components used by the component.
@ -1403,12 +1403,12 @@ class CustomComponent(Component):
# Store the seen components in a set to avoid infinite recursion. # Store the seen components in a set to avoid infinite recursion.
if seen is None: if seen is None:
seen = set() seen = set()
custom_components = {self} | super().get_custom_components(seen=seen) custom_components = {self} | super()._get_all_custom_components(seen=seen)
# Avoid adding the same component twice. # Avoid adding the same component twice.
if self.tag not in seen: if self.tag not in seen:
seen.add(self.tag) seen.add(self.tag)
custom_components |= self.get_component(self).get_custom_components( custom_components |= self.get_component(self)._get_all_custom_components(
seen=seen seen=seen
) )
@ -1420,7 +1420,9 @@ class CustomComponent(Component):
seen.add(child_component.tag) seen.add(child_component.tag)
if isinstance(child_component, CustomComponent): if isinstance(child_component, CustomComponent):
custom_components |= {child_component} custom_components |= {child_component}
custom_components |= child_component.get_custom_components(seen=seen) custom_components |= child_component._get_all_custom_components(
seen=seen
)
return custom_components return custom_components
def _render(self) -> Tag: def _render(self) -> Tag:
@ -1824,7 +1826,7 @@ class StatefulComponent(BaseComponent):
) )
return trigger_memo return trigger_memo
def get_hooks_internal(self) -> dict[str, None]: def _get_all_hooks_internal(self) -> dict[str, None]:
"""Get the reflex internal hooks for the component and its children. """Get the reflex internal hooks for the component and its children.
Returns: Returns:
@ -1832,7 +1834,7 @@ class StatefulComponent(BaseComponent):
""" """
return {} return {}
def get_hooks(self) -> dict[str, None]: def _get_all_hooks(self) -> dict[str, None]:
"""Get the React hooks for this component. """Get the React hooks for this component.
Returns: Returns:
@ -1840,7 +1842,7 @@ class StatefulComponent(BaseComponent):
""" """
return {} return {}
def get_imports(self) -> imports.ImportDict: def _get_all_imports(self) -> imports.ImportDict:
"""Get all the libraries and fields that are used by the component. """Get all the libraries and fields that are used by the component.
Returns: Returns:
@ -1852,9 +1854,9 @@ class StatefulComponent(BaseComponent):
ImportVar(tag=self.tag) ImportVar(tag=self.tag)
] ]
} }
return self.component.get_imports() return self.component._get_all_imports()
def get_dynamic_imports(self) -> set[str]: def _get_all_dynamic_imports(self) -> set[str]:
"""Get dynamic imports for the component. """Get dynamic imports for the component.
Returns: Returns:
@ -1862,9 +1864,9 @@ class StatefulComponent(BaseComponent):
""" """
if self.rendered_as_shared: if self.rendered_as_shared:
return set() return set()
return self.component.get_dynamic_imports() return self.component._get_all_dynamic_imports()
def get_custom_code(self) -> set[str]: def _get_all_custom_code(self) -> set[str]:
"""Get custom code for the component. """Get custom code for the component.
Returns: Returns:
@ -1872,9 +1874,9 @@ class StatefulComponent(BaseComponent):
""" """
if self.rendered_as_shared: if self.rendered_as_shared:
return set() return set()
return self.component.get_custom_code().union({self.code}) return self.component._get_all_custom_code().union({self.code})
def get_refs(self) -> set[str]: def _get_all_refs(self) -> set[str]:
"""Get the refs for the children of the component. """Get the refs for the children of the component.
Returns: Returns:
@ -1882,7 +1884,7 @@ class StatefulComponent(BaseComponent):
""" """
if self.rendered_as_shared: if self.rendered_as_shared:
return set() return set()
return self.component.get_refs() return self.component._get_all_refs()
def render(self) -> dict: def render(self) -> dict:
"""Define how to render the component in React. """Define how to render the component in React.
@ -1940,7 +1942,7 @@ class MemoizationLeaf(Component):
The memoization leaf The memoization leaf
""" """
comp = super().create(*children, **props) comp = super().create(*children, **props)
if comp.get_hooks() or comp.get_hooks_internal(): if comp._get_all_hooks() or comp._get_all_hooks_internal():
comp._memoization_mode = cls._memoization_mode.copy( comp._memoization_mode = cls._memoization_mode.copy(
update={"disposition": MemoizationDisposition.ALWAYS} update={"disposition": MemoizationDisposition.ALWAYS}
) )

View File

@ -163,7 +163,9 @@ class Form(BaseHTML):
props["handle_submit_unique_name"] = "" props["handle_submit_unique_name"] = ""
form = super().create(*children, **props) form = super().create(*children, **props)
form.handle_submit_unique_name = md5( form.handle_submit_unique_name = md5(
str({**form.get_hooks_internal(), **form.get_hooks()}).encode("utf-8") str({**form._get_all_hooks_internal(), **form._get_all_hooks()}).encode(
"utf-8"
)
).hexdigest() ).hexdigest()
return form return form
@ -208,7 +210,7 @@ class Form(BaseHTML):
def _get_form_refs(self) -> Dict[str, Any]: def _get_form_refs(self) -> Dict[str, Any]:
# Send all the input refs to the handler. # Send all the input refs to the handler.
form_refs = {} form_refs = {}
for ref in self.get_refs(): for ref in self._get_all_refs():
# when ref start with refs_ it's an array of refs, so we need different method # when ref start with refs_ it's an array of refs, so we need different method
# to collect data # to collect data
if ref.startswith("refs_"): if ref.startswith("refs_"):
@ -593,13 +595,13 @@ class Textarea(BaseHTML):
"enter_key_submit", "enter_key_submit",
] ]
def get_custom_code(self) -> Set[str]: def _get_all_custom_code(self) -> Set[str]:
"""Include the custom code for auto_height and enter_key_submit functionality. """Include the custom code for auto_height and enter_key_submit functionality.
Returns: Returns:
The custom code for the component. The custom code for the component.
""" """
custom_code = super().get_custom_code() custom_code = super()._get_all_custom_code()
if self.auto_height is not None: if self.auto_height is not None:
custom_code.add(AUTO_HEIGHT_JS) custom_code.add(AUTO_HEIGHT_JS)
if self.enter_key_submit is not None: if self.enter_key_submit is not None:

View File

@ -2022,7 +2022,6 @@ AUTO_HEIGHT_JS = '\nconst autoHeightOnInput = (e, is_enabled) => {\n if (is_e
ENTER_KEY_SUBMIT_JS = "\nconst enterKeySubmitOnKeyDown = (e, is_enabled) => {\n if (is_enabled && e.which === 13 && !e.shiftKey) {\n e.preventDefault();\n if (!e.repeat) {\n if (e.target.form) {\n e.target.form.requestSubmit();\n }\n }\n }\n}\n" ENTER_KEY_SUBMIT_JS = "\nconst enterKeySubmitOnKeyDown = (e, is_enabled) => {\n if (is_enabled && e.which === 13 && !e.shiftKey) {\n e.preventDefault();\n if (!e.repeat) {\n if (e.target.form) {\n e.target.form.requestSubmit();\n }\n }\n }\n}\n"
class Textarea(BaseHTML): class Textarea(BaseHTML):
def get_custom_code(self) -> Set[str]: ...
def get_event_triggers(self) -> Dict[str, Any]: ... def get_event_triggers(self) -> Dict[str, Any]: ...
@overload @overload
@classmethod @classmethod

View File

@ -133,7 +133,7 @@ class Markdown(Component):
**props, **props,
) )
def get_custom_components( def _get_all_custom_components(
self, seen: set[str] | None = None self, seen: set[str] | None = None
) -> set[CustomComponent]: ) -> set[CustomComponent]:
"""Get all the custom components used by the component. """Get all the custom components used by the component.
@ -144,11 +144,13 @@ class Markdown(Component):
Returns: Returns:
The set of custom components. The set of custom components.
""" """
custom_components = super().get_custom_components(seen=seen) custom_components = super()._get_all_custom_components(seen=seen)
# Get the custom components for each tag. # Get the custom components for each tag.
for component in self.component_map.values(): for component in self.component_map.values():
custom_components |= component(_MOCK_ARG).get_custom_components(seen=seen) custom_components |= component(_MOCK_ARG)._get_all_custom_components(
seen=seen
)
return custom_components return custom_components
@ -183,7 +185,9 @@ class Markdown(Component):
# Get the imports for each component. # Get the imports for each component.
for component in self.component_map.values(): for component in self.component_map.values():
imports = utils.merge_imports(imports, component(_MOCK_ARG).get_imports()) imports = utils.merge_imports(
imports, component(_MOCK_ARG)._get_all_imports()
)
# Get the imports for the code components. # Get the imports for the code components.
imports = utils.merge_imports( imports = utils.merge_imports(
@ -293,8 +297,8 @@ class Markdown(Component):
hooks = set() hooks = set()
for _component in self.component_map.values(): for _component in self.component_map.values():
comp = _component(_MOCK_ARG) comp = _component(_MOCK_ARG)
hooks.update(comp.get_hooks_internal()) hooks.update(comp._get_all_hooks_internal())
hooks.update(comp.get_hooks()) hooks.update(comp._get_all_hooks())
formatted_hooks = "\n".join(hooks) formatted_hooks = "\n".join(hooks)
return f""" return f"""
function {self._get_component_map_name()} () {{ function {self._get_component_map_name()} () {{

View File

@ -123,9 +123,6 @@ class Markdown(Component):
The markdown component. The markdown component.
""" """
... ...
def get_custom_components(
self, seen: set[str] | None = None
) -> set[CustomComponent]: ...
def get_component(self, tag: str, **props) -> Component: ... def get_component(self, tag: str, **props) -> Component: ...
def format_component(self, tag: str, **props) -> str: ... def format_component(self, tag: str, **props) -> str: ...
def format_component_map(self) -> dict[str, str]: ... def format_component_map(self) -> dict[str, str]: ...

View File

@ -9,13 +9,13 @@ from reflex.components.radix.themes.typography.text import Text
def test_websocket_target_url(): def test_websocket_target_url():
url = WebsocketTargetURL.create() url = WebsocketTargetURL.create()
_imports = url.get_imports(collapse=True) _imports = url._get_all_imports(collapse=True)
assert list(_imports.keys()) == ["/utils/state", "/env.json"] assert list(_imports.keys()) == ["/utils/state", "/env.json"]
def test_connection_banner(): def test_connection_banner():
banner = ConnectionBanner.create() banner = ConnectionBanner.create()
_imports = banner.get_imports(collapse=True) _imports = banner._get_all_imports(collapse=True)
assert list(_imports.keys()) == [ assert list(_imports.keys()) == [
"react", "react",
"/utils/context", "/utils/context",
@ -31,7 +31,7 @@ def test_connection_banner():
def test_connection_modal(): def test_connection_modal():
modal = ConnectionModal.create() modal = ConnectionModal.create()
_imports = modal.get_imports(collapse=True) _imports = modal._get_all_imports(collapse=True)
assert list(_imports.keys()) == [ assert list(_imports.keys()) == [
"react", "react",
"/utils/context", "/utils/context",
@ -47,5 +47,5 @@ def test_connection_modal():
def test_connection_pulser(): def test_connection_pulser():
pulser = ConnectionPulser.create() pulser = ConnectionPulser.create()
_custom_code = pulser.get_custom_code() _custom_code = pulser._get_all_custom_code()
_imports = pulser.get_imports(collapse=True) _imports = pulser._get_all_imports(collapse=True)

View File

@ -263,8 +263,8 @@ def test_add_style(component1, component2):
component1: Style({"color": "white"}), component1: Style({"color": "white"}),
component2: Style({"color": "black"}), component2: Style({"color": "black"}),
} }
c1 = component1().add_style(style) # type: ignore c1 = component1()._add_style_recursive(style) # type: ignore
c2 = component2().add_style(style) # type: ignore c2 = component2()._add_style_recursive(style) # type: ignore
assert c1.style["color"] == "white" assert c1.style["color"] == "white"
assert c2.style["color"] == "black" assert c2.style["color"] == "black"
@ -280,8 +280,8 @@ def test_add_style_create(component1, component2):
component1.create: Style({"color": "white"}), component1.create: Style({"color": "white"}),
component2.create: Style({"color": "black"}), component2.create: Style({"color": "black"}),
} }
c1 = component1().add_style(style) # type: ignore c1 = component1()._add_style_recursive(style) # type: ignore
c2 = component2().add_style(style) # type: ignore c2 = component2()._add_style_recursive(style) # type: ignore
assert c1.style["color"] == "white" assert c1.style["color"] == "white"
assert c2.style["color"] == "black" assert c2.style["color"] == "black"
@ -295,8 +295,8 @@ def test_get_imports(component1, component2):
""" """
c1 = component1.create() c1 = component1.create()
c2 = component2.create(c1) c2 = component2.create(c1)
assert c1.get_imports() == {"react": [ImportVar(tag="Component")]} assert c1._get_all_imports() == {"react": [ImportVar(tag="Component")]}
assert c2.get_imports() == { assert c2._get_all_imports() == {
"react-redux": [ImportVar(tag="connect")], "react-redux": [ImportVar(tag="connect")],
"react": [ImportVar(tag="Component")], "react": [ImportVar(tag="Component")],
} }
@ -312,19 +312,19 @@ def test_get_custom_code(component1, component2):
# Check that the code gets compiled correctly. # Check that the code gets compiled correctly.
c1 = component1.create() c1 = component1.create()
c2 = component2.create() c2 = component2.create()
assert c1.get_custom_code() == {"console.log('component1')"} assert c1._get_all_custom_code() == {"console.log('component1')"}
assert c2.get_custom_code() == {"console.log('component2')"} assert c2._get_all_custom_code() == {"console.log('component2')"}
# Check that nesting components compiles both codes. # Check that nesting components compiles both codes.
c1 = component1.create(c2) c1 = component1.create(c2)
assert c1.get_custom_code() == { assert c1._get_all_custom_code() == {
"console.log('component1')", "console.log('component1')",
"console.log('component2')", "console.log('component2')",
} }
# Check that code is not duplicated. # Check that code is not duplicated.
c1 = component1.create(c2, c2, c1, c1) c1 = component1.create(c2, c2, c1, c1)
assert c1.get_custom_code() == { assert c1._get_all_custom_code() == {
"console.log('component1')", "console.log('component1')",
"console.log('component2')", "console.log('component2')",
} }
@ -502,7 +502,7 @@ def test_create_custom_component(my_component):
component = CustomComponent(component_fn=my_component, prop1="test", prop2=1) component = CustomComponent(component_fn=my_component, prop1="test", prop2=1)
assert component.tag == "MyComponent" assert component.tag == "MyComponent"
assert component.get_props() == set() assert component.get_props() == set()
assert component.get_custom_components() == {component} assert component._get_all_custom_components() == {component}
def test_custom_component_hash(my_component): def test_custom_component_hash(my_component):
@ -586,7 +586,7 @@ def test_get_hooks_nested(component1, component2, component3):
text="a", text="a",
number=1, number=1,
) )
assert c.get_hooks() == component3().get_hooks() assert c._get_all_hooks() == component3()._get_all_hooks()
def test_get_hooks_nested2(component3, component4): def test_get_hooks_nested2(component3, component4):
@ -596,15 +596,15 @@ def test_get_hooks_nested2(component3, component4):
component3: component with hooks defined. component3: component with hooks defined.
component4: component with different hooks defined. component4: component with different hooks defined.
""" """
exp_hooks = {**component3().get_hooks(), **component4().get_hooks()} exp_hooks = {**component3()._get_all_hooks(), **component4()._get_all_hooks()}
assert component3.create(component4.create()).get_hooks() == exp_hooks assert component3.create(component4.create())._get_all_hooks() == exp_hooks
assert component4.create(component3.create()).get_hooks() == exp_hooks assert component4.create(component3.create())._get_all_hooks() == exp_hooks
assert ( assert (
component4.create( component4.create(
component3.create(), component3.create(),
component4.create(), component4.create(),
component3.create(), component3.create(),
).get_hooks() )._get_all_hooks()
== exp_hooks == exp_hooks
) )
@ -1329,20 +1329,20 @@ def test_custom_component_get_imports():
custom_comp = wrapper() custom_comp = wrapper()
# Inner is not imported directly, but it is imported by the custom component. # Inner is not imported directly, but it is imported by the custom component.
assert "inner" not in custom_comp.get_imports() assert "inner" not in custom_comp._get_all_imports()
# The imports are only resolved during compilation. # The imports are only resolved during compilation.
_, _, imports_inner = compile_components(custom_comp.get_custom_components()) _, _, imports_inner = compile_components(custom_comp._get_all_custom_components())
assert "inner" in imports_inner assert "inner" in imports_inner
outer_comp = outer(c=wrapper()) outer_comp = outer(c=wrapper())
# Libraries are not imported directly, but are imported by the custom component. # Libraries are not imported directly, but are imported by the custom component.
assert "inner" not in outer_comp.get_imports() assert "inner" not in outer_comp._get_all_imports()
assert "other" not in outer_comp.get_imports() assert "other" not in outer_comp._get_all_imports()
# The imports are only resolved during compilation. # The imports are only resolved during compilation.
_, _, imports_outer = compile_components(outer_comp.get_custom_components()) _, _, imports_outer = compile_components(outer_comp._get_all_custom_components())
assert "inner" in imports_outer assert "inner" in imports_outer
assert "other" in imports_outer assert "other" in imports_outer