diff --git a/reflex/app.py b/reflex/app.py index b3d9149e6..dff45151e 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -627,7 +627,7 @@ class App(Base): Example: >>> get_frontend_packages({"react": "16.14.0", "react-dom": "16.14.0"}) """ - page_imports = [i.package for i in imports.collapse().values() if i.install] + page_imports = [i.package for i in imports if i.install and i.package] frontend_packages = get_config().frontend_packages _frontend_packages = [] for package in frontend_packages: @@ -643,7 +643,7 @@ class App(Base): continue _frontend_packages.append(package) page_imports.extend(_frontend_packages) - prerequisites.install_frontend_packages(page_imports, get_config()) + prerequisites.install_frontend_packages(set(page_imports), get_config()) def _app_root(self, app_wrappers: dict[tuple[int, str], Component]) -> Component: for component in tuple(app_wrappers.values()): diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index dae97c153..60a46bf9c 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -188,7 +188,7 @@ def _compile_component(component: Component) -> str: def _compile_components( components: set[CustomComponent], -) -> tuple[str, Dict[str, list[ImportVar]]]: +) -> tuple[str, ImportList]: """Compile the components. Args: diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index ee65d08a2..8385fd76c 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -247,7 +247,7 @@ def compile_custom_component( render = component.get_component(component) # Get the imports. - component_library_name = format.format_library_name(component.library) + component_library_name = format.format_library_name(component.library or "") _imports = imports.ImportList( imp for imp in render._get_all_imports() diff --git a/reflex/components/chakra/base.py b/reflex/components/chakra/base.py index 95e457a05..88b39fe97 100644 --- a/reflex/components/chakra/base.py +++ b/reflex/components/chakra/base.py @@ -35,7 +35,7 @@ class ChakraComponent(Component): @classmethod @lru_cache(maxsize=None) - def _get_dependencies_imports(cls) -> imports.ImportList: + def _get_dependencies_imports(cls) -> List[imports.ImportVar]: """Get the imports from lib_dependencies for installing. Returns: @@ -67,13 +67,21 @@ class ChakraProvider(ChakraComponent): theme=Var.create("extendTheme(theme)", _var_is_local=False), ) - def _get_imports(self) -> imports.ImportDict: - _imports = super()._get_imports() - _imports.setdefault(self.__fields__["library"].default, []).append( - imports.ImportVar(tag="extendTheme", is_default=False), - ) - _imports.setdefault("/utils/theme.js", []).append( - imports.ImportVar(tag="theme", is_default=True), + def _get_imports_list(self) -> List[imports.ImportVar]: + _imports = super()._get_imports_list() + _imports.extend( + [ + imports.ImportVar( + package=self.__fields__["library"].default, + tag="extendTheme", + is_default=False, + ), + imports.ImportVar( + package="/utils/theme.js", + tag="theme", + is_default=True, + ), + ], ) return _imports diff --git a/reflex/components/component.py b/reflex/components/component.py index 96834be08..62a242a75 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -1026,7 +1026,7 @@ class Component(BaseComponent, ABC): or format.format_library_name(dep or "") in self.transpile_packages ) - def _get_dependencies_imports(self) -> imports.ImportList: + def _get_dependencies_imports(self) -> List[ImportVar]: """Get the imports from lib_dependencies for installing. Returns: @@ -1073,7 +1073,11 @@ class Component(BaseComponent, ABC): ) user_hooks = self._get_hooks() - if user_hooks is not None and isinstance(user_hooks, Var): + if ( + user_hooks is not None + and isinstance(user_hooks, Var) + and user_hooks._var_data is not None + ): _imports.extend(user_hooks._var_data.imports) return _imports @@ -1086,7 +1090,7 @@ class Component(BaseComponent, ABC): """ return {} - def _get_imports_list(self) -> imports.ImportList: + def _get_imports_list(self) -> List[ImportVar]: """Internal method to get the imports as a list. Returns: @@ -1117,7 +1121,7 @@ class Component(BaseComponent, ABC): # Get static imports required for event processing. if self.event_triggers: - _imports.append(Imports.EVENTS) + _imports.extend(Imports.EVENTS) # Collect imports from Vars used directly by this component. for var in self._get_vars(): diff --git a/reflex/components/core/banner.py b/reflex/components/core/banner.py index 0c781fba8..5a2bb2425 100644 --- a/reflex/components/core/banner.py +++ b/reflex/components/core/banner.py @@ -51,11 +51,14 @@ has_too_many_connection_errors: Var = Var.create_safe( class WebsocketTargetURL(Bare): """A component that renders the websocket target URL.""" - def _get_imports(self) -> imports.ImportDict: - return { - f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")], - "/env.json": [imports.ImportVar(tag="env", is_default=True)], - } + def _get_imports_list(self) -> list[imports.ImportVar]: + return [ + imports.ImportVar( + library=f"/{Dirs.STATE_PATH}", + tag="getBackendURL", + ), + imports.ImportVar(library="/env.json", tag="env", is_default=True), + ] @classmethod def create(cls) -> Component: diff --git a/reflex/components/markdown/markdown.py b/reflex/components/markdown/markdown.py index ca54d0aa5..06f46aee3 100644 --- a/reflex/components/markdown/markdown.py +++ b/reflex/components/markdown/markdown.py @@ -7,7 +7,6 @@ from functools import lru_cache from hashlib import md5 from typing import Any, Callable, Dict, Union -from reflex.compiler import utils from reflex.components.component import Component, CustomComponent from reflex.components.radix.themes.layout.list import ( ListItem, @@ -154,47 +153,53 @@ class Markdown(Component): return custom_components - def _get_imports(self) -> imports.ImportDict: + def _get_imports_list(self) -> list[imports.ImportVar]: # Import here to avoid circular imports. from reflex.components.datadisplay.code import CodeBlock from reflex.components.radix.themes.typography.code import Code - imports = super()._get_imports() + _imports = super()._get_imports_list() # Special markdown imports. - imports.update( - { - "": [ImportVar(tag="katex/dist/katex.min.css")], - "remark-math@5.1.1": [ - ImportVar(tag=_REMARK_MATH._var_name, is_default=True) - ], - "remark-gfm@3.0.1": [ - ImportVar(tag=_REMARK_GFM._var_name, is_default=True) - ], - "remark-unwrap-images@4.0.0": [ - ImportVar(tag=_REMARK_UNWRAP_IMAGES._var_name, is_default=True) - ], - "rehype-katex@6.0.3": [ - ImportVar(tag=_REHYPE_KATEX._var_name, is_default=True) - ], - "rehype-raw@6.1.1": [ - ImportVar(tag=_REHYPE_RAW._var_name, is_default=True) - ], - } + _imports.extend( + [ + ImportVar(library="", tag="katex/dist/katex.min.css"), + ImportVar( + package="remark-math@5.1.1", + tag=_REMARK_MATH._var_name, + is_default=True, + ), + ImportVar( + package="remark-gfm@3.0.1", + tag=_REMARK_GFM._var_name, + is_default=True, + ), + ImportVar( + package="remark-unwrap-images@4.0.0", + tag=_REMARK_UNWRAP_IMAGES._var_name, + is_default=True, + ), + ImportVar( + package="remark-katex@6.0.3", + tag=_REHYPE_KATEX._var_name, + is_default=True, + ), + ImportVar( + package="rehype-raw@6.1.1", + tag=_REHYPE_RAW._var_name, + is_default=True, + ), + ] ) # Get the imports for each component. for component in self.component_map.values(): - imports = utils.merge_imports( - imports, component(_MOCK_ARG)._get_all_imports() - ) + _imports.extend(component(_MOCK_ARG)._get_all_imports()) # Get the imports for the code components. - imports = utils.merge_imports( - imports, CodeBlock.create(theme="light")._get_imports() - ) - imports = utils.merge_imports(imports, Code.create()._get_imports()) - return imports + _imports.extend(CodeBlock.create(theme="light")._get_all_imports()) + _imports.extend(Code.create()._get_all_imports()) + return _imports def get_component(self, tag: str, **props) -> Component: """Get the component for a tag and props. diff --git a/reflex/components/radix/themes/base.py b/reflex/components/radix/themes/base.py index 559d10239..2e282e604 100644 --- a/reflex/components/radix/themes/base.py +++ b/reflex/components/radix/themes/base.py @@ -243,13 +243,11 @@ class ThemePanel(RadixThemesComponent): # Whether the panel is open. Defaults to False. default_open: Var[bool] - def _get_imports(self) -> dict[str, list[imports.ImportVar]]: - return imports.merge_imports( - super()._get_imports(), - { - "react": [imports.ImportVar(tag="useEffect")], - }, - ) + def _get_imports_list(self) -> list[imports.ImportVar]: + return [ + *super()._get_imports_list(), + imports.ImportVar(package="react", tag="useEffect"), + ] def _get_hooks(self) -> str | None: # The panel freezes the tab if the user color preference differs from the diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index 4686ef5f8..cf52d8a1a 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -102,11 +102,13 @@ class ComponentName(Enum): class Imports(SimpleNamespace): """Common sets of import vars.""" - EVENTS: ImportList = [ - ImportVar(package="react", tag="useContext"), - ImportVar(package=f"/{Dirs.CONTEXTS_PATH}", tag="EventLoopContext"), - ImportVar(package=f"/{Dirs.STATE_PATH}", tag=CompileVars.TO_EVENT), - ] + EVENTS: ImportList = ImportList( + [ + ImportVar(package="react", tag="useContext"), + ImportVar(package=f"/{Dirs.CONTEXTS_PATH}", tag="EventLoopContext"), + ImportVar(package=f"/{Dirs.STATE_PATH}", tag=CompileVars.TO_EVENT), + ] + ) class Hooks(SimpleNamespace): diff --git a/reflex/utils/imports.py b/reflex/utils/imports.py index 42c3a9385..5c8062772 100644 --- a/reflex/utils/imports.py +++ b/reflex/utils/imports.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections import defaultdict -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional from reflex.base import Base from reflex.constants.installer import PackageJson @@ -91,6 +91,15 @@ class ImportVar(Base): package: Optional[str] = None, **kwargs, ): + """Create a new ImportVar. + + Args: + package: The package to install for this import. + **kwargs: The import var fields. + + Raises: + ValueError: If the package is provided with library or version. + """ if package is not None: if ( kwargs.get("library", None) is not None @@ -128,8 +137,8 @@ class ImportVar(Base): return self.tag or "" @property - def package(self) -> str: - """The package to install for this import + def package(self) -> str | None: + """The package to install for this import. Returns: The library name and (optional) version to be installed by npm/bun. @@ -150,10 +159,6 @@ class ImportVar(Base): self.tag, self.is_default, self.alias, - # These do not fundamentally change the import in any way - # self.install, - # self.render, - # self.transpile, ) ) @@ -183,16 +188,22 @@ class ImportVar(Base): Returns: The collapsed import var with sticky props perserved. + + Raises: + ValueError: If the two import vars have conflicting properties. """ if self != other_import_var: raise ValueError("Cannot collapse two import vars with different hashes") - if self.version is not None and other_import_var.version is not None: - if self.version != other_import_var.version: - raise ValueError( - "Cannot collapse two import vars with conflicting version specifiers: " - f"{self} {other_import_var}" - ) + if ( + self.version is not None + and other_import_var.version is not None + and self.version != other_import_var.version + ): + raise ValueError( + "Cannot collapse two import vars with conflicting version specifiers: " + f"{self} {other_import_var}" + ) return type(self)( library=self.library, @@ -210,6 +221,15 @@ class ImportList(List[ImportVar]): """A list of import vars.""" def __init__(self, *args, **kwargs): + """Create a new ImportList (wrapper over `list`). + + Any items that are not already `ImportVar` will be assumed as dicts to convert + into an ImportVar. + + Args: + *args: The args to pass to list.__init__ + **kwargs: The kwargs to pass to list.__init__ + """ super().__init__(*args, **kwargs) for ix, value in enumerate(self): if not isinstance(value, ImportVar): @@ -217,26 +237,41 @@ class ImportList(List[ImportVar]): self[ix] = ImportVar(**value) @classmethod - def from_import_dict(cls, import_dict: ImportDict) -> ImportList: - return [ + def from_import_dict( + cls, import_dict: ImportDict | Dict[str, set[ImportVar]] + ) -> ImportList: + """Create an import list from an import dict. + + Args: + import_dict: The import dict to convert. + + Returns: + The import list. + """ + return cls( ImportVar(package=lib, **imp.dict()) for lib, imps in import_dict.items() for imp in imps - ] + ) def collapse(self) -> ImportDict: - """When collapsing an import list, prefer packages with version specifiers.""" - collapsed = {} + """When collapsing an import list, prefer packages with version specifiers. + + Returns: + The collapsed import dict ({library_name: [import_var1, ...]}). + """ + collapsed: dict[str, dict[ImportVar, ImportVar]] = {} for imp in self: - collapsed.setdefault(imp.library, {}) - if imp in collapsed[imp.library]: + lib = imp.library or "" + collapsed.setdefault(lib, {}) + if imp in collapsed[lib]: # Need to check if the current import has any special properties that need to # be preserved, like the version specifier, install, or transpile. - existing_imp = collapsed[imp.library][imp] - collapsed[imp.library][imp] = existing_imp.collapse(imp) + existing_imp = collapsed[lib][imp] + collapsed[lib][imp] = existing_imp.collapse(imp) else: - collapsed[imp.library][imp] = imp - return {lib: set(imps) for lib, imps in collapsed.items()} + collapsed[lib][imp] = imp + return {lib: list(set(imps)) for lib, imps in collapsed.items()} -ImportDict = Dict[str, Set[ImportVar]] +ImportDict = Dict[str, List[ImportVar]] diff --git a/reflex/vars.py b/reflex/vars.py index 7793ef07c..4a402f852 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -34,7 +34,7 @@ from typing import ( from reflex import constants from reflex.base import Base -from reflex.utils import console, format, imports, serializers, types +from reflex.utils import console, format, serializers, types # This module used to export ImportVar itself, so we still import it for export here from reflex.utils.imports import ImportDict, ImportList, ImportVar @@ -116,7 +116,7 @@ class VarData(Base): state: str = "" # Imports needed to render this var - imports: ImportList = [] + imports: ImportList = ImportList() # Hooks that need to be present in the component to render this var hooks: Dict[str, None] = {} @@ -126,7 +126,24 @@ class VarData(Base): # segments. interpolations: List[Tuple[int, int]] = [] - def __init__(self, imports: ImportDict | ImportList = None, **kwargs): + def __init__( + self, + imports: ImportList + | List[ImportVar | Dict[str, Optional[Union[str, bool]]]] + | ImportDict + | Dict[str, set[ImportVar]] + | None = None, + **kwargs, + ): + """Initialize the VarData. + + If imports is an ImportDict it will be converted to an ImportList and a + deprecation warning will be displayed. + + Args: + imports: The imports needed to render this var. + **kwargs: Additional fields to set. + """ if isinstance(imports, dict): imports = ImportList.from_import_dict(imports) console.deprecate( @@ -135,9 +152,12 @@ class VarData(Base): deprecation_version="0.5.0", removal_version="0.6.0", ) - elif imports is None: - imports = [] - super().__init__(imports=imports, **kwargs) + else: + imports = ImportList(imports or []) + super().__init__( + imports=imports, # type: ignore + **kwargs, + ) @classmethod def merge(cls, *others: VarData | None) -> VarData | None: @@ -150,7 +170,7 @@ class VarData(Base): The merged var data object. """ state = "" - _imports = [] + _imports = ImportList() hooks = {} interpolations = [] for var_data in others: @@ -1059,11 +1079,12 @@ class Var: ",", other, fn="spreadArraysOrObjects", flip=flip )._replace( merge_var_data=VarData( - imports={ - f"/{constants.Dirs.STATE_PATH}": [ - ImportVar(tag="spreadArraysOrObjects") - ] - }, + imports=[ + ImportVar( + package=f"/{constants.Dirs.STATE_PATH}", + tag="spreadArraysOrObjects", + ), + ], ), ) return self.operation("+", other, flip=flip) @@ -1612,11 +1633,11 @@ class Var: v2._var_data, step._var_data, VarData( - imports={ - "/utils/helpers/range.js": [ - ImportVar(tag="range", is_default=True), - ], - }, + imports=[ + ImportVar( + package="/utils/helpers/range", tag="range", is_default=True + ), + ] ), ), ) @@ -1644,9 +1665,9 @@ class Var: _var_is_string=False, _var_full_name_needs_state_prefix=False, merge_var_data=VarData( - imports={ - f"/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")], - }, + imports=[ + ImportVar(package=f"/{constants.Dirs.STATE_PATH}", tag="refs") + ], ), ) @@ -1684,10 +1705,14 @@ class Var: format.format_state_name(state_name) ): None }, - imports={ - f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")], - "react": [ImportVar(tag="useContext")], - }, + imports=ImportList( + [ + ImportVar( + package=f"/{constants.Dirs.CONTEXTS_PATH}", tag="StateContexts" + ), + ImportVar(package="react", tag="useContext"), + ] + ), ) self._var_data = VarData.merge(self._var_data, new_var_data) self._var_full_name_needs_state_prefix = True diff --git a/tests/components/core/test_banner.py b/tests/components/core/test_banner.py index bfdf86b7c..ac01a825c 100644 --- a/tests/components/core/test_banner.py +++ b/tests/components/core/test_banner.py @@ -9,14 +9,14 @@ from reflex.components.radix.themes.typography.text import Text def test_websocket_target_url(): url = WebsocketTargetURL.create() - _imports = url._get_all_imports(collapse=True) - assert list(_imports.keys()) == ["/utils/state", "/env.json"] + _imports = url._get_all_imports() + assert [i.library for i in _imports] == ["/utils/state", "/env.json"] def test_connection_banner(): banner = ConnectionBanner.create() - _imports = banner._get_all_imports(collapse=True) - assert list(_imports.keys()) == [ + _imports = banner._get_all_imports() + assert [i.library for i in _imports] == [ "react", "/utils/context", "/utils/state", @@ -31,8 +31,8 @@ def test_connection_banner(): def test_connection_modal(): modal = ConnectionModal.create() - _imports = modal._get_all_imports(collapse=True) - assert list(_imports.keys()) == [ + _imports = modal._get_all_imports() + assert [i.library for i in _imports] == [ "react", "/utils/context", "/utils/state", @@ -48,4 +48,4 @@ def test_connection_modal(): def test_connection_pulser(): pulser = ConnectionPulser.create() _custom_code = pulser._get_all_custom_code() - _imports = pulser._get_all_imports(collapse=True) + _imports = pulser._get_all_imports()