From 5274f04b66612fa30b3f53fe67e65a26f05a78f3 Mon Sep 17 00:00:00 2001 From: Martin Xu <15661672+martinxu9@users.noreply.github.com> Date: Thu, 4 Apr 2024 09:26:31 -0700 Subject: [PATCH] [REF-2296] Rename recursive functions (#2999) --- .../web/pages/stateful_component.js.jinja2 | 4 +- reflex/app.py | 14 +-- reflex/compiler/compiler.py | 20 ++--- reflex/compiler/utils.py | 6 +- reflex/components/chakra/forms/pininput.py | 3 +- reflex/components/component.py | 90 ++++++++++--------- reflex/components/el/elements/forms.py | 10 ++- reflex/components/el/elements/forms.pyi | 1 - reflex/components/markdown/markdown.py | 16 ++-- reflex/components/markdown/markdown.pyi | 3 - tests/components/core/test_banner.py | 10 +-- tests/components/test_component.py | 42 ++++----- 12 files changed, 112 insertions(+), 107 deletions(-) diff --git a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 index ad970c2f5..4a40ef545 100644 --- a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 @@ -1,7 +1,7 @@ {% import 'web/pages/utils.js.jinja2' as utils %} export function {{tag_name}} () { - {% for hook in component.get_hooks_internal() %} + {% for hook in component._get_all_hooks_internal() %} {{ hook }} {% endfor %} @@ -9,7 +9,7 @@ export function {{tag_name}} () { {{ hook }} {% endfor %} - {% for hook in component.get_hooks() %} + {% for hook in component._get_all_hooks() %} {{ hook }} {% endfor %} diff --git a/reflex/app.py b/reflex/app.py index ef8780b49..50000e507 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -653,7 +653,7 @@ class App(Base): def _app_root(self, app_wrappers: dict[tuple[int, str], Component]) -> Component: for component in tuple(app_wrappers.values()): - app_wrappers.update(component.get_app_wrap_components()) + app_wrappers.update(component._get_all_app_wrap_components()) order = sorted(app_wrappers, key=lambda k: k[0], reverse=True) root = parent = copy.deepcopy(app_wrappers[order[0]]) for key in order[1:]: @@ -791,18 +791,18 @@ class App(Base): for _route, component in self.pages.items(): # Merge the component style with the app style. - component.add_style(self.style) + component._add_style_recursive(self.style) component.apply_theme(self.theme) - # Add component.get_imports() to all_imports. - all_imports.update(component.get_imports()) + # Add component._get_all_imports() to all_imports. + all_imports.update(component._get_all_imports()) # Add the app wrappers from this component. - app_wrappers.update(component.get_app_wrap_components()) + app_wrappers.update(component._get_all_app_wrap_components()) # Add the custom components from the page to the set. - custom_components |= component.get_custom_components() + custom_components |= component._get_all_custom_components() progress.advance(task) @@ -924,7 +924,7 @@ class App(Base): all_imports.update(custom_components_imports) # Get imports from AppWrap components. - all_imports.update(app_root.get_imports()) + all_imports.update(app_root._get_all_imports()) progress.advance(task) diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 02b2a2a16..89ac867f7 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -33,7 +33,7 @@ def _compile_document_root(root: Component) -> str: The compiled document root. """ return templates.DOCUMENT_ROOT.render( - imports=utils.compile_imports(root.get_imports()), + imports=utils.compile_imports(root._get_all_imports()), document=root.render(), ) @@ -48,9 +48,9 @@ def _compile_app(app_root: Component) -> str: The compiled app. """ return templates.APP_ROOT.render( - imports=utils.compile_imports(app_root.get_imports()), - custom_codes=app_root.get_custom_code(), - hooks={**app_root.get_hooks_internal(), **app_root.get_hooks()}, + imports=utils.compile_imports(app_root._get_all_imports()), + custom_codes=app_root._get_all_custom_code(), + hooks={**app_root._get_all_hooks_internal(), **app_root._get_all_hooks()}, render=app_root.render(), ) @@ -109,7 +109,7 @@ def _compile_page( Returns: The compiled component. """ - imports = component.get_imports() + imports = component._get_all_imports() imports = utils.compile_imports(imports) # Compile the code to render the component. @@ -117,9 +117,9 @@ def _compile_page( return templates.PAGE.render( imports=imports, - dynamic_imports=component.get_dynamic_imports(), - custom_codes=component.get_custom_code(), - hooks={**component.get_hooks_internal(), **component.get_hooks()}, + dynamic_imports=component._get_all_dynamic_imports(), + custom_codes=component._get_all_custom_code(), + hooks={**component._get_all_hooks_internal(), **component._get_all_hooks()}, render=component.render(), **kwargs, ) @@ -264,9 +264,9 @@ def _compile_stateful_components( component.rendered_as_shared = False rendered_components.update( - {code: None for code in component.get_custom_code()}, + {code: None for code in component._get_all_custom_code()}, ) - all_import_dicts.append(component.get_imports()) + all_import_dicts.append(component._get_all_imports()) # Indicate that this component now imports from the shared file. component.rendered_as_shared = True diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 3ebba40e0..ab5bf9650 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -252,7 +252,7 @@ def compile_custom_component( # Get the imports. imports = { lib: fields - for lib, fields in render.get_imports().items() + for lib, fields in render._get_all_imports().items() if lib != component.library } @@ -265,8 +265,8 @@ def compile_custom_component( "name": component.tag, "props": props, "render": render.render(), - "hooks": {**render.get_hooks_internal(), **render.get_hooks()}, - "custom_code": render.get_custom_code(), + "hooks": {**render._get_all_hooks_internal(), **render._get_all_hooks()}, + "custom_code": render._get_all_custom_code(), }, imports, ) diff --git a/reflex/components/chakra/forms/pininput.py b/reflex/components/chakra/forms/pininput.py index d57c9acaf..cea95ed7b 100644 --- a/reflex/components/chakra/forms/pininput.py +++ b/reflex/components/chakra/forms/pininput.py @@ -1,4 +1,5 @@ """A pin input component.""" + from __future__ import annotations from typing import Any, Optional, Union @@ -71,7 +72,7 @@ class PinInput(ChakraComponent): range_var = Var.range(0) return merge_imports( super()._get_imports(), - PinInputField().get_imports(), # type: ignore + PinInputField()._get_all_imports(), # type: ignore range_var._var_data.imports if range_var._var_data is not None else {}, ) diff --git a/reflex/components/component.py b/reflex/components/component.py index c873e687b..057a55361 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -76,7 +76,7 @@ class BaseComponent(Base, ABC): """ @abstractmethod - def get_hooks_internal(self) -> dict[str, None]: + def _get_all_hooks_internal(self) -> dict[str, None]: """Get the reflex internal hooks for the component and its children. Returns: @@ -84,7 +84,7 @@ class BaseComponent(Base, ABC): """ @abstractmethod - def get_hooks(self) -> dict[str, None]: + def _get_all_hooks(self) -> dict[str, None]: """Get the React hooks for this component. Returns: @@ -92,7 +92,7 @@ class BaseComponent(Base, ABC): """ @abstractmethod - def get_imports(self) -> imports.ImportDict: + def _get_all_imports(self) -> imports.ImportDict: """Get all the libraries and fields that are used by the component. Returns: @@ -100,7 +100,7 @@ class BaseComponent(Base, ABC): """ @abstractmethod - def get_dynamic_imports(self) -> set[str]: + def _get_all_dynamic_imports(self) -> set[str]: """Get dynamic imports for the component. Returns: @@ -108,7 +108,7 @@ class BaseComponent(Base, ABC): """ @abstractmethod - def get_custom_code(self) -> set[str]: + def _get_all_custom_code(self) -> set[str]: """Get custom code for the component. Returns: @@ -116,7 +116,7 @@ class BaseComponent(Base, ABC): """ @abstractmethod - def get_refs(self) -> set[str]: + def _get_all_refs(self) -> set[str]: """Get the refs for the children of the component. Returns: @@ -669,7 +669,7 @@ class Component(BaseComponent, ABC): """ self.style.update(style) - def add_style(self, style: ComponentStyle) -> Component: + def _add_style_recursive(self, style: ComponentStyle) -> Component: """Add additional style to the component and its children. Args: @@ -698,7 +698,7 @@ class Component(BaseComponent, ABC): # Skip BaseComponent and StatefulComponent children. if not isinstance(child, Component): continue - child.add_style(style) + child._add_style_recursive(style) return self def _get_style(self) -> dict: @@ -921,7 +921,7 @@ class Component(BaseComponent, ABC): """ return None - def get_custom_code(self) -> set[str]: + def _get_all_custom_code(self) -> set[str]: """Get custom code for the component and its children. Returns: @@ -937,7 +937,7 @@ class Component(BaseComponent, ABC): # Add the custom code for the children. for child in self.children: - code |= child.get_custom_code() + code |= child._get_all_custom_code() # Return the code. return code @@ -950,7 +950,7 @@ class Component(BaseComponent, ABC): """ return None - def get_dynamic_imports(self) -> Set[str]: + def _get_all_dynamic_imports(self) -> Set[str]: """Get dynamic imports for the component and its children. Returns: @@ -966,11 +966,11 @@ class Component(BaseComponent, ABC): # Get the dynamic imports from children for child in self.children: - dynamic_imports |= child.get_dynamic_imports() + dynamic_imports |= child._get_all_dynamic_imports() for prop in self.get_component_props(): if getattr(self, prop) is not None: - dynamic_imports |= getattr(self, prop).get_dynamic_imports() + dynamic_imports |= getattr(self, prop)._get_all_dynamic_imports() # Return the dynamic imports return dynamic_imports @@ -982,7 +982,7 @@ class Component(BaseComponent, ABC): The imports for the components props of the component. """ return [ - getattr(self, prop).get_imports() + getattr(self, prop)._get_all_imports() for prop in self.get_component_props() if getattr(self, prop) is not None ] @@ -1053,7 +1053,7 @@ class Component(BaseComponent, ABC): *var_imports, ) - def get_imports(self, collapse: bool = False) -> imports.ImportDict: + def _get_all_imports(self, collapse: bool = False) -> imports.ImportDict: """Get all the libraries and fields that are used by the component and its children. Args: @@ -1063,7 +1063,7 @@ class Component(BaseComponent, ABC): The import dict with the required imports. """ _imports = imports.merge_imports( - self._get_imports(), *[child.get_imports() for child in self.children] + self._get_imports(), *[child._get_all_imports() for child in self.children] ) return imports.collapse_imports(_imports) if collapse else _imports @@ -1158,7 +1158,7 @@ class Component(BaseComponent, ABC): """ return - def get_hooks_internal(self) -> dict[str, None]: + def _get_all_hooks_internal(self) -> dict[str, None]: """Get the reflex internal hooks for the component and its children. Returns: @@ -1169,11 +1169,11 @@ class Component(BaseComponent, ABC): # Add the hook code for the children. for child in self.children: - code = {**code, **child.get_hooks_internal()} + code = {**code, **child._get_all_hooks_internal()} return code - def get_hooks(self) -> dict[str, None]: + def _get_all_hooks(self) -> dict[str, None]: """Get the React hooks for this component and its children. Returns: @@ -1188,7 +1188,7 @@ class Component(BaseComponent, ABC): # Add the hook code for the children. for child in self.children: - code = {**code, **child.get_hooks()} + code = {**code, **child._get_all_hooks()} return code @@ -1203,7 +1203,7 @@ class Component(BaseComponent, ABC): return None return format.format_ref(self.id) - def get_refs(self) -> set[str]: + def _get_all_refs(self) -> set[str]: """Get the refs for the children of the component. Returns: @@ -1214,10 +1214,10 @@ class Component(BaseComponent, ABC): if ref is not None: refs.add(ref) for child in self.children: - refs |= child.get_refs() + refs |= child._get_all_refs() return refs - def get_custom_components( + def _get_all_custom_components( self, seen: set[str] | None = None ) -> Set[CustomComponent]: """Get all the custom components used by the component. @@ -1237,7 +1237,7 @@ class Component(BaseComponent, ABC): # Skip BaseComponent and StatefulComponent children. if not isinstance(child, Component): continue - custom_components |= child.get_custom_components(seen=seen) + custom_components |= child._get_all_custom_components(seen=seen) return custom_components @property @@ -1261,7 +1261,7 @@ class Component(BaseComponent, ABC): """ return {} - def get_app_wrap_components(self) -> dict[tuple[int, str], Component]: + def _get_all_app_wrap_components(self) -> dict[tuple[int, str], Component]: """Get the app wrap components for the component and its children. Returns: @@ -1271,14 +1271,14 @@ class Component(BaseComponent, ABC): components = self._get_app_wrap_components() for component in tuple(components.values()): - components.update(component.get_app_wrap_components()) + components.update(component._get_all_app_wrap_components()) # Add the app wrap components for the children. for child in self.children: # Skip BaseComponent and StatefulComponent children. if not isinstance(child, Component): continue - components.update(child.get_app_wrap_components()) + components.update(child._get_all_app_wrap_components()) # Return the components. return components @@ -1347,8 +1347,8 @@ class CustomComponent(Component): self.component_props[key] = value value = base_value._replace( merge_var_data=VarData( # type: ignore - imports=value.get_imports(), - hooks=value.get_hooks(), + imports=value._get_all_imports(), + hooks=value._get_all_hooks(), ) ) else: @@ -1387,7 +1387,7 @@ class CustomComponent(Component): """ return set() - def get_custom_components( + def _get_all_custom_components( self, seen: set[str] | None = None ) -> Set[CustomComponent]: """Get all the custom components used by the component. @@ -1403,12 +1403,12 @@ class CustomComponent(Component): # Store the seen components in a set to avoid infinite recursion. if seen is None: seen = set() - custom_components = {self} | super().get_custom_components(seen=seen) + custom_components = {self} | super()._get_all_custom_components(seen=seen) # Avoid adding the same component twice. if self.tag not in seen: seen.add(self.tag) - custom_components |= self.get_component(self).get_custom_components( + custom_components |= self.get_component(self)._get_all_custom_components( seen=seen ) @@ -1420,7 +1420,9 @@ class CustomComponent(Component): seen.add(child_component.tag) if isinstance(child_component, CustomComponent): custom_components |= {child_component} - custom_components |= child_component.get_custom_components(seen=seen) + custom_components |= child_component._get_all_custom_components( + seen=seen + ) return custom_components def _render(self) -> Tag: @@ -1824,7 +1826,7 @@ class StatefulComponent(BaseComponent): ) return trigger_memo - def get_hooks_internal(self) -> dict[str, None]: + def _get_all_hooks_internal(self) -> dict[str, None]: """Get the reflex internal hooks for the component and its children. Returns: @@ -1832,7 +1834,7 @@ class StatefulComponent(BaseComponent): """ return {} - def get_hooks(self) -> dict[str, None]: + def _get_all_hooks(self) -> dict[str, None]: """Get the React hooks for this component. Returns: @@ -1840,7 +1842,7 @@ class StatefulComponent(BaseComponent): """ return {} - def get_imports(self) -> imports.ImportDict: + def _get_all_imports(self) -> imports.ImportDict: """Get all the libraries and fields that are used by the component. Returns: @@ -1852,9 +1854,9 @@ class StatefulComponent(BaseComponent): ImportVar(tag=self.tag) ] } - return self.component.get_imports() + return self.component._get_all_imports() - def get_dynamic_imports(self) -> set[str]: + def _get_all_dynamic_imports(self) -> set[str]: """Get dynamic imports for the component. Returns: @@ -1862,9 +1864,9 @@ class StatefulComponent(BaseComponent): """ if self.rendered_as_shared: return set() - return self.component.get_dynamic_imports() + return self.component._get_all_dynamic_imports() - def get_custom_code(self) -> set[str]: + def _get_all_custom_code(self) -> set[str]: """Get custom code for the component. Returns: @@ -1872,9 +1874,9 @@ class StatefulComponent(BaseComponent): """ if self.rendered_as_shared: return set() - return self.component.get_custom_code().union({self.code}) + return self.component._get_all_custom_code().union({self.code}) - def get_refs(self) -> set[str]: + def _get_all_refs(self) -> set[str]: """Get the refs for the children of the component. Returns: @@ -1882,7 +1884,7 @@ class StatefulComponent(BaseComponent): """ if self.rendered_as_shared: return set() - return self.component.get_refs() + return self.component._get_all_refs() def render(self) -> dict: """Define how to render the component in React. @@ -1940,7 +1942,7 @@ class MemoizationLeaf(Component): The memoization leaf """ comp = super().create(*children, **props) - if comp.get_hooks() or comp.get_hooks_internal(): + if comp._get_all_hooks() or comp._get_all_hooks_internal(): comp._memoization_mode = cls._memoization_mode.copy( update={"disposition": MemoizationDisposition.ALWAYS} ) diff --git a/reflex/components/el/elements/forms.py b/reflex/components/el/elements/forms.py index dc7aa5355..ab66fd5fc 100644 --- a/reflex/components/el/elements/forms.py +++ b/reflex/components/el/elements/forms.py @@ -163,7 +163,9 @@ class Form(BaseHTML): props["handle_submit_unique_name"] = "" form = super().create(*children, **props) form.handle_submit_unique_name = md5( - str({**form.get_hooks_internal(), **form.get_hooks()}).encode("utf-8") + str({**form._get_all_hooks_internal(), **form._get_all_hooks()}).encode( + "utf-8" + ) ).hexdigest() return form @@ -208,7 +210,7 @@ class Form(BaseHTML): def _get_form_refs(self) -> Dict[str, Any]: # Send all the input refs to the handler. form_refs = {} - for ref in self.get_refs(): + for ref in self._get_all_refs(): # when ref start with refs_ it's an array of refs, so we need different method # to collect data if ref.startswith("refs_"): @@ -593,13 +595,13 @@ class Textarea(BaseHTML): "enter_key_submit", ] - def get_custom_code(self) -> Set[str]: + def _get_all_custom_code(self) -> Set[str]: """Include the custom code for auto_height and enter_key_submit functionality. Returns: The custom code for the component. """ - custom_code = super().get_custom_code() + custom_code = super()._get_all_custom_code() if self.auto_height is not None: custom_code.add(AUTO_HEIGHT_JS) if self.enter_key_submit is not None: diff --git a/reflex/components/el/elements/forms.pyi b/reflex/components/el/elements/forms.pyi index 78f55c9a9..1a3015804 100644 --- a/reflex/components/el/elements/forms.pyi +++ b/reflex/components/el/elements/forms.pyi @@ -2022,7 +2022,6 @@ AUTO_HEIGHT_JS = '\nconst autoHeightOnInput = (e, is_enabled) => {\n if (is_e ENTER_KEY_SUBMIT_JS = "\nconst enterKeySubmitOnKeyDown = (e, is_enabled) => {\n if (is_enabled && e.which === 13 && !e.shiftKey) {\n e.preventDefault();\n if (!e.repeat) {\n if (e.target.form) {\n e.target.form.requestSubmit();\n }\n }\n }\n}\n" class Textarea(BaseHTML): - def get_custom_code(self) -> Set[str]: ... def get_event_triggers(self) -> Dict[str, Any]: ... @overload @classmethod diff --git a/reflex/components/markdown/markdown.py b/reflex/components/markdown/markdown.py index 5328cb6ff..933b73756 100644 --- a/reflex/components/markdown/markdown.py +++ b/reflex/components/markdown/markdown.py @@ -133,7 +133,7 @@ class Markdown(Component): **props, ) - def get_custom_components( + def _get_all_custom_components( self, seen: set[str] | None = None ) -> set[CustomComponent]: """Get all the custom components used by the component. @@ -144,11 +144,13 @@ class Markdown(Component): Returns: The set of custom components. """ - custom_components = super().get_custom_components(seen=seen) + custom_components = super()._get_all_custom_components(seen=seen) # Get the custom components for each tag. for component in self.component_map.values(): - custom_components |= component(_MOCK_ARG).get_custom_components(seen=seen) + custom_components |= component(_MOCK_ARG)._get_all_custom_components( + seen=seen + ) return custom_components @@ -183,7 +185,9 @@ class Markdown(Component): # Get the imports for each component. for component in self.component_map.values(): - imports = utils.merge_imports(imports, component(_MOCK_ARG).get_imports()) + imports = utils.merge_imports( + imports, component(_MOCK_ARG)._get_all_imports() + ) # Get the imports for the code components. imports = utils.merge_imports( @@ -293,8 +297,8 @@ class Markdown(Component): hooks = set() for _component in self.component_map.values(): comp = _component(_MOCK_ARG) - hooks.update(comp.get_hooks_internal()) - hooks.update(comp.get_hooks()) + hooks.update(comp._get_all_hooks_internal()) + hooks.update(comp._get_all_hooks()) formatted_hooks = "\n".join(hooks) return f""" function {self._get_component_map_name()} () {{ diff --git a/reflex/components/markdown/markdown.pyi b/reflex/components/markdown/markdown.pyi index 7f7b7c4e5..85ffeeeed 100644 --- a/reflex/components/markdown/markdown.pyi +++ b/reflex/components/markdown/markdown.pyi @@ -123,9 +123,6 @@ class Markdown(Component): The markdown component. """ ... - def get_custom_components( - self, seen: set[str] | None = None - ) -> set[CustomComponent]: ... def get_component(self, tag: str, **props) -> Component: ... def format_component(self, tag: str, **props) -> str: ... def format_component_map(self) -> dict[str, str]: ... diff --git a/tests/components/core/test_banner.py b/tests/components/core/test_banner.py index 66fa598c7..5131a4b85 100644 --- a/tests/components/core/test_banner.py +++ b/tests/components/core/test_banner.py @@ -9,13 +9,13 @@ from reflex.components.radix.themes.typography.text import Text def test_websocket_target_url(): url = WebsocketTargetURL.create() - _imports = url.get_imports(collapse=True) + _imports = url._get_all_imports(collapse=True) assert list(_imports.keys()) == ["/utils/state", "/env.json"] def test_connection_banner(): banner = ConnectionBanner.create() - _imports = banner.get_imports(collapse=True) + _imports = banner._get_all_imports(collapse=True) assert list(_imports.keys()) == [ "react", "/utils/context", @@ -31,7 +31,7 @@ def test_connection_banner(): def test_connection_modal(): modal = ConnectionModal.create() - _imports = modal.get_imports(collapse=True) + _imports = modal._get_all_imports(collapse=True) assert list(_imports.keys()) == [ "react", "/utils/context", @@ -47,5 +47,5 @@ def test_connection_modal(): def test_connection_pulser(): pulser = ConnectionPulser.create() - _custom_code = pulser.get_custom_code() - _imports = pulser.get_imports(collapse=True) + _custom_code = pulser._get_all_custom_code() + _imports = pulser._get_all_imports(collapse=True) diff --git a/tests/components/test_component.py b/tests/components/test_component.py index f63d81fbd..924f2b169 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -263,8 +263,8 @@ def test_add_style(component1, component2): component1: Style({"color": "white"}), component2: Style({"color": "black"}), } - c1 = component1().add_style(style) # type: ignore - c2 = component2().add_style(style) # type: ignore + c1 = component1()._add_style_recursive(style) # type: ignore + c2 = component2()._add_style_recursive(style) # type: ignore assert c1.style["color"] == "white" assert c2.style["color"] == "black" @@ -280,8 +280,8 @@ def test_add_style_create(component1, component2): component1.create: Style({"color": "white"}), component2.create: Style({"color": "black"}), } - c1 = component1().add_style(style) # type: ignore - c2 = component2().add_style(style) # type: ignore + c1 = component1()._add_style_recursive(style) # type: ignore + c2 = component2()._add_style_recursive(style) # type: ignore assert c1.style["color"] == "white" assert c2.style["color"] == "black" @@ -295,8 +295,8 @@ def test_get_imports(component1, component2): """ c1 = component1.create() c2 = component2.create(c1) - assert c1.get_imports() == {"react": [ImportVar(tag="Component")]} - assert c2.get_imports() == { + assert c1._get_all_imports() == {"react": [ImportVar(tag="Component")]} + assert c2._get_all_imports() == { "react-redux": [ImportVar(tag="connect")], "react": [ImportVar(tag="Component")], } @@ -312,19 +312,19 @@ def test_get_custom_code(component1, component2): # Check that the code gets compiled correctly. c1 = component1.create() c2 = component2.create() - assert c1.get_custom_code() == {"console.log('component1')"} - assert c2.get_custom_code() == {"console.log('component2')"} + assert c1._get_all_custom_code() == {"console.log('component1')"} + assert c2._get_all_custom_code() == {"console.log('component2')"} # Check that nesting components compiles both codes. c1 = component1.create(c2) - assert c1.get_custom_code() == { + assert c1._get_all_custom_code() == { "console.log('component1')", "console.log('component2')", } # Check that code is not duplicated. c1 = component1.create(c2, c2, c1, c1) - assert c1.get_custom_code() == { + assert c1._get_all_custom_code() == { "console.log('component1')", "console.log('component2')", } @@ -502,7 +502,7 @@ def test_create_custom_component(my_component): component = CustomComponent(component_fn=my_component, prop1="test", prop2=1) assert component.tag == "MyComponent" assert component.get_props() == set() - assert component.get_custom_components() == {component} + assert component._get_all_custom_components() == {component} def test_custom_component_hash(my_component): @@ -586,7 +586,7 @@ def test_get_hooks_nested(component1, component2, component3): text="a", number=1, ) - assert c.get_hooks() == component3().get_hooks() + assert c._get_all_hooks() == component3()._get_all_hooks() def test_get_hooks_nested2(component3, component4): @@ -596,15 +596,15 @@ def test_get_hooks_nested2(component3, component4): component3: component with hooks defined. component4: component with different hooks defined. """ - exp_hooks = {**component3().get_hooks(), **component4().get_hooks()} - assert component3.create(component4.create()).get_hooks() == exp_hooks - assert component4.create(component3.create()).get_hooks() == exp_hooks + exp_hooks = {**component3()._get_all_hooks(), **component4()._get_all_hooks()} + assert component3.create(component4.create())._get_all_hooks() == exp_hooks + assert component4.create(component3.create())._get_all_hooks() == exp_hooks assert ( component4.create( component3.create(), component4.create(), component3.create(), - ).get_hooks() + )._get_all_hooks() == exp_hooks ) @@ -1329,20 +1329,20 @@ def test_custom_component_get_imports(): custom_comp = wrapper() # Inner is not imported directly, but it is imported by the custom component. - assert "inner" not in custom_comp.get_imports() + assert "inner" not in custom_comp._get_all_imports() # The imports are only resolved during compilation. - _, _, imports_inner = compile_components(custom_comp.get_custom_components()) + _, _, imports_inner = compile_components(custom_comp._get_all_custom_components()) assert "inner" in imports_inner outer_comp = outer(c=wrapper()) # Libraries are not imported directly, but are imported by the custom component. - assert "inner" not in outer_comp.get_imports() - assert "other" not in outer_comp.get_imports() + assert "inner" not in outer_comp._get_all_imports() + assert "other" not in outer_comp._get_all_imports() # The imports are only resolved during compilation. - _, _, imports_outer = compile_components(outer_comp.get_custom_components()) + _, _, imports_outer = compile_components(outer_comp._get_all_custom_components()) assert "inner" in imports_outer assert "other" in imports_outer