Pass static checks

This commit is contained in:
Masen Furer 2024-04-30 12:24:36 -07:00
parent 35252464a0
commit 3423fec2a6
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
12 changed files with 199 additions and 119 deletions

View File

@ -627,7 +627,7 @@ class App(Base):
Example: Example:
>>> get_frontend_packages({"react": "16.14.0", "react-dom": "16.14.0"}) >>> get_frontend_packages({"react": "16.14.0", "react-dom": "16.14.0"})
""" """
page_imports = [i.package for i in imports.collapse().values() if i.install] page_imports = [i.package for i in imports if i.install and i.package]
frontend_packages = get_config().frontend_packages frontend_packages = get_config().frontend_packages
_frontend_packages = [] _frontend_packages = []
for package in frontend_packages: for package in frontend_packages:
@ -643,7 +643,7 @@ class App(Base):
continue continue
_frontend_packages.append(package) _frontend_packages.append(package)
page_imports.extend(_frontend_packages) page_imports.extend(_frontend_packages)
prerequisites.install_frontend_packages(page_imports, get_config()) prerequisites.install_frontend_packages(set(page_imports), get_config())
def _app_root(self, app_wrappers: dict[tuple[int, str], Component]) -> Component: def _app_root(self, app_wrappers: dict[tuple[int, str], Component]) -> Component:
for component in tuple(app_wrappers.values()): for component in tuple(app_wrappers.values()):

View File

@ -188,7 +188,7 @@ def _compile_component(component: Component) -> str:
def _compile_components( def _compile_components(
components: set[CustomComponent], components: set[CustomComponent],
) -> tuple[str, Dict[str, list[ImportVar]]]: ) -> tuple[str, ImportList]:
"""Compile the components. """Compile the components.
Args: Args:

View File

@ -247,7 +247,7 @@ def compile_custom_component(
render = component.get_component(component) render = component.get_component(component)
# Get the imports. # Get the imports.
component_library_name = format.format_library_name(component.library) component_library_name = format.format_library_name(component.library or "")
_imports = imports.ImportList( _imports = imports.ImportList(
imp imp
for imp in render._get_all_imports() for imp in render._get_all_imports()

View File

@ -35,7 +35,7 @@ class ChakraComponent(Component):
@classmethod @classmethod
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def _get_dependencies_imports(cls) -> imports.ImportList: def _get_dependencies_imports(cls) -> List[imports.ImportVar]:
"""Get the imports from lib_dependencies for installing. """Get the imports from lib_dependencies for installing.
Returns: Returns:
@ -67,13 +67,21 @@ class ChakraProvider(ChakraComponent):
theme=Var.create("extendTheme(theme)", _var_is_local=False), theme=Var.create("extendTheme(theme)", _var_is_local=False),
) )
def _get_imports(self) -> imports.ImportDict: def _get_imports_list(self) -> List[imports.ImportVar]:
_imports = super()._get_imports() _imports = super()._get_imports_list()
_imports.setdefault(self.__fields__["library"].default, []).append( _imports.extend(
imports.ImportVar(tag="extendTheme", is_default=False), [
) imports.ImportVar(
_imports.setdefault("/utils/theme.js", []).append( package=self.__fields__["library"].default,
imports.ImportVar(tag="theme", is_default=True), tag="extendTheme",
is_default=False,
),
imports.ImportVar(
package="/utils/theme.js",
tag="theme",
is_default=True,
),
],
) )
return _imports return _imports

View File

@ -1026,7 +1026,7 @@ class Component(BaseComponent, ABC):
or format.format_library_name(dep or "") in self.transpile_packages or format.format_library_name(dep or "") in self.transpile_packages
) )
def _get_dependencies_imports(self) -> imports.ImportList: def _get_dependencies_imports(self) -> List[ImportVar]:
"""Get the imports from lib_dependencies for installing. """Get the imports from lib_dependencies for installing.
Returns: Returns:
@ -1073,7 +1073,11 @@ class Component(BaseComponent, ABC):
) )
user_hooks = self._get_hooks() user_hooks = self._get_hooks()
if user_hooks is not None and isinstance(user_hooks, Var): if (
user_hooks is not None
and isinstance(user_hooks, Var)
and user_hooks._var_data is not None
):
_imports.extend(user_hooks._var_data.imports) _imports.extend(user_hooks._var_data.imports)
return _imports return _imports
@ -1086,7 +1090,7 @@ class Component(BaseComponent, ABC):
""" """
return {} return {}
def _get_imports_list(self) -> imports.ImportList: def _get_imports_list(self) -> List[ImportVar]:
"""Internal method to get the imports as a list. """Internal method to get the imports as a list.
Returns: Returns:
@ -1117,7 +1121,7 @@ class Component(BaseComponent, ABC):
# Get static imports required for event processing. # Get static imports required for event processing.
if self.event_triggers: if self.event_triggers:
_imports.append(Imports.EVENTS) _imports.extend(Imports.EVENTS)
# Collect imports from Vars used directly by this component. # Collect imports from Vars used directly by this component.
for var in self._get_vars(): for var in self._get_vars():

View File

@ -51,11 +51,14 @@ has_too_many_connection_errors: Var = Var.create_safe(
class WebsocketTargetURL(Bare): class WebsocketTargetURL(Bare):
"""A component that renders the websocket target URL.""" """A component that renders the websocket target URL."""
def _get_imports(self) -> imports.ImportDict: def _get_imports_list(self) -> list[imports.ImportVar]:
return { return [
f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")], imports.ImportVar(
"/env.json": [imports.ImportVar(tag="env", is_default=True)], library=f"/{Dirs.STATE_PATH}",
} tag="getBackendURL",
),
imports.ImportVar(library="/env.json", tag="env", is_default=True),
]
@classmethod @classmethod
def create(cls) -> Component: def create(cls) -> Component:

View File

@ -7,7 +7,6 @@ from functools import lru_cache
from hashlib import md5 from hashlib import md5
from typing import Any, Callable, Dict, Union from typing import Any, Callable, Dict, Union
from reflex.compiler import utils
from reflex.components.component import Component, CustomComponent from reflex.components.component import Component, CustomComponent
from reflex.components.radix.themes.layout.list import ( from reflex.components.radix.themes.layout.list import (
ListItem, ListItem,
@ -154,47 +153,53 @@ class Markdown(Component):
return custom_components return custom_components
def _get_imports(self) -> imports.ImportDict: def _get_imports_list(self) -> list[imports.ImportVar]:
# Import here to avoid circular imports. # Import here to avoid circular imports.
from reflex.components.datadisplay.code import CodeBlock from reflex.components.datadisplay.code import CodeBlock
from reflex.components.radix.themes.typography.code import Code from reflex.components.radix.themes.typography.code import Code
imports = super()._get_imports() _imports = super()._get_imports_list()
# Special markdown imports. # Special markdown imports.
imports.update( _imports.extend(
{ [
"": [ImportVar(tag="katex/dist/katex.min.css")], ImportVar(library="", tag="katex/dist/katex.min.css"),
"remark-math@5.1.1": [ ImportVar(
ImportVar(tag=_REMARK_MATH._var_name, is_default=True) package="remark-math@5.1.1",
], tag=_REMARK_MATH._var_name,
"remark-gfm@3.0.1": [ is_default=True,
ImportVar(tag=_REMARK_GFM._var_name, is_default=True) ),
], ImportVar(
"remark-unwrap-images@4.0.0": [ package="remark-gfm@3.0.1",
ImportVar(tag=_REMARK_UNWRAP_IMAGES._var_name, is_default=True) tag=_REMARK_GFM._var_name,
], is_default=True,
"rehype-katex@6.0.3": [ ),
ImportVar(tag=_REHYPE_KATEX._var_name, is_default=True) ImportVar(
], package="remark-unwrap-images@4.0.0",
"rehype-raw@6.1.1": [ tag=_REMARK_UNWRAP_IMAGES._var_name,
ImportVar(tag=_REHYPE_RAW._var_name, is_default=True) is_default=True,
], ),
} ImportVar(
package="remark-katex@6.0.3",
tag=_REHYPE_KATEX._var_name,
is_default=True,
),
ImportVar(
package="rehype-raw@6.1.1",
tag=_REHYPE_RAW._var_name,
is_default=True,
),
]
) )
# Get the imports for each component. # Get the imports for each component.
for component in self.component_map.values(): for component in self.component_map.values():
imports = utils.merge_imports( _imports.extend(component(_MOCK_ARG)._get_all_imports())
imports, component(_MOCK_ARG)._get_all_imports()
)
# Get the imports for the code components. # Get the imports for the code components.
imports = utils.merge_imports( _imports.extend(CodeBlock.create(theme="light")._get_all_imports())
imports, CodeBlock.create(theme="light")._get_imports() _imports.extend(Code.create()._get_all_imports())
) return _imports
imports = utils.merge_imports(imports, Code.create()._get_imports())
return imports
def get_component(self, tag: str, **props) -> Component: def get_component(self, tag: str, **props) -> Component:
"""Get the component for a tag and props. """Get the component for a tag and props.

View File

@ -243,13 +243,11 @@ class ThemePanel(RadixThemesComponent):
# Whether the panel is open. Defaults to False. # Whether the panel is open. Defaults to False.
default_open: Var[bool] default_open: Var[bool]
def _get_imports(self) -> dict[str, list[imports.ImportVar]]: def _get_imports_list(self) -> list[imports.ImportVar]:
return imports.merge_imports( return [
super()._get_imports(), *super()._get_imports_list(),
{ imports.ImportVar(package="react", tag="useEffect"),
"react": [imports.ImportVar(tag="useEffect")], ]
},
)
def _get_hooks(self) -> str | None: def _get_hooks(self) -> str | None:
# The panel freezes the tab if the user color preference differs from the # The panel freezes the tab if the user color preference differs from the

View File

@ -102,11 +102,13 @@ class ComponentName(Enum):
class Imports(SimpleNamespace): class Imports(SimpleNamespace):
"""Common sets of import vars.""" """Common sets of import vars."""
EVENTS: ImportList = [ EVENTS: ImportList = ImportList(
ImportVar(package="react", tag="useContext"), [
ImportVar(package=f"/{Dirs.CONTEXTS_PATH}", tag="EventLoopContext"), ImportVar(package="react", tag="useContext"),
ImportVar(package=f"/{Dirs.STATE_PATH}", tag=CompileVars.TO_EVENT), ImportVar(package=f"/{Dirs.CONTEXTS_PATH}", tag="EventLoopContext"),
] ImportVar(package=f"/{Dirs.STATE_PATH}", tag=CompileVars.TO_EVENT),
]
)
class Hooks(SimpleNamespace): class Hooks(SimpleNamespace):

View File

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Optional, Set from typing import Dict, List, Optional
from reflex.base import Base from reflex.base import Base
from reflex.constants.installer import PackageJson from reflex.constants.installer import PackageJson
@ -91,6 +91,15 @@ class ImportVar(Base):
package: Optional[str] = None, package: Optional[str] = None,
**kwargs, **kwargs,
): ):
"""Create a new ImportVar.
Args:
package: The package to install for this import.
**kwargs: The import var fields.
Raises:
ValueError: If the package is provided with library or version.
"""
if package is not None: if package is not None:
if ( if (
kwargs.get("library", None) is not None kwargs.get("library", None) is not None
@ -128,8 +137,8 @@ class ImportVar(Base):
return self.tag or "" return self.tag or ""
@property @property
def package(self) -> str: def package(self) -> str | None:
"""The package to install for this import """The package to install for this import.
Returns: Returns:
The library name and (optional) version to be installed by npm/bun. The library name and (optional) version to be installed by npm/bun.
@ -150,10 +159,6 @@ class ImportVar(Base):
self.tag, self.tag,
self.is_default, self.is_default,
self.alias, self.alias,
# These do not fundamentally change the import in any way
# self.install,
# self.render,
# self.transpile,
) )
) )
@ -183,16 +188,22 @@ class ImportVar(Base):
Returns: Returns:
The collapsed import var with sticky props perserved. The collapsed import var with sticky props perserved.
Raises:
ValueError: If the two import vars have conflicting properties.
""" """
if self != other_import_var: if self != other_import_var:
raise ValueError("Cannot collapse two import vars with different hashes") raise ValueError("Cannot collapse two import vars with different hashes")
if self.version is not None and other_import_var.version is not None: if (
if self.version != other_import_var.version: self.version is not None
raise ValueError( and other_import_var.version is not None
"Cannot collapse two import vars with conflicting version specifiers: " and self.version != other_import_var.version
f"{self} {other_import_var}" ):
) raise ValueError(
"Cannot collapse two import vars with conflicting version specifiers: "
f"{self} {other_import_var}"
)
return type(self)( return type(self)(
library=self.library, library=self.library,
@ -210,6 +221,15 @@ class ImportList(List[ImportVar]):
"""A list of import vars.""" """A list of import vars."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Create a new ImportList (wrapper over `list`).
Any items that are not already `ImportVar` will be assumed as dicts to convert
into an ImportVar.
Args:
*args: The args to pass to list.__init__
**kwargs: The kwargs to pass to list.__init__
"""
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
for ix, value in enumerate(self): for ix, value in enumerate(self):
if not isinstance(value, ImportVar): if not isinstance(value, ImportVar):
@ -217,26 +237,41 @@ class ImportList(List[ImportVar]):
self[ix] = ImportVar(**value) self[ix] = ImportVar(**value)
@classmethod @classmethod
def from_import_dict(cls, import_dict: ImportDict) -> ImportList: def from_import_dict(
return [ cls, import_dict: ImportDict | Dict[str, set[ImportVar]]
) -> ImportList:
"""Create an import list from an import dict.
Args:
import_dict: The import dict to convert.
Returns:
The import list.
"""
return cls(
ImportVar(package=lib, **imp.dict()) ImportVar(package=lib, **imp.dict())
for lib, imps in import_dict.items() for lib, imps in import_dict.items()
for imp in imps for imp in imps
] )
def collapse(self) -> ImportDict: def collapse(self) -> ImportDict:
"""When collapsing an import list, prefer packages with version specifiers.""" """When collapsing an import list, prefer packages with version specifiers.
collapsed = {}
Returns:
The collapsed import dict ({library_name: [import_var1, ...]}).
"""
collapsed: dict[str, dict[ImportVar, ImportVar]] = {}
for imp in self: for imp in self:
collapsed.setdefault(imp.library, {}) lib = imp.library or ""
if imp in collapsed[imp.library]: collapsed.setdefault(lib, {})
if imp in collapsed[lib]:
# Need to check if the current import has any special properties that need to # Need to check if the current import has any special properties that need to
# be preserved, like the version specifier, install, or transpile. # be preserved, like the version specifier, install, or transpile.
existing_imp = collapsed[imp.library][imp] existing_imp = collapsed[lib][imp]
collapsed[imp.library][imp] = existing_imp.collapse(imp) collapsed[lib][imp] = existing_imp.collapse(imp)
else: else:
collapsed[imp.library][imp] = imp collapsed[lib][imp] = imp
return {lib: set(imps) for lib, imps in collapsed.items()} return {lib: list(set(imps)) for lib, imps in collapsed.items()}
ImportDict = Dict[str, Set[ImportVar]] ImportDict = Dict[str, List[ImportVar]]

View File

@ -34,7 +34,7 @@ from typing import (
from reflex import constants from reflex import constants
from reflex.base import Base from reflex.base import Base
from reflex.utils import console, format, imports, serializers, types from reflex.utils import console, format, serializers, types
# This module used to export ImportVar itself, so we still import it for export here # This module used to export ImportVar itself, so we still import it for export here
from reflex.utils.imports import ImportDict, ImportList, ImportVar from reflex.utils.imports import ImportDict, ImportList, ImportVar
@ -116,7 +116,7 @@ class VarData(Base):
state: str = "" state: str = ""
# Imports needed to render this var # Imports needed to render this var
imports: ImportList = [] imports: ImportList = ImportList()
# Hooks that need to be present in the component to render this var # Hooks that need to be present in the component to render this var
hooks: Dict[str, None] = {} hooks: Dict[str, None] = {}
@ -126,7 +126,24 @@ class VarData(Base):
# segments. # segments.
interpolations: List[Tuple[int, int]] = [] interpolations: List[Tuple[int, int]] = []
def __init__(self, imports: ImportDict | ImportList = None, **kwargs): def __init__(
self,
imports: ImportList
| List[ImportVar | Dict[str, Optional[Union[str, bool]]]]
| ImportDict
| Dict[str, set[ImportVar]]
| None = None,
**kwargs,
):
"""Initialize the VarData.
If imports is an ImportDict it will be converted to an ImportList and a
deprecation warning will be displayed.
Args:
imports: The imports needed to render this var.
**kwargs: Additional fields to set.
"""
if isinstance(imports, dict): if isinstance(imports, dict):
imports = ImportList.from_import_dict(imports) imports = ImportList.from_import_dict(imports)
console.deprecate( console.deprecate(
@ -135,9 +152,12 @@ class VarData(Base):
deprecation_version="0.5.0", deprecation_version="0.5.0",
removal_version="0.6.0", removal_version="0.6.0",
) )
elif imports is None: else:
imports = [] imports = ImportList(imports or [])
super().__init__(imports=imports, **kwargs) super().__init__(
imports=imports, # type: ignore
**kwargs,
)
@classmethod @classmethod
def merge(cls, *others: VarData | None) -> VarData | None: def merge(cls, *others: VarData | None) -> VarData | None:
@ -150,7 +170,7 @@ class VarData(Base):
The merged var data object. The merged var data object.
""" """
state = "" state = ""
_imports = [] _imports = ImportList()
hooks = {} hooks = {}
interpolations = [] interpolations = []
for var_data in others: for var_data in others:
@ -1059,11 +1079,12 @@ class Var:
",", other, fn="spreadArraysOrObjects", flip=flip ",", other, fn="spreadArraysOrObjects", flip=flip
)._replace( )._replace(
merge_var_data=VarData( merge_var_data=VarData(
imports={ imports=[
f"/{constants.Dirs.STATE_PATH}": [ ImportVar(
ImportVar(tag="spreadArraysOrObjects") package=f"/{constants.Dirs.STATE_PATH}",
] tag="spreadArraysOrObjects",
}, ),
],
), ),
) )
return self.operation("+", other, flip=flip) return self.operation("+", other, flip=flip)
@ -1612,11 +1633,11 @@ class Var:
v2._var_data, v2._var_data,
step._var_data, step._var_data,
VarData( VarData(
imports={ imports=[
"/utils/helpers/range.js": [ ImportVar(
ImportVar(tag="range", is_default=True), package="/utils/helpers/range", tag="range", is_default=True
], ),
}, ]
), ),
), ),
) )
@ -1644,9 +1665,9 @@ class Var:
_var_is_string=False, _var_is_string=False,
_var_full_name_needs_state_prefix=False, _var_full_name_needs_state_prefix=False,
merge_var_data=VarData( merge_var_data=VarData(
imports={ imports=[
f"/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")], ImportVar(package=f"/{constants.Dirs.STATE_PATH}", tag="refs")
}, ],
), ),
) )
@ -1684,10 +1705,14 @@ class Var:
format.format_state_name(state_name) format.format_state_name(state_name)
): None ): None
}, },
imports={ imports=ImportList(
f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")], [
"react": [ImportVar(tag="useContext")], ImportVar(
}, package=f"/{constants.Dirs.CONTEXTS_PATH}", tag="StateContexts"
),
ImportVar(package="react", tag="useContext"),
]
),
) )
self._var_data = VarData.merge(self._var_data, new_var_data) self._var_data = VarData.merge(self._var_data, new_var_data)
self._var_full_name_needs_state_prefix = True self._var_full_name_needs_state_prefix = True

View File

@ -9,14 +9,14 @@ from reflex.components.radix.themes.typography.text import Text
def test_websocket_target_url(): def test_websocket_target_url():
url = WebsocketTargetURL.create() url = WebsocketTargetURL.create()
_imports = url._get_all_imports(collapse=True) _imports = url._get_all_imports()
assert list(_imports.keys()) == ["/utils/state", "/env.json"] assert [i.library for i in _imports] == ["/utils/state", "/env.json"]
def test_connection_banner(): def test_connection_banner():
banner = ConnectionBanner.create() banner = ConnectionBanner.create()
_imports = banner._get_all_imports(collapse=True) _imports = banner._get_all_imports()
assert list(_imports.keys()) == [ assert [i.library for i in _imports] == [
"react", "react",
"/utils/context", "/utils/context",
"/utils/state", "/utils/state",
@ -31,8 +31,8 @@ def test_connection_banner():
def test_connection_modal(): def test_connection_modal():
modal = ConnectionModal.create() modal = ConnectionModal.create()
_imports = modal._get_all_imports(collapse=True) _imports = modal._get_all_imports()
assert list(_imports.keys()) == [ assert [i.library for i in _imports] == [
"react", "react",
"/utils/context", "/utils/context",
"/utils/state", "/utils/state",
@ -48,4 +48,4 @@ def test_connection_modal():
def test_connection_pulser(): def test_connection_pulser():
pulser = ConnectionPulser.create() pulser = ConnectionPulser.create()
_custom_code = pulser._get_all_custom_code() _custom_code = pulser._get_all_custom_code()
_imports = pulser._get_all_imports(collapse=True) _imports = pulser._get_all_imports()