diff --git a/reflex/app.py b/reflex/app.py index 477d06511..6248bcec0 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio import contextlib +import copy import functools import os from multiprocessing.pool import ThreadPool @@ -589,9 +590,9 @@ class App(Base): def _app_root(self, app_wrappers): order = sorted(app_wrappers, key=lambda k: k[0], reverse=True) - root = parent = app_wrappers[order[0]] + root = parent = copy.deepcopy(app_wrappers[order[0]]) for key in order[1:]: - child = app_wrappers[key] + child = copy.deepcopy(app_wrappers[key]) parent.children.append(child) parent = child return root diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index aa2e5f2b9..4ffd4a876 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -15,15 +15,15 @@ from reflex.vars import ImportVar # Imports to be included in every Reflex app. DEFAULT_IMPORTS: imports.ImportDict = { - "react": { + "react": [ ImportVar(tag="Fragment"), ImportVar(tag="useEffect"), ImportVar(tag="useRef"), ImportVar(tag="useState"), ImportVar(tag="useContext"), - }, - "next/router": {ImportVar(tag="useRouter")}, - f"/{constants.Dirs.STATE_PATH}": { + ], + "next/router": [ImportVar(tag="useRouter")], + f"/{constants.Dirs.STATE_PATH}": [ ImportVar(tag="uploadFiles"), ImportVar(tag="Event"), ImportVar(tag="isTrue"), @@ -34,17 +34,17 @@ DEFAULT_IMPORTS: imports.ImportDict = { ImportVar(tag="getRefValues"), ImportVar(tag="getAllLocalStorageItems"), ImportVar(tag="useEventLoop"), - }, - "/utils/context.js": { + ], + "/utils/context.js": [ ImportVar(tag="EventLoopContext"), ImportVar(tag="initialEvents"), ImportVar(tag="StateContext"), ImportVar(tag="ColorModeContext"), - }, - "/utils/helpers/range.js": { + ], + "/utils/helpers/range.js": [ ImportVar(tag="range", is_default=True), - }, - "": {ImportVar(tag="focus-visible/dist/focus-visible", install=False)}, + ], + "": [ImportVar(tag="focus-visible/dist/focus-visible", install=False)], } @@ -129,6 +129,7 @@ def _compile_page( """ # Merge the default imports with the app-specific imports. imports = utils.merge_imports(DEFAULT_IMPORTS, component.get_imports()) + imports = {k: list(set(v)) for k, v in imports.items()} utils.validate_imports(imports) imports = utils.compile_imports(imports) @@ -216,8 +217,8 @@ def _compile_components(components: set[CustomComponent]) -> str: The compiled components. """ imports = { - "react": {ImportVar(tag="memo")}, - f"/{constants.Dirs.STATE_PATH}": {ImportVar(tag="E"), ImportVar(tag="isTrue")}, + "react": [ImportVar(tag="memo")], + f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="E"), ImportVar(tag="isTrue")], } component_renders = [] diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index dc7c384a4..b6967eaef 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -30,7 +30,7 @@ from reflex.vars import ImportVar merge_imports = imports.merge_imports -def compile_import_statement(fields: set[ImportVar]) -> tuple[str, set[str]]: +def compile_import_statement(fields: list[ImportVar]) -> tuple[str, list[str]]: """Compile an import statement. Args: @@ -42,17 +42,17 @@ def compile_import_statement(fields: set[ImportVar]) -> tuple[str, set[str]]: rest: rest of libraries. When install "import {rest1, rest2} from library" """ # ignore the ImportVar fields with render=False during compilation - fields = {field for field in fields if field.render} + fields_set = {field for field in fields if field.render} # Check for default imports. - defaults = {field for field in fields if field.is_default} + defaults = {field for field in fields_set if field.is_default} assert len(defaults) < 2 # Get the default import, and the specific imports. default = next(iter({field.name for field in defaults}), "") - rest = {field.name for field in fields - defaults} + rest = {field.name for field in fields_set - defaults} - return default, rest + return default, list(rest) def validate_imports(imports: imports.ImportDict): @@ -109,7 +109,7 @@ def compile_imports(imports: imports.ImportDict) -> list[dict]: return import_dicts -def get_import_dict(lib: str, default: str = "", rest: set[str] | None = None) -> dict: +def get_import_dict(lib: str, default: str = "", rest: list[str] | None = None) -> dict: """Get dictionary for import template. Args: @@ -123,7 +123,7 @@ def get_import_dict(lib: str, default: str = "", rest: set[str] | None = None) - return { "lib": lib, "default": default, - "rest": rest if rest else set(), + "rest": rest if rest else [], } @@ -235,7 +235,7 @@ def compile_custom_component( A tuple of the compiled component and the imports required by the component. """ # Render the component. - render = component.get_component() + render = component.get_component(component) # Get the imports. imports = { diff --git a/reflex/components/component.py b/reflex/components/component.py index 1771bea6d..d274fbc37 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -4,7 +4,7 @@ from __future__ import annotations import typing from abc import ABC -from functools import wraps +from functools import lru_cache, wraps from typing import Any, Callable, Dict, List, Optional, Set, Type, Union from reflex.base import Base @@ -399,6 +399,7 @@ class Component(Base, ABC): return tag.add_props(**props) @classmethod + @lru_cache(maxsize=None) def get_props(cls) -> Set[str]: """Get the unique fields for the component. @@ -408,6 +409,7 @@ class Component(Base, ABC): return set(cls.get_fields()) - set(Component.get_fields()) @classmethod + @lru_cache(maxsize=None) def get_initial_props(cls) -> Set[str]: """Get the initial props to set for the component. @@ -417,6 +419,7 @@ class Component(Base, ABC): return set() @classmethod + @lru_cache(maxsize=None) def get_component_props(cls) -> set[str]: """Get the props that expected a component as value. @@ -619,19 +622,17 @@ class Component(Base, ABC): # Return the dynamic imports return dynamic_imports - def _get_props_imports(self) -> imports.ImportDict: + def _get_props_imports(self) -> List[str]: """Get the imports needed for components props. Returns: The imports for the components props of the component. """ - return imports.merge_imports( - *[ - getattr(self, prop).get_imports() - for prop in self.get_component_props() - if getattr(self, prop) is not None - ] - ) + return [ + getattr(self, prop).get_imports() + for prop in self.get_component_props() + if getattr(self, prop) is not None + ] def _get_dependencies_imports(self) -> imports.ImportDict: """Get the imports from lib_dependencies for installing. @@ -639,9 +640,9 @@ class Component(Base, ABC): Returns: The dependencies imports of the component. """ - return imports.merge_imports( - {dep: {ImportVar(tag=None, render=False)} for dep in self.lib_dependencies} - ) + return { + dep: [ImportVar(tag=None, render=False)] for dep in self.lib_dependencies + } def _get_imports(self) -> imports.ImportDict: """Get all the libraries and fields that are used by the component. @@ -654,7 +655,7 @@ class Component(Base, ABC): _imports[self.library] = {self.import_var} return imports.merge_imports( - self._get_props_imports(), + *self._get_props_imports(), self._get_dependencies_imports(), _imports, ) @@ -804,7 +805,8 @@ class Component(Base, ABC): alias = self.alias.partition(".")[0] if self.alias else None return ImportVar(tag=tag, is_default=self.is_default, alias=alias) - def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]: + @staticmethod + def _get_app_wrap_components() -> dict[tuple[int, str], Component]: """Get the app wrap components for the component. Returns: @@ -948,7 +950,9 @@ class CustomComponent(Component): # Avoid adding the same component twice. if self.tag not in seen: seen.add(self.tag) - custom_components |= self.get_component().get_custom_components(seen=seen) + custom_components |= self.get_component(self).get_custom_components( + seen=seen + ) return custom_components def _render(self) -> Tag: @@ -975,6 +979,7 @@ class CustomComponent(Component): for name, prop in self.props.items() ] + @lru_cache(maxsize=None) # noqa def get_component(self) -> Component: """Render the component. diff --git a/reflex/components/datadisplay/dataeditor.py b/reflex/components/datadisplay/dataeditor.py index a39cdfc90..e666d284c 100644 --- a/reflex/components/datadisplay/dataeditor.py +++ b/reflex/components/datadisplay/dataeditor.py @@ -333,7 +333,8 @@ class DataEditor(NoSSRComponent): height=props.pop("height", "100%"), ) - def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]: + @staticmethod + def _get_app_wrap_components() -> dict[tuple[int, str], Component]: """Get the app wrap components for the component. Returns: diff --git a/reflex/components/forms/editor.py b/reflex/components/forms/editor.py index 3f9bf34ff..b7bfb08a5 100644 --- a/reflex/components/forms/editor.py +++ b/reflex/components/forms/editor.py @@ -177,9 +177,9 @@ class Editor(NoSSRComponent): def _get_imports(self): imports = super()._get_imports() - imports[""] = { + imports[""] = [ ImportVar(tag="suneditor/dist/css/suneditor.min.css", install=False) - } + ] return imports def get_event_triggers(self) -> Dict[str, Any]: diff --git a/reflex/components/forms/upload.py b/reflex/components/forms/upload.py index 21457e64e..bd01d01c6 100644 --- a/reflex/components/forms/upload.py +++ b/reflex/components/forms/upload.py @@ -175,5 +175,5 @@ class Upload(Component): def _get_imports(self) -> imports.ImportDict: return { **super()._get_imports(), - f"/{constants.Dirs.STATE_PATH}": {ImportVar(tag="upload_files")}, + f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="upload_files")], } diff --git a/reflex/components/layout/cond.py b/reflex/components/layout/cond.py index b584b0f0c..2455433d7 100644 --- a/reflex/components/layout/cond.py +++ b/reflex/components/layout/cond.py @@ -53,6 +53,14 @@ class Cond(Component): ) ) + def _get_props_imports(self): + """Get the imports needed for components props. + + Returns: + The imports for the components props of the component. + """ + return [] + def _render(self) -> Tag: return CondTag( cond=self.cond, diff --git a/reflex/components/libs/chakra.py b/reflex/components/libs/chakra.py index f2773017b..fdab62fd4 100644 --- a/reflex/components/libs/chakra.py +++ b/reflex/components/libs/chakra.py @@ -1,6 +1,7 @@ """Components that are based on Chakra-UI.""" from __future__ import annotations +from functools import lru_cache from typing import List, Literal from reflex.components.component import Component @@ -17,10 +18,27 @@ class ChakraComponent(Component): "framer-motion@10.16.4", ] - def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]: + @staticmethod + @lru_cache(maxsize=None) + def _get_app_wrap_components() -> dict[tuple[int, str], Component]: return { - **super()._get_app_wrap_components(), - (60, "ChakraProvider"): ChakraProvider.create(), + (60, "ChakraProvider"): chakra_provider, + } + + @classmethod + @lru_cache(maxsize=None) + def _get_dependencies_imports(cls) -> imports.ImportDict: + """Get the imports from lib_dependencies for installing. + + Returns: + The dependencies imports of the component. + """ + return { + dep: [ImportVar(tag=None, render=False)] + for dep in [ + "@chakra-ui/system@2.5.7", + "framer-motion@10.16.4", + ] } @@ -58,13 +76,13 @@ class ChakraProvider(ChakraComponent): def _get_imports(self) -> imports.ImportDict: imports = super()._get_imports() - imports.setdefault(self.__fields__["library"].default, set()).add( + imports.setdefault(self.__fields__["library"].default, []).append( ImportVar(tag="extendTheme", is_default=False), ) - imports.setdefault("/utils/theme.js", set()).add( + imports.setdefault("/utils/theme.js", []).append( ImportVar(tag="theme", is_default=True), ) - imports.setdefault(Global.__fields__["library"].default, set()).add( + imports.setdefault(Global.__fields__["library"].default, []).append( ImportVar(tag="css", is_default=False), ) return imports @@ -82,12 +100,17 @@ const GlobalStyles = css` `; """ - def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]: + @staticmethod + @lru_cache(maxsize=None) + def _get_app_wrap_components() -> dict[tuple[int, str], Component]: return { - (50, "ChakraColorModeProvider"): ChakraColorModeProvider.create(), + (50, "ChakraColorModeProvider"): chakra_color_mode_provider, } +chakra_provider = ChakraProvider.create() + + class ChakraColorModeProvider(Component): """Next-themes integration for chakra colorModeProvider.""" @@ -96,6 +119,9 @@ class ChakraColorModeProvider(Component): is_default = True +chakra_color_mode_provider = ChakraColorModeProvider.create() + + LiteralColorScheme = Literal[ "none", "gray", diff --git a/reflex/components/libs/chakra.pyi b/reflex/components/libs/chakra.pyi index 4679c3f8d..967dcda35 100644 --- a/reflex/components/libs/chakra.pyi +++ b/reflex/components/libs/chakra.pyi @@ -7,6 +7,7 @@ from typing import Any, Dict, Literal, Optional, Union, overload from reflex.vars import Var, BaseVar, ComputedVar from reflex.event import EventChain, EventHandler, EventSpec from reflex.style import Style +from functools import lru_cache from typing import List, Literal from reflex.components.component import Component from reflex.utils import imports @@ -238,6 +239,8 @@ class ChakraProvider(ChakraComponent): """ ... +chakra_provider = ChakraProvider.create() + class ChakraColorModeProvider(Component): @overload @classmethod @@ -317,6 +320,7 @@ class ChakraColorModeProvider(Component): """ ... +chakra_color_mode_provider = ChakraColorModeProvider.create() LiteralColorScheme = Literal[ "none", "gray", diff --git a/reflex/components/navigation/link.py b/reflex/components/navigation/link.py index 6b8c61cd7..3894a00b3 100644 --- a/reflex/components/navigation/link.py +++ b/reflex/components/navigation/link.py @@ -7,6 +7,8 @@ from reflex.components.navigation.nextlink import NextLink from reflex.utils import imports from reflex.vars import BaseVar, Var +next_link = NextLink.create() + class Link(ChakraComponent): """Link to another page.""" @@ -29,7 +31,7 @@ class Link(ChakraComponent): is_external: Var[bool] def _get_imports(self) -> imports.ImportDict: - return {**super()._get_imports(), **NextLink.create()._get_imports()} + return {**super()._get_imports(), **next_link._get_imports()} @classmethod def create(cls, *children, **props) -> Component: diff --git a/reflex/components/navigation/link.pyi b/reflex/components/navigation/link.pyi index 4e3cac328..b5af30572 100644 --- a/reflex/components/navigation/link.pyi +++ b/reflex/components/navigation/link.pyi @@ -13,6 +13,8 @@ from reflex.components.navigation.nextlink import NextLink from reflex.utils import imports from reflex.vars import BaseVar, Var +next_link = NextLink.create() + class Link(ChakraComponent): @overload @classmethod diff --git a/reflex/components/overlay/banner.py b/reflex/components/overlay/banner.py index 7b9a31eb9..cdb01c063 100644 --- a/reflex/components/overlay/banner.py +++ b/reflex/components/overlay/banner.py @@ -28,7 +28,7 @@ class WebsocketTargetURL(Bare): def _get_imports(self) -> imports.ImportDict: return { - "/utils/state.js": {ImportVar(tag="getEventURL")}, + "/utils/state.js": [ImportVar(tag="getEventURL")], } @classmethod diff --git a/reflex/components/radix/themes/base.py b/reflex/components/radix/themes/base.py index b24b1c0e8..3e589faca 100644 --- a/reflex/components/radix/themes/base.py +++ b/reflex/components/radix/themes/base.py @@ -66,9 +66,9 @@ class RadixThemesComponent(Component): ) return component - def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]: + @staticmethod + def _get_app_wrap_components() -> dict[tuple[int, str], Component]: return { - **super()._get_app_wrap_components(), (45, "RadixThemesColorModeProvider"): RadixThemesColorModeProvider.create(), } @@ -147,7 +147,7 @@ class Theme(RadixThemesComponent): def _get_imports(self) -> imports.ImportDict: return { **super()._get_imports(), - "": {ImportVar(tag="@radix-ui/themes/styles.css", install=False)}, + "": [ImportVar(tag="@radix-ui/themes/styles.css", install=False)], } diff --git a/reflex/components/typography/markdown.py b/reflex/components/typography/markdown.py index 7459d1e38..f414f1781 100644 --- a/reflex/components/typography/markdown.py +++ b/reflex/components/typography/markdown.py @@ -151,19 +151,19 @@ class Markdown(Component): # Special markdown imports. imports.update( { - "": {ImportVar(tag="katex/dist/katex.min.css")}, - "remark-math@5.1.1": { + "": [ImportVar(tag="katex/dist/katex.min.css")], + "remark-math@5.1.1": [ ImportVar(tag=_REMARK_MATH._var_name, is_default=True) - }, - "remark-gfm@3.0.1": { + ], + "remark-gfm@3.0.1": [ ImportVar(tag=_REMARK_GFM._var_name, is_default=True) - }, - "rehype-katex@6.0.3": { + ], + "rehype-katex@6.0.3": [ ImportVar(tag=_REHYPE_KATEX._var_name, is_default=True) - }, - "rehype-raw@6.1.1": { + ], + "rehype-raw@6.1.1": [ ImportVar(tag=_REHYPE_RAW._var_name, is_default=True) - }, + ], } ) diff --git a/reflex/utils/imports.py b/reflex/utils/imports.py index d1e37ae5d..f20ec141d 100644 --- a/reflex/utils/imports.py +++ b/reflex/utils/imports.py @@ -3,11 +3,11 @@ from __future__ import annotations from collections import defaultdict -from typing import Dict, Set +from typing import Dict, List from reflex.vars import ImportVar -ImportDict = Dict[str, Set[ImportVar]] +ImportDict = Dict[str, List[ImportVar]] def merge_imports(*imports) -> ImportDict: @@ -19,9 +19,8 @@ def merge_imports(*imports) -> ImportDict: Returns: The merged import dicts. """ - all_imports = defaultdict(set) + all_imports = defaultdict(list) for import_dict in imports: for lib, fields in import_dict.items(): - for field in fields: - all_imports[lib].add(field) + all_imports[lib].extend(fields) return all_imports diff --git a/tests/compiler/test_compiler.py b/tests/compiler/test_compiler.py index 5329b9778..e8423d97f 100644 --- a/tests/compiler/test_compiler.py +++ b/tests/compiler/test_compiler.py @@ -1,5 +1,5 @@ import os -from typing import List, Set +from typing import List import pytest @@ -12,28 +12,28 @@ from reflex.vars import ImportVar "fields,test_default,test_rest", [ ( - {ImportVar(tag="axios", is_default=True)}, + [ImportVar(tag="axios", is_default=True)], "axios", - set(), + [], ), ( - {ImportVar(tag="foo"), ImportVar(tag="bar")}, + [ImportVar(tag="foo"), ImportVar(tag="bar")], "", - {"foo", "bar"}, + ["bar", "foo"], ), ( - { + [ ImportVar(tag="axios", is_default=True), ImportVar(tag="foo"), ImportVar(tag="bar"), - }, + ], "axios", - {"foo", "bar"}, + ["bar", "foo"], ), ], ) def test_compile_import_statement( - fields: Set[ImportVar], test_default: str, test_rest: str + fields: List[ImportVar], test_default: str, test_rest: str ): """Test the compile_import_statement function. @@ -44,7 +44,7 @@ def test_compile_import_statement( """ default, rest = utils.compile_import_statement(fields) assert default == test_default - assert rest == test_rest + assert sorted(rest) == test_rest @pytest.mark.parametrize( @@ -52,43 +52,43 @@ def test_compile_import_statement( [ ({}, []), ( - {"axios": {ImportVar(tag="axios", is_default=True)}}, - [{"lib": "axios", "default": "axios", "rest": set()}], + {"axios": [ImportVar(tag="axios", is_default=True)]}, + [{"lib": "axios", "default": "axios", "rest": []}], ), ( - {"axios": {ImportVar(tag="foo"), ImportVar(tag="bar")}}, - [{"lib": "axios", "default": "", "rest": {"foo", "bar"}}], + {"axios": [ImportVar(tag="foo"), ImportVar(tag="bar")]}, + [{"lib": "axios", "default": "", "rest": ["bar", "foo"]}], ), ( { - "axios": { + "axios": [ ImportVar(tag="axios", is_default=True), ImportVar(tag="foo"), ImportVar(tag="bar"), - }, - "react": {ImportVar(tag="react", is_default=True)}, + ], + "react": [ImportVar(tag="react", is_default=True)], }, [ - {"lib": "axios", "default": "axios", "rest": {"foo", "bar"}}, - {"lib": "react", "default": "react", "rest": set()}, + {"lib": "axios", "default": "axios", "rest": ["bar", "foo"]}, + {"lib": "react", "default": "react", "rest": []}, ], ), ( - {"": {ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")}}, + {"": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")]}, [ - {"lib": "lib1.js", "default": "", "rest": set()}, - {"lib": "lib2.js", "default": "", "rest": set()}, + {"lib": "lib1.js", "default": "", "rest": []}, + {"lib": "lib2.js", "default": "", "rest": []}, ], ), ( { - "": {ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")}, - "axios": {ImportVar(tag="axios", is_default=True)}, + "": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")], + "axios": [ImportVar(tag="axios", is_default=True)], }, [ - {"lib": "lib1.js", "default": "", "rest": set()}, - {"lib": "lib2.js", "default": "", "rest": set()}, - {"lib": "axios", "default": "axios", "rest": set()}, + {"lib": "lib1.js", "default": "", "rest": []}, + {"lib": "lib2.js", "default": "", "rest": []}, + {"lib": "axios", "default": "axios", "rest": []}, ], ), ], @@ -104,7 +104,7 @@ def test_compile_imports(import_dict: imports.ImportDict, test_dicts: List[dict] for import_dict, test_dict in zip(imports, test_dicts): assert import_dict["lib"] == test_dict["lib"] assert import_dict["default"] == test_dict["default"] - assert import_dict["rest"] == test_dict["rest"] + assert sorted(import_dict["rest"]) == test_dict["rest"] # type: ignore def test_compile_stylesheets(tmp_path, mocker): diff --git a/tests/components/test_component.py b/tests/components/test_component.py index f2f481e29..abf6c5a06 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -44,7 +44,7 @@ def component1() -> Type[Component]: number: Var[int] def _get_imports(self) -> imports.ImportDict: - return {"react": {ImportVar(tag="Component")}} + return {"react": [ImportVar(tag="Component")]} def _get_custom_code(self) -> str: return "console.log('component1')" @@ -77,7 +77,7 @@ def component2() -> Type[Component]: } def _get_imports(self) -> imports.ImportDict: - return {"react-redux": {ImportVar(tag="connect")}} + return {"react-redux": [ImportVar(tag="connect")]} def _get_custom_code(self) -> str: return "console.log('component2')" @@ -268,10 +268,10 @@ def test_get_imports(component1, component2): """ c1 = component1.create() c2 = component2.create(c1) - assert c1.get_imports() == {"react": {ImportVar(tag="Component")}} + assert c1.get_imports() == {"react": [ImportVar(tag="Component")]} assert c2.get_imports() == { - "react-redux": {ImportVar(tag="connect")}, - "react": {ImportVar(tag="Component")}, + "react-redux": [ImportVar(tag="connect")], + "react": [ImportVar(tag="Component")], } @@ -469,7 +469,7 @@ def test_custom_component_wrapper(): assert len(ccomponent.children) == 1 assert isinstance(ccomponent.children[0], rx.Text) - component = ccomponent.get_component() + component = ccomponent.get_component(ccomponent) assert isinstance(component, Box)