WiP: use ImportList internally instead of ImportDict

* deprecate `_get_imports` in favor of new `_get_imports_list`
* `_get_all_imports` now returns an `ImportList`
* Compiler uses `ImportList.collapse` to get an `ImportDict`
This commit is contained in:
Masen Furer 2024-04-29 19:05:52 -07:00
parent 3564df7620
commit 35252464a0
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
14 changed files with 404 additions and 187 deletions

View File

@ -79,7 +79,7 @@ from reflex.state import (
) )
from reflex.utils import console, exceptions, format, prerequisites, types from reflex.utils import console, exceptions, format, prerequisites, types
from reflex.utils.exec import is_testing_env, should_skip_compile from reflex.utils.exec import is_testing_env, should_skip_compile
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportList
# Define custom types. # Define custom types.
ComponentCallable = Callable[[], Component] ComponentCallable = Callable[[], Component]
@ -618,27 +618,16 @@ class App(Base):
admin.mount_to(self.api) admin.mount_to(self.api)
def get_frontend_packages(self, imports: Dict[str, set[ImportVar]]): def get_frontend_packages(self, imports: ImportList):
"""Gets the frontend packages to be installed and filters out the unnecessary ones. """Gets the frontend packages to be installed and filters out the unnecessary ones.
Args: Args:
imports: A dictionary containing the imports used in the current page. imports: A list containing the imports used in the current page.
Example: Example:
>>> get_frontend_packages({"react": "16.14.0", "react-dom": "16.14.0"}) >>> get_frontend_packages({"react": "16.14.0", "react-dom": "16.14.0"})
""" """
page_imports = { page_imports = [i.package for i in imports.collapse().values() if i.install]
i
for i, tags in imports.items()
if i
not in [
*constants.PackageJson.DEPENDENCIES.keys(),
*constants.PackageJson.DEV_DEPENDENCIES.keys(),
]
and not any(i.startswith(prefix) for prefix in ["/", ".", "next/"])
and i != ""
and any(tag.install for tag in tags)
}
frontend_packages = get_config().frontend_packages frontend_packages = get_config().frontend_packages
_frontend_packages = [] _frontend_packages = []
for package in frontend_packages: for package in frontend_packages:
@ -653,7 +642,7 @@ class App(Base):
) )
continue continue
_frontend_packages.append(package) _frontend_packages.append(package)
page_imports.update(_frontend_packages) page_imports.extend(_frontend_packages)
prerequisites.install_frontend_packages(page_imports, get_config()) prerequisites.install_frontend_packages(page_imports, get_config())
def _app_root(self, app_wrappers: dict[tuple[int, str], Component]) -> Component: def _app_root(self, app_wrappers: dict[tuple[int, str], Component]) -> Component:
@ -794,7 +783,7 @@ class App(Base):
self.style = evaluate_style_namespaces(self.style) self.style = evaluate_style_namespaces(self.style)
# Track imports and custom components found. # Track imports and custom components found.
all_imports = {} all_imports = ImportList()
custom_components = set() custom_components = set()
for _route, component in self.pages.items(): for _route, component in self.pages.items():
@ -804,7 +793,7 @@ class App(Base):
component.apply_theme(self.theme) component.apply_theme(self.theme)
# Add component._get_all_imports() to all_imports. # Add component._get_all_imports() to all_imports.
all_imports.update(component._get_all_imports()) all_imports.extend(component._get_all_imports())
# Add the app wrappers from this component. # Add the app wrappers from this component.
app_wrappers.update(component._get_all_app_wrap_components()) app_wrappers.update(component._get_all_app_wrap_components())
@ -932,10 +921,10 @@ class App(Base):
custom_components_imports, custom_components_imports,
) = custom_components_future.result() ) = custom_components_future.result()
compile_results.append(custom_components_result) compile_results.append(custom_components_result)
all_imports.update(custom_components_imports) all_imports.extend(custom_components_imports)
# Get imports from AppWrap components. # Get imports from AppWrap components.
all_imports.update(app_root._get_all_imports()) all_imports.extend(app_root._get_all_imports())
progress.advance(task) progress.advance(task)
@ -951,7 +940,7 @@ class App(Base):
# Setup the next.config.js # Setup the next.config.js
transpile_packages = [ transpile_packages = [
package package
for package, import_vars in all_imports.items() for package, import_vars in all_imports.collapse().items()
if any(import_var.transpile for import_var in import_vars) if any(import_var.transpile for import_var in import_vars)
] ]
prerequisites.update_next_config( prerequisites.update_next_config(

View File

@ -19,7 +19,7 @@ from reflex.config import get_config
from reflex.state import BaseState from reflex.state import BaseState
from reflex.style import LIGHT_COLOR_MODE from reflex.style import LIGHT_COLOR_MODE
from reflex.utils.exec import is_prod_mode from reflex.utils.exec import is_prod_mode
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportList, ImportVar
from reflex.vars import Var from reflex.vars import Var
@ -197,25 +197,34 @@ def _compile_components(
Returns: Returns:
The compiled components. The compiled components.
""" """
imports = { _imports = ImportList(
"react": [ImportVar(tag="memo")], [
f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="E"), ImportVar(tag="isTrue")], ImportVar(package="react", tag="memo"),
} ImportVar(
package=f"/{constants.Dirs.STATE_PATH}",
tag="E",
),
ImportVar(
package=f"/{constants.Dirs.STATE_PATH}",
tag="isTrue",
),
]
)
component_renders = [] component_renders = []
# Compile each component. # Compile each component.
for component in components: for component in components:
component_render, component_imports = utils.compile_custom_component(component) component_render, component_imports = utils.compile_custom_component(component)
component_renders.append(component_render) component_renders.append(component_render)
imports = utils.merge_imports(imports, component_imports) _imports.extend(component_imports)
# Compile the components page. # Compile the components page.
return ( return (
templates.COMPONENTS.render( templates.COMPONENTS.render(
imports=utils.compile_imports(imports), imports=utils.compile_imports(_imports),
components=component_renders, components=component_renders,
), ),
imports, _imports,
) )
@ -235,7 +244,7 @@ def _compile_stateful_components(
Returns: Returns:
The rendered stateful components code. The rendered stateful components code.
""" """
all_import_dicts = [] all_imports = []
rendered_components = {} rendered_components = {}
def get_shared_components_recursive(component: BaseComponent): def get_shared_components_recursive(component: BaseComponent):
@ -266,7 +275,7 @@ def _compile_stateful_components(
rendered_components.update( rendered_components.update(
{code: None for code in component._get_all_custom_code()}, {code: None for code in component._get_all_custom_code()},
) )
all_import_dicts.append(component._get_all_imports()) all_imports.extend(component._get_all_imports())
# Indicate that this component now imports from the shared file. # Indicate that this component now imports from the shared file.
component.rendered_as_shared = True component.rendered_as_shared = True
@ -275,9 +284,11 @@ def _compile_stateful_components(
get_shared_components_recursive(page_component) get_shared_components_recursive(page_component)
# Don't import from the file that we're about to create. # Don't import from the file that we're about to create.
all_imports = utils.merge_imports(*all_import_dicts) all_imports = ImportList(
all_imports.pop( imp
f"/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}", None for imp in all_imports
if imp.library
!= f"/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}"
) )
return templates.STATEFUL_COMPONENTS.render( return templates.STATEFUL_COMPONENTS.render(
@ -408,7 +419,7 @@ def compile_page(
def compile_components( def compile_components(
components: set[CustomComponent], components: set[CustomComponent],
) -> tuple[str, str, Dict[str, list[ImportVar]]]: ) -> tuple[str, str, ImportList]:
"""Compile the custom components. """Compile the custom components.
Args: Args:

View File

@ -88,16 +88,16 @@ def validate_imports(import_dict: imports.ImportDict):
used_tags[import_name] = lib used_tags[import_name] = lib
def compile_imports(import_dict: imports.ImportDict) -> list[dict]: def compile_imports(import_list: imports.ImportList) -> list[dict]:
"""Compile an import dict. """Compile an import list.
Args: Args:
import_dict: The import dict to compile. import_list: The import list to compile.
Returns: Returns:
The list of import dict. The list of template import dict.
""" """
collapsed_import_dict = imports.collapse_imports(import_dict) collapsed_import_dict = import_list.collapse()
validate_imports(collapsed_import_dict) validate_imports(collapsed_import_dict)
import_dicts = [] import_dicts = []
for lib, fields in collapsed_import_dict.items(): for lib, fields in collapsed_import_dict.items():
@ -114,9 +114,6 @@ def compile_imports(import_dict: imports.ImportDict) -> list[dict]:
import_dicts.append(get_import_dict(module)) import_dicts.append(get_import_dict(module))
continue continue
# remove the version before rendering the package imports
lib = format.format_library_name(lib)
import_dicts.append(get_import_dict(lib, default, rest)) import_dicts.append(get_import_dict(lib, default, rest))
return import_dicts return import_dicts
@ -237,7 +234,7 @@ def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
def compile_custom_component( def compile_custom_component(
component: CustomComponent, component: CustomComponent,
) -> tuple[dict, imports.ImportDict]: ) -> tuple[dict, imports.ImportList]:
"""Compile a custom component. """Compile a custom component.
Args: Args:
@ -250,11 +247,12 @@ def compile_custom_component(
render = component.get_component(component) render = component.get_component(component)
# Get the imports. # Get the imports.
imports = { component_library_name = format.format_library_name(component.library)
lib: fields _imports = imports.ImportList(
for lib, fields in render._get_all_imports().items() imp
if lib != component.library for imp in render._get_all_imports()
} if imp.library != component_library_name
)
# Concatenate the props. # Concatenate the props.
props = [prop._var_name for prop in component.get_prop_vars()] props = [prop._var_name for prop in component.get_prop_vars()]
@ -268,7 +266,7 @@ def compile_custom_component(
"hooks": {**render._get_all_hooks_internal(), **render._get_all_hooks()}, "hooks": {**render._get_all_hooks_internal(), **render._get_all_hooks()},
"custom_code": render._get_all_custom_code(), "custom_code": render._get_all_custom_code(),
}, },
imports, _imports,
) )

View File

@ -35,19 +35,18 @@ class ChakraComponent(Component):
@classmethod @classmethod
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def _get_dependencies_imports(cls) -> imports.ImportDict: def _get_dependencies_imports(cls) -> imports.ImportList:
"""Get the imports from lib_dependencies for installing. """Get the imports from lib_dependencies for installing.
Returns: Returns:
The dependencies imports of the component. The dependencies imports of the component.
""" """
return { return [
dep: [imports.ImportVar(tag=None, render=False)] imports.ImportVar(
for dep in [ package="@chakra-ui/system@2.5.7", tag=None, render=False
"@chakra-ui/system@2.5.7", ),
"framer-motion@10.16.4", imports.ImportVar(package="framer-motion@10.16.4", tag=None, render=False),
] ]
}
class ChakraProvider(ChakraComponent): class ChakraProvider(ChakraComponent):

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import copy import copy
import itertools
import typing import typing
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import lru_cache, wraps from functools import lru_cache, wraps
@ -95,11 +96,11 @@ class BaseComponent(Base, ABC):
""" """
@abstractmethod @abstractmethod
def _get_all_imports(self) -> imports.ImportDict: def _get_all_imports(self) -> imports.ImportList:
"""Get all the libraries and fields that are used by the component. """Get all the libraries and fields that are used by the component.
Returns: Returns:
The import dict with the required imports. The list of all required ImportVar.
""" """
@abstractmethod @abstractmethod
@ -994,17 +995,22 @@ class Component(BaseComponent, ABC):
# Return the dynamic imports # Return the dynamic imports
return dynamic_imports return dynamic_imports
def _get_props_imports(self) -> List[str]: def _get_props_imports(self) -> imports.ImportList:
"""Get the imports needed for components props. """Get the imports needed for components props.
Returns: Returns:
The imports for the components props of the component. The imports for the components props of the component.
""" """
return [ return imports.ImportList(
getattr(self, prop)._get_all_imports() sum(
for prop in self.get_component_props() (
if getattr(self, prop) is not None getattr(self, prop)._get_all_imports()
] for prop in self.get_component_props()
if getattr(self, prop) is not None
),
[],
)
)
def _should_transpile(self, dep: str | None) -> bool: def _should_transpile(self, dep: str | None) -> bool:
"""Check if a dependency should be transpiled. """Check if a dependency should be transpiled.
@ -1020,97 +1026,129 @@ class Component(BaseComponent, ABC):
or format.format_library_name(dep or "") in self.transpile_packages or format.format_library_name(dep or "") in self.transpile_packages
) )
def _get_dependencies_imports(self) -> imports.ImportDict: def _get_dependencies_imports(self) -> imports.ImportList:
"""Get the imports from lib_dependencies for installing. """Get the imports from lib_dependencies for installing.
Returns: Returns:
The dependencies imports of the component. The dependencies imports of the component.
""" """
return { return imports.ImportList(
dep: [ ImportVar(
ImportVar( package=dep,
tag=None, tag=None,
render=False, render=False,
transpile=self._should_transpile(dep), transpile=self._should_transpile(dep),
) )
]
for dep in self.lib_dependencies for dep in self.lib_dependencies
} )
def _get_hooks_imports(self) -> imports.ImportDict: def _get_hooks_imports(self) -> imports.ImportList:
"""Get the imports required by certain hooks. """Get the imports required by certain hooks.
Returns: Returns:
The imports required for all selected hooks. The imports required for all selected hooks.
""" """
_imports = {} _imports = imports.ImportList()
if self._get_ref_hook(): if self._get_ref_hook():
# Handle hooks needed for attaching react refs to DOM nodes. # Handle hooks needed for attaching react refs to DOM nodes.
_imports.setdefault("react", set()).add(ImportVar(tag="useRef")) _imports.extend(
_imports.setdefault(f"/{Dirs.STATE_PATH}", set()).add(ImportVar(tag="refs")) [
ImportVar(package="react", tag="useRef"),
ImportVar(package=f"/{Dirs.STATE_PATH}", tag="refs"),
]
)
if self._get_mount_lifecycle_hook(): if self._get_mount_lifecycle_hook():
# Handle hooks for `on_mount` / `on_unmount`. # Handle hooks for `on_mount` / `on_unmount`.
_imports.setdefault("react", set()).add(ImportVar(tag="useEffect")) _imports.append(ImportVar(package="react", tag="useEffect"))
if self._get_special_hooks(): if self._get_special_hooks():
# Handle additional internal hooks (autofocus, etc). # Handle additional internal hooks (autofocus, etc).
_imports.setdefault("react", set()).update( _imports.extend(
{ [
ImportVar(tag="useRef"), ImportVar(package="react", tag="useEffect"),
ImportVar(tag="useEffect"), ImportVar(package="react", tag="useRef"),
}, ]
) )
user_hooks = self._get_hooks() user_hooks = self._get_hooks()
if user_hooks is not None and isinstance(user_hooks, Var): if user_hooks is not None and isinstance(user_hooks, Var):
_imports = imports.merge_imports(_imports, user_hooks._var_data.imports) # type: ignore _imports.extend(user_hooks._var_data.imports)
return _imports return _imports
def _get_imports(self) -> imports.ImportDict: def _get_imports(self) -> imports.ImportDict:
"""Get all the libraries and fields that are used by the component. """Deprecated method to get all the libraries and fields used by the component.
Returns: Returns:
The imports needed by the component. The imports needed by the component.
""" """
_imports = {} return {}
def _get_imports_list(self) -> imports.ImportList:
"""Internal method to get the imports as a list.
Returns:
The imports as a list.
"""
_imports = imports.ImportList(
itertools.chain(
self._get_props_imports(),
self._get_dependencies_imports(),
self._get_hooks_imports(),
)
)
# Handle deprecated _get_imports
import_dict = self._get_imports()
if import_dict:
console.deprecate(
feature_name="_get_imports",
reason="use add_imports instead",
deprecation_version="0.5.0",
removal_version="0.6.0",
)
_imports.extend(imports.ImportList.from_import_dict(import_dict))
# Import this component's tag from the main library. # Import this component's tag from the main library.
if self.library is not None and self.tag is not None: if self.library is not None and self.tag is not None:
_imports[self.library] = {self.import_var} _imports.append(self.import_var)
# Get static imports required for event processing. # Get static imports required for event processing.
event_imports = Imports.EVENTS if self.event_triggers else {} if self.event_triggers:
_imports.append(Imports.EVENTS)
# Collect imports from Vars used directly by this component. # Collect imports from Vars used directly by this component.
var_imports = [ for var in self._get_vars():
var._var_data.imports for var in self._get_vars() if var._var_data if var._var_data:
] _imports.extend(var._var_data.imports)
return _imports
return imports.merge_imports( def _get_all_imports(self, collapse: bool = False) -> imports.ImportList:
*self._get_props_imports(),
self._get_dependencies_imports(),
self._get_hooks_imports(),
_imports,
event_imports,
*var_imports,
)
def _get_all_imports(self, collapse: bool = False) -> imports.ImportDict:
"""Get all the libraries and fields that are used by the component and its children. """Get all the libraries and fields that are used by the component and its children.
Args: Args:
collapse: Whether to collapse the imports by removing duplicates. collapse: Whether to collapse the imports into a dict (deprecated).
Returns: Returns:
The import dict with the required imports. The list of all required imports.
""" """
_imports = imports.merge_imports( _imports = imports.ImportList(
self._get_imports(), *[child._get_all_imports() for child in self.children] self._get_imports_list()
+ sum((child._get_all_imports() for child in self.children), [])
) )
return imports.collapse_imports(_imports) if collapse else _imports
if collapse:
console.deprecate(
feature_name="collapse kwarg to _get_all_imports",
reason="use ImportList.collapse instead",
deprecation_version="0.5.0",
removal_version="0.6.0",
)
return _imports.collapse() # type: ignore
return _imports
def _get_mount_lifecycle_hook(self) -> str | None: def _get_mount_lifecycle_hook(self) -> str | None:
"""Generate the component lifecycle hook. """Generate the component lifecycle hook.
@ -1296,6 +1334,7 @@ class Component(BaseComponent, ABC):
tag = self.tag.partition(".")[0] if self.tag else None tag = self.tag.partition(".")[0] if self.tag else None
alias = self.alias.partition(".")[0] if self.alias else None alias = self.alias.partition(".")[0] if self.alias else None
return ImportVar( return ImportVar(
package=self.library,
tag=tag, tag=tag,
is_default=self.is_default, is_default=self.is_default,
alias=alias, alias=alias,
@ -1575,7 +1614,6 @@ class NoSSRComponent(Component):
return imports.merge_imports( return imports.merge_imports(
dynamic_import, dynamic_import,
_imports, _imports,
self._get_dependencies_imports(),
) )
def _get_dynamic_imports(self) -> str: def _get_dynamic_imports(self) -> str:
@ -1893,18 +1931,21 @@ class StatefulComponent(BaseComponent):
""" """
return {} return {}
def _get_all_imports(self) -> imports.ImportDict: def _get_all_imports(self) -> imports.ImportList:
"""Get all the libraries and fields that are used by the component. """Get all the libraries and fields that are used by the component.
Returns: Returns:
The import dict with the required imports. The list of all required imports.
""" """
if self.rendered_as_shared: if self.rendered_as_shared:
return { return imports.ImportList(
f"/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}": [ [
ImportVar(tag=self.tag) imports.ImportVar(
package=f"/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}",
tag=self.tag,
)
] ]
} )
return self.component._get_all_imports() return self.component._get_all_imports()
def _get_all_dynamic_imports(self) -> set[str]: def _get_all_dynamic_imports(self) -> set[str]:

View File

@ -12,9 +12,9 @@ from reflex.style import LIGHT_COLOR_MODE, color_mode
from reflex.utils import format, imports from reflex.utils import format, imports
from reflex.vars import BaseVar, Var, VarData from reflex.vars import BaseVar, Var, VarData
_IS_TRUE_IMPORT = { _IS_TRUE_IMPORT = imports.ImportList(
f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")}, [imports.ImportVar(library=f"/{Dirs.STATE_PATH}", tag="isTrue")]
} )
class Cond(MemoizationLeaf): class Cond(MemoizationLeaf):
@ -95,11 +95,13 @@ class Cond(MemoizationLeaf):
cond_state=f"isTrue({self.cond._var_full_name})", cond_state=f"isTrue({self.cond._var_full_name})",
) )
def _get_imports(self) -> imports.ImportDict: def _get_imports_list(self) -> imports.ImportList:
return imports.merge_imports( return imports.ImportList(
super()._get_imports(), [
getattr(self.cond._var_data, "imports", {}), *super()._get_imports_list(),
_IS_TRUE_IMPORT, *getattr(self.cond._var_data, "imports", []),
*_IS_TRUE_IMPORT,
]
) )
def _apply_theme(self, theme: Component): def _apply_theme(self, theme: Component):

View File

@ -6,7 +6,7 @@ from types import SimpleNamespace
from reflex.base import Base from reflex.base import Base
from reflex.constants import Dirs from reflex.constants import Dirs
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportList, ImportVar
# The prefix used to create setters for state vars. # The prefix used to create setters for state vars.
SETTER_PREFIX = "set_" SETTER_PREFIX = "set_"
@ -102,11 +102,11 @@ class ComponentName(Enum):
class Imports(SimpleNamespace): class Imports(SimpleNamespace):
"""Common sets of import vars.""" """Common sets of import vars."""
EVENTS = { EVENTS: ImportList = [
"react": {ImportVar(tag="useContext")}, ImportVar(package="react", tag="useContext"),
f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")}, ImportVar(package=f"/{Dirs.CONTEXTS_PATH}", tag="EventLoopContext"),
f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)}, ImportVar(package=f"/{Dirs.STATE_PATH}", tag=CompileVars.TO_EVENT),
} ]
class Hooks(SimpleNamespace): class Hooks(SimpleNamespace):

View File

@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
from reflex import constants from reflex import constants
from reflex.utils import exceptions, serializers, types from reflex.utils import exceptions, serializers, types
from reflex.utils.imports import split_library_name_version
from reflex.utils.serializers import serialize from reflex.utils.serializers import serialize
from reflex.vars import BaseVar, Var from reflex.vars import BaseVar, Var
@ -716,11 +717,7 @@ def format_library_name(library_fullname: str):
Returns: Returns:
The name without the @version if it was part of the name The name without the @version if it was part of the name
""" """
lib, at, version = library_fullname.rpartition("@") return split_library_name_version(library_fullname)[0]
if not lib:
lib = at + version
return lib
def json_dumps(obj: Any) -> str: def json_dumps(obj: Any) -> str:

View File

@ -3,9 +3,10 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Optional from typing import Dict, List, Optional, Set
from reflex.base import Base from reflex.base import Base
from reflex.constants.installer import PackageJson
def merge_imports(*imports) -> ImportDict: def merge_imports(*imports) -> ImportDict:
@ -36,9 +37,29 @@ def collapse_imports(imports: ImportDict) -> ImportDict:
return {lib: list(set(import_vars)) for lib, import_vars in imports.items()} return {lib: list(set(import_vars)) for lib, import_vars in imports.items()}
def split_library_name_version(library_fullname: str):
"""Split the name of a library from its version.
Args:
library_fullname: The fullname of the library.
Returns:
A tuple of the library name and version.
"""
lib, at, version = library_fullname.rpartition("@")
if not lib:
lib = at + version
version = None
return lib, version
class ImportVar(Base): class ImportVar(Base):
"""An import var.""" """An import var."""
# The package name associated with the tag
library: Optional[str]
# The name of the import tag. # The name of the import tag.
tag: Optional[str] tag: Optional[str]
@ -48,6 +69,12 @@ class ImportVar(Base):
# The tag alias. # The tag alias.
alias: Optional[str] = None alias: Optional[str] = None
# The following fields provide extra information about the import,
# but are not factored in when considering hash or equality
# The version of the package
version: Optional[str]
# Whether this import need to install the associated lib # Whether this import need to install the associated lib
install: Optional[bool] = True install: Optional[bool] = True
@ -58,6 +85,34 @@ class ImportVar(Base):
# https://nextjs.org/docs/app/api-reference/next-config-js/transpilePackages # https://nextjs.org/docs/app/api-reference/next-config-js/transpilePackages
transpile: Optional[bool] = False transpile: Optional[bool] = False
def __init__(
self,
*,
package: Optional[str] = None,
**kwargs,
):
if package is not None:
if (
kwargs.get("library", None) is not None
or kwargs.get("version", None) is not None
):
raise ValueError(
"Cannot provide 'library' or 'version' as keyword arguments when "
"specifying 'package' as an argument"
)
kwargs["library"], kwargs["version"] = split_library_name_version(package)
install = (
package is not None
# TODO: handle version conflicts
and package not in PackageJson.DEPENDENCIES
and package not in PackageJson.DEV_DEPENDENCIES
and not any(package.startswith(prefix) for prefix in ["/", ".", "next/"])
and package != ""
)
kwargs.setdefault("install", install)
super().__init__(**kwargs)
@property @property
def name(self) -> str: def name(self) -> str:
"""The name of the import. """The name of the import.
@ -72,6 +127,17 @@ class ImportVar(Base):
else: else:
return self.tag or "" return self.tag or ""
@property
def package(self) -> str:
"""The package to install for this import
Returns:
The library name and (optional) version to be installed by npm/bun.
"""
if self.version:
return f"{self.library}@{self.version}"
return self.library
def __hash__(self) -> int: def __hash__(self) -> int:
"""Define a hash function for the import var. """Define a hash function for the import var.
@ -80,14 +146,97 @@ class ImportVar(Base):
""" """
return hash( return hash(
( (
self.library,
self.tag, self.tag,
self.is_default, self.is_default,
self.alias, self.alias,
self.install, # These do not fundamentally change the import in any way
self.render, # self.install,
self.transpile, # self.render,
# self.transpile,
) )
) )
def __eq__(self, other: ImportVar) -> bool:
"""Define equality for the import var.
ImportDict = Dict[str, List[ImportVar]] Args:
other: The other import var to compare.
Returns:
Whether the two import vars are equal.
"""
if type(self) != type(other):
return NotImplemented
return (self.library, self.tag, self.is_default, self.alias) == (
other.library,
other.tag,
other.is_default,
other.alias,
)
def collapse(self, other_import_var: ImportVar) -> ImportVar:
"""Collapse two import vars together.
Args:
other_import_var: The other import var to collapse with.
Returns:
The collapsed import var with sticky props perserved.
"""
if self != other_import_var:
raise ValueError("Cannot collapse two import vars with different hashes")
if self.version is not None and other_import_var.version is not None:
if self.version != other_import_var.version:
raise ValueError(
"Cannot collapse two import vars with conflicting version specifiers: "
f"{self} {other_import_var}"
)
return type(self)(
library=self.library,
version=self.version or other_import_var.version,
tag=self.tag,
is_default=self.is_default,
alias=self.alias,
install=self.install or other_import_var.install,
render=self.render or other_import_var.render,
transpile=self.transpile or other_import_var.transpile,
)
class ImportList(List[ImportVar]):
"""A list of import vars."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
for ix, value in enumerate(self):
if not isinstance(value, ImportVar):
# convert dicts to ImportVar
self[ix] = ImportVar(**value)
@classmethod
def from_import_dict(cls, import_dict: ImportDict) -> ImportList:
return [
ImportVar(package=lib, **imp.dict())
for lib, imps in import_dict.items()
for imp in imps
]
def collapse(self) -> ImportDict:
"""When collapsing an import list, prefer packages with version specifiers."""
collapsed = {}
for imp in self:
collapsed.setdefault(imp.library, {})
if imp in collapsed[imp.library]:
# Need to check if the current import has any special properties that need to
# be preserved, like the version specifier, install, or transpile.
existing_imp = collapsed[imp.library][imp]
collapsed[imp.library][imp] = existing_imp.collapse(imp)
else:
collapsed[imp.library][imp] = imp
return {lib: set(imps) for lib, imps in collapsed.items()}
ImportDict = Dict[str, Set[ImportVar]]

View File

@ -37,7 +37,7 @@ from reflex.base import Base
from reflex.utils import console, format, imports, serializers, types from reflex.utils import console, format, imports, serializers, types
# This module used to export ImportVar itself, so we still import it for export here # This module used to export ImportVar itself, so we still import it for export here
from reflex.utils.imports import ImportDict, ImportVar from reflex.utils.imports import ImportDict, ImportList, ImportVar
if TYPE_CHECKING: if TYPE_CHECKING:
from reflex.state import BaseState from reflex.state import BaseState
@ -116,7 +116,7 @@ class VarData(Base):
state: str = "" state: str = ""
# Imports needed to render this var # Imports needed to render this var
imports: ImportDict = {} imports: ImportList = []
# Hooks that need to be present in the component to render this var # Hooks that need to be present in the component to render this var
hooks: Dict[str, None] = {} hooks: Dict[str, None] = {}
@ -126,6 +126,19 @@ class VarData(Base):
# segments. # segments.
interpolations: List[Tuple[int, int]] = [] interpolations: List[Tuple[int, int]] = []
def __init__(self, imports: ImportDict | ImportList = None, **kwargs):
if isinstance(imports, dict):
imports = ImportList.from_import_dict(imports)
console.deprecate(
feature_name="Passing ImportDict for VarData",
reason="use ImportList instead",
deprecation_version="0.5.0",
removal_version="0.6.0",
)
elif imports is None:
imports = []
super().__init__(imports=imports, **kwargs)
@classmethod @classmethod
def merge(cls, *others: VarData | None) -> VarData | None: def merge(cls, *others: VarData | None) -> VarData | None:
"""Merge multiple var data objects. """Merge multiple var data objects.
@ -137,14 +150,14 @@ class VarData(Base):
The merged var data object. The merged var data object.
""" """
state = "" state = ""
_imports = {} _imports = []
hooks = {} hooks = {}
interpolations = [] interpolations = []
for var_data in others: for var_data in others:
if var_data is None: if var_data is None:
continue continue
state = state or var_data.state state = state or var_data.state
_imports = imports.merge_imports(_imports, var_data.imports) _imports.extend(var_data.imports)
hooks.update(var_data.hooks) hooks.update(var_data.hooks)
interpolations += var_data.interpolations interpolations += var_data.interpolations
@ -180,11 +193,18 @@ class VarData(Base):
# Don't compare interpolations - that's added in by the decoder, and # Don't compare interpolations - that's added in by the decoder, and
# not part of the vardata itself. # not part of the vardata itself.
if not isinstance(self.imports, ImportList):
self_imports = ImportList(self.imports).collapse()
else:
self_imports = self.imports.collapse()
if not isinstance(other.imports, ImportList):
other_imports = ImportList(other.imports).collapse()
else:
other_imports = other.imports.collapse()
return ( return (
self.state == other.state self.state == other.state
and self.hooks.keys() == other.hooks.keys() and self.hooks.keys() == other.hooks.keys()
and imports.collapse_imports(self.imports) and self_imports == other_imports
== imports.collapse_imports(other.imports)
) )
def dict(self) -> dict: def dict(self) -> dict:
@ -196,10 +216,7 @@ class VarData(Base):
return { return {
"state": self.state, "state": self.state,
"interpolations": list(self.interpolations), "interpolations": list(self.interpolations),
"imports": { "imports": [import_var.dict() for import_var in self.imports],
lib: [import_var.dict() for import_var in import_vars]
for lib, import_vars in self.imports.items()
},
"hooks": self.hooks, "hooks": self.hooks,
} }

View File

@ -4,8 +4,7 @@ from typing import List
import pytest import pytest
from reflex.compiler import compiler, utils from reflex.compiler import compiler, utils
from reflex.utils import imports from reflex.utils.imports import ImportList, ImportVar
from reflex.utils.imports import ImportVar
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -48,43 +47,56 @@ def test_compile_import_statement(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"import_dict,test_dicts", "import_list,test_dicts",
[ [
({}, []), (ImportList(), []),
( (
{"axios": [ImportVar(tag="axios", is_default=True)]}, ImportList([ImportVar(library="axios", tag="axios", is_default=True)]),
[{"lib": "axios", "default": "axios", "rest": []}], [{"lib": "axios", "default": "axios", "rest": []}],
), ),
( (
{"axios": [ImportVar(tag="foo"), ImportVar(tag="bar")]}, ImportList(
[
ImportVar(library="axios", tag="foo"),
ImportVar(library="axios", tag="bar"),
]
),
[{"lib": "axios", "default": "", "rest": ["bar", "foo"]}], [{"lib": "axios", "default": "", "rest": ["bar", "foo"]}],
), ),
( (
{ ImportList(
"axios": [ [
ImportVar(tag="axios", is_default=True), ImportVar(library="axios", tag="axios", is_default=True),
ImportVar(tag="foo"), ImportVar(library="axios", tag="foo"),
ImportVar(tag="bar"), ImportVar(library="axios", tag="bar"),
], ImportVar(library="react", tag="react", is_default=True),
"react": [ImportVar(tag="react", is_default=True)], ]
}, ),
[ [
{"lib": "axios", "default": "axios", "rest": ["bar", "foo"]}, {"lib": "axios", "default": "axios", "rest": ["bar", "foo"]},
{"lib": "react", "default": "react", "rest": []}, {"lib": "react", "default": "react", "rest": []},
], ],
), ),
( (
{"": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")]}, ImportList(
[
ImportVar(library="", tag="lib1.js"),
ImportVar(library="", tag="lib2.js"),
]
),
[ [
{"lib": "lib1.js", "default": "", "rest": []}, {"lib": "lib1.js", "default": "", "rest": []},
{"lib": "lib2.js", "default": "", "rest": []}, {"lib": "lib2.js", "default": "", "rest": []},
], ],
), ),
( (
{ ImportList(
"": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")], [
"axios": [ImportVar(tag="axios", is_default=True)], ImportVar(library="", tag="lib1.js"),
}, ImportVar(library="", tag="lib2.js"),
ImportVar(library="axios", tag="axios", is_default=True),
]
),
[ [
{"lib": "lib1.js", "default": "", "rest": []}, {"lib": "lib1.js", "default": "", "rest": []},
{"lib": "lib2.js", "default": "", "rest": []}, {"lib": "lib2.js", "default": "", "rest": []},
@ -93,14 +105,14 @@ def test_compile_import_statement(
), ),
], ],
) )
def test_compile_imports(import_dict: imports.ImportDict, test_dicts: List[dict]): def test_compile_imports(import_list: ImportList, test_dicts: List[dict]):
"""Test the compile_imports function. """Test the compile_imports function.
Args: Args:
import_dict: The import dictionary. import_list: The list of ImportVar.
test_dicts: The expected output. test_dicts: The expected output.
""" """
imports = utils.compile_imports(import_dict) imports = utils.compile_imports(import_list)
for import_dict, test_dict in zip(imports, test_dicts): for import_dict, test_dict in zip(imports, test_dicts):
assert import_dict["lib"] == test_dict["lib"] assert import_dict["lib"] == test_dict["lib"]
assert import_dict["default"] == test_dict["default"] assert import_dict["default"] == test_dict["default"]

View File

@ -20,7 +20,7 @@ def test_connection_banner():
"react", "react",
"/utils/context", "/utils/context",
"/utils/state", "/utils/state",
"@radix-ui/themes@^3.0.0", "@radix-ui/themes",
"/env.json", "/env.json",
] ]
@ -36,7 +36,7 @@ def test_connection_modal():
"react", "react",
"/utils/context", "/utils/context",
"/utils/state", "/utils/state",
"@radix-ui/themes@^3.0.0", "@radix-ui/themes",
"/env.json", "/env.json",
] ]

View File

@ -296,11 +296,11 @@ def test_get_imports(component1, component2):
""" """
c1 = component1.create() c1 = component1.create()
c2 = component2.create(c1) c2 = component2.create(c1)
assert c1._get_all_imports() == {"react": [ImportVar(tag="Component")]} assert c1._get_all_imports() == [ImportVar(library="react", tag="Component")]
assert c2._get_all_imports() == { assert c2._get_all_imports() == [
"react-redux": [ImportVar(tag="connect")], ImportVar(library="react-redux", tag="connect"),
"react": [ImportVar(tag="Component")], ImportVar(library="react", tag="Component"),
} ]
def test_get_custom_code(component1, component2): def test_get_custom_code(component1, component2):
@ -1514,22 +1514,24 @@ def test_custom_component_get_imports():
custom_comp = wrapper() custom_comp = wrapper()
# Inner is not imported directly, but it is imported by the custom component. # Inner is not imported directly, but it is imported by the custom component.
assert "inner" not in custom_comp._get_all_imports() inner_import = ImportVar(library="inner", tag="Inner")
assert inner_import not in custom_comp._get_all_imports()
# The imports are only resolved during compilation. # The imports are only resolved during compilation.
_, _, imports_inner = compile_components(custom_comp._get_all_custom_components()) _, _, imports_inner = compile_components(custom_comp._get_all_custom_components())
assert "inner" in imports_inner assert inner_import in imports_inner
outer_comp = outer(c=wrapper()) outer_comp = outer(c=wrapper())
# Libraries are not imported directly, but are imported by the custom component. # Libraries are not imported directly, but are imported by the custom component.
assert "inner" not in outer_comp._get_all_imports() other_import = ImportVar(library="other", tag="Other")
assert "other" not in outer_comp._get_all_imports() assert inner_import not in outer_comp._get_all_imports()
assert other_import not in outer_comp._get_all_imports()
# The imports are only resolved during compilation. # The imports are only resolved during compilation.
_, _, imports_outer = compile_components(outer_comp._get_all_custom_components()) _, _, imports_outer = compile_components(outer_comp._get_all_custom_components())
assert "inner" in imports_outer assert inner_import in imports_outer
assert "other" in imports_outer assert other_import in imports_outer
def test_custom_component_declare_event_handlers_in_fields(): def test_custom_component_declare_event_handlers_in_fields():

View File

@ -837,7 +837,7 @@ def test_state_with_initial_computed_var(
(f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"), (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
( (
f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}", f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
'testing f-string with $<reflex.Var>{"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true, "transpile": false}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true, "transpile": false}]}, "hooks": {"const state = useContext(StateContexts.state)": null}, "string_length": 13}</reflex.Var>{state.myvar}', 'testing f-string with $<reflex.Var>{"state": "state", "interpolations": [], "imports": [{"library": "/utils/context", "tag": "StateContexts", "is_default": false, "alias": null, "version": null, "install": false, "render": true, "transpile": false}, {"library": "react", "tag": "useContext", "is_default": false, "alias": null, "version": null, "install": false, "render": true, "transpile": false}], "hooks": {"const state = useContext(StateContexts.state)": null}, "string_length": 13}</reflex.Var>{state.myvar}',
), ),
( (
f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}", f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",