From 8edd1dfdc99e21ed3d300c2ad1ed2afe4c420e36 Mon Sep 17 00:00:00 2001 From: Martin Xu <15661672+martinxu9@users.noreply.github.com> Date: Thu, 28 Mar 2024 11:03:26 -0700 Subject: [PATCH] [REF-2269] Add `add_imports` API for component class (#2937) --- reflex/components/component.py | 27 ++++++++++++++++++ tests/components/test_component.py | 44 +++++++++++++++++++++++++++++- 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index 506031af9..58842062f 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -1017,6 +1017,16 @@ class Component(BaseComponent, ABC): ) return _imports + def add_imports( + self, + ) -> Dict[str, Union[str, ImportVar, List[str | ImportVar]]]: + """User defined imports for the component. Need to be overriden in subclass. + + Returns: + The user defined imports as a dict. + """ + return {} + def _get_imports(self) -> imports.ImportDict: """Get all the libraries and fields that are used by the component. @@ -1037,12 +1047,29 @@ class Component(BaseComponent, ABC): var._var_data.imports for var in self._get_vars() if var._var_data ] + # If the subclass implements add_imports, merge the imports. + def _make_list( + value: str | ImportVar | list[str | ImportVar], + ) -> list[str | ImportVar]: + if isinstance(value, (str, ImportVar)): + return [value] + return value + + added_imports = { + package: [ + ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag + for tag in _make_list(maybe_tags) + ] + for package, maybe_tags in self.add_imports().items() + } + return imports.merge_imports( *self._get_props_imports(), self._get_dependencies_imports(), self._get_hooks_imports(), _imports, event_imports, + added_imports, *var_imports, ) diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 8dfe6dedc..9c090baa9 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Type +from typing import Any, Dict, List, Type, Union import pytest @@ -1308,3 +1308,45 @@ def test_custom_component_get_imports(): _, _, imports_outer = compile_components(outer_comp.get_custom_components()) assert "inner" in imports_outer assert "other" in imports_outer + + +@pytest.mark.parametrize( + "tags", + ( + ["Component"], + ["Component", "useState"], + [ImportVar(tag="Component")], + [ImportVar(tag="Component"), ImportVar(tag="useState")], + ["Component", ImportVar(tag="useState")], + ), +) +def test_custom_component_add_imports(tags): + def _list_to_import_vars(tags: List[str]) -> List[ImportVar]: + return [ + ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag + for tag in tags + ] + + class BaseComponent(Component): + def _get_imports(self) -> imports.ImportDict: + return {} + + class Reference(Component): + def _get_imports(self) -> imports.ImportDict: + return imports.merge_imports( + super()._get_imports(), + {"react": _list_to_import_vars(tags)}, + ) + + class Test(Component): + def add_imports( + self, + ) -> Dict[str, Union[str, ImportVar, List[str], List[ImportVar]]]: + + return {"react": (tags[0] if len(tags) == 1 else tags)} + + baseline = Reference.create() + test = Test.create() + + assert baseline.get_imports() == {"react": _list_to_import_vars(tags)} + assert test.get_imports() == baseline.get_imports()