diff --git a/reflex/app.py b/reflex/app.py
index c47077d0e..b3d9149e6 100644
--- a/reflex/app.py
+++ b/reflex/app.py
@@ -79,7 +79,7 @@ from reflex.state import (
)
from reflex.utils import console, exceptions, format, prerequisites, types
from reflex.utils.exec import is_testing_env, should_skip_compile
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportList
# Define custom types.
ComponentCallable = Callable[[], Component]
@@ -618,27 +618,16 @@ class App(Base):
admin.mount_to(self.api)
- def get_frontend_packages(self, imports: Dict[str, set[ImportVar]]):
+ def get_frontend_packages(self, imports: ImportList):
"""Gets the frontend packages to be installed and filters out the unnecessary ones.
Args:
- imports: A dictionary containing the imports used in the current page.
+ imports: A list containing the imports used in the current page.
Example:
>>> get_frontend_packages({"react": "16.14.0", "react-dom": "16.14.0"})
"""
- page_imports = {
- i
- for i, tags in imports.items()
- if i
- not in [
- *constants.PackageJson.DEPENDENCIES.keys(),
- *constants.PackageJson.DEV_DEPENDENCIES.keys(),
- ]
- and not any(i.startswith(prefix) for prefix in ["/", ".", "next/"])
- and i != ""
- and any(tag.install for tag in tags)
- }
+ page_imports = [i.package for i in imports.collapse().values() if i.install]
frontend_packages = get_config().frontend_packages
_frontend_packages = []
for package in frontend_packages:
@@ -653,7 +642,7 @@ class App(Base):
)
continue
_frontend_packages.append(package)
- page_imports.update(_frontend_packages)
+ page_imports.extend(_frontend_packages)
prerequisites.install_frontend_packages(page_imports, get_config())
def _app_root(self, app_wrappers: dict[tuple[int, str], Component]) -> Component:
@@ -794,7 +783,7 @@ class App(Base):
self.style = evaluate_style_namespaces(self.style)
# Track imports and custom components found.
- all_imports = {}
+ all_imports = ImportList()
custom_components = set()
for _route, component in self.pages.items():
@@ -804,7 +793,7 @@ class App(Base):
component.apply_theme(self.theme)
# Add component._get_all_imports() to all_imports.
- all_imports.update(component._get_all_imports())
+ all_imports.extend(component._get_all_imports())
# Add the app wrappers from this component.
app_wrappers.update(component._get_all_app_wrap_components())
@@ -932,10 +921,10 @@ class App(Base):
custom_components_imports,
) = custom_components_future.result()
compile_results.append(custom_components_result)
- all_imports.update(custom_components_imports)
+ all_imports.extend(custom_components_imports)
# Get imports from AppWrap components.
- all_imports.update(app_root._get_all_imports())
+ all_imports.extend(app_root._get_all_imports())
progress.advance(task)
@@ -951,7 +940,7 @@ class App(Base):
# Setup the next.config.js
transpile_packages = [
package
- for package, import_vars in all_imports.items()
+ for package, import_vars in all_imports.collapse().items()
if any(import_var.transpile for import_var in import_vars)
]
prerequisites.update_next_config(
diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py
index 89ac867f7..dae97c153 100644
--- a/reflex/compiler/compiler.py
+++ b/reflex/compiler/compiler.py
@@ -19,7 +19,7 @@ from reflex.config import get_config
from reflex.state import BaseState
from reflex.style import LIGHT_COLOR_MODE
from reflex.utils.exec import is_prod_mode
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportList, ImportVar
from reflex.vars import Var
@@ -197,25 +197,34 @@ def _compile_components(
Returns:
The compiled components.
"""
- imports = {
- "react": [ImportVar(tag="memo")],
- f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="E"), ImportVar(tag="isTrue")],
- }
+ _imports = ImportList(
+ [
+ ImportVar(package="react", tag="memo"),
+ ImportVar(
+ package=f"/{constants.Dirs.STATE_PATH}",
+ tag="E",
+ ),
+ ImportVar(
+ package=f"/{constants.Dirs.STATE_PATH}",
+ tag="isTrue",
+ ),
+ ]
+ )
component_renders = []
# Compile each component.
for component in components:
component_render, component_imports = utils.compile_custom_component(component)
component_renders.append(component_render)
- imports = utils.merge_imports(imports, component_imports)
+ _imports.extend(component_imports)
# Compile the components page.
return (
templates.COMPONENTS.render(
- imports=utils.compile_imports(imports),
+ imports=utils.compile_imports(_imports),
components=component_renders,
),
- imports,
+ _imports,
)
@@ -235,7 +244,7 @@ def _compile_stateful_components(
Returns:
The rendered stateful components code.
"""
- all_import_dicts = []
+ all_imports = []
rendered_components = {}
def get_shared_components_recursive(component: BaseComponent):
@@ -266,7 +275,7 @@ def _compile_stateful_components(
rendered_components.update(
{code: None for code in component._get_all_custom_code()},
)
- all_import_dicts.append(component._get_all_imports())
+ all_imports.extend(component._get_all_imports())
# Indicate that this component now imports from the shared file.
component.rendered_as_shared = True
@@ -275,9 +284,11 @@ def _compile_stateful_components(
get_shared_components_recursive(page_component)
# Don't import from the file that we're about to create.
- all_imports = utils.merge_imports(*all_import_dicts)
- all_imports.pop(
- f"/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}", None
+ all_imports = ImportList(
+ imp
+ for imp in all_imports
+ if imp.library
+ != f"/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}"
)
return templates.STATEFUL_COMPONENTS.render(
@@ -408,7 +419,7 @@ def compile_page(
def compile_components(
components: set[CustomComponent],
-) -> tuple[str, str, Dict[str, list[ImportVar]]]:
+) -> tuple[str, str, ImportList]:
"""Compile the custom components.
Args:
diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py
index 14d7d4d36..ee65d08a2 100644
--- a/reflex/compiler/utils.py
+++ b/reflex/compiler/utils.py
@@ -88,16 +88,16 @@ def validate_imports(import_dict: imports.ImportDict):
used_tags[import_name] = lib
-def compile_imports(import_dict: imports.ImportDict) -> list[dict]:
- """Compile an import dict.
+def compile_imports(import_list: imports.ImportList) -> list[dict]:
+ """Compile an import list.
Args:
- import_dict: The import dict to compile.
+ import_list: The import list to compile.
Returns:
- The list of import dict.
+ The list of template import dict.
"""
- collapsed_import_dict = imports.collapse_imports(import_dict)
+ collapsed_import_dict = import_list.collapse()
validate_imports(collapsed_import_dict)
import_dicts = []
for lib, fields in collapsed_import_dict.items():
@@ -114,9 +114,6 @@ def compile_imports(import_dict: imports.ImportDict) -> list[dict]:
import_dicts.append(get_import_dict(module))
continue
- # remove the version before rendering the package imports
- lib = format.format_library_name(lib)
-
import_dicts.append(get_import_dict(lib, default, rest))
return import_dicts
@@ -237,7 +234,7 @@ def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
def compile_custom_component(
component: CustomComponent,
-) -> tuple[dict, imports.ImportDict]:
+) -> tuple[dict, imports.ImportList]:
"""Compile a custom component.
Args:
@@ -250,11 +247,12 @@ def compile_custom_component(
render = component.get_component(component)
# Get the imports.
- imports = {
- lib: fields
- for lib, fields in render._get_all_imports().items()
- if lib != component.library
- }
+ component_library_name = format.format_library_name(component.library)
+ _imports = imports.ImportList(
+ imp
+ for imp in render._get_all_imports()
+ if imp.library != component_library_name
+ )
# Concatenate the props.
props = [prop._var_name for prop in component.get_prop_vars()]
@@ -268,7 +266,7 @@ def compile_custom_component(
"hooks": {**render._get_all_hooks_internal(), **render._get_all_hooks()},
"custom_code": render._get_all_custom_code(),
},
- imports,
+ _imports,
)
diff --git a/reflex/components/chakra/base.py b/reflex/components/chakra/base.py
index 9e1f3f698..95e457a05 100644
--- a/reflex/components/chakra/base.py
+++ b/reflex/components/chakra/base.py
@@ -35,19 +35,18 @@ class ChakraComponent(Component):
@classmethod
@lru_cache(maxsize=None)
- def _get_dependencies_imports(cls) -> imports.ImportDict:
+ def _get_dependencies_imports(cls) -> imports.ImportList:
"""Get the imports from lib_dependencies for installing.
Returns:
The dependencies imports of the component.
"""
- return {
- dep: [imports.ImportVar(tag=None, render=False)]
- for dep in [
- "@chakra-ui/system@2.5.7",
- "framer-motion@10.16.4",
- ]
- }
+ return [
+ imports.ImportVar(
+ package="@chakra-ui/system@2.5.7", tag=None, render=False
+ ),
+ imports.ImportVar(package="framer-motion@10.16.4", tag=None, render=False),
+ ]
class ChakraProvider(ChakraComponent):
diff --git a/reflex/components/component.py b/reflex/components/component.py
index 9dd11254c..96834be08 100644
--- a/reflex/components/component.py
+++ b/reflex/components/component.py
@@ -3,6 +3,7 @@
from __future__ import annotations
import copy
+import itertools
import typing
from abc import ABC, abstractmethod
from functools import lru_cache, wraps
@@ -95,11 +96,11 @@ class BaseComponent(Base, ABC):
"""
@abstractmethod
- def _get_all_imports(self) -> imports.ImportDict:
+ def _get_all_imports(self) -> imports.ImportList:
"""Get all the libraries and fields that are used by the component.
Returns:
- The import dict with the required imports.
+ The list of all required ImportVar.
"""
@abstractmethod
@@ -994,17 +995,22 @@ class Component(BaseComponent, ABC):
# Return the dynamic imports
return dynamic_imports
- def _get_props_imports(self) -> List[str]:
+ def _get_props_imports(self) -> imports.ImportList:
"""Get the imports needed for components props.
Returns:
- The imports for the components props of the component.
+ The imports for the components props of the component.
"""
- return [
- getattr(self, prop)._get_all_imports()
- for prop in self.get_component_props()
- if getattr(self, prop) is not None
- ]
+ return imports.ImportList(
+ sum(
+ (
+ getattr(self, prop)._get_all_imports()
+ for prop in self.get_component_props()
+ if getattr(self, prop) is not None
+ ),
+ [],
+ )
+ )
def _should_transpile(self, dep: str | None) -> bool:
"""Check if a dependency should be transpiled.
@@ -1020,97 +1026,129 @@ class Component(BaseComponent, ABC):
or format.format_library_name(dep or "") in self.transpile_packages
)
- def _get_dependencies_imports(self) -> imports.ImportDict:
+ def _get_dependencies_imports(self) -> imports.ImportList:
"""Get the imports from lib_dependencies for installing.
Returns:
The dependencies imports of the component.
"""
- return {
- dep: [
- ImportVar(
- tag=None,
- render=False,
- transpile=self._should_transpile(dep),
- )
- ]
+ return imports.ImportList(
+ ImportVar(
+ package=dep,
+ tag=None,
+ render=False,
+ transpile=self._should_transpile(dep),
+ )
for dep in self.lib_dependencies
- }
+ )
- def _get_hooks_imports(self) -> imports.ImportDict:
+ def _get_hooks_imports(self) -> imports.ImportList:
"""Get the imports required by certain hooks.
Returns:
The imports required for all selected hooks.
"""
- _imports = {}
+ _imports = imports.ImportList()
if self._get_ref_hook():
# Handle hooks needed for attaching react refs to DOM nodes.
- _imports.setdefault("react", set()).add(ImportVar(tag="useRef"))
- _imports.setdefault(f"/{Dirs.STATE_PATH}", set()).add(ImportVar(tag="refs"))
+ _imports.extend(
+ [
+ ImportVar(package="react", tag="useRef"),
+ ImportVar(package=f"/{Dirs.STATE_PATH}", tag="refs"),
+ ]
+ )
if self._get_mount_lifecycle_hook():
# Handle hooks for `on_mount` / `on_unmount`.
- _imports.setdefault("react", set()).add(ImportVar(tag="useEffect"))
+ _imports.append(ImportVar(package="react", tag="useEffect"))
if self._get_special_hooks():
# Handle additional internal hooks (autofocus, etc).
- _imports.setdefault("react", set()).update(
- {
- ImportVar(tag="useRef"),
- ImportVar(tag="useEffect"),
- },
+ _imports.extend(
+ [
+ ImportVar(package="react", tag="useEffect"),
+ ImportVar(package="react", tag="useRef"),
+ ]
)
user_hooks = self._get_hooks()
if user_hooks is not None and isinstance(user_hooks, Var):
- _imports = imports.merge_imports(_imports, user_hooks._var_data.imports) # type: ignore
+ _imports.extend(user_hooks._var_data.imports)
return _imports
def _get_imports(self) -> imports.ImportDict:
- """Get all the libraries and fields that are used by the component.
+ """Deprecated method to get all the libraries and fields used by the component.
Returns:
The imports needed by the component.
"""
- _imports = {}
+ return {}
+
+ def _get_imports_list(self) -> imports.ImportList:
+ """Internal method to get the imports as a list.
+
+ Returns:
+ The imports as a list.
+ """
+ _imports = imports.ImportList(
+ itertools.chain(
+ self._get_props_imports(),
+ self._get_dependencies_imports(),
+ self._get_hooks_imports(),
+ )
+ )
+
+ # Handle deprecated _get_imports
+ import_dict = self._get_imports()
+ if import_dict:
+ console.deprecate(
+ feature_name="_get_imports",
+ reason="use add_imports instead",
+ deprecation_version="0.5.0",
+ removal_version="0.6.0",
+ )
+ _imports.extend(imports.ImportList.from_import_dict(import_dict))
# Import this component's tag from the main library.
if self.library is not None and self.tag is not None:
- _imports[self.library] = {self.import_var}
+ _imports.append(self.import_var)
# Get static imports required for event processing.
- event_imports = Imports.EVENTS if self.event_triggers else {}
+ if self.event_triggers:
+ _imports.append(Imports.EVENTS)
# Collect imports from Vars used directly by this component.
- var_imports = [
- var._var_data.imports for var in self._get_vars() if var._var_data
- ]
+ for var in self._get_vars():
+ if var._var_data:
+ _imports.extend(var._var_data.imports)
+ return _imports
- return imports.merge_imports(
- *self._get_props_imports(),
- self._get_dependencies_imports(),
- self._get_hooks_imports(),
- _imports,
- event_imports,
- *var_imports,
- )
-
- def _get_all_imports(self, collapse: bool = False) -> imports.ImportDict:
+ def _get_all_imports(self, collapse: bool = False) -> imports.ImportList:
"""Get all the libraries and fields that are used by the component and its children.
Args:
- collapse: Whether to collapse the imports by removing duplicates.
+ collapse: Whether to collapse the imports into a dict (deprecated).
Returns:
- The import dict with the required imports.
+ The list of all required imports.
"""
- _imports = imports.merge_imports(
- self._get_imports(), *[child._get_all_imports() for child in self.children]
+ _imports = imports.ImportList(
+ self._get_imports_list()
+ + sum((child._get_all_imports() for child in self.children), [])
)
- return imports.collapse_imports(_imports) if collapse else _imports
+
+ if collapse:
+ console.deprecate(
+ feature_name="collapse kwarg to _get_all_imports",
+ reason="use ImportList.collapse instead",
+ deprecation_version="0.5.0",
+ removal_version="0.6.0",
+ )
+ return _imports.collapse() # type: ignore
+
+ return _imports
def _get_mount_lifecycle_hook(self) -> str | None:
"""Generate the component lifecycle hook.
@@ -1296,6 +1334,7 @@ class Component(BaseComponent, ABC):
tag = self.tag.partition(".")[0] if self.tag else None
alias = self.alias.partition(".")[0] if self.alias else None
return ImportVar(
+ package=self.library,
tag=tag,
is_default=self.is_default,
alias=alias,
@@ -1575,7 +1614,6 @@ class NoSSRComponent(Component):
return imports.merge_imports(
dynamic_import,
_imports,
- self._get_dependencies_imports(),
)
def _get_dynamic_imports(self) -> str:
@@ -1893,18 +1931,21 @@ class StatefulComponent(BaseComponent):
"""
return {}
- def _get_all_imports(self) -> imports.ImportDict:
+ def _get_all_imports(self) -> imports.ImportList:
"""Get all the libraries and fields that are used by the component.
Returns:
- The import dict with the required imports.
+ The list of all required imports.
"""
if self.rendered_as_shared:
- return {
- f"/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}": [
- ImportVar(tag=self.tag)
+ return imports.ImportList(
+ [
+ imports.ImportVar(
+ package=f"/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}",
+ tag=self.tag,
+ )
]
- }
+ )
return self.component._get_all_imports()
def _get_all_dynamic_imports(self) -> set[str]:
diff --git a/reflex/components/core/cond.py b/reflex/components/core/cond.py
index 737b650d8..343b082b3 100644
--- a/reflex/components/core/cond.py
+++ b/reflex/components/core/cond.py
@@ -12,9 +12,9 @@ from reflex.style import LIGHT_COLOR_MODE, color_mode
from reflex.utils import format, imports
from reflex.vars import BaseVar, Var, VarData
-_IS_TRUE_IMPORT = {
- f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")},
-}
+_IS_TRUE_IMPORT = imports.ImportList(
+ [imports.ImportVar(library=f"/{Dirs.STATE_PATH}", tag="isTrue")]
+)
class Cond(MemoizationLeaf):
@@ -95,11 +95,13 @@ class Cond(MemoizationLeaf):
cond_state=f"isTrue({self.cond._var_full_name})",
)
- def _get_imports(self) -> imports.ImportDict:
- return imports.merge_imports(
- super()._get_imports(),
- getattr(self.cond._var_data, "imports", {}),
- _IS_TRUE_IMPORT,
+ def _get_imports_list(self) -> imports.ImportList:
+ return imports.ImportList(
+ [
+ *super()._get_imports_list(),
+ *getattr(self.cond._var_data, "imports", []),
+ *_IS_TRUE_IMPORT,
+ ]
)
def _apply_theme(self, theme: Component):
diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py
index b99e31e8c..4686ef5f8 100644
--- a/reflex/constants/compiler.py
+++ b/reflex/constants/compiler.py
@@ -6,7 +6,7 @@ from types import SimpleNamespace
from reflex.base import Base
from reflex.constants import Dirs
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportList, ImportVar
# The prefix used to create setters for state vars.
SETTER_PREFIX = "set_"
@@ -102,11 +102,11 @@ class ComponentName(Enum):
class Imports(SimpleNamespace):
"""Common sets of import vars."""
- EVENTS = {
- "react": {ImportVar(tag="useContext")},
- f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")},
- f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)},
- }
+ EVENTS: ImportList = [
+ ImportVar(package="react", tag="useContext"),
+ ImportVar(package=f"/{Dirs.CONTEXTS_PATH}", tag="EventLoopContext"),
+ ImportVar(package=f"/{Dirs.STATE_PATH}", tag=CompileVars.TO_EVENT),
+ ]
class Hooks(SimpleNamespace):
diff --git a/reflex/utils/format.py b/reflex/utils/format.py
index 70f6b5b25..fa2d115ad 100644
--- a/reflex/utils/format.py
+++ b/reflex/utils/format.py
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
from reflex import constants
from reflex.utils import exceptions, serializers, types
+from reflex.utils.imports import split_library_name_version
from reflex.utils.serializers import serialize
from reflex.vars import BaseVar, Var
@@ -716,11 +717,7 @@ def format_library_name(library_fullname: str):
Returns:
The name without the @version if it was part of the name
"""
- lib, at, version = library_fullname.rpartition("@")
- if not lib:
- lib = at + version
-
- return lib
+ return split_library_name_version(library_fullname)[0]
def json_dumps(obj: Any) -> str:
diff --git a/reflex/utils/imports.py b/reflex/utils/imports.py
index 263de1e3d..42c3a9385 100644
--- a/reflex/utils/imports.py
+++ b/reflex/utils/imports.py
@@ -3,9 +3,10 @@
from __future__ import annotations
from collections import defaultdict
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Set
from reflex.base import Base
+from reflex.constants.installer import PackageJson
def merge_imports(*imports) -> ImportDict:
@@ -36,9 +37,29 @@ def collapse_imports(imports: ImportDict) -> ImportDict:
return {lib: list(set(import_vars)) for lib, import_vars in imports.items()}
+def split_library_name_version(library_fullname: str):
+ """Split the name of a library from its version.
+
+ Args:
+ library_fullname: The fullname of the library.
+
+ Returns:
+ A tuple of the library name and version.
+ """
+ lib, at, version = library_fullname.rpartition("@")
+ if not lib:
+ lib = at + version
+ version = None
+
+ return lib, version
+
+
class ImportVar(Base):
"""An import var."""
+ # The package name associated with the tag
+ library: Optional[str]
+
# The name of the import tag.
tag: Optional[str]
@@ -48,6 +69,12 @@ class ImportVar(Base):
# The tag alias.
alias: Optional[str] = None
+ # The following fields provide extra information about the import,
+ # but are not factored in when considering hash or equality
+
+ # The version of the package
+ version: Optional[str]
+
# Whether this import need to install the associated lib
install: Optional[bool] = True
@@ -58,6 +85,34 @@ class ImportVar(Base):
# https://nextjs.org/docs/app/api-reference/next-config-js/transpilePackages
transpile: Optional[bool] = False
+ def __init__(
+ self,
+ *,
+ package: Optional[str] = None,
+ **kwargs,
+ ):
+ if package is not None:
+ if (
+ kwargs.get("library", None) is not None
+ or kwargs.get("version", None) is not None
+ ):
+ raise ValueError(
+ "Cannot provide 'library' or 'version' as keyword arguments when "
+ "specifying 'package' as an argument"
+ )
+ kwargs["library"], kwargs["version"] = split_library_name_version(package)
+
+ install = (
+ package is not None
+ # TODO: handle version conflicts
+ and package not in PackageJson.DEPENDENCIES
+ and package not in PackageJson.DEV_DEPENDENCIES
+ and not any(package.startswith(prefix) for prefix in ["/", ".", "next/"])
+ and package != ""
+ )
+ kwargs.setdefault("install", install)
+ super().__init__(**kwargs)
+
@property
def name(self) -> str:
"""The name of the import.
@@ -72,6 +127,17 @@ class ImportVar(Base):
else:
return self.tag or ""
+ @property
+ def package(self) -> str:
+ """The package to install for this import
+
+ Returns:
+ The library name and (optional) version to be installed by npm/bun.
+ """
+ if self.version:
+ return f"{self.library}@{self.version}"
+ return self.library
+
def __hash__(self) -> int:
"""Define a hash function for the import var.
@@ -80,14 +146,97 @@ class ImportVar(Base):
"""
return hash(
(
+ self.library,
self.tag,
self.is_default,
self.alias,
- self.install,
- self.render,
- self.transpile,
+ # These do not fundamentally change the import in any way
+ # self.install,
+ # self.render,
+ # self.transpile,
)
)
+ def __eq__(self, other: ImportVar) -> bool:
+ """Define equality for the import var.
-ImportDict = Dict[str, List[ImportVar]]
+ Args:
+ other: The other import var to compare.
+
+ Returns:
+ Whether the two import vars are equal.
+ """
+ if type(self) != type(other):
+ return NotImplemented
+ return (self.library, self.tag, self.is_default, self.alias) == (
+ other.library,
+ other.tag,
+ other.is_default,
+ other.alias,
+ )
+
+ def collapse(self, other_import_var: ImportVar) -> ImportVar:
+ """Collapse two import vars together.
+
+ Args:
+ other_import_var: The other import var to collapse with.
+
+ Returns:
+ The collapsed import var with sticky props perserved.
+ """
+ if self != other_import_var:
+ raise ValueError("Cannot collapse two import vars with different hashes")
+
+ if self.version is not None and other_import_var.version is not None:
+ if self.version != other_import_var.version:
+ raise ValueError(
+ "Cannot collapse two import vars with conflicting version specifiers: "
+ f"{self} {other_import_var}"
+ )
+
+ return type(self)(
+ library=self.library,
+ version=self.version or other_import_var.version,
+ tag=self.tag,
+ is_default=self.is_default,
+ alias=self.alias,
+ install=self.install or other_import_var.install,
+ render=self.render or other_import_var.render,
+ transpile=self.transpile or other_import_var.transpile,
+ )
+
+
+class ImportList(List[ImportVar]):
+ """A list of import vars."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ for ix, value in enumerate(self):
+ if not isinstance(value, ImportVar):
+ # convert dicts to ImportVar
+ self[ix] = ImportVar(**value)
+
+ @classmethod
+ def from_import_dict(cls, import_dict: ImportDict) -> ImportList:
+ return [
+ ImportVar(package=lib, **imp.dict())
+ for lib, imps in import_dict.items()
+ for imp in imps
+ ]
+
+ def collapse(self) -> ImportDict:
+ """When collapsing an import list, prefer packages with version specifiers."""
+ collapsed = {}
+ for imp in self:
+ collapsed.setdefault(imp.library, {})
+ if imp in collapsed[imp.library]:
+ # Need to check if the current import has any special properties that need to
+ # be preserved, like the version specifier, install, or transpile.
+ existing_imp = collapsed[imp.library][imp]
+ collapsed[imp.library][imp] = existing_imp.collapse(imp)
+ else:
+ collapsed[imp.library][imp] = imp
+ return {lib: set(imps) for lib, imps in collapsed.items()}
+
+
+ImportDict = Dict[str, Set[ImportVar]]
diff --git a/reflex/vars.py b/reflex/vars.py
index 4a8e6b30f..7793ef07c 100644
--- a/reflex/vars.py
+++ b/reflex/vars.py
@@ -37,7 +37,7 @@ from reflex.base import Base
from reflex.utils import console, format, imports, serializers, types
# This module used to export ImportVar itself, so we still import it for export here
-from reflex.utils.imports import ImportDict, ImportVar
+from reflex.utils.imports import ImportDict, ImportList, ImportVar
if TYPE_CHECKING:
from reflex.state import BaseState
@@ -116,7 +116,7 @@ class VarData(Base):
state: str = ""
# Imports needed to render this var
- imports: ImportDict = {}
+ imports: ImportList = []
# Hooks that need to be present in the component to render this var
hooks: Dict[str, None] = {}
@@ -126,6 +126,19 @@ class VarData(Base):
# segments.
interpolations: List[Tuple[int, int]] = []
+ def __init__(self, imports: ImportDict | ImportList = None, **kwargs):
+ if isinstance(imports, dict):
+ imports = ImportList.from_import_dict(imports)
+ console.deprecate(
+ feature_name="Passing ImportDict for VarData",
+ reason="use ImportList instead",
+ deprecation_version="0.5.0",
+ removal_version="0.6.0",
+ )
+ elif imports is None:
+ imports = []
+ super().__init__(imports=imports, **kwargs)
+
@classmethod
def merge(cls, *others: VarData | None) -> VarData | None:
"""Merge multiple var data objects.
@@ -137,14 +150,14 @@ class VarData(Base):
The merged var data object.
"""
state = ""
- _imports = {}
+ _imports = []
hooks = {}
interpolations = []
for var_data in others:
if var_data is None:
continue
state = state or var_data.state
- _imports = imports.merge_imports(_imports, var_data.imports)
+ _imports.extend(var_data.imports)
hooks.update(var_data.hooks)
interpolations += var_data.interpolations
@@ -180,11 +193,18 @@ class VarData(Base):
# Don't compare interpolations - that's added in by the decoder, and
# not part of the vardata itself.
+ if not isinstance(self.imports, ImportList):
+ self_imports = ImportList(self.imports).collapse()
+ else:
+ self_imports = self.imports.collapse()
+ if not isinstance(other.imports, ImportList):
+ other_imports = ImportList(other.imports).collapse()
+ else:
+ other_imports = other.imports.collapse()
return (
self.state == other.state
and self.hooks.keys() == other.hooks.keys()
- and imports.collapse_imports(self.imports)
- == imports.collapse_imports(other.imports)
+ and self_imports == other_imports
)
def dict(self) -> dict:
@@ -196,10 +216,7 @@ class VarData(Base):
return {
"state": self.state,
"interpolations": list(self.interpolations),
- "imports": {
- lib: [import_var.dict() for import_var in import_vars]
- for lib, import_vars in self.imports.items()
- },
+ "imports": [import_var.dict() for import_var in self.imports],
"hooks": self.hooks,
}
diff --git a/tests/compiler/test_compiler.py b/tests/compiler/test_compiler.py
index b6191974a..db0baae86 100644
--- a/tests/compiler/test_compiler.py
+++ b/tests/compiler/test_compiler.py
@@ -4,8 +4,7 @@ from typing import List
import pytest
from reflex.compiler import compiler, utils
-from reflex.utils import imports
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportList, ImportVar
@pytest.mark.parametrize(
@@ -48,43 +47,56 @@ def test_compile_import_statement(
@pytest.mark.parametrize(
- "import_dict,test_dicts",
+ "import_list,test_dicts",
[
- ({}, []),
+ (ImportList(), []),
(
- {"axios": [ImportVar(tag="axios", is_default=True)]},
+ ImportList([ImportVar(library="axios", tag="axios", is_default=True)]),
[{"lib": "axios", "default": "axios", "rest": []}],
),
(
- {"axios": [ImportVar(tag="foo"), ImportVar(tag="bar")]},
+ ImportList(
+ [
+ ImportVar(library="axios", tag="foo"),
+ ImportVar(library="axios", tag="bar"),
+ ]
+ ),
[{"lib": "axios", "default": "", "rest": ["bar", "foo"]}],
),
(
- {
- "axios": [
- ImportVar(tag="axios", is_default=True),
- ImportVar(tag="foo"),
- ImportVar(tag="bar"),
- ],
- "react": [ImportVar(tag="react", is_default=True)],
- },
+ ImportList(
+ [
+ ImportVar(library="axios", tag="axios", is_default=True),
+ ImportVar(library="axios", tag="foo"),
+ ImportVar(library="axios", tag="bar"),
+ ImportVar(library="react", tag="react", is_default=True),
+ ]
+ ),
[
{"lib": "axios", "default": "axios", "rest": ["bar", "foo"]},
{"lib": "react", "default": "react", "rest": []},
],
),
(
- {"": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")]},
+ ImportList(
+ [
+ ImportVar(library="", tag="lib1.js"),
+ ImportVar(library="", tag="lib2.js"),
+ ]
+ ),
[
{"lib": "lib1.js", "default": "", "rest": []},
{"lib": "lib2.js", "default": "", "rest": []},
],
),
(
- {
- "": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")],
- "axios": [ImportVar(tag="axios", is_default=True)],
- },
+ ImportList(
+ [
+ ImportVar(library="", tag="lib1.js"),
+ ImportVar(library="", tag="lib2.js"),
+ ImportVar(library="axios", tag="axios", is_default=True),
+ ]
+ ),
[
{"lib": "lib1.js", "default": "", "rest": []},
{"lib": "lib2.js", "default": "", "rest": []},
@@ -93,14 +105,14 @@ def test_compile_import_statement(
),
],
)
-def test_compile_imports(import_dict: imports.ImportDict, test_dicts: List[dict]):
+def test_compile_imports(import_list: ImportList, test_dicts: List[dict]):
"""Test the compile_imports function.
Args:
- import_dict: The import dictionary.
+ import_list: The list of ImportVar.
test_dicts: The expected output.
"""
- imports = utils.compile_imports(import_dict)
+ imports = utils.compile_imports(import_list)
for import_dict, test_dict in zip(imports, test_dicts):
assert import_dict["lib"] == test_dict["lib"]
assert import_dict["default"] == test_dict["default"]
diff --git a/tests/components/core/test_banner.py b/tests/components/core/test_banner.py
index f929eef37..bfdf86b7c 100644
--- a/tests/components/core/test_banner.py
+++ b/tests/components/core/test_banner.py
@@ -20,7 +20,7 @@ def test_connection_banner():
"react",
"/utils/context",
"/utils/state",
- "@radix-ui/themes@^3.0.0",
+ "@radix-ui/themes",
"/env.json",
]
@@ -36,7 +36,7 @@ def test_connection_modal():
"react",
"/utils/context",
"/utils/state",
- "@radix-ui/themes@^3.0.0",
+ "@radix-ui/themes",
"/env.json",
]
diff --git a/tests/components/test_component.py b/tests/components/test_component.py
index 21ec409af..942e7a932 100644
--- a/tests/components/test_component.py
+++ b/tests/components/test_component.py
@@ -296,11 +296,11 @@ def test_get_imports(component1, component2):
"""
c1 = component1.create()
c2 = component2.create(c1)
- assert c1._get_all_imports() == {"react": [ImportVar(tag="Component")]}
- assert c2._get_all_imports() == {
- "react-redux": [ImportVar(tag="connect")],
- "react": [ImportVar(tag="Component")],
- }
+ assert c1._get_all_imports() == [ImportVar(library="react", tag="Component")]
+ assert c2._get_all_imports() == [
+ ImportVar(library="react-redux", tag="connect"),
+ ImportVar(library="react", tag="Component"),
+ ]
def test_get_custom_code(component1, component2):
@@ -1514,22 +1514,24 @@ def test_custom_component_get_imports():
custom_comp = wrapper()
# Inner is not imported directly, but it is imported by the custom component.
- assert "inner" not in custom_comp._get_all_imports()
+ inner_import = ImportVar(library="inner", tag="Inner")
+ assert inner_import not in custom_comp._get_all_imports()
# The imports are only resolved during compilation.
_, _, imports_inner = compile_components(custom_comp._get_all_custom_components())
- assert "inner" in imports_inner
+ assert inner_import in imports_inner
outer_comp = outer(c=wrapper())
# Libraries are not imported directly, but are imported by the custom component.
- assert "inner" not in outer_comp._get_all_imports()
- assert "other" not in outer_comp._get_all_imports()
+ other_import = ImportVar(library="other", tag="Other")
+ assert inner_import not in outer_comp._get_all_imports()
+ assert other_import not in outer_comp._get_all_imports()
# The imports are only resolved during compilation.
_, _, imports_outer = compile_components(outer_comp._get_all_custom_components())
- assert "inner" in imports_outer
- assert "other" in imports_outer
+ assert inner_import in imports_outer
+ assert other_import in imports_outer
def test_custom_component_declare_event_handlers_in_fields():
diff --git a/tests/test_var.py b/tests/test_var.py
index a58c49392..6af50e187 100644
--- a/tests/test_var.py
+++ b/tests/test_var.py
@@ -837,7 +837,7 @@ def test_state_with_initial_computed_var(
(f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
(
f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
- 'testing f-string with ${"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true, "transpile": false}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true, "transpile": false}]}, "hooks": {"const state = useContext(StateContexts.state)": null}, "string_length": 13}{state.myvar}',
+ 'testing f-string with ${"state": "state", "interpolations": [], "imports": [{"library": "/utils/context", "tag": "StateContexts", "is_default": false, "alias": null, "version": null, "install": false, "render": true, "transpile": false}, {"library": "react", "tag": "useContext", "is_default": false, "alias": null, "version": null, "install": false, "render": true, "transpile": false}], "hooks": {"const state = useContext(StateContexts.state)": null}, "string_length": 13}{state.myvar}',
),
(
f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",