From 35252464a0442d7fc7a633846d0b7df5bd8e6c3f Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 29 Apr 2024 19:05:52 -0700 Subject: [PATCH] WiP: use ImportList internally instead of ImportDict * deprecate `_get_imports` in favor of new `_get_imports_list` * `_get_all_imports` now returns an `ImportList` * Compiler uses `ImportList.collapse` to get an `ImportDict` --- reflex/app.py | 31 ++---- reflex/compiler/compiler.py | 39 ++++--- reflex/compiler/utils.py | 28 +++-- reflex/components/chakra/base.py | 15 ++- reflex/components/component.py | 159 +++++++++++++++++---------- reflex/components/core/cond.py | 18 +-- reflex/constants/compiler.py | 12 +- reflex/utils/format.py | 7 +- reflex/utils/imports.py | 159 ++++++++++++++++++++++++++- reflex/vars.py | 37 +++++-- tests/compiler/test_compiler.py | 56 ++++++---- tests/components/core/test_banner.py | 4 +- tests/components/test_component.py | 24 ++-- tests/test_var.py | 2 +- 14 files changed, 404 insertions(+), 187 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index c47077d0e..b3d9149e6 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -79,7 +79,7 @@ from reflex.state import ( ) from reflex.utils import console, exceptions, format, prerequisites, types from reflex.utils.exec import is_testing_env, should_skip_compile -from reflex.utils.imports import ImportVar +from reflex.utils.imports import ImportList # Define custom types. ComponentCallable = Callable[[], Component] @@ -618,27 +618,16 @@ class App(Base): admin.mount_to(self.api) - def get_frontend_packages(self, imports: Dict[str, set[ImportVar]]): + def get_frontend_packages(self, imports: ImportList): """Gets the frontend packages to be installed and filters out the unnecessary ones. Args: - imports: A dictionary containing the imports used in the current page. + imports: A list containing the imports used in the current page. Example: >>> get_frontend_packages({"react": "16.14.0", "react-dom": "16.14.0"}) """ - page_imports = { - i - for i, tags in imports.items() - if i - not in [ - *constants.PackageJson.DEPENDENCIES.keys(), - *constants.PackageJson.DEV_DEPENDENCIES.keys(), - ] - and not any(i.startswith(prefix) for prefix in ["/", ".", "next/"]) - and i != "" - and any(tag.install for tag in tags) - } + page_imports = [i.package for i in imports.collapse().values() if i.install] frontend_packages = get_config().frontend_packages _frontend_packages = [] for package in frontend_packages: @@ -653,7 +642,7 @@ class App(Base): ) continue _frontend_packages.append(package) - page_imports.update(_frontend_packages) + page_imports.extend(_frontend_packages) prerequisites.install_frontend_packages(page_imports, get_config()) def _app_root(self, app_wrappers: dict[tuple[int, str], Component]) -> Component: @@ -794,7 +783,7 @@ class App(Base): self.style = evaluate_style_namespaces(self.style) # Track imports and custom components found. - all_imports = {} + all_imports = ImportList() custom_components = set() for _route, component in self.pages.items(): @@ -804,7 +793,7 @@ class App(Base): component.apply_theme(self.theme) # Add component._get_all_imports() to all_imports. - all_imports.update(component._get_all_imports()) + all_imports.extend(component._get_all_imports()) # Add the app wrappers from this component. app_wrappers.update(component._get_all_app_wrap_components()) @@ -932,10 +921,10 @@ class App(Base): custom_components_imports, ) = custom_components_future.result() compile_results.append(custom_components_result) - all_imports.update(custom_components_imports) + all_imports.extend(custom_components_imports) # Get imports from AppWrap components. - all_imports.update(app_root._get_all_imports()) + all_imports.extend(app_root._get_all_imports()) progress.advance(task) @@ -951,7 +940,7 @@ class App(Base): # Setup the next.config.js transpile_packages = [ package - for package, import_vars in all_imports.items() + for package, import_vars in all_imports.collapse().items() if any(import_var.transpile for import_var in import_vars) ] prerequisites.update_next_config( diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 89ac867f7..dae97c153 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -19,7 +19,7 @@ from reflex.config import get_config from reflex.state import BaseState from reflex.style import LIGHT_COLOR_MODE from reflex.utils.exec import is_prod_mode -from reflex.utils.imports import ImportVar +from reflex.utils.imports import ImportList, ImportVar from reflex.vars import Var @@ -197,25 +197,34 @@ def _compile_components( Returns: The compiled components. """ - imports = { - "react": [ImportVar(tag="memo")], - f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="E"), ImportVar(tag="isTrue")], - } + _imports = ImportList( + [ + ImportVar(package="react", tag="memo"), + ImportVar( + package=f"/{constants.Dirs.STATE_PATH}", + tag="E", + ), + ImportVar( + package=f"/{constants.Dirs.STATE_PATH}", + tag="isTrue", + ), + ] + ) component_renders = [] # Compile each component. for component in components: component_render, component_imports = utils.compile_custom_component(component) component_renders.append(component_render) - imports = utils.merge_imports(imports, component_imports) + _imports.extend(component_imports) # Compile the components page. return ( templates.COMPONENTS.render( - imports=utils.compile_imports(imports), + imports=utils.compile_imports(_imports), components=component_renders, ), - imports, + _imports, ) @@ -235,7 +244,7 @@ def _compile_stateful_components( Returns: The rendered stateful components code. """ - all_import_dicts = [] + all_imports = [] rendered_components = {} def get_shared_components_recursive(component: BaseComponent): @@ -266,7 +275,7 @@ def _compile_stateful_components( rendered_components.update( {code: None for code in component._get_all_custom_code()}, ) - all_import_dicts.append(component._get_all_imports()) + all_imports.extend(component._get_all_imports()) # Indicate that this component now imports from the shared file. component.rendered_as_shared = True @@ -275,9 +284,11 @@ def _compile_stateful_components( get_shared_components_recursive(page_component) # Don't import from the file that we're about to create. - all_imports = utils.merge_imports(*all_import_dicts) - all_imports.pop( - f"/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}", None + all_imports = ImportList( + imp + for imp in all_imports + if imp.library + != f"/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}" ) return templates.STATEFUL_COMPONENTS.render( @@ -408,7 +419,7 @@ def compile_page( def compile_components( components: set[CustomComponent], -) -> tuple[str, str, Dict[str, list[ImportVar]]]: +) -> tuple[str, str, ImportList]: """Compile the custom components. Args: diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 14d7d4d36..ee65d08a2 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -88,16 +88,16 @@ def validate_imports(import_dict: imports.ImportDict): used_tags[import_name] = lib -def compile_imports(import_dict: imports.ImportDict) -> list[dict]: - """Compile an import dict. +def compile_imports(import_list: imports.ImportList) -> list[dict]: + """Compile an import list. Args: - import_dict: The import dict to compile. + import_list: The import list to compile. Returns: - The list of import dict. + The list of template import dict. """ - collapsed_import_dict = imports.collapse_imports(import_dict) + collapsed_import_dict = import_list.collapse() validate_imports(collapsed_import_dict) import_dicts = [] for lib, fields in collapsed_import_dict.items(): @@ -114,9 +114,6 @@ def compile_imports(import_dict: imports.ImportDict) -> list[dict]: import_dicts.append(get_import_dict(module)) continue - # remove the version before rendering the package imports - lib = format.format_library_name(lib) - import_dicts.append(get_import_dict(lib, default, rest)) return import_dicts @@ -237,7 +234,7 @@ def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]: def compile_custom_component( component: CustomComponent, -) -> tuple[dict, imports.ImportDict]: +) -> tuple[dict, imports.ImportList]: """Compile a custom component. Args: @@ -250,11 +247,12 @@ def compile_custom_component( render = component.get_component(component) # Get the imports. - imports = { - lib: fields - for lib, fields in render._get_all_imports().items() - if lib != component.library - } + component_library_name = format.format_library_name(component.library) + _imports = imports.ImportList( + imp + for imp in render._get_all_imports() + if imp.library != component_library_name + ) # Concatenate the props. props = [prop._var_name for prop in component.get_prop_vars()] @@ -268,7 +266,7 @@ def compile_custom_component( "hooks": {**render._get_all_hooks_internal(), **render._get_all_hooks()}, "custom_code": render._get_all_custom_code(), }, - imports, + _imports, ) diff --git a/reflex/components/chakra/base.py b/reflex/components/chakra/base.py index 9e1f3f698..95e457a05 100644 --- a/reflex/components/chakra/base.py +++ b/reflex/components/chakra/base.py @@ -35,19 +35,18 @@ class ChakraComponent(Component): @classmethod @lru_cache(maxsize=None) - def _get_dependencies_imports(cls) -> imports.ImportDict: + def _get_dependencies_imports(cls) -> imports.ImportList: """Get the imports from lib_dependencies for installing. Returns: The dependencies imports of the component. """ - return { - dep: [imports.ImportVar(tag=None, render=False)] - for dep in [ - "@chakra-ui/system@2.5.7", - "framer-motion@10.16.4", - ] - } + return [ + imports.ImportVar( + package="@chakra-ui/system@2.5.7", tag=None, render=False + ), + imports.ImportVar(package="framer-motion@10.16.4", tag=None, render=False), + ] class ChakraProvider(ChakraComponent): diff --git a/reflex/components/component.py b/reflex/components/component.py index 9dd11254c..96834be08 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -3,6 +3,7 @@ from __future__ import annotations import copy +import itertools import typing from abc import ABC, abstractmethod from functools import lru_cache, wraps @@ -95,11 +96,11 @@ class BaseComponent(Base, ABC): """ @abstractmethod - def _get_all_imports(self) -> imports.ImportDict: + def _get_all_imports(self) -> imports.ImportList: """Get all the libraries and fields that are used by the component. Returns: - The import dict with the required imports. + The list of all required ImportVar. """ @abstractmethod @@ -994,17 +995,22 @@ class Component(BaseComponent, ABC): # Return the dynamic imports return dynamic_imports - def _get_props_imports(self) -> List[str]: + def _get_props_imports(self) -> imports.ImportList: """Get the imports needed for components props. Returns: - The imports for the components props of the component. + The imports for the components props of the component. """ - return [ - getattr(self, prop)._get_all_imports() - for prop in self.get_component_props() - if getattr(self, prop) is not None - ] + return imports.ImportList( + sum( + ( + getattr(self, prop)._get_all_imports() + for prop in self.get_component_props() + if getattr(self, prop) is not None + ), + [], + ) + ) def _should_transpile(self, dep: str | None) -> bool: """Check if a dependency should be transpiled. @@ -1020,97 +1026,129 @@ class Component(BaseComponent, ABC): or format.format_library_name(dep or "") in self.transpile_packages ) - def _get_dependencies_imports(self) -> imports.ImportDict: + def _get_dependencies_imports(self) -> imports.ImportList: """Get the imports from lib_dependencies for installing. Returns: The dependencies imports of the component. """ - return { - dep: [ - ImportVar( - tag=None, - render=False, - transpile=self._should_transpile(dep), - ) - ] + return imports.ImportList( + ImportVar( + package=dep, + tag=None, + render=False, + transpile=self._should_transpile(dep), + ) for dep in self.lib_dependencies - } + ) - def _get_hooks_imports(self) -> imports.ImportDict: + def _get_hooks_imports(self) -> imports.ImportList: """Get the imports required by certain hooks. Returns: The imports required for all selected hooks. """ - _imports = {} + _imports = imports.ImportList() if self._get_ref_hook(): # Handle hooks needed for attaching react refs to DOM nodes. - _imports.setdefault("react", set()).add(ImportVar(tag="useRef")) - _imports.setdefault(f"/{Dirs.STATE_PATH}", set()).add(ImportVar(tag="refs")) + _imports.extend( + [ + ImportVar(package="react", tag="useRef"), + ImportVar(package=f"/{Dirs.STATE_PATH}", tag="refs"), + ] + ) if self._get_mount_lifecycle_hook(): # Handle hooks for `on_mount` / `on_unmount`. - _imports.setdefault("react", set()).add(ImportVar(tag="useEffect")) + _imports.append(ImportVar(package="react", tag="useEffect")) if self._get_special_hooks(): # Handle additional internal hooks (autofocus, etc). - _imports.setdefault("react", set()).update( - { - ImportVar(tag="useRef"), - ImportVar(tag="useEffect"), - }, + _imports.extend( + [ + ImportVar(package="react", tag="useEffect"), + ImportVar(package="react", tag="useRef"), + ] ) user_hooks = self._get_hooks() if user_hooks is not None and isinstance(user_hooks, Var): - _imports = imports.merge_imports(_imports, user_hooks._var_data.imports) # type: ignore + _imports.extend(user_hooks._var_data.imports) return _imports def _get_imports(self) -> imports.ImportDict: - """Get all the libraries and fields that are used by the component. + """Deprecated method to get all the libraries and fields used by the component. Returns: The imports needed by the component. """ - _imports = {} + return {} + + def _get_imports_list(self) -> imports.ImportList: + """Internal method to get the imports as a list. + + Returns: + The imports as a list. + """ + _imports = imports.ImportList( + itertools.chain( + self._get_props_imports(), + self._get_dependencies_imports(), + self._get_hooks_imports(), + ) + ) + + # Handle deprecated _get_imports + import_dict = self._get_imports() + if import_dict: + console.deprecate( + feature_name="_get_imports", + reason="use add_imports instead", + deprecation_version="0.5.0", + removal_version="0.6.0", + ) + _imports.extend(imports.ImportList.from_import_dict(import_dict)) # Import this component's tag from the main library. if self.library is not None and self.tag is not None: - _imports[self.library] = {self.import_var} + _imports.append(self.import_var) # Get static imports required for event processing. - event_imports = Imports.EVENTS if self.event_triggers else {} + if self.event_triggers: + _imports.append(Imports.EVENTS) # Collect imports from Vars used directly by this component. - var_imports = [ - var._var_data.imports for var in self._get_vars() if var._var_data - ] + for var in self._get_vars(): + if var._var_data: + _imports.extend(var._var_data.imports) + return _imports - return imports.merge_imports( - *self._get_props_imports(), - self._get_dependencies_imports(), - self._get_hooks_imports(), - _imports, - event_imports, - *var_imports, - ) - - def _get_all_imports(self, collapse: bool = False) -> imports.ImportDict: + def _get_all_imports(self, collapse: bool = False) -> imports.ImportList: """Get all the libraries and fields that are used by the component and its children. Args: - collapse: Whether to collapse the imports by removing duplicates. + collapse: Whether to collapse the imports into a dict (deprecated). Returns: - The import dict with the required imports. + The list of all required imports. """ - _imports = imports.merge_imports( - self._get_imports(), *[child._get_all_imports() for child in self.children] + _imports = imports.ImportList( + self._get_imports_list() + + sum((child._get_all_imports() for child in self.children), []) ) - return imports.collapse_imports(_imports) if collapse else _imports + + if collapse: + console.deprecate( + feature_name="collapse kwarg to _get_all_imports", + reason="use ImportList.collapse instead", + deprecation_version="0.5.0", + removal_version="0.6.0", + ) + return _imports.collapse() # type: ignore + + return _imports def _get_mount_lifecycle_hook(self) -> str | None: """Generate the component lifecycle hook. @@ -1296,6 +1334,7 @@ class Component(BaseComponent, ABC): tag = self.tag.partition(".")[0] if self.tag else None alias = self.alias.partition(".")[0] if self.alias else None return ImportVar( + package=self.library, tag=tag, is_default=self.is_default, alias=alias, @@ -1575,7 +1614,6 @@ class NoSSRComponent(Component): return imports.merge_imports( dynamic_import, _imports, - self._get_dependencies_imports(), ) def _get_dynamic_imports(self) -> str: @@ -1893,18 +1931,21 @@ class StatefulComponent(BaseComponent): """ return {} - def _get_all_imports(self) -> imports.ImportDict: + def _get_all_imports(self) -> imports.ImportList: """Get all the libraries and fields that are used by the component. Returns: - The import dict with the required imports. + The list of all required imports. """ if self.rendered_as_shared: - return { - f"/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}": [ - ImportVar(tag=self.tag) + return imports.ImportList( + [ + imports.ImportVar( + package=f"/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}", + tag=self.tag, + ) ] - } + ) return self.component._get_all_imports() def _get_all_dynamic_imports(self) -> set[str]: diff --git a/reflex/components/core/cond.py b/reflex/components/core/cond.py index 737b650d8..343b082b3 100644 --- a/reflex/components/core/cond.py +++ b/reflex/components/core/cond.py @@ -12,9 +12,9 @@ from reflex.style import LIGHT_COLOR_MODE, color_mode from reflex.utils import format, imports from reflex.vars import BaseVar, Var, VarData -_IS_TRUE_IMPORT = { - f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")}, -} +_IS_TRUE_IMPORT = imports.ImportList( + [imports.ImportVar(library=f"/{Dirs.STATE_PATH}", tag="isTrue")] +) class Cond(MemoizationLeaf): @@ -95,11 +95,13 @@ class Cond(MemoizationLeaf): cond_state=f"isTrue({self.cond._var_full_name})", ) - def _get_imports(self) -> imports.ImportDict: - return imports.merge_imports( - super()._get_imports(), - getattr(self.cond._var_data, "imports", {}), - _IS_TRUE_IMPORT, + def _get_imports_list(self) -> imports.ImportList: + return imports.ImportList( + [ + *super()._get_imports_list(), + *getattr(self.cond._var_data, "imports", []), + *_IS_TRUE_IMPORT, + ] ) def _apply_theme(self, theme: Component): diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index b99e31e8c..4686ef5f8 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -6,7 +6,7 @@ from types import SimpleNamespace from reflex.base import Base from reflex.constants import Dirs -from reflex.utils.imports import ImportVar +from reflex.utils.imports import ImportList, ImportVar # The prefix used to create setters for state vars. SETTER_PREFIX = "set_" @@ -102,11 +102,11 @@ class ComponentName(Enum): class Imports(SimpleNamespace): """Common sets of import vars.""" - EVENTS = { - "react": {ImportVar(tag="useContext")}, - f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")}, - f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)}, - } + EVENTS: ImportList = [ + ImportVar(package="react", tag="useContext"), + ImportVar(package=f"/{Dirs.CONTEXTS_PATH}", tag="EventLoopContext"), + ImportVar(package=f"/{Dirs.STATE_PATH}", tag=CompileVars.TO_EVENT), + ] class Hooks(SimpleNamespace): diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 70f6b5b25..fa2d115ad 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union from reflex import constants from reflex.utils import exceptions, serializers, types +from reflex.utils.imports import split_library_name_version from reflex.utils.serializers import serialize from reflex.vars import BaseVar, Var @@ -716,11 +717,7 @@ def format_library_name(library_fullname: str): Returns: The name without the @version if it was part of the name """ - lib, at, version = library_fullname.rpartition("@") - if not lib: - lib = at + version - - return lib + return split_library_name_version(library_fullname)[0] def json_dumps(obj: Any) -> str: diff --git a/reflex/utils/imports.py b/reflex/utils/imports.py index 263de1e3d..42c3a9385 100644 --- a/reflex/utils/imports.py +++ b/reflex/utils/imports.py @@ -3,9 +3,10 @@ from __future__ import annotations from collections import defaultdict -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Set from reflex.base import Base +from reflex.constants.installer import PackageJson def merge_imports(*imports) -> ImportDict: @@ -36,9 +37,29 @@ def collapse_imports(imports: ImportDict) -> ImportDict: return {lib: list(set(import_vars)) for lib, import_vars in imports.items()} +def split_library_name_version(library_fullname: str): + """Split the name of a library from its version. + + Args: + library_fullname: The fullname of the library. + + Returns: + A tuple of the library name and version. + """ + lib, at, version = library_fullname.rpartition("@") + if not lib: + lib = at + version + version = None + + return lib, version + + class ImportVar(Base): """An import var.""" + # The package name associated with the tag + library: Optional[str] + # The name of the import tag. tag: Optional[str] @@ -48,6 +69,12 @@ class ImportVar(Base): # The tag alias. alias: Optional[str] = None + # The following fields provide extra information about the import, + # but are not factored in when considering hash or equality + + # The version of the package + version: Optional[str] + # Whether this import need to install the associated lib install: Optional[bool] = True @@ -58,6 +85,34 @@ class ImportVar(Base): # https://nextjs.org/docs/app/api-reference/next-config-js/transpilePackages transpile: Optional[bool] = False + def __init__( + self, + *, + package: Optional[str] = None, + **kwargs, + ): + if package is not None: + if ( + kwargs.get("library", None) is not None + or kwargs.get("version", None) is not None + ): + raise ValueError( + "Cannot provide 'library' or 'version' as keyword arguments when " + "specifying 'package' as an argument" + ) + kwargs["library"], kwargs["version"] = split_library_name_version(package) + + install = ( + package is not None + # TODO: handle version conflicts + and package not in PackageJson.DEPENDENCIES + and package not in PackageJson.DEV_DEPENDENCIES + and not any(package.startswith(prefix) for prefix in ["/", ".", "next/"]) + and package != "" + ) + kwargs.setdefault("install", install) + super().__init__(**kwargs) + @property def name(self) -> str: """The name of the import. @@ -72,6 +127,17 @@ class ImportVar(Base): else: return self.tag or "" + @property + def package(self) -> str: + """The package to install for this import + + Returns: + The library name and (optional) version to be installed by npm/bun. + """ + if self.version: + return f"{self.library}@{self.version}" + return self.library + def __hash__(self) -> int: """Define a hash function for the import var. @@ -80,14 +146,97 @@ class ImportVar(Base): """ return hash( ( + self.library, self.tag, self.is_default, self.alias, - self.install, - self.render, - self.transpile, + # These do not fundamentally change the import in any way + # self.install, + # self.render, + # self.transpile, ) ) + def __eq__(self, other: ImportVar) -> bool: + """Define equality for the import var. -ImportDict = Dict[str, List[ImportVar]] + Args: + other: The other import var to compare. + + Returns: + Whether the two import vars are equal. + """ + if type(self) != type(other): + return NotImplemented + return (self.library, self.tag, self.is_default, self.alias) == ( + other.library, + other.tag, + other.is_default, + other.alias, + ) + + def collapse(self, other_import_var: ImportVar) -> ImportVar: + """Collapse two import vars together. + + Args: + other_import_var: The other import var to collapse with. + + Returns: + The collapsed import var with sticky props perserved. + """ + if self != other_import_var: + raise ValueError("Cannot collapse two import vars with different hashes") + + if self.version is not None and other_import_var.version is not None: + if self.version != other_import_var.version: + raise ValueError( + "Cannot collapse two import vars with conflicting version specifiers: " + f"{self} {other_import_var}" + ) + + return type(self)( + library=self.library, + version=self.version or other_import_var.version, + tag=self.tag, + is_default=self.is_default, + alias=self.alias, + install=self.install or other_import_var.install, + render=self.render or other_import_var.render, + transpile=self.transpile or other_import_var.transpile, + ) + + +class ImportList(List[ImportVar]): + """A list of import vars.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + for ix, value in enumerate(self): + if not isinstance(value, ImportVar): + # convert dicts to ImportVar + self[ix] = ImportVar(**value) + + @classmethod + def from_import_dict(cls, import_dict: ImportDict) -> ImportList: + return [ + ImportVar(package=lib, **imp.dict()) + for lib, imps in import_dict.items() + for imp in imps + ] + + def collapse(self) -> ImportDict: + """When collapsing an import list, prefer packages with version specifiers.""" + collapsed = {} + for imp in self: + collapsed.setdefault(imp.library, {}) + if imp in collapsed[imp.library]: + # Need to check if the current import has any special properties that need to + # be preserved, like the version specifier, install, or transpile. + existing_imp = collapsed[imp.library][imp] + collapsed[imp.library][imp] = existing_imp.collapse(imp) + else: + collapsed[imp.library][imp] = imp + return {lib: set(imps) for lib, imps in collapsed.items()} + + +ImportDict = Dict[str, Set[ImportVar]] diff --git a/reflex/vars.py b/reflex/vars.py index 4a8e6b30f..7793ef07c 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -37,7 +37,7 @@ from reflex.base import Base from reflex.utils import console, format, imports, serializers, types # This module used to export ImportVar itself, so we still import it for export here -from reflex.utils.imports import ImportDict, ImportVar +from reflex.utils.imports import ImportDict, ImportList, ImportVar if TYPE_CHECKING: from reflex.state import BaseState @@ -116,7 +116,7 @@ class VarData(Base): state: str = "" # Imports needed to render this var - imports: ImportDict = {} + imports: ImportList = [] # Hooks that need to be present in the component to render this var hooks: Dict[str, None] = {} @@ -126,6 +126,19 @@ class VarData(Base): # segments. interpolations: List[Tuple[int, int]] = [] + def __init__(self, imports: ImportDict | ImportList = None, **kwargs): + if isinstance(imports, dict): + imports = ImportList.from_import_dict(imports) + console.deprecate( + feature_name="Passing ImportDict for VarData", + reason="use ImportList instead", + deprecation_version="0.5.0", + removal_version="0.6.0", + ) + elif imports is None: + imports = [] + super().__init__(imports=imports, **kwargs) + @classmethod def merge(cls, *others: VarData | None) -> VarData | None: """Merge multiple var data objects. @@ -137,14 +150,14 @@ class VarData(Base): The merged var data object. """ state = "" - _imports = {} + _imports = [] hooks = {} interpolations = [] for var_data in others: if var_data is None: continue state = state or var_data.state - _imports = imports.merge_imports(_imports, var_data.imports) + _imports.extend(var_data.imports) hooks.update(var_data.hooks) interpolations += var_data.interpolations @@ -180,11 +193,18 @@ class VarData(Base): # Don't compare interpolations - that's added in by the decoder, and # not part of the vardata itself. + if not isinstance(self.imports, ImportList): + self_imports = ImportList(self.imports).collapse() + else: + self_imports = self.imports.collapse() + if not isinstance(other.imports, ImportList): + other_imports = ImportList(other.imports).collapse() + else: + other_imports = other.imports.collapse() return ( self.state == other.state and self.hooks.keys() == other.hooks.keys() - and imports.collapse_imports(self.imports) - == imports.collapse_imports(other.imports) + and self_imports == other_imports ) def dict(self) -> dict: @@ -196,10 +216,7 @@ class VarData(Base): return { "state": self.state, "interpolations": list(self.interpolations), - "imports": { - lib: [import_var.dict() for import_var in import_vars] - for lib, import_vars in self.imports.items() - }, + "imports": [import_var.dict() for import_var in self.imports], "hooks": self.hooks, } diff --git a/tests/compiler/test_compiler.py b/tests/compiler/test_compiler.py index b6191974a..db0baae86 100644 --- a/tests/compiler/test_compiler.py +++ b/tests/compiler/test_compiler.py @@ -4,8 +4,7 @@ from typing import List import pytest from reflex.compiler import compiler, utils -from reflex.utils import imports -from reflex.utils.imports import ImportVar +from reflex.utils.imports import ImportList, ImportVar @pytest.mark.parametrize( @@ -48,43 +47,56 @@ def test_compile_import_statement( @pytest.mark.parametrize( - "import_dict,test_dicts", + "import_list,test_dicts", [ - ({}, []), + (ImportList(), []), ( - {"axios": [ImportVar(tag="axios", is_default=True)]}, + ImportList([ImportVar(library="axios", tag="axios", is_default=True)]), [{"lib": "axios", "default": "axios", "rest": []}], ), ( - {"axios": [ImportVar(tag="foo"), ImportVar(tag="bar")]}, + ImportList( + [ + ImportVar(library="axios", tag="foo"), + ImportVar(library="axios", tag="bar"), + ] + ), [{"lib": "axios", "default": "", "rest": ["bar", "foo"]}], ), ( - { - "axios": [ - ImportVar(tag="axios", is_default=True), - ImportVar(tag="foo"), - ImportVar(tag="bar"), - ], - "react": [ImportVar(tag="react", is_default=True)], - }, + ImportList( + [ + ImportVar(library="axios", tag="axios", is_default=True), + ImportVar(library="axios", tag="foo"), + ImportVar(library="axios", tag="bar"), + ImportVar(library="react", tag="react", is_default=True), + ] + ), [ {"lib": "axios", "default": "axios", "rest": ["bar", "foo"]}, {"lib": "react", "default": "react", "rest": []}, ], ), ( - {"": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")]}, + ImportList( + [ + ImportVar(library="", tag="lib1.js"), + ImportVar(library="", tag="lib2.js"), + ] + ), [ {"lib": "lib1.js", "default": "", "rest": []}, {"lib": "lib2.js", "default": "", "rest": []}, ], ), ( - { - "": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")], - "axios": [ImportVar(tag="axios", is_default=True)], - }, + ImportList( + [ + ImportVar(library="", tag="lib1.js"), + ImportVar(library="", tag="lib2.js"), + ImportVar(library="axios", tag="axios", is_default=True), + ] + ), [ {"lib": "lib1.js", "default": "", "rest": []}, {"lib": "lib2.js", "default": "", "rest": []}, @@ -93,14 +105,14 @@ def test_compile_import_statement( ), ], ) -def test_compile_imports(import_dict: imports.ImportDict, test_dicts: List[dict]): +def test_compile_imports(import_list: ImportList, test_dicts: List[dict]): """Test the compile_imports function. Args: - import_dict: The import dictionary. + import_list: The list of ImportVar. test_dicts: The expected output. """ - imports = utils.compile_imports(import_dict) + imports = utils.compile_imports(import_list) for import_dict, test_dict in zip(imports, test_dicts): assert import_dict["lib"] == test_dict["lib"] assert import_dict["default"] == test_dict["default"] diff --git a/tests/components/core/test_banner.py b/tests/components/core/test_banner.py index f929eef37..bfdf86b7c 100644 --- a/tests/components/core/test_banner.py +++ b/tests/components/core/test_banner.py @@ -20,7 +20,7 @@ def test_connection_banner(): "react", "/utils/context", "/utils/state", - "@radix-ui/themes@^3.0.0", + "@radix-ui/themes", "/env.json", ] @@ -36,7 +36,7 @@ def test_connection_modal(): "react", "/utils/context", "/utils/state", - "@radix-ui/themes@^3.0.0", + "@radix-ui/themes", "/env.json", ] diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 21ec409af..942e7a932 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -296,11 +296,11 @@ def test_get_imports(component1, component2): """ c1 = component1.create() c2 = component2.create(c1) - assert c1._get_all_imports() == {"react": [ImportVar(tag="Component")]} - assert c2._get_all_imports() == { - "react-redux": [ImportVar(tag="connect")], - "react": [ImportVar(tag="Component")], - } + assert c1._get_all_imports() == [ImportVar(library="react", tag="Component")] + assert c2._get_all_imports() == [ + ImportVar(library="react-redux", tag="connect"), + ImportVar(library="react", tag="Component"), + ] def test_get_custom_code(component1, component2): @@ -1514,22 +1514,24 @@ def test_custom_component_get_imports(): custom_comp = wrapper() # Inner is not imported directly, but it is imported by the custom component. - assert "inner" not in custom_comp._get_all_imports() + inner_import = ImportVar(library="inner", tag="Inner") + assert inner_import not in custom_comp._get_all_imports() # The imports are only resolved during compilation. _, _, imports_inner = compile_components(custom_comp._get_all_custom_components()) - assert "inner" in imports_inner + assert inner_import in imports_inner outer_comp = outer(c=wrapper()) # Libraries are not imported directly, but are imported by the custom component. - assert "inner" not in outer_comp._get_all_imports() - assert "other" not in outer_comp._get_all_imports() + other_import = ImportVar(library="other", tag="Other") + assert inner_import not in outer_comp._get_all_imports() + assert other_import not in outer_comp._get_all_imports() # The imports are only resolved during compilation. _, _, imports_outer = compile_components(outer_comp._get_all_custom_components()) - assert "inner" in imports_outer - assert "other" in imports_outer + assert inner_import in imports_outer + assert other_import in imports_outer def test_custom_component_declare_event_handlers_in_fields(): diff --git a/tests/test_var.py b/tests/test_var.py index a58c49392..6af50e187 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -837,7 +837,7 @@ def test_state_with_initial_computed_var( (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"), ( f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}", - 'testing f-string with ${"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true, "transpile": false}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true, "transpile": false}]}, "hooks": {"const state = useContext(StateContexts.state)": null}, "string_length": 13}{state.myvar}', + 'testing f-string with ${"state": "state", "interpolations": [], "imports": [{"library": "/utils/context", "tag": "StateContexts", "is_default": false, "alias": null, "version": null, "install": false, "render": true, "transpile": false}, {"library": "react", "tag": "useContext", "is_default": false, "alias": null, "version": null, "install": false, "render": true, "transpile": false}], "hooks": {"const state = useContext(StateContexts.state)": null}, "string_length": 13}{state.myvar}', ), ( f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",