Pass static checks

This commit is contained in:
Masen Furer 2024-04-30 12:24:36 -07:00
parent 35252464a0
commit 3423fec2a6
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
12 changed files with 199 additions and 119 deletions
reflex
tests/components/core

View File

@ -627,7 +627,7 @@ class App(Base):
Example:
>>> get_frontend_packages({"react": "16.14.0", "react-dom": "16.14.0"})
"""
page_imports = [i.package for i in imports.collapse().values() if i.install]
page_imports = [i.package for i in imports if i.install and i.package]
frontend_packages = get_config().frontend_packages
_frontend_packages = []
for package in frontend_packages:
@ -643,7 +643,7 @@ class App(Base):
continue
_frontend_packages.append(package)
page_imports.extend(_frontend_packages)
prerequisites.install_frontend_packages(page_imports, get_config())
prerequisites.install_frontend_packages(set(page_imports), get_config())
def _app_root(self, app_wrappers: dict[tuple[int, str], Component]) -> Component:
for component in tuple(app_wrappers.values()):

View File

@ -188,7 +188,7 @@ def _compile_component(component: Component) -> str:
def _compile_components(
components: set[CustomComponent],
) -> tuple[str, Dict[str, list[ImportVar]]]:
) -> tuple[str, ImportList]:
"""Compile the components.
Args:

View File

@ -247,7 +247,7 @@ def compile_custom_component(
render = component.get_component(component)
# Get the imports.
component_library_name = format.format_library_name(component.library)
component_library_name = format.format_library_name(component.library or "")
_imports = imports.ImportList(
imp
for imp in render._get_all_imports()

View File

@ -35,7 +35,7 @@ class ChakraComponent(Component):
@classmethod
@lru_cache(maxsize=None)
def _get_dependencies_imports(cls) -> imports.ImportList:
def _get_dependencies_imports(cls) -> List[imports.ImportVar]:
"""Get the imports from lib_dependencies for installing.
Returns:
@ -67,13 +67,21 @@ class ChakraProvider(ChakraComponent):
theme=Var.create("extendTheme(theme)", _var_is_local=False),
)
def _get_imports(self) -> imports.ImportDict:
_imports = super()._get_imports()
_imports.setdefault(self.__fields__["library"].default, []).append(
imports.ImportVar(tag="extendTheme", is_default=False),
)
_imports.setdefault("/utils/theme.js", []).append(
imports.ImportVar(tag="theme", is_default=True),
def _get_imports_list(self) -> List[imports.ImportVar]:
_imports = super()._get_imports_list()
_imports.extend(
[
imports.ImportVar(
package=self.__fields__["library"].default,
tag="extendTheme",
is_default=False,
),
imports.ImportVar(
package="/utils/theme.js",
tag="theme",
is_default=True,
),
],
)
return _imports

View File

@ -1026,7 +1026,7 @@ class Component(BaseComponent, ABC):
or format.format_library_name(dep or "") in self.transpile_packages
)
def _get_dependencies_imports(self) -> imports.ImportList:
def _get_dependencies_imports(self) -> List[ImportVar]:
"""Get the imports from lib_dependencies for installing.
Returns:
@ -1073,7 +1073,11 @@ class Component(BaseComponent, ABC):
)
user_hooks = self._get_hooks()
if user_hooks is not None and isinstance(user_hooks, Var):
if (
user_hooks is not None
and isinstance(user_hooks, Var)
and user_hooks._var_data is not None
):
_imports.extend(user_hooks._var_data.imports)
return _imports
@ -1086,7 +1090,7 @@ class Component(BaseComponent, ABC):
"""
return {}
def _get_imports_list(self) -> imports.ImportList:
def _get_imports_list(self) -> List[ImportVar]:
"""Internal method to get the imports as a list.
Returns:
@ -1117,7 +1121,7 @@ class Component(BaseComponent, ABC):
# Get static imports required for event processing.
if self.event_triggers:
_imports.append(Imports.EVENTS)
_imports.extend(Imports.EVENTS)
# Collect imports from Vars used directly by this component.
for var in self._get_vars():

View File

@ -51,11 +51,14 @@ has_too_many_connection_errors: Var = Var.create_safe(
class WebsocketTargetURL(Bare):
"""A component that renders the websocket target URL."""
def _get_imports(self) -> imports.ImportDict:
return {
f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")],
"/env.json": [imports.ImportVar(tag="env", is_default=True)],
}
def _get_imports_list(self) -> list[imports.ImportVar]:
return [
imports.ImportVar(
library=f"/{Dirs.STATE_PATH}",
tag="getBackendURL",
),
imports.ImportVar(library="/env.json", tag="env", is_default=True),
]
@classmethod
def create(cls) -> Component:

View File

@ -7,7 +7,6 @@ from functools import lru_cache
from hashlib import md5
from typing import Any, Callable, Dict, Union
from reflex.compiler import utils
from reflex.components.component import Component, CustomComponent
from reflex.components.radix.themes.layout.list import (
ListItem,
@ -154,47 +153,53 @@ class Markdown(Component):
return custom_components
def _get_imports(self) -> imports.ImportDict:
def _get_imports_list(self) -> list[imports.ImportVar]:
# Import here to avoid circular imports.
from reflex.components.datadisplay.code import CodeBlock
from reflex.components.radix.themes.typography.code import Code
imports = super()._get_imports()
_imports = super()._get_imports_list()
# Special markdown imports.
imports.update(
{
"": [ImportVar(tag="katex/dist/katex.min.css")],
"remark-math@5.1.1": [
ImportVar(tag=_REMARK_MATH._var_name, is_default=True)
],
"remark-gfm@3.0.1": [
ImportVar(tag=_REMARK_GFM._var_name, is_default=True)
],
"remark-unwrap-images@4.0.0": [
ImportVar(tag=_REMARK_UNWRAP_IMAGES._var_name, is_default=True)
],
"rehype-katex@6.0.3": [
ImportVar(tag=_REHYPE_KATEX._var_name, is_default=True)
],
"rehype-raw@6.1.1": [
ImportVar(tag=_REHYPE_RAW._var_name, is_default=True)
],
}
_imports.extend(
[
ImportVar(library="", tag="katex/dist/katex.min.css"),
ImportVar(
package="remark-math@5.1.1",
tag=_REMARK_MATH._var_name,
is_default=True,
),
ImportVar(
package="remark-gfm@3.0.1",
tag=_REMARK_GFM._var_name,
is_default=True,
),
ImportVar(
package="remark-unwrap-images@4.0.0",
tag=_REMARK_UNWRAP_IMAGES._var_name,
is_default=True,
),
ImportVar(
package="remark-katex@6.0.3",
tag=_REHYPE_KATEX._var_name,
is_default=True,
),
ImportVar(
package="rehype-raw@6.1.1",
tag=_REHYPE_RAW._var_name,
is_default=True,
),
]
)
# Get the imports for each component.
for component in self.component_map.values():
imports = utils.merge_imports(
imports, component(_MOCK_ARG)._get_all_imports()
)
_imports.extend(component(_MOCK_ARG)._get_all_imports())
# Get the imports for the code components.
imports = utils.merge_imports(
imports, CodeBlock.create(theme="light")._get_imports()
)
imports = utils.merge_imports(imports, Code.create()._get_imports())
return imports
_imports.extend(CodeBlock.create(theme="light")._get_all_imports())
_imports.extend(Code.create()._get_all_imports())
return _imports
def get_component(self, tag: str, **props) -> Component:
"""Get the component for a tag and props.

View File

@ -243,13 +243,11 @@ class ThemePanel(RadixThemesComponent):
# Whether the panel is open. Defaults to False.
default_open: Var[bool]
def _get_imports(self) -> dict[str, list[imports.ImportVar]]:
return imports.merge_imports(
super()._get_imports(),
{
"react": [imports.ImportVar(tag="useEffect")],
},
)
def _get_imports_list(self) -> list[imports.ImportVar]:
return [
*super()._get_imports_list(),
imports.ImportVar(package="react", tag="useEffect"),
]
def _get_hooks(self) -> str | None:
# The panel freezes the tab if the user color preference differs from the

View File

@ -102,11 +102,13 @@ class ComponentName(Enum):
class Imports(SimpleNamespace):
"""Common sets of import vars."""
EVENTS: ImportList = [
ImportVar(package="react", tag="useContext"),
ImportVar(package=f"/{Dirs.CONTEXTS_PATH}", tag="EventLoopContext"),
ImportVar(package=f"/{Dirs.STATE_PATH}", tag=CompileVars.TO_EVENT),
]
EVENTS: ImportList = ImportList(
[
ImportVar(package="react", tag="useContext"),
ImportVar(package=f"/{Dirs.CONTEXTS_PATH}", tag="EventLoopContext"),
ImportVar(package=f"/{Dirs.STATE_PATH}", tag=CompileVars.TO_EVENT),
]
)
class Hooks(SimpleNamespace):

View File

@ -3,7 +3,7 @@
from __future__ import annotations
from collections import defaultdict
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional
from reflex.base import Base
from reflex.constants.installer import PackageJson
@ -91,6 +91,15 @@ class ImportVar(Base):
package: Optional[str] = None,
**kwargs,
):
"""Create a new ImportVar.
Args:
package: The package to install for this import.
**kwargs: The import var fields.
Raises:
ValueError: If the package is provided with library or version.
"""
if package is not None:
if (
kwargs.get("library", None) is not None
@ -128,8 +137,8 @@ class ImportVar(Base):
return self.tag or ""
@property
def package(self) -> str:
"""The package to install for this import
def package(self) -> str | None:
"""The package to install for this import.
Returns:
The library name and (optional) version to be installed by npm/bun.
@ -150,10 +159,6 @@ class ImportVar(Base):
self.tag,
self.is_default,
self.alias,
# These do not fundamentally change the import in any way
# self.install,
# self.render,
# self.transpile,
)
)
@ -183,16 +188,22 @@ class ImportVar(Base):
Returns:
The collapsed import var with sticky props perserved.
Raises:
ValueError: If the two import vars have conflicting properties.
"""
if self != other_import_var:
raise ValueError("Cannot collapse two import vars with different hashes")
if self.version is not None and other_import_var.version is not None:
if self.version != other_import_var.version:
raise ValueError(
"Cannot collapse two import vars with conflicting version specifiers: "
f"{self} {other_import_var}"
)
if (
self.version is not None
and other_import_var.version is not None
and self.version != other_import_var.version
):
raise ValueError(
"Cannot collapse two import vars with conflicting version specifiers: "
f"{self} {other_import_var}"
)
return type(self)(
library=self.library,
@ -210,6 +221,15 @@ class ImportList(List[ImportVar]):
"""A list of import vars."""
def __init__(self, *args, **kwargs):
"""Create a new ImportList (wrapper over `list`).
Any items that are not already `ImportVar` will be assumed as dicts to convert
into an ImportVar.
Args:
*args: The args to pass to list.__init__
**kwargs: The kwargs to pass to list.__init__
"""
super().__init__(*args, **kwargs)
for ix, value in enumerate(self):
if not isinstance(value, ImportVar):
@ -217,26 +237,41 @@ class ImportList(List[ImportVar]):
self[ix] = ImportVar(**value)
@classmethod
def from_import_dict(cls, import_dict: ImportDict) -> ImportList:
return [
def from_import_dict(
cls, import_dict: ImportDict | Dict[str, set[ImportVar]]
) -> ImportList:
"""Create an import list from an import dict.
Args:
import_dict: The import dict to convert.
Returns:
The import list.
"""
return cls(
ImportVar(package=lib, **imp.dict())
for lib, imps in import_dict.items()
for imp in imps
]
)
def collapse(self) -> ImportDict:
"""When collapsing an import list, prefer packages with version specifiers."""
collapsed = {}
"""When collapsing an import list, prefer packages with version specifiers.
Returns:
The collapsed import dict ({library_name: [import_var1, ...]}).
"""
collapsed: dict[str, dict[ImportVar, ImportVar]] = {}
for imp in self:
collapsed.setdefault(imp.library, {})
if imp in collapsed[imp.library]:
lib = imp.library or ""
collapsed.setdefault(lib, {})
if imp in collapsed[lib]:
# Need to check if the current import has any special properties that need to
# be preserved, like the version specifier, install, or transpile.
existing_imp = collapsed[imp.library][imp]
collapsed[imp.library][imp] = existing_imp.collapse(imp)
existing_imp = collapsed[lib][imp]
collapsed[lib][imp] = existing_imp.collapse(imp)
else:
collapsed[imp.library][imp] = imp
return {lib: set(imps) for lib, imps in collapsed.items()}
collapsed[lib][imp] = imp
return {lib: list(set(imps)) for lib, imps in collapsed.items()}
ImportDict = Dict[str, Set[ImportVar]]
ImportDict = Dict[str, List[ImportVar]]

View File

@ -34,7 +34,7 @@ from typing import (
from reflex import constants
from reflex.base import Base
from reflex.utils import console, format, imports, serializers, types
from reflex.utils import console, format, serializers, types
# This module used to export ImportVar itself, so we still import it for export here
from reflex.utils.imports import ImportDict, ImportList, ImportVar
@ -116,7 +116,7 @@ class VarData(Base):
state: str = ""
# Imports needed to render this var
imports: ImportList = []
imports: ImportList = ImportList()
# Hooks that need to be present in the component to render this var
hooks: Dict[str, None] = {}
@ -126,7 +126,24 @@ class VarData(Base):
# segments.
interpolations: List[Tuple[int, int]] = []
def __init__(self, imports: ImportDict | ImportList = None, **kwargs):
def __init__(
self,
imports: ImportList
| List[ImportVar | Dict[str, Optional[Union[str, bool]]]]
| ImportDict
| Dict[str, set[ImportVar]]
| None = None,
**kwargs,
):
"""Initialize the VarData.
If imports is an ImportDict it will be converted to an ImportList and a
deprecation warning will be displayed.
Args:
imports: The imports needed to render this var.
**kwargs: Additional fields to set.
"""
if isinstance(imports, dict):
imports = ImportList.from_import_dict(imports)
console.deprecate(
@ -135,9 +152,12 @@ class VarData(Base):
deprecation_version="0.5.0",
removal_version="0.6.0",
)
elif imports is None:
imports = []
super().__init__(imports=imports, **kwargs)
else:
imports = ImportList(imports or [])
super().__init__(
imports=imports, # type: ignore
**kwargs,
)
@classmethod
def merge(cls, *others: VarData | None) -> VarData | None:
@ -150,7 +170,7 @@ class VarData(Base):
The merged var data object.
"""
state = ""
_imports = []
_imports = ImportList()
hooks = {}
interpolations = []
for var_data in others:
@ -1059,11 +1079,12 @@ class Var:
",", other, fn="spreadArraysOrObjects", flip=flip
)._replace(
merge_var_data=VarData(
imports={
f"/{constants.Dirs.STATE_PATH}": [
ImportVar(tag="spreadArraysOrObjects")
]
},
imports=[
ImportVar(
package=f"/{constants.Dirs.STATE_PATH}",
tag="spreadArraysOrObjects",
),
],
),
)
return self.operation("+", other, flip=flip)
@ -1612,11 +1633,11 @@ class Var:
v2._var_data,
step._var_data,
VarData(
imports={
"/utils/helpers/range.js": [
ImportVar(tag="range", is_default=True),
],
},
imports=[
ImportVar(
package="/utils/helpers/range", tag="range", is_default=True
),
]
),
),
)
@ -1644,9 +1665,9 @@ class Var:
_var_is_string=False,
_var_full_name_needs_state_prefix=False,
merge_var_data=VarData(
imports={
f"/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")],
},
imports=[
ImportVar(package=f"/{constants.Dirs.STATE_PATH}", tag="refs")
],
),
)
@ -1684,10 +1705,14 @@ class Var:
format.format_state_name(state_name)
): None
},
imports={
f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")],
"react": [ImportVar(tag="useContext")],
},
imports=ImportList(
[
ImportVar(
package=f"/{constants.Dirs.CONTEXTS_PATH}", tag="StateContexts"
),
ImportVar(package="react", tag="useContext"),
]
),
)
self._var_data = VarData.merge(self._var_data, new_var_data)
self._var_full_name_needs_state_prefix = True

View File

@ -9,14 +9,14 @@ from reflex.components.radix.themes.typography.text import Text
def test_websocket_target_url():
url = WebsocketTargetURL.create()
_imports = url._get_all_imports(collapse=True)
assert list(_imports.keys()) == ["/utils/state", "/env.json"]
_imports = url._get_all_imports()
assert [i.library for i in _imports] == ["/utils/state", "/env.json"]
def test_connection_banner():
banner = ConnectionBanner.create()
_imports = banner._get_all_imports(collapse=True)
assert list(_imports.keys()) == [
_imports = banner._get_all_imports()
assert [i.library for i in _imports] == [
"react",
"/utils/context",
"/utils/state",
@ -31,8 +31,8 @@ def test_connection_banner():
def test_connection_modal():
modal = ConnectionModal.create()
_imports = modal._get_all_imports(collapse=True)
assert list(_imports.keys()) == [
_imports = modal._get_all_imports()
assert [i.library for i in _imports] == [
"react",
"/utils/context",
"/utils/state",
@ -48,4 +48,4 @@ def test_connection_modal():
def test_connection_pulser():
pulser = ConnectionPulser.create()
_custom_code = pulser._get_all_custom_code()
_imports = pulser._get_all_imports(collapse=True)
_imports = pulser._get_all_imports()