use add_imports everywhere (#3448)

This commit is contained in:
Thomas Brandého 2024-06-12 18:26:45 +02:00 committed by GitHub
parent 991f6e0183
commit 462b023019
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
40 changed files with 469 additions and 304 deletions

View File

@ -28,13 +28,14 @@ from reflex.components.component import Component, ComponentStyle, CustomCompone
from reflex.state import BaseState, Cookie, LocalStorage
from reflex.style import Style
from reflex.utils import console, format, imports, path_ops
from reflex.utils.imports import ImportVar, ParsedImportDict
from reflex.vars import Var
# To re-export this function.
merge_imports = imports.merge_imports
def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list[str]]:
def compile_import_statement(fields: list[ImportVar]) -> tuple[str, list[str]]:
"""Compile an import statement.
Args:
@ -59,7 +60,7 @@ def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list
return default, list(rest)
def validate_imports(import_dict: imports.ImportDict):
def validate_imports(import_dict: ParsedImportDict):
"""Verify that the same Tag is not used in multiple import.
Args:
@ -82,7 +83,7 @@ def validate_imports(import_dict: imports.ImportDict):
used_tags[import_name] = lib
def compile_imports(import_dict: imports.ImportDict) -> list[dict]:
def compile_imports(import_dict: ParsedImportDict) -> list[dict]:
"""Compile an import dict.
Args:
@ -91,7 +92,7 @@ def compile_imports(import_dict: imports.ImportDict) -> list[dict]:
Returns:
The list of import dict.
"""
collapsed_import_dict = imports.collapse_imports(import_dict)
collapsed_import_dict: ParsedImportDict = imports.collapse_imports(import_dict)
validate_imports(collapsed_import_dict)
import_dicts = []
for lib, fields in collapsed_import_dict.items():
@ -231,7 +232,7 @@ def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
def compile_custom_component(
component: CustomComponent,
) -> tuple[dict, imports.ImportDict]:
) -> tuple[dict, ParsedImportDict]:
"""Compile a custom component.
Args:
@ -244,7 +245,7 @@ def compile_custom_component(
render = component.get_component(component)
# Get the imports.
imports = {
imports: ParsedImportDict = {
lib: fields
for lib, fields in render._get_all_imports().items()
if lib != component.library

View File

@ -5,14 +5,14 @@ from functools import lru_cache
from typing import List, Literal
from reflex.components.component import Component
from reflex.utils import imports
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var
class ChakraComponent(Component):
"""A component that wraps a Chakra component."""
library = "@chakra-ui/react@2.6.1"
library: str = "@chakra-ui/react@2.6.1" # type: ignore
lib_dependencies: List[str] = [
"@chakra-ui/system@2.5.7",
"framer-motion@10.16.4",
@ -35,14 +35,14 @@ class ChakraComponent(Component):
@classmethod
@lru_cache(maxsize=None)
def _get_dependencies_imports(cls) -> imports.ImportDict:
def _get_dependencies_imports(cls) -> ImportDict:
"""Get the imports from lib_dependencies for installing.
Returns:
The dependencies imports of the component.
"""
return {
dep: [imports.ImportVar(tag=None, render=False)]
dep: [ImportVar(tag=None, render=False)]
for dep in [
"@chakra-ui/system@2.5.7",
"framer-motion@10.16.4",
@ -70,15 +70,16 @@ class ChakraProvider(ChakraComponent):
),
)
def _get_imports(self) -> imports.ImportDict:
_imports = super()._get_imports()
_imports.setdefault(self.__fields__["library"].default, []).append(
imports.ImportVar(tag="extendTheme", is_default=False),
)
_imports.setdefault("/utils/theme.js", []).append(
imports.ImportVar(tag="theme", is_default=True),
)
return _imports
def add_imports(self) -> ImportDict:
"""Add imports for the ChakraProvider component.
Returns:
The import dict for the component.
"""
return {
self.library: ImportVar(tag="extendTheme", is_default=False),
"/utils/theme.js": ImportVar(tag="theme", is_default=True),
}
@staticmethod
@lru_cache(maxsize=None)

View File

@ -10,7 +10,7 @@ from reflex.style import Style
from functools import lru_cache
from typing import List, Literal
from reflex.components.component import Component
from reflex.utils import imports
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var
class ChakraComponent(Component):
@ -155,6 +155,7 @@ class ChakraProvider(ChakraComponent):
A new ChakraProvider component.
"""
...
def add_imports(self) -> ImportDict: ...
chakra_provider = ChakraProvider.create()

View File

@ -11,7 +11,7 @@ from reflex.components.component import Component
from reflex.components.core.debounce import DebounceInput
from reflex.components.literals import LiteralInputType
from reflex.constants import EventTriggers, MemoizationMode
from reflex.utils import imports
from reflex.utils.imports import ImportDict
from reflex.vars import Var
@ -59,11 +59,13 @@ class Input(ChakraComponent):
# The name of the form field
name: Var[str]
def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(),
{"/utils/state": {imports.ImportVar(tag="set_val")}},
)
def add_imports(self) -> ImportDict:
"""Add imports for the Input component.
Returns:
The import dict.
"""
return {"/utils/state": "set_val"}
def get_event_triggers(self) -> Dict[str, Any]:
"""Get the event triggers that pass the component's value to the handler.

View File

@ -17,10 +17,11 @@ from reflex.components.component import Component
from reflex.components.core.debounce import DebounceInput
from reflex.components.literals import LiteralInputType
from reflex.constants import EventTriggers, MemoizationMode
from reflex.utils import imports
from reflex.utils.imports import ImportDict
from reflex.vars import Var
class Input(ChakraComponent):
def add_imports(self) -> ImportDict: ...
def get_event_triggers(self) -> Dict[str, Any]: ...
@overload
@classmethod

View File

@ -4,7 +4,7 @@
from reflex.components.chakra import ChakraComponent
from reflex.components.component import Component
from reflex.components.next.link import NextLink
from reflex.utils import imports
from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var
next_link = NextLink.create()
@ -32,8 +32,13 @@ class Link(ChakraComponent):
# If true, the link will open in new tab.
is_external: Var[bool]
def _get_imports(self) -> imports.ImportDict:
return {**super()._get_imports(), **next_link._get_imports()}
def add_imports(self) -> ImportDict:
"""Add imports for the link component.
Returns:
The import dict.
"""
return next_link._get_imports() # type: ignore
@classmethod
def create(cls, *children, **props) -> Component:

View File

@ -10,12 +10,13 @@ from reflex.style import Style
from reflex.components.chakra import ChakraComponent
from reflex.components.component import Component
from reflex.components.next.link import NextLink
from reflex.utils import imports
from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var
next_link = NextLink.create()
class Link(ChakraComponent):
def add_imports(self) -> ImportDict: ...
@overload
@classmethod
def create( # type: ignore

View File

@ -44,7 +44,7 @@ from reflex.event import (
)
from reflex.style import Style, format_as_emotion
from reflex.utils import console, format, imports, types
from reflex.utils.imports import ImportVar
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.utils.serializers import serializer
from reflex.vars import BaseVar, Var, VarData
@ -95,7 +95,7 @@ class BaseComponent(Base, ABC):
"""
@abstractmethod
def _get_all_imports(self) -> imports.ImportDict:
def _get_all_imports(self) -> ParsedImportDict:
"""Get all the libraries and fields that are used by the component.
Returns:
@ -213,7 +213,7 @@ class Component(BaseComponent, ABC):
# State class associated with this component instance
State: Optional[Type[reflex.state.State]] = None
def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]:
def add_imports(self) -> ImportDict | list[ImportDict]:
"""Add imports for the component.
This method should be implemented by subclasses to add new imports for the component.
@ -1224,7 +1224,7 @@ class Component(BaseComponent, ABC):
# Return the dynamic imports
return dynamic_imports
def _get_props_imports(self) -> List[str]:
def _get_props_imports(self) -> List[ParsedImportDict]:
"""Get the imports needed for components props.
Returns:
@ -1250,7 +1250,7 @@ class Component(BaseComponent, ABC):
or format.format_library_name(dep or "") in self.transpile_packages
)
def _get_dependencies_imports(self) -> imports.ImportDict:
def _get_dependencies_imports(self) -> ParsedImportDict:
"""Get the imports from lib_dependencies for installing.
Returns:
@ -1267,7 +1267,7 @@ class Component(BaseComponent, ABC):
for dep in self.lib_dependencies
}
def _get_hooks_imports(self) -> imports.ImportDict:
def _get_hooks_imports(self) -> ParsedImportDict:
"""Get the imports required by certain hooks.
Returns:
@ -1308,7 +1308,7 @@ class Component(BaseComponent, ABC):
return imports.merge_imports(_imports, *other_imports)
def _get_imports(self) -> imports.ImportDict:
def _get_imports(self) -> ParsedImportDict:
"""Get all the libraries and fields that are used by the component.
Returns:
@ -1328,25 +1328,15 @@ class Component(BaseComponent, ABC):
var._var_data.imports for var in self._get_vars() if var._var_data
]
# If any subclass implements add_imports, merge the imports.
def _make_list(
value: str | ImportVar | list[str | ImportVar],
) -> list[str | ImportVar]:
if isinstance(value, (str, ImportVar)):
return [value]
return value
_added_import_dicts = []
added_import_dicts: list[ParsedImportDict] = []
for clz in self._iter_parent_classes_with_method("add_imports"):
_added_import_dicts.append(
{
package: [
ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag
for tag in _make_list(maybe_tags)
]
for package, maybe_tags in clz.add_imports(self).items()
}
)
list_of_import_dict = clz.add_imports(self)
if not isinstance(list_of_import_dict, list):
list_of_import_dict = [list_of_import_dict]
for import_dict in list_of_import_dict:
added_import_dicts.append(parse_imports(import_dict))
return imports.merge_imports(
*self._get_props_imports(),
@ -1355,10 +1345,10 @@ class Component(BaseComponent, ABC):
_imports,
event_imports,
*var_imports,
*_added_import_dicts,
*added_import_dicts,
)
def _get_all_imports(self, collapse: bool = False) -> imports.ImportDict:
def _get_all_imports(self, collapse: bool = False) -> ParsedImportDict:
"""Get all the libraries and fields that are used by the component and its children.
Args:
@ -1453,7 +1443,7 @@ class Component(BaseComponent, ABC):
**self._get_special_hooks(),
}
def _get_added_hooks(self) -> dict[str, imports.ImportDict]:
def _get_added_hooks(self) -> dict[str, ImportDict]:
"""Get the hooks added via `add_hooks` method.
Returns:
@ -1842,7 +1832,7 @@ memo = custom_component
class NoSSRComponent(Component):
"""A dynamic component that is not rendered on the server."""
def _get_imports(self) -> imports.ImportDict:
def _get_imports(self) -> ParsedImportDict:
"""Get the imports for the component.
Returns:
@ -2185,7 +2175,7 @@ class StatefulComponent(BaseComponent):
"""
return {}
def _get_all_imports(self) -> imports.ImportDict:
def _get_all_imports(self) -> ParsedImportDict:
"""Get all the libraries and fields that are used by the component.
Returns:

View File

@ -19,7 +19,7 @@ from reflex.components.radix.themes.typography.text import Text
from reflex.components.sonner.toast import Toaster, ToastProps
from reflex.constants import Dirs, Hooks, Imports
from reflex.constants.compiler import CompileVars
from reflex.utils import imports
from reflex.utils.imports import ImportDict, ImportVar
from reflex.utils.serializers import serialize
from reflex.vars import Var, VarData
@ -65,10 +65,15 @@ has_too_many_connection_errors: Var = Var.create_safe(
class WebsocketTargetURL(Bare):
"""A component that renders the websocket target URL."""
def _get_imports(self) -> imports.ImportDict:
def add_imports(self) -> ImportDict:
"""Add imports for the websocket target URL component.
Returns:
The import dict.
"""
return {
f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")],
"/env.json": [imports.ImportVar(tag="env", is_default=True)],
f"/{Dirs.STATE_PATH}": [ImportVar(tag="getBackendURL")],
"/env.json": [ImportVar(tag="env", is_default=True)],
}
@classmethod
@ -98,7 +103,7 @@ def default_connection_error() -> list[str | Var | Component]:
class ConnectionToaster(Toaster):
"""A connection toaster component."""
def add_hooks(self) -> list[str]:
def add_hooks(self) -> list[str | Var]:
"""Add the hooks for the connection toaster.
Returns:
@ -116,7 +121,7 @@ class ConnectionToaster(Toaster):
duration=120000,
id=toast_id,
)
hook = Var.create(
hook = Var.create_safe(
f"""
const toast_props = {serialize(props)};
const [userDismissed, setUserDismissed] = useState(false);
@ -135,22 +140,17 @@ useEffect(() => {{
}}, [{connect_errors}]);""",
_var_is_string=False,
)
hook._var_data = VarData.merge( # type: ignore
imports: ImportDict = {
"react": ["useEffect", "useState"],
**target_url._get_imports(), # type: ignore
}
hook._var_data = VarData.merge(
connect_errors._var_data,
VarData(
imports={
"react": [
imports.ImportVar(tag="useEffect"),
imports.ImportVar(tag="useState"),
],
**target_url._get_imports(),
}
),
VarData(imports=imports),
)
return [
Hooks.EVENTS,
hook, # type: ignore
hook,
]
@ -216,10 +216,11 @@ class WifiOffPulse(Icon):
"""A wifi_off icon with an animated opacity pulse."""
@classmethod
def create(cls, **props) -> Component:
def create(cls, *children, **props) -> Icon:
"""Create a wifi_off icon with an animated opacity pulse.
Args:
*children: The children of the component.
**props: The properties of the component.
Returns:
@ -237,11 +238,13 @@ class WifiOffPulse(Icon):
**props,
)
def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(),
{"@emotion/react": [imports.ImportVar(tag="keyframes")]},
)
def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]:
"""Add imports for the WifiOffPulse component.
Returns:
The import dict.
"""
return {"@emotion/react": [ImportVar(tag="keyframes")]}
def _get_custom_code(self) -> str | None:
return """

View File

@ -23,7 +23,7 @@ from reflex.components.radix.themes.typography.text import Text
from reflex.components.sonner.toast import Toaster, ToastProps
from reflex.constants import Dirs, Hooks, Imports
from reflex.constants.compiler import CompileVars
from reflex.utils import imports
from reflex.utils.imports import ImportDict, ImportVar
from reflex.utils.serializers import serialize
from reflex.vars import Var, VarData
@ -35,6 +35,7 @@ has_connection_errors: Var
has_too_many_connection_errors: Var
class WebsocketTargetURL(Bare):
def add_imports(self) -> ImportDict: ...
@overload
@classmethod
def create( # type: ignore
@ -104,7 +105,7 @@ class WebsocketTargetURL(Bare):
def default_connection_error() -> list[str | Var | Component]: ...
class ConnectionToaster(Toaster):
def add_hooks(self) -> list[str]: ...
def add_hooks(self) -> list[str | Var]: ...
@overload
@classmethod
def create( # type: ignore
@ -430,6 +431,7 @@ class WifiOffPulse(Icon):
"""Create a wifi_off icon with an animated opacity pulse.
Args:
*children: The children of the component.
size: The size of the icon in pixels.
style: The style of the component.
key: A unique key for the component.
@ -443,6 +445,7 @@ class WifiOffPulse(Icon):
The icon component with default props applied.
"""
...
def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]: ...
class ConnectionPulser(Div):
@overload

View File

@ -10,11 +10,12 @@ from reflex.components.tags import CondTag, Tag
from reflex.constants import Dirs
from reflex.constants.colors import Color
from reflex.style import LIGHT_COLOR_MODE, color_mode
from reflex.utils import format, imports
from reflex.utils import format
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var, VarData
_IS_TRUE_IMPORT = {
f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="isTrue")],
_IS_TRUE_IMPORT: ImportDict = {
f"/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")],
}
@ -96,12 +97,16 @@ class Cond(MemoizationLeaf):
cond_state=f"isTrue({self.cond._var_full_name})",
)
def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(),
getattr(self.cond._var_data, "imports", {}),
_IS_TRUE_IMPORT,
def add_imports(self) -> ImportDict:
"""Add imports for the Cond component.
Returns:
The import dict for the component.
"""
cond_imports: dict[str, str | ImportVar | list[str | ImportVar]] = getattr(
self.cond._var_data, "imports", {}
)
return {**cond_imports, **_IS_TRUE_IMPORT}
@overload

View File

@ -8,8 +8,9 @@ from reflex.components.component import BaseComponent, Component, MemoizationLea
from reflex.components.core.colors import Color
from reflex.components.tags import MatchTag, Tag
from reflex.style import Style
from reflex.utils import format, imports, types
from reflex.utils import format, types
from reflex.utils.exceptions import MatchTypeError
from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var, VarData
@ -268,11 +269,13 @@ class Match(MemoizationLeaf):
tag.name = "match"
return dict(tag)
def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(),
getattr(self.cond._var_data, "imports", {}),
)
def add_imports(self) -> ImportDict:
"""Add imports for the Match component.
Returns:
The import dict.
"""
return getattr(self.cond._var_data, "imports", {})
match = Match.create

View File

@ -19,17 +19,15 @@ from reflex.event import (
call_script,
parse_args_spec,
)
from reflex.utils import imports
from reflex.utils.imports import ImportVar
from reflex.vars import BaseVar, CallableVar, Var, VarData
DEFAULT_UPLOAD_ID: str = "default"
upload_files_context_var_data: VarData = VarData(
imports={
"react": [imports.ImportVar(tag="useContext")],
f"/{Dirs.CONTEXTS_PATH}": [
imports.ImportVar(tag="UploadFilesContext"),
],
"react": "useContext",
f"/{Dirs.CONTEXTS_PATH}": "UploadFilesContext",
},
hooks={
"const [filesById, setFilesById] = useContext(UploadFilesContext);": None,
@ -133,8 +131,8 @@ uploaded_files_url_prefix: Var = Var.create_safe(
_var_is_string=False,
_var_data=VarData(
imports={
f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")],
"/env.json": [imports.ImportVar(tag="env", is_default=True)],
f"/{Dirs.STATE_PATH}": "getBackendURL",
"/env.json": ImportVar(tag="env", is_default=True),
}
),
)

View File

@ -23,7 +23,7 @@ from reflex.event import (
call_script,
parse_args_spec,
)
from reflex.utils import imports
from reflex.utils.imports import ImportVar
from reflex.vars import BaseVar, CallableVar, Var, VarData
DEFAULT_UPLOAD_ID: str

View File

@ -12,8 +12,8 @@ from reflex.components.core.cond import color_mode_cond
from reflex.constants.colors import Color
from reflex.event import set_clipboard
from reflex.style import Style
from reflex.utils import format, imports
from reflex.utils.imports import ImportVar
from reflex.utils import format
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var
LiteralCodeBlockTheme = Literal[
@ -381,42 +381,45 @@ class CodeBlock(Component):
# Props passed down to the code tag.
code_tag_props: Var[Dict[str, str]]
def _get_imports(self) -> imports.ImportDict:
merged_imports = super()._get_imports()
# Get all themes from a cond literal
def add_imports(self) -> ImportDict:
"""Add imports for the CodeBlock component.
Returns:
The import dict.
"""
imports_: ImportDict = {}
themes = re.findall(r"`(.*?)`", self.theme._var_name)
if not themes:
themes = [self.theme._var_name]
merged_imports = imports.merge_imports(
merged_imports,
imports_.update(
{
f"react-syntax-highlighter/dist/cjs/styles/prism/{self.convert_theme_name(theme)}": {
f"react-syntax-highlighter/dist/cjs/styles/prism/{self.convert_theme_name(theme)}": [
ImportVar(
tag=format.to_camel_case(self.convert_theme_name(theme)),
is_default=True,
install=False,
)
}
]
for theme in themes
},
}
)
if (
self.language is not None
and self.language._var_name in LiteralCodeLanguage.__args__ # type: ignore
):
merged_imports = imports.merge_imports(
merged_imports,
{
f"react-syntax-highlighter/dist/cjs/languages/prism/{self.language._var_name}": {
ImportVar(
tag=format.to_camel_case(self.language._var_name),
is_default=True,
install=False,
)
}
},
)
return merged_imports
imports_[
f"react-syntax-highlighter/dist/cjs/languages/prism/{self.language._var_name}"
] = [
ImportVar(
tag=format.to_camel_case(self.language._var_name),
is_default=True,
install=False,
)
]
return imports_
def _get_custom_code(self) -> Optional[str]:
if (

View File

@ -17,8 +17,8 @@ from reflex.components.core.cond import color_mode_cond
from reflex.constants.colors import Color
from reflex.event import set_clipboard
from reflex.style import Style
from reflex.utils import format, imports
from reflex.utils.imports import ImportVar
from reflex.utils import format
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var
LiteralCodeBlockTheme = Literal[
@ -351,6 +351,7 @@ LiteralCodeLanguage = Literal[
]
class CodeBlock(Component):
def add_imports(self) -> ImportDict: ...
@overload
@classmethod
def create( # type: ignore

View File

@ -2,13 +2,14 @@
from __future__ import annotations
from enum import Enum
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent
from reflex.components.literals import LiteralRowMarker
from reflex.utils import console, format, imports, types
from reflex.utils.imports import ImportVar
from reflex.event import EventHandler
from reflex.utils import console, format, types
from reflex.utils.imports import ImportDict, ImportVar
from reflex.utils.serializers import serializer
from reflex.vars import Var, get_unique_variable_name
@ -205,51 +206,66 @@ class DataEditor(NoSSRComponent):
# global theme
theme: Var[Union[DataEditorTheme, Dict]]
def _get_imports(self):
return imports.merge_imports(
super()._get_imports(),
{
"": {
ImportVar(
tag=f"{format.format_library_name(self.library)}/dist/index.css"
)
},
self.library: {ImportVar(tag="GridCellKind")},
"/utils/helpers/dataeditor.js": {
ImportVar(
tag=f"formatDataEditorCells", is_default=False, install=False
),
},
},
)
# Triggered when a cell is activated.
on_cell_activated: EventHandler[lambda pos: [pos]]
def get_event_triggers(self) -> Dict[str, Callable]:
"""The event triggers of the component.
# Triggered when a cell is clicked.
on_cell_clicked: EventHandler[lambda pos: [pos]]
# Triggered when a cell is right-clicked.
on_cell_context_menu: EventHandler[lambda pos: [pos]]
# Triggered when a cell is edited.
on_cell_edited: EventHandler[lambda pos, data: [pos, data]]
# Triggered when a group header is clicked.
on_group_header_clicked: EventHandler[lambda pos, data: [pos, data]]
# Triggered when a group header is right-clicked.
on_group_header_context_menu: EventHandler[lambda grp_idx, data: [grp_idx, data]]
# Triggered when a group header is renamed.
on_group_header_renamed: EventHandler[lambda idx, val: [idx, val]]
# Triggered when a header is clicked.
on_header_clicked: EventHandler[lambda pos: [pos]]
# Triggered when a header is right-clicked.
on_header_context_menu: EventHandler[lambda pos: [pos]]
# Triggered when a header menu is clicked.
on_header_menu_click: EventHandler[lambda col, pos: [col, pos]]
# Triggered when an item is hovered.
on_item_hovered: EventHandler[lambda pos: [pos]]
# Triggered when a selection is deleted.
on_delete: EventHandler[lambda selection: [selection]]
# Triggered when editing is finished.
on_finished_editing: EventHandler[lambda new_value, movement: [new_value, movement]]
# Triggered when a row is appended.
on_row_appended: EventHandler[lambda: []]
# Triggered when the selection is cleared.
on_selection_cleared: EventHandler[lambda: []]
# Triggered when a column is resized.
on_column_resize: EventHandler[lambda col, width: [col, width]]
def add_imports(self) -> ImportDict:
"""Add imports for the component.
Returns:
The dict describing the event triggers.
The import dict.
"""
def edit_sig(pos, data: dict[str, Any]):
return [pos, data]
return {
"on_cell_activated": lambda pos: [pos],
"on_cell_clicked": lambda pos: [pos],
"on_cell_context_menu": lambda pos: [pos],
"on_cell_edited": edit_sig,
"on_group_header_clicked": edit_sig,
"on_group_header_context_menu": lambda grp_idx, data: [grp_idx, data],
"on_group_header_renamed": lambda idx, val: [idx, val],
"on_header_clicked": lambda pos: [pos],
"on_header_context_menu": lambda pos: [pos],
"on_header_menu_click": lambda col, pos: [col, pos],
"on_item_hovered": lambda pos: [pos],
"on_delete": lambda selection: [selection],
"on_finished_editing": lambda new_value, movement: [new_value, movement],
"on_row_appended": lambda: [],
"on_selection_cleared": lambda: [],
"on_column_resize": lambda col, width: [col, width],
"": f"{format.format_library_name(self.library)}/dist/index.css",
self.library: "GridCellKind",
"/utils/helpers/dataeditor.js": ImportVar(
tag="formatDataEditorCells", is_default=False, install=False
),
}
def add_hooks(self) -> list[str]:

View File

@ -8,12 +8,13 @@ from reflex.vars import Var, BaseVar, ComputedVar
from reflex.event import EventChain, EventHandler, EventSpec
from reflex.style import Style
from enum import Enum
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent
from reflex.components.literals import LiteralRowMarker
from reflex.utils import console, format, imports, types
from reflex.utils.imports import ImportVar
from reflex.event import EventHandler
from reflex.utils import console, format, types
from reflex.utils.imports import ImportDict, ImportVar
from reflex.utils.serializers import serializer
from reflex.vars import Var, get_unique_variable_name
@ -80,7 +81,7 @@ class DataEditorTheme(Base):
text_medium: Optional[str]
class DataEditor(NoSSRComponent):
def get_event_triggers(self) -> Dict[str, Callable]: ...
def add_imports(self) -> ImportDict: ...
def add_hooks(self) -> list[str]: ...
@overload
@classmethod
@ -136,6 +137,9 @@ class DataEditor(NoSSRComponent):
class_name: Optional[Any] = None,
autofocus: Optional[bool] = None,
custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
on_blur: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_cell_activated: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
@ -148,15 +152,27 @@ class DataEditor(NoSSRComponent):
on_cell_edited: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_click: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_column_resize: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_context_menu: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_delete: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_double_click: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_finished_editing: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_focus: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_group_header_clicked: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
@ -178,12 +194,42 @@ class DataEditor(NoSSRComponent):
on_item_hovered: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_mount: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_mouse_down: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_mouse_enter: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_mouse_leave: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_mouse_move: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_mouse_out: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_mouse_over: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_mouse_up: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_row_appended: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_scroll: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_selection_cleared: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_unmount: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
**props
) -> "DataEditor":
"""Create the DataEditor component.

View File

@ -11,8 +11,8 @@ from reflex.components.el.element import Element
from reflex.components.tags.tag import Tag
from reflex.constants import Dirs, EventTriggers
from reflex.event import EventChain
from reflex.utils import imports
from reflex.utils.format import format_event_chain
from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var
from .base import BaseHTML
@ -169,17 +169,16 @@ class Form(BaseHTML):
).hexdigest()
return form
def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(),
{
"react": {imports.ImportVar(tag="useCallback")},
f"/{Dirs.STATE_PATH}": {
imports.ImportVar(tag="getRefValue"),
imports.ImportVar(tag="getRefValues"),
},
},
)
def add_imports(self) -> ImportDict:
"""Add imports needed by the form component.
Returns:
The imports for the form component.
"""
return {
"react": "useCallback",
f"/{Dirs.STATE_PATH}": ["getRefValue", "getRefValues"],
}
def add_hooks(self) -> list[str]:
"""Add hooks for the form.

View File

@ -14,8 +14,8 @@ from reflex.components.el.element import Element
from reflex.components.tags.tag import Tag
from reflex.constants import Dirs, EventTriggers
from reflex.event import EventChain
from reflex.utils import imports
from reflex.utils.format import format_event_chain
from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var
from .base import BaseHTML
@ -581,6 +581,7 @@ class Form(BaseHTML):
The form component.
"""
...
def add_imports(self) -> ImportDict: ...
def add_hooks(self) -> list[str]: ...
class Input(BaseHTML):

View File

@ -6,7 +6,8 @@ from typing import Any, Dict, List, Union
from reflex.components.component import Component
from reflex.components.tags import Tag
from reflex.utils import imports, types
from reflex.utils import types
from reflex.utils.imports import ImportDict
from reflex.utils.serializers import serialize
from reflex.vars import BaseVar, ComputedVar, Var
@ -102,11 +103,13 @@ class DataTable(Gridjs):
**props,
)
def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(),
{"": {imports.ImportVar(tag="gridjs/dist/theme/mermaid.css")}},
)
def add_imports(self) -> ImportDict:
"""Add the imports for the datatable component.
Returns:
The import dict for the component.
"""
return {"": "gridjs/dist/theme/mermaid.css"}
def _render(self) -> Tag:
if isinstance(self.data, Var) and types.is_dataframe(self.data._var_type):

View File

@ -10,7 +10,8 @@ from reflex.style import Style
from typing import Any, Dict, List, Union
from reflex.components.component import Component
from reflex.components.tags import Tag
from reflex.utils import imports, types
from reflex.utils import types
from reflex.utils.imports import ImportDict
from reflex.utils.serializers import serialize
from reflex.vars import BaseVar, ComputedVar, Var
@ -180,3 +181,4 @@ class DataTable(Gridjs):
ValueError: If a pandas dataframe is passed in and columns are also provided.
"""
...
def add_imports(self) -> ImportDict: ...

View File

@ -7,7 +7,6 @@ from functools import lru_cache
from hashlib import md5
from typing import Any, Callable, Dict, Union
from reflex.compiler import utils
from reflex.components.component import Component, CustomComponent
from reflex.components.radix.themes.layout.list import (
ListItem,
@ -18,8 +17,8 @@ from reflex.components.radix.themes.typography.heading import Heading
from reflex.components.radix.themes.typography.link import Link
from reflex.components.radix.themes.typography.text import Text
from reflex.components.tags.tag import Tag
from reflex.utils import imports, types
from reflex.utils.imports import ImportVar
from reflex.utils import types
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var
# Special vars used in the component map.
@ -145,47 +144,41 @@ class Markdown(Component):
return custom_components
def _get_imports(self) -> imports.ImportDict:
# Import here to avoid circular imports.
def add_imports(self) -> ImportDict | list[ImportDict]:
"""Add imports for the markdown component.
Returns:
The imports for the markdown component.
"""
from reflex.components.datadisplay.code import CodeBlock
from reflex.components.radix.themes.typography.code import Code
imports = super()._get_imports()
# Special markdown imports.
imports.update(
return [
{
"": [ImportVar(tag="katex/dist/katex.min.css")],
"remark-math@5.1.1": [
ImportVar(tag=_REMARK_MATH._var_name, is_default=True)
],
"remark-gfm@3.0.1": [
ImportVar(tag=_REMARK_GFM._var_name, is_default=True)
],
"remark-unwrap-images@4.0.0": [
ImportVar(tag=_REMARK_UNWRAP_IMAGES._var_name, is_default=True)
],
"rehype-katex@6.0.3": [
ImportVar(tag=_REHYPE_KATEX._var_name, is_default=True)
],
"rehype-raw@6.1.1": [
ImportVar(tag=_REHYPE_RAW._var_name, is_default=True)
],
}
)
# Get the imports for each component.
for component in self.component_map.values():
imports = utils.merge_imports(
imports, component(_MOCK_ARG)._get_all_imports()
)
# Get the imports for the code components.
imports = utils.merge_imports(
imports, CodeBlock.create(theme="light")._get_imports()
)
imports = utils.merge_imports(imports, Code.create()._get_imports())
return imports
"": "katex/dist/katex.min.css",
"remark-math@5.1.1": ImportVar(
tag=_REMARK_MATH._var_name, is_default=True
),
"remark-gfm@3.0.1": ImportVar(
tag=_REMARK_GFM._var_name, is_default=True
),
"remark-unwrap-images@4.0.0": ImportVar(
tag=_REMARK_UNWRAP_IMAGES._var_name, is_default=True
),
"rehype-katex@6.0.3": ImportVar(
tag=_REHYPE_KATEX._var_name, is_default=True
),
"rehype-raw@6.1.1": ImportVar(
tag=_REHYPE_RAW._var_name, is_default=True
),
},
*[
component(_MOCK_ARG)._get_imports() # type: ignore
for component in self.component_map.values()
],
CodeBlock.create(theme="light")._get_imports(), # type: ignore,
Code.create()._get_imports(), # type: ignore,
]
def get_component(self, tag: str, **props) -> Component:
"""Get the component for a tag and props.

View File

@ -11,7 +11,6 @@ import textwrap
from functools import lru_cache
from hashlib import md5
from typing import Any, Callable, Dict, Union
from reflex.compiler import utils
from reflex.components.component import Component, CustomComponent
from reflex.components.radix.themes.layout.list import (
ListItem,
@ -22,8 +21,8 @@ from reflex.components.radix.themes.typography.heading import Heading
from reflex.components.radix.themes.typography.link import Link
from reflex.components.radix.themes.typography.text import Text
from reflex.components.tags.tag import Tag
from reflex.utils import imports, types
from reflex.utils.imports import ImportVar
from reflex.utils import types
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var
_CHILDREN = Var.create_safe("children", _var_is_local=False, _var_is_string=False)
@ -124,6 +123,7 @@ class Markdown(Component):
The markdown component.
"""
...
def add_imports(self) -> ImportDict | list[ImportDict]: ...
def get_component(self, tag: str, **props) -> Component: ...
def format_component(self, tag: str, **props) -> str: ...
def format_component_map(self) -> dict[str, str]: ...

View File

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent
from reflex.utils import imports
from reflex.utils.imports import ImportDict
from reflex.vars import Var
@ -90,14 +90,15 @@ class Moment(NoSSRComponent):
# Display the date in the given timezone.
tz: Var[str]
def _get_imports(self) -> imports.ImportDict:
merged_imports = super()._get_imports()
def add_imports(self) -> ImportDict:
"""Add the imports for the Moment component.
Returns:
The import dict for the component.
"""
if self.tz is not None:
merged_imports = imports.merge_imports(
merged_imports,
{"moment-timezone": {imports.ImportVar(tag="")}},
)
return merged_imports
return {"moment-timezone": ""}
return {}
def get_event_triggers(self) -> Dict[str, Any]:
"""Get the events triggers signatures for the component.

View File

@ -10,7 +10,7 @@ from reflex.style import Style
from typing import Any, Dict, List, Optional
from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent
from reflex.utils import imports
from reflex.utils.imports import ImportDict
from reflex.vars import Var
class MomentDelta(Base):
@ -25,6 +25,7 @@ class MomentDelta(Base):
milliseconds: Optional[int]
class Moment(NoSSRComponent):
def add_imports(self) -> ImportDict: ...
def get_event_triggers(self) -> Dict[str, Any]: ...
@overload
@classmethod

View File

@ -11,7 +11,6 @@ from reflex.components.lucide.icon import Icon
from reflex.components.radix.primitives.base import RadixPrimitiveComponent
from reflex.components.radix.themes.base import LiteralAccentColor, LiteralRadius
from reflex.style import Style
from reflex.utils import imports
from reflex.vars import Var, get_uuid_string_var
LiteralAccordionType = Literal["single", "multiple"]
@ -413,13 +412,13 @@ class AccordionContent(AccordionComponent):
alias = "RadixAccordionContent"
def add_imports(self) -> imports.ImportDict:
def add_imports(self) -> dict:
"""Add imports to the component.
Returns:
The imports of the component.
"""
return {"@emotion/react": [imports.ImportVar(tag="keyframes")]}
return {"@emotion/react": "keyframes"}
@classmethod
def create(cls, *children, **props) -> Component:

View File

@ -15,7 +15,6 @@ from reflex.components.lucide.icon import Icon
from reflex.components.radix.primitives.base import RadixPrimitiveComponent
from reflex.components.radix.themes.base import LiteralAccentColor, LiteralRadius
from reflex.style import Style
from reflex.utils import imports
from reflex.vars import Var, get_uuid_string_var
LiteralAccordionType = Literal["single", "multiple"]
@ -899,7 +898,7 @@ class AccordionIcon(Icon):
...
class AccordionContent(AccordionComponent):
def add_imports(self) -> imports.ImportDict: ...
def add_imports(self) -> dict: ...
@overload
@classmethod
def create( # type: ignore

View File

@ -7,7 +7,7 @@ from typing import Any, Dict, Literal
from reflex.components import Component
from reflex.components.tags import Tag
from reflex.config import get_config
from reflex.utils.imports import ImportVar
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var
LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"]
@ -209,13 +209,13 @@ class Theme(RadixThemesComponent):
children = [ThemePanel.create(), *children]
return super().create(*children, **props)
def add_imports(self) -> dict[str, list[ImportVar] | ImportVar]:
def add_imports(self) -> ImportDict | list[ImportDict]:
"""Add imports for the Theme component.
Returns:
The import dict.
"""
_imports: dict[str, list[ImportVar] | ImportVar] = {
_imports: ImportDict = {
"/utils/theme.js": [ImportVar(tag="theme", is_default=True)],
}
if get_config().tailwind is None:

View File

@ -11,7 +11,7 @@ from typing import Any, Dict, Literal
from reflex.components import Component
from reflex.components.tags import Tag
from reflex.config import get_config
from reflex.utils.imports import ImportVar
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var
LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"]
@ -580,7 +580,7 @@ class Theme(RadixThemesComponent):
A new component instance.
"""
...
def add_imports(self) -> dict[str, list[ImportVar] | ImportVar]: ...
def add_imports(self) -> ImportDict | list[ImportDict]: ...
class ThemePanel(RadixThemesComponent):
def add_imports(self) -> dict[str, str]: ...

View File

@ -12,7 +12,7 @@ from reflex.components.core.colors import color
from reflex.components.core.cond import cond
from reflex.components.el.elements.inline import A
from reflex.components.next.link import NextLink
from reflex.utils import imports
from reflex.utils.imports import ImportDict
from reflex.vars import Var
from ..base import (
@ -59,8 +59,13 @@ class Link(RadixThemesComponent, A, MemoizationLeaf):
# If True, the link will open in a new tab
is_external: Var[bool]
def _get_imports(self) -> imports.ImportDict:
return {**super()._get_imports(), **next_link._get_imports()}
def add_imports(self) -> ImportDict:
"""Add imports for the Link component.
Returns:
The import dict.
"""
return next_link._get_imports() # type: ignore
@classmethod
def create(cls, *children, **props) -> Component:

View File

@ -13,7 +13,7 @@ from reflex.components.core.colors import color
from reflex.components.core.cond import cond
from reflex.components.el.elements.inline import A
from reflex.components.next.link import NextLink
from reflex.utils import imports
from reflex.utils.imports import ImportDict
from reflex.vars import Var
from ..base import LiteralAccentColor, RadixThemesComponent
from .base import LiteralTextSize, LiteralTextTrim, LiteralTextWeight
@ -22,6 +22,7 @@ LiteralLinkUnderline = Literal["auto", "hover", "always", "none"]
next_link = NextLink.create()
class Link(RadixThemesComponent, A, MemoizationLeaf):
def add_imports(self) -> ImportDict: ...
@overload
@classmethod
def create( # type: ignore

View File

@ -8,7 +8,7 @@ from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent
from reflex.constants import EventTriggers
from reflex.utils.format import to_camel_case
from reflex.utils.imports import ImportVar
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var
@ -176,12 +176,15 @@ class Editor(NoSSRComponent):
# default: False
disable_toolbar: Var[bool]
def _get_imports(self):
imports = super()._get_imports()
imports[""] = [
ImportVar(tag="suneditor/dist/css/suneditor.min.css", install=False)
]
return imports
def add_imports(self) -> ImportDict:
"""Add imports for the Editor component.
Returns:
The import dict.
"""
return {
"": ImportVar(tag="suneditor/dist/css/suneditor.min.css", install=False)
}
def get_event_triggers(self) -> Dict[str, Any]:
"""Get the event triggers that pass the component's value to the handler.

View File

@ -13,7 +13,7 @@ from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent
from reflex.constants import EventTriggers
from reflex.utils.format import to_camel_case
from reflex.utils.imports import ImportVar
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var
class EditorButtonList(list, enum.Enum):
@ -48,6 +48,7 @@ class EditorOptions(Base):
button_list: Optional[List[Union[List[str], str]]]
class Editor(NoSSRComponent):
def add_imports(self) -> ImportDict: ...
def get_event_triggers(self) -> Dict[str, Any]: ...
@overload
@classmethod

View File

@ -3,12 +3,12 @@
from __future__ import annotations
from collections import defaultdict
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
from reflex.base import Base
def merge_imports(*imports) -> ImportDict:
def merge_imports(*imports: ImportDict | ParsedImportDict) -> ParsedImportDict:
"""Merge multiple import dicts together.
Args:
@ -24,7 +24,31 @@ def merge_imports(*imports) -> ImportDict:
return all_imports
def collapse_imports(imports: ImportDict) -> ImportDict:
def parse_imports(imports: ImportDict | ParsedImportDict) -> ParsedImportDict:
"""Parse the import dict into a standard format.
Args:
imports: The import dict to parse.
Returns:
The parsed import dict.
"""
def _make_list(value: ImportTypes) -> list[str | ImportVar] | list[ImportVar]:
if isinstance(value, (str, ImportVar)):
return [value]
return value
return {
package: [
ImportVar(tag=tag) if isinstance(tag, str) else tag
for tag in _make_list(maybe_tags)
]
for package, maybe_tags in imports.items()
}
def collapse_imports(imports: ParsedImportDict) -> ParsedImportDict:
"""Remove all duplicate ImportVar within an ImportDict.
Args:
@ -33,7 +57,10 @@ def collapse_imports(imports: ImportDict) -> ImportDict:
Returns:
The collapsed import dict.
"""
return {lib: list(set(import_vars)) for lib, import_vars in imports.items()}
return {
lib: list(set(import_vars)) if isinstance(import_vars, list) else import_vars
for lib, import_vars in imports.items()
}
class ImportVar(Base):
@ -90,4 +117,6 @@ class ImportVar(Base):
)
ImportDict = Dict[str, List[ImportVar]]
ImportTypes = Union[str, ImportVar, List[Union[str, ImportVar]], List[ImportVar]]
ImportDict = Dict[str, ImportTypes]
ParsedImportDict = Dict[str, List[ImportVar]]

View File

@ -39,7 +39,12 @@ from reflex.utils import console, imports, serializers, types
from reflex.utils.exceptions import VarAttributeError, VarTypeError, VarValueError
# This module used to export ImportVar itself, so we still import it for export here
from reflex.utils.imports import ImportDict, ImportVar
from reflex.utils.imports import (
ImportDict,
ImportVar,
ParsedImportDict,
parse_imports,
)
from reflex.utils.types import override
if TYPE_CHECKING:
@ -120,7 +125,7 @@ class VarData(Base):
state: str = ""
# Imports needed to render this var
imports: ImportDict = {}
imports: ParsedImportDict = {}
# Hooks that need to be present in the component to render this var
hooks: Dict[str, None] = {}
@ -130,6 +135,19 @@ class VarData(Base):
# segments.
interpolations: List[Tuple[int, int]] = []
def __init__(
self, imports: Union[ImportDict, ParsedImportDict] | None = None, **kwargs: Any
):
"""Initialize the var data.
Args:
imports: The imports needed to render this var.
**kwargs: The var data fields.
"""
if imports:
kwargs["imports"] = parse_imports(imports)
super().__init__(**kwargs)
@classmethod
def merge(cls, *others: VarData | None) -> VarData | None:
"""Merge multiple var data objects.

View File

@ -10,7 +10,7 @@ from reflex.base import Base as Base
from reflex.state import State as State
from reflex.state import BaseState as BaseState
from reflex.utils import console as console, format as format, types as types
from reflex.utils.imports import ImportVar
from reflex.utils.imports import ImportVar, ImportDict, ParsedImportDict
from types import FunctionType
from typing import (
Any,
@ -36,7 +36,7 @@ def _extract_var_data(value: Iterable) -> list[VarData | None]: ...
class VarData(Base):
state: str = ""
imports: dict[str, List[ImportVar]] = {}
imports: Union[ImportDict, ParsedImportDict] = {}
hooks: Dict[str, None] = {}
interpolations: List[Tuple[int, int]] = []
@classmethod

View File

@ -4,8 +4,7 @@ from typing import List
import pytest
from reflex.compiler import compiler, utils
from reflex.utils import imports
from reflex.utils.imports import ImportVar
from reflex.utils.imports import ImportVar, ParsedImportDict
@pytest.mark.parametrize(
@ -93,7 +92,7 @@ def test_compile_import_statement(
),
],
)
def test_compile_imports(import_dict: imports.ImportDict, test_dicts: List[dict]):
def test_compile_imports(import_dict: ParsedImportDict, test_dicts: List[dict]):
"""Test the compile_imports function.
Args:

View File

@ -20,7 +20,7 @@ from reflex.event import EventChain, EventHandler, parse_args_spec
from reflex.state import BaseState
from reflex.style import Style
from reflex.utils import imports
from reflex.utils.imports import ImportVar
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.vars import BaseVar, Var, VarData
@ -56,7 +56,7 @@ def component1() -> Type[Component]:
# A test string/number prop.
text_or_number: Var[Union[int, str]]
def _get_imports(self) -> imports.ImportDict:
def _get_imports(self) -> ParsedImportDict:
return {"react": [ImportVar(tag="Component")]}
def _get_custom_code(self) -> str:
@ -89,7 +89,7 @@ def component2() -> Type[Component]:
"on_close": lambda e0: [e0],
}
def _get_imports(self) -> imports.ImportDict:
def _get_imports(self) -> ParsedImportDict:
return {"react-redux": [ImportVar(tag="connect")]}
def _get_custom_code(self) -> str:
@ -1773,21 +1773,15 @@ def test_invalid_event_trigger():
),
)
def test_component_add_imports(tags):
def _list_to_import_vars(tags: List[str]) -> List[ImportVar]:
return [
ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag
for tag in tags
]
class BaseComponent(Component):
def _get_imports(self) -> imports.ImportDict:
def _get_imports(self) -> ImportDict:
return {}
class Reference(Component):
def _get_imports(self) -> imports.ImportDict:
def _get_imports(self) -> ParsedImportDict:
return imports.merge_imports(
super()._get_imports(),
{"react": _list_to_import_vars(tags)},
parse_imports({"react": tags}),
{"foo": [ImportVar(tag="bar")]},
)
@ -1806,10 +1800,12 @@ def test_component_add_imports(tags):
baseline = Reference.create()
test = Test.create()
assert baseline._get_all_imports() == {
"react": _list_to_import_vars(tags),
"foo": [ImportVar(tag="bar")],
}
assert baseline._get_all_imports() == parse_imports(
{
"react": tags,
"foo": [ImportVar(tag="bar")],
}
)
assert test._get_all_imports() == baseline._get_all_imports()

View File

@ -1,6 +1,12 @@
import pytest
from reflex.utils.imports import ImportVar, merge_imports
from reflex.utils.imports import (
ImportDict,
ImportVar,
ParsedImportDict,
merge_imports,
parse_imports,
)
@pytest.mark.parametrize(
@ -76,3 +82,32 @@ def test_merge_imports(input_1, input_2, output):
for key in output:
assert set(res[key]) == set(output[key])
@pytest.mark.parametrize(
"input, output",
[
({}, {}),
(
{"react": "Component"},
{"react": [ImportVar(tag="Component")]},
),
(
{"react": ["Component"]},
{"react": [ImportVar(tag="Component")]},
),
(
{"react": ["Component", ImportVar(tag="useState")]},
{"react": [ImportVar(tag="Component"), ImportVar(tag="useState")]},
),
(
{"react": ["Component"], "foo": "anotherFunction"},
{
"react": [ImportVar(tag="Component")],
"foo": [ImportVar(tag="anotherFunction")],
},
),
],
)
def test_parse_imports(input: ImportDict, output: ParsedImportDict):
assert parse_imports(input) == output