Use singleton app provider to speed up compiles (#2172)

This commit is contained in:
Nikhil Rao 2023-11-20 18:11:24 -08:00 committed by GitHub
parent b5f6ab3a82
commit e9437ad941
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 151 additions and 102 deletions

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
import contextlib
import copy
import functools
import os
from multiprocessing.pool import ThreadPool
@ -589,9 +590,9 @@ class App(Base):
def _app_root(self, app_wrappers):
order = sorted(app_wrappers, key=lambda k: k[0], reverse=True)
root = parent = app_wrappers[order[0]]
root = parent = copy.deepcopy(app_wrappers[order[0]])
for key in order[1:]:
child = app_wrappers[key]
child = copy.deepcopy(app_wrappers[key])
parent.children.append(child)
parent = child
return root

View File

@ -15,15 +15,15 @@ from reflex.vars import ImportVar
# Imports to be included in every Reflex app.
DEFAULT_IMPORTS: imports.ImportDict = {
"react": {
"react": [
ImportVar(tag="Fragment"),
ImportVar(tag="useEffect"),
ImportVar(tag="useRef"),
ImportVar(tag="useState"),
ImportVar(tag="useContext"),
},
"next/router": {ImportVar(tag="useRouter")},
f"/{constants.Dirs.STATE_PATH}": {
],
"next/router": [ImportVar(tag="useRouter")],
f"/{constants.Dirs.STATE_PATH}": [
ImportVar(tag="uploadFiles"),
ImportVar(tag="Event"),
ImportVar(tag="isTrue"),
@ -34,17 +34,17 @@ DEFAULT_IMPORTS: imports.ImportDict = {
ImportVar(tag="getRefValues"),
ImportVar(tag="getAllLocalStorageItems"),
ImportVar(tag="useEventLoop"),
},
"/utils/context.js": {
],
"/utils/context.js": [
ImportVar(tag="EventLoopContext"),
ImportVar(tag="initialEvents"),
ImportVar(tag="StateContext"),
ImportVar(tag="ColorModeContext"),
},
"/utils/helpers/range.js": {
],
"/utils/helpers/range.js": [
ImportVar(tag="range", is_default=True),
},
"": {ImportVar(tag="focus-visible/dist/focus-visible", install=False)},
],
"": [ImportVar(tag="focus-visible/dist/focus-visible", install=False)],
}
@ -129,6 +129,7 @@ def _compile_page(
"""
# Merge the default imports with the app-specific imports.
imports = utils.merge_imports(DEFAULT_IMPORTS, component.get_imports())
imports = {k: list(set(v)) for k, v in imports.items()}
utils.validate_imports(imports)
imports = utils.compile_imports(imports)
@ -216,8 +217,8 @@ def _compile_components(components: set[CustomComponent]) -> str:
The compiled components.
"""
imports = {
"react": {ImportVar(tag="memo")},
f"/{constants.Dirs.STATE_PATH}": {ImportVar(tag="E"), ImportVar(tag="isTrue")},
"react": [ImportVar(tag="memo")],
f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="E"), ImportVar(tag="isTrue")],
}
component_renders = []

View File

@ -30,7 +30,7 @@ from reflex.vars import ImportVar
merge_imports = imports.merge_imports
def compile_import_statement(fields: set[ImportVar]) -> tuple[str, set[str]]:
def compile_import_statement(fields: list[ImportVar]) -> tuple[str, list[str]]:
"""Compile an import statement.
Args:
@ -42,17 +42,17 @@ def compile_import_statement(fields: set[ImportVar]) -> tuple[str, set[str]]:
rest: rest of libraries. When install "import {rest1, rest2} from library"
"""
# ignore the ImportVar fields with render=False during compilation
fields = {field for field in fields if field.render}
fields_set = {field for field in fields if field.render}
# Check for default imports.
defaults = {field for field in fields if field.is_default}
defaults = {field for field in fields_set if field.is_default}
assert len(defaults) < 2
# Get the default import, and the specific imports.
default = next(iter({field.name for field in defaults}), "")
rest = {field.name for field in fields - defaults}
rest = {field.name for field in fields_set - defaults}
return default, rest
return default, list(rest)
def validate_imports(imports: imports.ImportDict):
@ -109,7 +109,7 @@ def compile_imports(imports: imports.ImportDict) -> list[dict]:
return import_dicts
def get_import_dict(lib: str, default: str = "", rest: set[str] | None = None) -> dict:
def get_import_dict(lib: str, default: str = "", rest: list[str] | None = None) -> dict:
"""Get dictionary for import template.
Args:
@ -123,7 +123,7 @@ def get_import_dict(lib: str, default: str = "", rest: set[str] | None = None) -
return {
"lib": lib,
"default": default,
"rest": rest if rest else set(),
"rest": rest if rest else [],
}
@ -235,7 +235,7 @@ def compile_custom_component(
A tuple of the compiled component and the imports required by the component.
"""
# Render the component.
render = component.get_component()
render = component.get_component(component)
# Get the imports.
imports = {

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import typing
from abc import ABC
from functools import wraps
from functools import lru_cache, wraps
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
from reflex.base import Base
@ -399,6 +399,7 @@ class Component(Base, ABC):
return tag.add_props(**props)
@classmethod
@lru_cache(maxsize=None)
def get_props(cls) -> Set[str]:
"""Get the unique fields for the component.
@ -408,6 +409,7 @@ class Component(Base, ABC):
return set(cls.get_fields()) - set(Component.get_fields())
@classmethod
@lru_cache(maxsize=None)
def get_initial_props(cls) -> Set[str]:
"""Get the initial props to set for the component.
@ -417,6 +419,7 @@ class Component(Base, ABC):
return set()
@classmethod
@lru_cache(maxsize=None)
def get_component_props(cls) -> set[str]:
"""Get the props that expected a component as value.
@ -619,19 +622,17 @@ class Component(Base, ABC):
# Return the dynamic imports
return dynamic_imports
def _get_props_imports(self) -> imports.ImportDict:
def _get_props_imports(self) -> List[str]:
"""Get the imports needed for components props.
Returns:
The imports for the components props of the component.
"""
return imports.merge_imports(
*[
getattr(self, prop).get_imports()
for prop in self.get_component_props()
if getattr(self, prop) is not None
]
)
return [
getattr(self, prop).get_imports()
for prop in self.get_component_props()
if getattr(self, prop) is not None
]
def _get_dependencies_imports(self) -> imports.ImportDict:
"""Get the imports from lib_dependencies for installing.
@ -639,9 +640,9 @@ class Component(Base, ABC):
Returns:
The dependencies imports of the component.
"""
return imports.merge_imports(
{dep: {ImportVar(tag=None, render=False)} for dep in self.lib_dependencies}
)
return {
dep: [ImportVar(tag=None, render=False)] for dep in self.lib_dependencies
}
def _get_imports(self) -> imports.ImportDict:
"""Get all the libraries and fields that are used by the component.
@ -654,7 +655,7 @@ class Component(Base, ABC):
_imports[self.library] = {self.import_var}
return imports.merge_imports(
self._get_props_imports(),
*self._get_props_imports(),
self._get_dependencies_imports(),
_imports,
)
@ -804,7 +805,8 @@ class Component(Base, ABC):
alias = self.alias.partition(".")[0] if self.alias else None
return ImportVar(tag=tag, is_default=self.is_default, alias=alias)
def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]:
@staticmethod
def _get_app_wrap_components() -> dict[tuple[int, str], Component]:
"""Get the app wrap components for the component.
Returns:
@ -948,7 +950,9 @@ class CustomComponent(Component):
# Avoid adding the same component twice.
if self.tag not in seen:
seen.add(self.tag)
custom_components |= self.get_component().get_custom_components(seen=seen)
custom_components |= self.get_component(self).get_custom_components(
seen=seen
)
return custom_components
def _render(self) -> Tag:
@ -975,6 +979,7 @@ class CustomComponent(Component):
for name, prop in self.props.items()
]
@lru_cache(maxsize=None) # noqa
def get_component(self) -> Component:
"""Render the component.

View File

@ -333,7 +333,8 @@ class DataEditor(NoSSRComponent):
height=props.pop("height", "100%"),
)
def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]:
@staticmethod
def _get_app_wrap_components() -> dict[tuple[int, str], Component]:
"""Get the app wrap components for the component.
Returns:

View File

@ -177,9 +177,9 @@ class Editor(NoSSRComponent):
def _get_imports(self):
imports = super()._get_imports()
imports[""] = {
imports[""] = [
ImportVar(tag="suneditor/dist/css/suneditor.min.css", install=False)
}
]
return imports
def get_event_triggers(self) -> Dict[str, Any]:

View File

@ -175,5 +175,5 @@ class Upload(Component):
def _get_imports(self) -> imports.ImportDict:
return {
**super()._get_imports(),
f"/{constants.Dirs.STATE_PATH}": {ImportVar(tag="upload_files")},
f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="upload_files")],
}

View File

@ -53,6 +53,14 @@ class Cond(Component):
)
)
def _get_props_imports(self):
"""Get the imports needed for components props.
Returns:
The imports for the components props of the component.
"""
return []
def _render(self) -> Tag:
return CondTag(
cond=self.cond,

View File

@ -1,6 +1,7 @@
"""Components that are based on Chakra-UI."""
from __future__ import annotations
from functools import lru_cache
from typing import List, Literal
from reflex.components.component import Component
@ -17,10 +18,27 @@ class ChakraComponent(Component):
"framer-motion@10.16.4",
]
def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]:
@staticmethod
@lru_cache(maxsize=None)
def _get_app_wrap_components() -> dict[tuple[int, str], Component]:
return {
**super()._get_app_wrap_components(),
(60, "ChakraProvider"): ChakraProvider.create(),
(60, "ChakraProvider"): chakra_provider,
}
@classmethod
@lru_cache(maxsize=None)
def _get_dependencies_imports(cls) -> imports.ImportDict:
"""Get the imports from lib_dependencies for installing.
Returns:
The dependencies imports of the component.
"""
return {
dep: [ImportVar(tag=None, render=False)]
for dep in [
"@chakra-ui/system@2.5.7",
"framer-motion@10.16.4",
]
}
@ -58,13 +76,13 @@ class ChakraProvider(ChakraComponent):
def _get_imports(self) -> imports.ImportDict:
imports = super()._get_imports()
imports.setdefault(self.__fields__["library"].default, set()).add(
imports.setdefault(self.__fields__["library"].default, []).append(
ImportVar(tag="extendTheme", is_default=False),
)
imports.setdefault("/utils/theme.js", set()).add(
imports.setdefault("/utils/theme.js", []).append(
ImportVar(tag="theme", is_default=True),
)
imports.setdefault(Global.__fields__["library"].default, set()).add(
imports.setdefault(Global.__fields__["library"].default, []).append(
ImportVar(tag="css", is_default=False),
)
return imports
@ -82,12 +100,17 @@ const GlobalStyles = css`
`;
"""
def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]:
@staticmethod
@lru_cache(maxsize=None)
def _get_app_wrap_components() -> dict[tuple[int, str], Component]:
return {
(50, "ChakraColorModeProvider"): ChakraColorModeProvider.create(),
(50, "ChakraColorModeProvider"): chakra_color_mode_provider,
}
chakra_provider = ChakraProvider.create()
class ChakraColorModeProvider(Component):
"""Next-themes integration for chakra colorModeProvider."""
@ -96,6 +119,9 @@ class ChakraColorModeProvider(Component):
is_default = True
chakra_color_mode_provider = ChakraColorModeProvider.create()
LiteralColorScheme = Literal[
"none",
"gray",

View File

@ -7,6 +7,7 @@ from typing import Any, Dict, Literal, Optional, Union, overload
from reflex.vars import Var, BaseVar, ComputedVar
from reflex.event import EventChain, EventHandler, EventSpec
from reflex.style import Style
from functools import lru_cache
from typing import List, Literal
from reflex.components.component import Component
from reflex.utils import imports
@ -238,6 +239,8 @@ class ChakraProvider(ChakraComponent):
"""
...
chakra_provider = ChakraProvider.create()
class ChakraColorModeProvider(Component):
@overload
@classmethod
@ -317,6 +320,7 @@ class ChakraColorModeProvider(Component):
"""
...
chakra_color_mode_provider = ChakraColorModeProvider.create()
LiteralColorScheme = Literal[
"none",
"gray",

View File

@ -7,6 +7,8 @@ from reflex.components.navigation.nextlink import NextLink
from reflex.utils import imports
from reflex.vars import BaseVar, Var
next_link = NextLink.create()
class Link(ChakraComponent):
"""Link to another page."""
@ -29,7 +31,7 @@ class Link(ChakraComponent):
is_external: Var[bool]
def _get_imports(self) -> imports.ImportDict:
return {**super()._get_imports(), **NextLink.create()._get_imports()}
return {**super()._get_imports(), **next_link._get_imports()}
@classmethod
def create(cls, *children, **props) -> Component:

View File

@ -13,6 +13,8 @@ from reflex.components.navigation.nextlink import NextLink
from reflex.utils import imports
from reflex.vars import BaseVar, Var
next_link = NextLink.create()
class Link(ChakraComponent):
@overload
@classmethod

View File

@ -28,7 +28,7 @@ class WebsocketTargetURL(Bare):
def _get_imports(self) -> imports.ImportDict:
return {
"/utils/state.js": {ImportVar(tag="getEventURL")},
"/utils/state.js": [ImportVar(tag="getEventURL")],
}
@classmethod

View File

@ -66,9 +66,9 @@ class RadixThemesComponent(Component):
)
return component
def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]:
@staticmethod
def _get_app_wrap_components() -> dict[tuple[int, str], Component]:
return {
**super()._get_app_wrap_components(),
(45, "RadixThemesColorModeProvider"): RadixThemesColorModeProvider.create(),
}
@ -147,7 +147,7 @@ class Theme(RadixThemesComponent):
def _get_imports(self) -> imports.ImportDict:
return {
**super()._get_imports(),
"": {ImportVar(tag="@radix-ui/themes/styles.css", install=False)},
"": [ImportVar(tag="@radix-ui/themes/styles.css", install=False)],
}

View File

@ -151,19 +151,19 @@ class Markdown(Component):
# Special markdown imports.
imports.update(
{
"": {ImportVar(tag="katex/dist/katex.min.css")},
"remark-math@5.1.1": {
"": [ImportVar(tag="katex/dist/katex.min.css")],
"remark-math@5.1.1": [
ImportVar(tag=_REMARK_MATH._var_name, is_default=True)
},
"remark-gfm@3.0.1": {
],
"remark-gfm@3.0.1": [
ImportVar(tag=_REMARK_GFM._var_name, is_default=True)
},
"rehype-katex@6.0.3": {
],
"rehype-katex@6.0.3": [
ImportVar(tag=_REHYPE_KATEX._var_name, is_default=True)
},
"rehype-raw@6.1.1": {
],
"rehype-raw@6.1.1": [
ImportVar(tag=_REHYPE_RAW._var_name, is_default=True)
},
],
}
)

View File

@ -3,11 +3,11 @@
from __future__ import annotations
from collections import defaultdict
from typing import Dict, Set
from typing import Dict, List
from reflex.vars import ImportVar
ImportDict = Dict[str, Set[ImportVar]]
ImportDict = Dict[str, List[ImportVar]]
def merge_imports(*imports) -> ImportDict:
@ -19,9 +19,8 @@ def merge_imports(*imports) -> ImportDict:
Returns:
The merged import dicts.
"""
all_imports = defaultdict(set)
all_imports = defaultdict(list)
for import_dict in imports:
for lib, fields in import_dict.items():
for field in fields:
all_imports[lib].add(field)
all_imports[lib].extend(fields)
return all_imports

View File

@ -1,5 +1,5 @@
import os
from typing import List, Set
from typing import List
import pytest
@ -12,28 +12,28 @@ from reflex.vars import ImportVar
"fields,test_default,test_rest",
[
(
{ImportVar(tag="axios", is_default=True)},
[ImportVar(tag="axios", is_default=True)],
"axios",
set(),
[],
),
(
{ImportVar(tag="foo"), ImportVar(tag="bar")},
[ImportVar(tag="foo"), ImportVar(tag="bar")],
"",
{"foo", "bar"},
["bar", "foo"],
),
(
{
[
ImportVar(tag="axios", is_default=True),
ImportVar(tag="foo"),
ImportVar(tag="bar"),
},
],
"axios",
{"foo", "bar"},
["bar", "foo"],
),
],
)
def test_compile_import_statement(
fields: Set[ImportVar], test_default: str, test_rest: str
fields: List[ImportVar], test_default: str, test_rest: str
):
"""Test the compile_import_statement function.
@ -44,7 +44,7 @@ def test_compile_import_statement(
"""
default, rest = utils.compile_import_statement(fields)
assert default == test_default
assert rest == test_rest
assert sorted(rest) == test_rest
@pytest.mark.parametrize(
@ -52,43 +52,43 @@ def test_compile_import_statement(
[
({}, []),
(
{"axios": {ImportVar(tag="axios", is_default=True)}},
[{"lib": "axios", "default": "axios", "rest": set()}],
{"axios": [ImportVar(tag="axios", is_default=True)]},
[{"lib": "axios", "default": "axios", "rest": []}],
),
(
{"axios": {ImportVar(tag="foo"), ImportVar(tag="bar")}},
[{"lib": "axios", "default": "", "rest": {"foo", "bar"}}],
{"axios": [ImportVar(tag="foo"), ImportVar(tag="bar")]},
[{"lib": "axios", "default": "", "rest": ["bar", "foo"]}],
),
(
{
"axios": {
"axios": [
ImportVar(tag="axios", is_default=True),
ImportVar(tag="foo"),
ImportVar(tag="bar"),
},
"react": {ImportVar(tag="react", is_default=True)},
],
"react": [ImportVar(tag="react", is_default=True)],
},
[
{"lib": "axios", "default": "axios", "rest": {"foo", "bar"}},
{"lib": "react", "default": "react", "rest": set()},
{"lib": "axios", "default": "axios", "rest": ["bar", "foo"]},
{"lib": "react", "default": "react", "rest": []},
],
),
(
{"": {ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")}},
{"": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")]},
[
{"lib": "lib1.js", "default": "", "rest": set()},
{"lib": "lib2.js", "default": "", "rest": set()},
{"lib": "lib1.js", "default": "", "rest": []},
{"lib": "lib2.js", "default": "", "rest": []},
],
),
(
{
"": {ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")},
"axios": {ImportVar(tag="axios", is_default=True)},
"": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")],
"axios": [ImportVar(tag="axios", is_default=True)],
},
[
{"lib": "lib1.js", "default": "", "rest": set()},
{"lib": "lib2.js", "default": "", "rest": set()},
{"lib": "axios", "default": "axios", "rest": set()},
{"lib": "lib1.js", "default": "", "rest": []},
{"lib": "lib2.js", "default": "", "rest": []},
{"lib": "axios", "default": "axios", "rest": []},
],
),
],
@ -104,7 +104,7 @@ def test_compile_imports(import_dict: imports.ImportDict, test_dicts: List[dict]
for import_dict, test_dict in zip(imports, test_dicts):
assert import_dict["lib"] == test_dict["lib"]
assert import_dict["default"] == test_dict["default"]
assert import_dict["rest"] == test_dict["rest"]
assert sorted(import_dict["rest"]) == test_dict["rest"] # type: ignore
def test_compile_stylesheets(tmp_path, mocker):

View File

@ -44,7 +44,7 @@ def component1() -> Type[Component]:
number: Var[int]
def _get_imports(self) -> imports.ImportDict:
return {"react": {ImportVar(tag="Component")}}
return {"react": [ImportVar(tag="Component")]}
def _get_custom_code(self) -> str:
return "console.log('component1')"
@ -77,7 +77,7 @@ def component2() -> Type[Component]:
}
def _get_imports(self) -> imports.ImportDict:
return {"react-redux": {ImportVar(tag="connect")}}
return {"react-redux": [ImportVar(tag="connect")]}
def _get_custom_code(self) -> str:
return "console.log('component2')"
@ -268,10 +268,10 @@ def test_get_imports(component1, component2):
"""
c1 = component1.create()
c2 = component2.create(c1)
assert c1.get_imports() == {"react": {ImportVar(tag="Component")}}
assert c1.get_imports() == {"react": [ImportVar(tag="Component")]}
assert c2.get_imports() == {
"react-redux": {ImportVar(tag="connect")},
"react": {ImportVar(tag="Component")},
"react-redux": [ImportVar(tag="connect")],
"react": [ImportVar(tag="Component")],
}
@ -469,7 +469,7 @@ def test_custom_component_wrapper():
assert len(ccomponent.children) == 1
assert isinstance(ccomponent.children[0], rx.Text)
component = ccomponent.get_component()
component = ccomponent.get_component(ccomponent)
assert isinstance(component, Box)