Use singleton app provider to speed up compiles (#2172)

This commit is contained in:
Nikhil Rao 2023-11-20 18:11:24 -08:00 committed by GitHub
parent b5f6ab3a82
commit e9437ad941
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 151 additions and 102 deletions

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
import copy
import functools import functools
import os import os
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
@ -589,9 +590,9 @@ class App(Base):
def _app_root(self, app_wrappers): def _app_root(self, app_wrappers):
order = sorted(app_wrappers, key=lambda k: k[0], reverse=True) order = sorted(app_wrappers, key=lambda k: k[0], reverse=True)
root = parent = app_wrappers[order[0]] root = parent = copy.deepcopy(app_wrappers[order[0]])
for key in order[1:]: for key in order[1:]:
child = app_wrappers[key] child = copy.deepcopy(app_wrappers[key])
parent.children.append(child) parent.children.append(child)
parent = child parent = child
return root return root

View File

@ -15,15 +15,15 @@ from reflex.vars import ImportVar
# Imports to be included in every Reflex app. # Imports to be included in every Reflex app.
DEFAULT_IMPORTS: imports.ImportDict = { DEFAULT_IMPORTS: imports.ImportDict = {
"react": { "react": [
ImportVar(tag="Fragment"), ImportVar(tag="Fragment"),
ImportVar(tag="useEffect"), ImportVar(tag="useEffect"),
ImportVar(tag="useRef"), ImportVar(tag="useRef"),
ImportVar(tag="useState"), ImportVar(tag="useState"),
ImportVar(tag="useContext"), ImportVar(tag="useContext"),
}, ],
"next/router": {ImportVar(tag="useRouter")}, "next/router": [ImportVar(tag="useRouter")],
f"/{constants.Dirs.STATE_PATH}": { f"/{constants.Dirs.STATE_PATH}": [
ImportVar(tag="uploadFiles"), ImportVar(tag="uploadFiles"),
ImportVar(tag="Event"), ImportVar(tag="Event"),
ImportVar(tag="isTrue"), ImportVar(tag="isTrue"),
@ -34,17 +34,17 @@ DEFAULT_IMPORTS: imports.ImportDict = {
ImportVar(tag="getRefValues"), ImportVar(tag="getRefValues"),
ImportVar(tag="getAllLocalStorageItems"), ImportVar(tag="getAllLocalStorageItems"),
ImportVar(tag="useEventLoop"), ImportVar(tag="useEventLoop"),
}, ],
"/utils/context.js": { "/utils/context.js": [
ImportVar(tag="EventLoopContext"), ImportVar(tag="EventLoopContext"),
ImportVar(tag="initialEvents"), ImportVar(tag="initialEvents"),
ImportVar(tag="StateContext"), ImportVar(tag="StateContext"),
ImportVar(tag="ColorModeContext"), ImportVar(tag="ColorModeContext"),
}, ],
"/utils/helpers/range.js": { "/utils/helpers/range.js": [
ImportVar(tag="range", is_default=True), ImportVar(tag="range", is_default=True),
}, ],
"": {ImportVar(tag="focus-visible/dist/focus-visible", install=False)}, "": [ImportVar(tag="focus-visible/dist/focus-visible", install=False)],
} }
@ -129,6 +129,7 @@ def _compile_page(
""" """
# Merge the default imports with the app-specific imports. # Merge the default imports with the app-specific imports.
imports = utils.merge_imports(DEFAULT_IMPORTS, component.get_imports()) imports = utils.merge_imports(DEFAULT_IMPORTS, component.get_imports())
imports = {k: list(set(v)) for k, v in imports.items()}
utils.validate_imports(imports) utils.validate_imports(imports)
imports = utils.compile_imports(imports) imports = utils.compile_imports(imports)
@ -216,8 +217,8 @@ def _compile_components(components: set[CustomComponent]) -> str:
The compiled components. The compiled components.
""" """
imports = { imports = {
"react": {ImportVar(tag="memo")}, "react": [ImportVar(tag="memo")],
f"/{constants.Dirs.STATE_PATH}": {ImportVar(tag="E"), ImportVar(tag="isTrue")}, f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="E"), ImportVar(tag="isTrue")],
} }
component_renders = [] component_renders = []

View File

@ -30,7 +30,7 @@ from reflex.vars import ImportVar
merge_imports = imports.merge_imports merge_imports = imports.merge_imports
def compile_import_statement(fields: set[ImportVar]) -> tuple[str, set[str]]: def compile_import_statement(fields: list[ImportVar]) -> tuple[str, list[str]]:
"""Compile an import statement. """Compile an import statement.
Args: Args:
@ -42,17 +42,17 @@ def compile_import_statement(fields: set[ImportVar]) -> tuple[str, set[str]]:
rest: rest of libraries. When install "import {rest1, rest2} from library" rest: rest of libraries. When install "import {rest1, rest2} from library"
""" """
# ignore the ImportVar fields with render=False during compilation # ignore the ImportVar fields with render=False during compilation
fields = {field for field in fields if field.render} fields_set = {field for field in fields if field.render}
# Check for default imports. # Check for default imports.
defaults = {field for field in fields if field.is_default} defaults = {field for field in fields_set if field.is_default}
assert len(defaults) < 2 assert len(defaults) < 2
# Get the default import, and the specific imports. # Get the default import, and the specific imports.
default = next(iter({field.name for field in defaults}), "") default = next(iter({field.name for field in defaults}), "")
rest = {field.name for field in fields - defaults} rest = {field.name for field in fields_set - defaults}
return default, rest return default, list(rest)
def validate_imports(imports: imports.ImportDict): def validate_imports(imports: imports.ImportDict):
@ -109,7 +109,7 @@ def compile_imports(imports: imports.ImportDict) -> list[dict]:
return import_dicts return import_dicts
def get_import_dict(lib: str, default: str = "", rest: set[str] | None = None) -> dict: def get_import_dict(lib: str, default: str = "", rest: list[str] | None = None) -> dict:
"""Get dictionary for import template. """Get dictionary for import template.
Args: Args:
@ -123,7 +123,7 @@ def get_import_dict(lib: str, default: str = "", rest: set[str] | None = None) -
return { return {
"lib": lib, "lib": lib,
"default": default, "default": default,
"rest": rest if rest else set(), "rest": rest if rest else [],
} }
@ -235,7 +235,7 @@ def compile_custom_component(
A tuple of the compiled component and the imports required by the component. A tuple of the compiled component and the imports required by the component.
""" """
# Render the component. # Render the component.
render = component.get_component() render = component.get_component(component)
# Get the imports. # Get the imports.
imports = { imports = {

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import typing import typing
from abc import ABC from abc import ABC
from functools import wraps from functools import lru_cache, wraps
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
from reflex.base import Base from reflex.base import Base
@ -399,6 +399,7 @@ class Component(Base, ABC):
return tag.add_props(**props) return tag.add_props(**props)
@classmethod @classmethod
@lru_cache(maxsize=None)
def get_props(cls) -> Set[str]: def get_props(cls) -> Set[str]:
"""Get the unique fields for the component. """Get the unique fields for the component.
@ -408,6 +409,7 @@ class Component(Base, ABC):
return set(cls.get_fields()) - set(Component.get_fields()) return set(cls.get_fields()) - set(Component.get_fields())
@classmethod @classmethod
@lru_cache(maxsize=None)
def get_initial_props(cls) -> Set[str]: def get_initial_props(cls) -> Set[str]:
"""Get the initial props to set for the component. """Get the initial props to set for the component.
@ -417,6 +419,7 @@ class Component(Base, ABC):
return set() return set()
@classmethod @classmethod
@lru_cache(maxsize=None)
def get_component_props(cls) -> set[str]: def get_component_props(cls) -> set[str]:
"""Get the props that expected a component as value. """Get the props that expected a component as value.
@ -619,19 +622,17 @@ class Component(Base, ABC):
# Return the dynamic imports # Return the dynamic imports
return dynamic_imports return dynamic_imports
def _get_props_imports(self) -> imports.ImportDict: def _get_props_imports(self) -> List[str]:
"""Get the imports needed for components props. """Get the imports needed for components props.
Returns: Returns:
The imports for the components props of the component. The imports for the components props of the component.
""" """
return imports.merge_imports( return [
*[ getattr(self, prop).get_imports()
getattr(self, prop).get_imports() for prop in self.get_component_props()
for prop in self.get_component_props() if getattr(self, prop) is not None
if getattr(self, prop) is not None ]
]
)
def _get_dependencies_imports(self) -> imports.ImportDict: def _get_dependencies_imports(self) -> imports.ImportDict:
"""Get the imports from lib_dependencies for installing. """Get the imports from lib_dependencies for installing.
@ -639,9 +640,9 @@ class Component(Base, ABC):
Returns: Returns:
The dependencies imports of the component. The dependencies imports of the component.
""" """
return imports.merge_imports( return {
{dep: {ImportVar(tag=None, render=False)} for dep in self.lib_dependencies} dep: [ImportVar(tag=None, render=False)] for dep in self.lib_dependencies
) }
def _get_imports(self) -> imports.ImportDict: def _get_imports(self) -> imports.ImportDict:
"""Get all the libraries and fields that are used by the component. """Get all the libraries and fields that are used by the component.
@ -654,7 +655,7 @@ class Component(Base, ABC):
_imports[self.library] = {self.import_var} _imports[self.library] = {self.import_var}
return imports.merge_imports( return imports.merge_imports(
self._get_props_imports(), *self._get_props_imports(),
self._get_dependencies_imports(), self._get_dependencies_imports(),
_imports, _imports,
) )
@ -804,7 +805,8 @@ class Component(Base, ABC):
alias = self.alias.partition(".")[0] if self.alias else None alias = self.alias.partition(".")[0] if self.alias else None
return ImportVar(tag=tag, is_default=self.is_default, alias=alias) return ImportVar(tag=tag, is_default=self.is_default, alias=alias)
def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]: @staticmethod
def _get_app_wrap_components() -> dict[tuple[int, str], Component]:
"""Get the app wrap components for the component. """Get the app wrap components for the component.
Returns: Returns:
@ -948,7 +950,9 @@ class CustomComponent(Component):
# Avoid adding the same component twice. # Avoid adding the same component twice.
if self.tag not in seen: if self.tag not in seen:
seen.add(self.tag) seen.add(self.tag)
custom_components |= self.get_component().get_custom_components(seen=seen) custom_components |= self.get_component(self).get_custom_components(
seen=seen
)
return custom_components return custom_components
def _render(self) -> Tag: def _render(self) -> Tag:
@ -975,6 +979,7 @@ class CustomComponent(Component):
for name, prop in self.props.items() for name, prop in self.props.items()
] ]
@lru_cache(maxsize=None) # noqa
def get_component(self) -> Component: def get_component(self) -> Component:
"""Render the component. """Render the component.

View File

@ -333,7 +333,8 @@ class DataEditor(NoSSRComponent):
height=props.pop("height", "100%"), height=props.pop("height", "100%"),
) )
def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]: @staticmethod
def _get_app_wrap_components() -> dict[tuple[int, str], Component]:
"""Get the app wrap components for the component. """Get the app wrap components for the component.
Returns: Returns:

View File

@ -177,9 +177,9 @@ class Editor(NoSSRComponent):
def _get_imports(self): def _get_imports(self):
imports = super()._get_imports() imports = super()._get_imports()
imports[""] = { imports[""] = [
ImportVar(tag="suneditor/dist/css/suneditor.min.css", install=False) ImportVar(tag="suneditor/dist/css/suneditor.min.css", install=False)
} ]
return imports return imports
def get_event_triggers(self) -> Dict[str, Any]: def get_event_triggers(self) -> Dict[str, Any]:

View File

@ -175,5 +175,5 @@ class Upload(Component):
def _get_imports(self) -> imports.ImportDict: def _get_imports(self) -> imports.ImportDict:
return { return {
**super()._get_imports(), **super()._get_imports(),
f"/{constants.Dirs.STATE_PATH}": {ImportVar(tag="upload_files")}, f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="upload_files")],
} }

View File

@ -53,6 +53,14 @@ class Cond(Component):
) )
) )
def _get_props_imports(self):
"""Get the imports needed for components props.
Returns:
The imports for the components props of the component.
"""
return []
def _render(self) -> Tag: def _render(self) -> Tag:
return CondTag( return CondTag(
cond=self.cond, cond=self.cond,

View File

@ -1,6 +1,7 @@
"""Components that are based on Chakra-UI.""" """Components that are based on Chakra-UI."""
from __future__ import annotations from __future__ import annotations
from functools import lru_cache
from typing import List, Literal from typing import List, Literal
from reflex.components.component import Component from reflex.components.component import Component
@ -17,10 +18,27 @@ class ChakraComponent(Component):
"framer-motion@10.16.4", "framer-motion@10.16.4",
] ]
def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]: @staticmethod
@lru_cache(maxsize=None)
def _get_app_wrap_components() -> dict[tuple[int, str], Component]:
return { return {
**super()._get_app_wrap_components(), (60, "ChakraProvider"): chakra_provider,
(60, "ChakraProvider"): ChakraProvider.create(), }
@classmethod
@lru_cache(maxsize=None)
def _get_dependencies_imports(cls) -> imports.ImportDict:
"""Get the imports from lib_dependencies for installing.
Returns:
The dependencies imports of the component.
"""
return {
dep: [ImportVar(tag=None, render=False)]
for dep in [
"@chakra-ui/system@2.5.7",
"framer-motion@10.16.4",
]
} }
@ -58,13 +76,13 @@ class ChakraProvider(ChakraComponent):
def _get_imports(self) -> imports.ImportDict: def _get_imports(self) -> imports.ImportDict:
imports = super()._get_imports() imports = super()._get_imports()
imports.setdefault(self.__fields__["library"].default, set()).add( imports.setdefault(self.__fields__["library"].default, []).append(
ImportVar(tag="extendTheme", is_default=False), ImportVar(tag="extendTheme", is_default=False),
) )
imports.setdefault("/utils/theme.js", set()).add( imports.setdefault("/utils/theme.js", []).append(
ImportVar(tag="theme", is_default=True), ImportVar(tag="theme", is_default=True),
) )
imports.setdefault(Global.__fields__["library"].default, set()).add( imports.setdefault(Global.__fields__["library"].default, []).append(
ImportVar(tag="css", is_default=False), ImportVar(tag="css", is_default=False),
) )
return imports return imports
@ -82,12 +100,17 @@ const GlobalStyles = css`
`; `;
""" """
def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]: @staticmethod
@lru_cache(maxsize=None)
def _get_app_wrap_components() -> dict[tuple[int, str], Component]:
return { return {
(50, "ChakraColorModeProvider"): ChakraColorModeProvider.create(), (50, "ChakraColorModeProvider"): chakra_color_mode_provider,
} }
chakra_provider = ChakraProvider.create()
class ChakraColorModeProvider(Component): class ChakraColorModeProvider(Component):
"""Next-themes integration for chakra colorModeProvider.""" """Next-themes integration for chakra colorModeProvider."""
@ -96,6 +119,9 @@ class ChakraColorModeProvider(Component):
is_default = True is_default = True
chakra_color_mode_provider = ChakraColorModeProvider.create()
LiteralColorScheme = Literal[ LiteralColorScheme = Literal[
"none", "none",
"gray", "gray",

View File

@ -7,6 +7,7 @@ from typing import Any, Dict, Literal, Optional, Union, overload
from reflex.vars import Var, BaseVar, ComputedVar from reflex.vars import Var, BaseVar, ComputedVar
from reflex.event import EventChain, EventHandler, EventSpec from reflex.event import EventChain, EventHandler, EventSpec
from reflex.style import Style from reflex.style import Style
from functools import lru_cache
from typing import List, Literal from typing import List, Literal
from reflex.components.component import Component from reflex.components.component import Component
from reflex.utils import imports from reflex.utils import imports
@ -238,6 +239,8 @@ class ChakraProvider(ChakraComponent):
""" """
... ...
chakra_provider = ChakraProvider.create()
class ChakraColorModeProvider(Component): class ChakraColorModeProvider(Component):
@overload @overload
@classmethod @classmethod
@ -317,6 +320,7 @@ class ChakraColorModeProvider(Component):
""" """
... ...
chakra_color_mode_provider = ChakraColorModeProvider.create()
LiteralColorScheme = Literal[ LiteralColorScheme = Literal[
"none", "none",
"gray", "gray",

View File

@ -7,6 +7,8 @@ from reflex.components.navigation.nextlink import NextLink
from reflex.utils import imports from reflex.utils import imports
from reflex.vars import BaseVar, Var from reflex.vars import BaseVar, Var
next_link = NextLink.create()
class Link(ChakraComponent): class Link(ChakraComponent):
"""Link to another page.""" """Link to another page."""
@ -29,7 +31,7 @@ class Link(ChakraComponent):
is_external: Var[bool] is_external: Var[bool]
def _get_imports(self) -> imports.ImportDict: def _get_imports(self) -> imports.ImportDict:
return {**super()._get_imports(), **NextLink.create()._get_imports()} return {**super()._get_imports(), **next_link._get_imports()}
@classmethod @classmethod
def create(cls, *children, **props) -> Component: def create(cls, *children, **props) -> Component:

View File

@ -13,6 +13,8 @@ from reflex.components.navigation.nextlink import NextLink
from reflex.utils import imports from reflex.utils import imports
from reflex.vars import BaseVar, Var from reflex.vars import BaseVar, Var
next_link = NextLink.create()
class Link(ChakraComponent): class Link(ChakraComponent):
@overload @overload
@classmethod @classmethod

View File

@ -28,7 +28,7 @@ class WebsocketTargetURL(Bare):
def _get_imports(self) -> imports.ImportDict: def _get_imports(self) -> imports.ImportDict:
return { return {
"/utils/state.js": {ImportVar(tag="getEventURL")}, "/utils/state.js": [ImportVar(tag="getEventURL")],
} }
@classmethod @classmethod

View File

@ -66,9 +66,9 @@ class RadixThemesComponent(Component):
) )
return component return component
def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]: @staticmethod
def _get_app_wrap_components() -> dict[tuple[int, str], Component]:
return { return {
**super()._get_app_wrap_components(),
(45, "RadixThemesColorModeProvider"): RadixThemesColorModeProvider.create(), (45, "RadixThemesColorModeProvider"): RadixThemesColorModeProvider.create(),
} }
@ -147,7 +147,7 @@ class Theme(RadixThemesComponent):
def _get_imports(self) -> imports.ImportDict: def _get_imports(self) -> imports.ImportDict:
return { return {
**super()._get_imports(), **super()._get_imports(),
"": {ImportVar(tag="@radix-ui/themes/styles.css", install=False)}, "": [ImportVar(tag="@radix-ui/themes/styles.css", install=False)],
} }

View File

@ -151,19 +151,19 @@ class Markdown(Component):
# Special markdown imports. # Special markdown imports.
imports.update( imports.update(
{ {
"": {ImportVar(tag="katex/dist/katex.min.css")}, "": [ImportVar(tag="katex/dist/katex.min.css")],
"remark-math@5.1.1": { "remark-math@5.1.1": [
ImportVar(tag=_REMARK_MATH._var_name, is_default=True) ImportVar(tag=_REMARK_MATH._var_name, is_default=True)
}, ],
"remark-gfm@3.0.1": { "remark-gfm@3.0.1": [
ImportVar(tag=_REMARK_GFM._var_name, is_default=True) ImportVar(tag=_REMARK_GFM._var_name, is_default=True)
}, ],
"rehype-katex@6.0.3": { "rehype-katex@6.0.3": [
ImportVar(tag=_REHYPE_KATEX._var_name, is_default=True) ImportVar(tag=_REHYPE_KATEX._var_name, is_default=True)
}, ],
"rehype-raw@6.1.1": { "rehype-raw@6.1.1": [
ImportVar(tag=_REHYPE_RAW._var_name, is_default=True) ImportVar(tag=_REHYPE_RAW._var_name, is_default=True)
}, ],
} }
) )

View File

@ -3,11 +3,11 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from typing import Dict, Set from typing import Dict, List
from reflex.vars import ImportVar from reflex.vars import ImportVar
ImportDict = Dict[str, Set[ImportVar]] ImportDict = Dict[str, List[ImportVar]]
def merge_imports(*imports) -> ImportDict: def merge_imports(*imports) -> ImportDict:
@ -19,9 +19,8 @@ def merge_imports(*imports) -> ImportDict:
Returns: Returns:
The merged import dicts. The merged import dicts.
""" """
all_imports = defaultdict(set) all_imports = defaultdict(list)
for import_dict in imports: for import_dict in imports:
for lib, fields in import_dict.items(): for lib, fields in import_dict.items():
for field in fields: all_imports[lib].extend(fields)
all_imports[lib].add(field)
return all_imports return all_imports

View File

@ -1,5 +1,5 @@
import os import os
from typing import List, Set from typing import List
import pytest import pytest
@ -12,28 +12,28 @@ from reflex.vars import ImportVar
"fields,test_default,test_rest", "fields,test_default,test_rest",
[ [
( (
{ImportVar(tag="axios", is_default=True)}, [ImportVar(tag="axios", is_default=True)],
"axios", "axios",
set(), [],
), ),
( (
{ImportVar(tag="foo"), ImportVar(tag="bar")}, [ImportVar(tag="foo"), ImportVar(tag="bar")],
"", "",
{"foo", "bar"}, ["bar", "foo"],
), ),
( (
{ [
ImportVar(tag="axios", is_default=True), ImportVar(tag="axios", is_default=True),
ImportVar(tag="foo"), ImportVar(tag="foo"),
ImportVar(tag="bar"), ImportVar(tag="bar"),
}, ],
"axios", "axios",
{"foo", "bar"}, ["bar", "foo"],
), ),
], ],
) )
def test_compile_import_statement( def test_compile_import_statement(
fields: Set[ImportVar], test_default: str, test_rest: str fields: List[ImportVar], test_default: str, test_rest: str
): ):
"""Test the compile_import_statement function. """Test the compile_import_statement function.
@ -44,7 +44,7 @@ def test_compile_import_statement(
""" """
default, rest = utils.compile_import_statement(fields) default, rest = utils.compile_import_statement(fields)
assert default == test_default assert default == test_default
assert rest == test_rest assert sorted(rest) == test_rest
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -52,43 +52,43 @@ def test_compile_import_statement(
[ [
({}, []), ({}, []),
( (
{"axios": {ImportVar(tag="axios", is_default=True)}}, {"axios": [ImportVar(tag="axios", is_default=True)]},
[{"lib": "axios", "default": "axios", "rest": set()}], [{"lib": "axios", "default": "axios", "rest": []}],
), ),
( (
{"axios": {ImportVar(tag="foo"), ImportVar(tag="bar")}}, {"axios": [ImportVar(tag="foo"), ImportVar(tag="bar")]},
[{"lib": "axios", "default": "", "rest": {"foo", "bar"}}], [{"lib": "axios", "default": "", "rest": ["bar", "foo"]}],
), ),
( (
{ {
"axios": { "axios": [
ImportVar(tag="axios", is_default=True), ImportVar(tag="axios", is_default=True),
ImportVar(tag="foo"), ImportVar(tag="foo"),
ImportVar(tag="bar"), ImportVar(tag="bar"),
}, ],
"react": {ImportVar(tag="react", is_default=True)}, "react": [ImportVar(tag="react", is_default=True)],
}, },
[ [
{"lib": "axios", "default": "axios", "rest": {"foo", "bar"}}, {"lib": "axios", "default": "axios", "rest": ["bar", "foo"]},
{"lib": "react", "default": "react", "rest": set()}, {"lib": "react", "default": "react", "rest": []},
], ],
), ),
( (
{"": {ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")}}, {"": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")]},
[ [
{"lib": "lib1.js", "default": "", "rest": set()}, {"lib": "lib1.js", "default": "", "rest": []},
{"lib": "lib2.js", "default": "", "rest": set()}, {"lib": "lib2.js", "default": "", "rest": []},
], ],
), ),
( (
{ {
"": {ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")}, "": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")],
"axios": {ImportVar(tag="axios", is_default=True)}, "axios": [ImportVar(tag="axios", is_default=True)],
}, },
[ [
{"lib": "lib1.js", "default": "", "rest": set()}, {"lib": "lib1.js", "default": "", "rest": []},
{"lib": "lib2.js", "default": "", "rest": set()}, {"lib": "lib2.js", "default": "", "rest": []},
{"lib": "axios", "default": "axios", "rest": set()}, {"lib": "axios", "default": "axios", "rest": []},
], ],
), ),
], ],
@ -104,7 +104,7 @@ def test_compile_imports(import_dict: imports.ImportDict, test_dicts: List[dict]
for import_dict, test_dict in zip(imports, test_dicts): for import_dict, test_dict in zip(imports, test_dicts):
assert import_dict["lib"] == test_dict["lib"] assert import_dict["lib"] == test_dict["lib"]
assert import_dict["default"] == test_dict["default"] assert import_dict["default"] == test_dict["default"]
assert import_dict["rest"] == test_dict["rest"] assert sorted(import_dict["rest"]) == test_dict["rest"] # type: ignore
def test_compile_stylesheets(tmp_path, mocker): def test_compile_stylesheets(tmp_path, mocker):

View File

@ -44,7 +44,7 @@ def component1() -> Type[Component]:
number: Var[int] number: Var[int]
def _get_imports(self) -> imports.ImportDict: def _get_imports(self) -> imports.ImportDict:
return {"react": {ImportVar(tag="Component")}} return {"react": [ImportVar(tag="Component")]}
def _get_custom_code(self) -> str: def _get_custom_code(self) -> str:
return "console.log('component1')" return "console.log('component1')"
@ -77,7 +77,7 @@ def component2() -> Type[Component]:
} }
def _get_imports(self) -> imports.ImportDict: def _get_imports(self) -> imports.ImportDict:
return {"react-redux": {ImportVar(tag="connect")}} return {"react-redux": [ImportVar(tag="connect")]}
def _get_custom_code(self) -> str: def _get_custom_code(self) -> str:
return "console.log('component2')" return "console.log('component2')"
@ -268,10 +268,10 @@ def test_get_imports(component1, component2):
""" """
c1 = component1.create() c1 = component1.create()
c2 = component2.create(c1) c2 = component2.create(c1)
assert c1.get_imports() == {"react": {ImportVar(tag="Component")}} assert c1.get_imports() == {"react": [ImportVar(tag="Component")]}
assert c2.get_imports() == { assert c2.get_imports() == {
"react-redux": {ImportVar(tag="connect")}, "react-redux": [ImportVar(tag="connect")],
"react": {ImportVar(tag="Component")}, "react": [ImportVar(tag="Component")],
} }
@ -469,7 +469,7 @@ def test_custom_component_wrapper():
assert len(ccomponent.children) == 1 assert len(ccomponent.children) == 1
assert isinstance(ccomponent.children[0], rx.Text) assert isinstance(ccomponent.children[0], rx.Text)
component = ccomponent.get_component() component = ccomponent.get_component(ccomponent)
assert isinstance(component, Box) assert isinstance(component, Box)