diff --git a/reflex/utils/imports.py b/reflex/utils/imports.py index 26faa4820..d878fe438 100644 --- a/reflex/utils/imports.py +++ b/reflex/utils/imports.py @@ -49,7 +49,10 @@ class ImportVar(Base): Returns: The name(tag name with alias) of tag. """ - return self.tag if not self.alias else " as ".join([self.tag, self.alias]) # type: ignore + if self.alias: + return self.alias if self.is_default else " as ".join([self.tag, self.alias]) # type: ignore + else: + return self.tag or "" def __hash__(self) -> int: """Define a hash function for the import var. diff --git a/tests/test_var.py b/tests/test_var.py index e3cb28a1c..c57c7eb23 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -7,7 +7,6 @@ from pandas import DataFrame from reflex.base import Base from reflex.state import State -from reflex.utils.imports import ImportVar from reflex.vars import ( BaseVar, ComputedVar, @@ -24,8 +23,6 @@ test_vars = [ BaseVar(_var_name="local2", _var_type=str, _var_is_local=True), ] -test_import_vars = [ImportVar(tag="DataGrid"), ImportVar(tag="DataGrid", alias="Grid")] - class BaseState(State): """A Test State.""" @@ -610,26 +607,6 @@ def test_computed_var_with_annotation_error(request, fixture, full_name): ) -@pytest.mark.parametrize( - "import_var,expected", - zip( - test_import_vars, - [ - "DataGrid", - "DataGrid as Grid", - ], - ), -) -def test_import_var(import_var, expected): - """Test that the import var name is computed correctly. - - Args: - import_var: The import var. - expected: expected name - """ - assert import_var.name == expected - - @pytest.mark.parametrize( "out, expected", [ diff --git a/tests/utils/test_imports.py b/tests/utils/test_imports.py new file mode 100644 index 000000000..c7253ff6b --- /dev/null +++ b/tests/utils/test_imports.py @@ -0,0 +1,78 @@ +import pytest + +from reflex.utils.imports import ImportVar, merge_imports + + +@pytest.mark.parametrize( + "import_var, expected_name", + [ + ( + ImportVar(tag="BaseTag"), + "BaseTag", + ), + ( + ImportVar(tag="BaseTag", alias="AliasTag"), + "BaseTag as AliasTag", + ), + ( + ImportVar(tag="BaseTag", is_default=True), + "BaseTag", + ), + ( + ImportVar(tag="BaseTag", is_default=True, alias="AliasTag"), + "AliasTag", + ), + ( + ImportVar(tag="BaseTag", is_default=False), + "BaseTag", + ), + ( + ImportVar(tag="BaseTag", is_default=False, alias="AliasTag"), + "BaseTag as AliasTag", + ), + ], +) +def test_import_var(import_var, expected_name): + """Test that the import var name is computed correctly. + + Args: + import_var: The import var. + expected_name: The expected name. + """ + assert import_var.name == expected_name + + +@pytest.mark.parametrize( + "input_1, input_2, output", + [ + ( + {"react": {"Component"}}, + {"react": {"Component"}, "react-dom": {"render"}}, + {"react": {"Component"}, "react-dom": {"render"}}, + ), + ( + {"react": {"Component"}, "next/image": {"Image"}}, + {"react": {"Component"}, "react-dom": {"render"}}, + {"react": {"Component"}, "react-dom": {"render"}, "next/image": {"Image"}}, + ), + ( + {"react": {"Component"}}, + {"": {"some/custom.css"}}, + {"react": {"Component"}, "": {"some/custom.css"}}, + ), + ], +) +def test_merge_imports(input_1, input_2, output): + """Test that imports are merged correctly. + + Args: + input_1: The first dict to merge. + input_2: The second dict to merge. + output: The expected output dict after merging. + + """ + res = merge_imports(input_1, input_2) + assert set(res.keys()) == set(output.keys()) + + for key in output: + assert set(res[key]) == set(output[key]) diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 72c1d0f5e..b64f5f188 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -13,7 +13,6 @@ from reflex.event import EventHandler from reflex.state import State from reflex.utils import ( build, - imports, prerequisites, types, ) @@ -56,16 +55,6 @@ def test_func(): pass -def test_merge_imports(): - """Test that imports are merged correctly.""" - d1 = {"react": {"Component"}} - d2 = {"react": {"Component"}, "react-dom": {"render"}} - d = imports.merge_imports(d1, d2) - assert set(d.keys()) == {"react", "react-dom"} - assert set(d["react"]) == {"Component"} - assert set(d["react-dom"]) == {"render"} - - @pytest.mark.parametrize( "cls,expected", [