use add_imports everywhere (#3448)

This commit is contained in:
Thomas Brandého 2024-06-12 18:26:45 +02:00 committed by GitHub
parent 991f6e0183
commit 462b023019
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
40 changed files with 469 additions and 304 deletions

View File

@ -28,13 +28,14 @@ from reflex.components.component import Component, ComponentStyle, CustomCompone
from reflex.state import BaseState, Cookie, LocalStorage from reflex.state import BaseState, Cookie, LocalStorage
from reflex.style import Style from reflex.style import Style
from reflex.utils import console, format, imports, path_ops from reflex.utils import console, format, imports, path_ops
from reflex.utils.imports import ImportVar, ParsedImportDict
from reflex.vars import Var from reflex.vars import Var
# To re-export this function. # To re-export this function.
merge_imports = imports.merge_imports merge_imports = imports.merge_imports
def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list[str]]: def compile_import_statement(fields: list[ImportVar]) -> tuple[str, list[str]]:
"""Compile an import statement. """Compile an import statement.
Args: Args:
@ -59,7 +60,7 @@ def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list
return default, list(rest) return default, list(rest)
def validate_imports(import_dict: imports.ImportDict): def validate_imports(import_dict: ParsedImportDict):
"""Verify that the same Tag is not used in multiple import. """Verify that the same Tag is not used in multiple import.
Args: Args:
@ -82,7 +83,7 @@ def validate_imports(import_dict: imports.ImportDict):
used_tags[import_name] = lib used_tags[import_name] = lib
def compile_imports(import_dict: imports.ImportDict) -> list[dict]: def compile_imports(import_dict: ParsedImportDict) -> list[dict]:
"""Compile an import dict. """Compile an import dict.
Args: Args:
@ -91,7 +92,7 @@ def compile_imports(import_dict: imports.ImportDict) -> list[dict]:
Returns: Returns:
The list of import dict. The list of import dict.
""" """
collapsed_import_dict = imports.collapse_imports(import_dict) collapsed_import_dict: ParsedImportDict = imports.collapse_imports(import_dict)
validate_imports(collapsed_import_dict) validate_imports(collapsed_import_dict)
import_dicts = [] import_dicts = []
for lib, fields in collapsed_import_dict.items(): for lib, fields in collapsed_import_dict.items():
@ -231,7 +232,7 @@ def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
def compile_custom_component( def compile_custom_component(
component: CustomComponent, component: CustomComponent,
) -> tuple[dict, imports.ImportDict]: ) -> tuple[dict, ParsedImportDict]:
"""Compile a custom component. """Compile a custom component.
Args: Args:
@ -244,7 +245,7 @@ def compile_custom_component(
render = component.get_component(component) render = component.get_component(component)
# Get the imports. # Get the imports.
imports = { imports: ParsedImportDict = {
lib: fields lib: fields
for lib, fields in render._get_all_imports().items() for lib, fields in render._get_all_imports().items()
if lib != component.library if lib != component.library

View File

@ -5,14 +5,14 @@ from functools import lru_cache
from typing import List, Literal from typing import List, Literal
from reflex.components.component import Component from reflex.components.component import Component
from reflex.utils import imports from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var from reflex.vars import Var
class ChakraComponent(Component): class ChakraComponent(Component):
"""A component that wraps a Chakra component.""" """A component that wraps a Chakra component."""
library = "@chakra-ui/react@2.6.1" library: str = "@chakra-ui/react@2.6.1" # type: ignore
lib_dependencies: List[str] = [ lib_dependencies: List[str] = [
"@chakra-ui/system@2.5.7", "@chakra-ui/system@2.5.7",
"framer-motion@10.16.4", "framer-motion@10.16.4",
@ -35,14 +35,14 @@ class ChakraComponent(Component):
@classmethod @classmethod
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def _get_dependencies_imports(cls) -> imports.ImportDict: def _get_dependencies_imports(cls) -> ImportDict:
"""Get the imports from lib_dependencies for installing. """Get the imports from lib_dependencies for installing.
Returns: Returns:
The dependencies imports of the component. The dependencies imports of the component.
""" """
return { return {
dep: [imports.ImportVar(tag=None, render=False)] dep: [ImportVar(tag=None, render=False)]
for dep in [ for dep in [
"@chakra-ui/system@2.5.7", "@chakra-ui/system@2.5.7",
"framer-motion@10.16.4", "framer-motion@10.16.4",
@ -70,15 +70,16 @@ class ChakraProvider(ChakraComponent):
), ),
) )
def _get_imports(self) -> imports.ImportDict: def add_imports(self) -> ImportDict:
_imports = super()._get_imports() """Add imports for the ChakraProvider component.
_imports.setdefault(self.__fields__["library"].default, []).append(
imports.ImportVar(tag="extendTheme", is_default=False), Returns:
) The import dict for the component.
_imports.setdefault("/utils/theme.js", []).append( """
imports.ImportVar(tag="theme", is_default=True), return {
) self.library: ImportVar(tag="extendTheme", is_default=False),
return _imports "/utils/theme.js": ImportVar(tag="theme", is_default=True),
}
@staticmethod @staticmethod
@lru_cache(maxsize=None) @lru_cache(maxsize=None)

View File

@ -10,7 +10,7 @@ from reflex.style import Style
from functools import lru_cache from functools import lru_cache
from typing import List, Literal from typing import List, Literal
from reflex.components.component import Component from reflex.components.component import Component
from reflex.utils import imports from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var from reflex.vars import Var
class ChakraComponent(Component): class ChakraComponent(Component):
@ -155,6 +155,7 @@ class ChakraProvider(ChakraComponent):
A new ChakraProvider component. A new ChakraProvider component.
""" """
... ...
def add_imports(self) -> ImportDict: ...
chakra_provider = ChakraProvider.create() chakra_provider = ChakraProvider.create()

View File

@ -11,7 +11,7 @@ from reflex.components.component import Component
from reflex.components.core.debounce import DebounceInput from reflex.components.core.debounce import DebounceInput
from reflex.components.literals import LiteralInputType from reflex.components.literals import LiteralInputType
from reflex.constants import EventTriggers, MemoizationMode from reflex.constants import EventTriggers, MemoizationMode
from reflex.utils import imports from reflex.utils.imports import ImportDict
from reflex.vars import Var from reflex.vars import Var
@ -59,11 +59,13 @@ class Input(ChakraComponent):
# The name of the form field # The name of the form field
name: Var[str] name: Var[str]
def _get_imports(self) -> imports.ImportDict: def add_imports(self) -> ImportDict:
return imports.merge_imports( """Add imports for the Input component.
super()._get_imports(),
{"/utils/state": {imports.ImportVar(tag="set_val")}}, Returns:
) The import dict.
"""
return {"/utils/state": "set_val"}
def get_event_triggers(self) -> Dict[str, Any]: def get_event_triggers(self) -> Dict[str, Any]:
"""Get the event triggers that pass the component's value to the handler. """Get the event triggers that pass the component's value to the handler.

View File

@ -17,10 +17,11 @@ from reflex.components.component import Component
from reflex.components.core.debounce import DebounceInput from reflex.components.core.debounce import DebounceInput
from reflex.components.literals import LiteralInputType from reflex.components.literals import LiteralInputType
from reflex.constants import EventTriggers, MemoizationMode from reflex.constants import EventTriggers, MemoizationMode
from reflex.utils import imports from reflex.utils.imports import ImportDict
from reflex.vars import Var from reflex.vars import Var
class Input(ChakraComponent): class Input(ChakraComponent):
def add_imports(self) -> ImportDict: ...
def get_event_triggers(self) -> Dict[str, Any]: ... def get_event_triggers(self) -> Dict[str, Any]: ...
@overload @overload
@classmethod @classmethod

View File

@ -4,7 +4,7 @@
from reflex.components.chakra import ChakraComponent from reflex.components.chakra import ChakraComponent
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.next.link import NextLink from reflex.components.next.link import NextLink
from reflex.utils import imports from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var from reflex.vars import BaseVar, Var
next_link = NextLink.create() next_link = NextLink.create()
@ -32,8 +32,13 @@ class Link(ChakraComponent):
# If true, the link will open in new tab. # If true, the link will open in new tab.
is_external: Var[bool] is_external: Var[bool]
def _get_imports(self) -> imports.ImportDict: def add_imports(self) -> ImportDict:
return {**super()._get_imports(), **next_link._get_imports()} """Add imports for the link component.
Returns:
The import dict.
"""
return next_link._get_imports() # type: ignore
@classmethod @classmethod
def create(cls, *children, **props) -> Component: def create(cls, *children, **props) -> Component:

View File

@ -10,12 +10,13 @@ from reflex.style import Style
from reflex.components.chakra import ChakraComponent from reflex.components.chakra import ChakraComponent
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.next.link import NextLink from reflex.components.next.link import NextLink
from reflex.utils import imports from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var from reflex.vars import BaseVar, Var
next_link = NextLink.create() next_link = NextLink.create()
class Link(ChakraComponent): class Link(ChakraComponent):
def add_imports(self) -> ImportDict: ...
@overload @overload
@classmethod @classmethod
def create( # type: ignore def create( # type: ignore

View File

@ -44,7 +44,7 @@ from reflex.event import (
) )
from reflex.style import Style, format_as_emotion from reflex.style import Style, format_as_emotion
from reflex.utils import console, format, imports, types from reflex.utils import console, format, imports, types
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.utils.serializers import serializer from reflex.utils.serializers import serializer
from reflex.vars import BaseVar, Var, VarData from reflex.vars import BaseVar, Var, VarData
@ -95,7 +95,7 @@ class BaseComponent(Base, ABC):
""" """
@abstractmethod @abstractmethod
def _get_all_imports(self) -> imports.ImportDict: def _get_all_imports(self) -> ParsedImportDict:
"""Get all the libraries and fields that are used by the component. """Get all the libraries and fields that are used by the component.
Returns: Returns:
@ -213,7 +213,7 @@ class Component(BaseComponent, ABC):
# State class associated with this component instance # State class associated with this component instance
State: Optional[Type[reflex.state.State]] = None State: Optional[Type[reflex.state.State]] = None
def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]: def add_imports(self) -> ImportDict | list[ImportDict]:
"""Add imports for the component. """Add imports for the component.
This method should be implemented by subclasses to add new imports for the component. This method should be implemented by subclasses to add new imports for the component.
@ -1224,7 +1224,7 @@ class Component(BaseComponent, ABC):
# Return the dynamic imports # Return the dynamic imports
return dynamic_imports return dynamic_imports
def _get_props_imports(self) -> List[str]: def _get_props_imports(self) -> List[ParsedImportDict]:
"""Get the imports needed for components props. """Get the imports needed for components props.
Returns: Returns:
@ -1250,7 +1250,7 @@ class Component(BaseComponent, ABC):
or format.format_library_name(dep or "") in self.transpile_packages or format.format_library_name(dep or "") in self.transpile_packages
) )
def _get_dependencies_imports(self) -> imports.ImportDict: def _get_dependencies_imports(self) -> ParsedImportDict:
"""Get the imports from lib_dependencies for installing. """Get the imports from lib_dependencies for installing.
Returns: Returns:
@ -1267,7 +1267,7 @@ class Component(BaseComponent, ABC):
for dep in self.lib_dependencies for dep in self.lib_dependencies
} }
def _get_hooks_imports(self) -> imports.ImportDict: def _get_hooks_imports(self) -> ParsedImportDict:
"""Get the imports required by certain hooks. """Get the imports required by certain hooks.
Returns: Returns:
@ -1308,7 +1308,7 @@ class Component(BaseComponent, ABC):
return imports.merge_imports(_imports, *other_imports) return imports.merge_imports(_imports, *other_imports)
def _get_imports(self) -> imports.ImportDict: def _get_imports(self) -> ParsedImportDict:
"""Get all the libraries and fields that are used by the component. """Get all the libraries and fields that are used by the component.
Returns: Returns:
@ -1328,25 +1328,15 @@ class Component(BaseComponent, ABC):
var._var_data.imports for var in self._get_vars() if var._var_data var._var_data.imports for var in self._get_vars() if var._var_data
] ]
# If any subclass implements add_imports, merge the imports. added_import_dicts: list[ParsedImportDict] = []
def _make_list(
value: str | ImportVar | list[str | ImportVar],
) -> list[str | ImportVar]:
if isinstance(value, (str, ImportVar)):
return [value]
return value
_added_import_dicts = []
for clz in self._iter_parent_classes_with_method("add_imports"): for clz in self._iter_parent_classes_with_method("add_imports"):
_added_import_dicts.append( list_of_import_dict = clz.add_imports(self)
{
package: [ if not isinstance(list_of_import_dict, list):
ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag list_of_import_dict = [list_of_import_dict]
for tag in _make_list(maybe_tags)
] for import_dict in list_of_import_dict:
for package, maybe_tags in clz.add_imports(self).items() added_import_dicts.append(parse_imports(import_dict))
}
)
return imports.merge_imports( return imports.merge_imports(
*self._get_props_imports(), *self._get_props_imports(),
@ -1355,10 +1345,10 @@ class Component(BaseComponent, ABC):
_imports, _imports,
event_imports, event_imports,
*var_imports, *var_imports,
*_added_import_dicts, *added_import_dicts,
) )
def _get_all_imports(self, collapse: bool = False) -> imports.ImportDict: def _get_all_imports(self, collapse: bool = False) -> ParsedImportDict:
"""Get all the libraries and fields that are used by the component and its children. """Get all the libraries and fields that are used by the component and its children.
Args: Args:
@ -1453,7 +1443,7 @@ class Component(BaseComponent, ABC):
**self._get_special_hooks(), **self._get_special_hooks(),
} }
def _get_added_hooks(self) -> dict[str, imports.ImportDict]: def _get_added_hooks(self) -> dict[str, ImportDict]:
"""Get the hooks added via `add_hooks` method. """Get the hooks added via `add_hooks` method.
Returns: Returns:
@ -1842,7 +1832,7 @@ memo = custom_component
class NoSSRComponent(Component): class NoSSRComponent(Component):
"""A dynamic component that is not rendered on the server.""" """A dynamic component that is not rendered on the server."""
def _get_imports(self) -> imports.ImportDict: def _get_imports(self) -> ParsedImportDict:
"""Get the imports for the component. """Get the imports for the component.
Returns: Returns:
@ -2185,7 +2175,7 @@ class StatefulComponent(BaseComponent):
""" """
return {} return {}
def _get_all_imports(self) -> imports.ImportDict: def _get_all_imports(self) -> ParsedImportDict:
"""Get all the libraries and fields that are used by the component. """Get all the libraries and fields that are used by the component.
Returns: Returns:

View File

@ -19,7 +19,7 @@ from reflex.components.radix.themes.typography.text import Text
from reflex.components.sonner.toast import Toaster, ToastProps from reflex.components.sonner.toast import Toaster, ToastProps
from reflex.constants import Dirs, Hooks, Imports from reflex.constants import Dirs, Hooks, Imports
from reflex.constants.compiler import CompileVars from reflex.constants.compiler import CompileVars
from reflex.utils import imports from reflex.utils.imports import ImportDict, ImportVar
from reflex.utils.serializers import serialize from reflex.utils.serializers import serialize
from reflex.vars import Var, VarData from reflex.vars import Var, VarData
@ -65,10 +65,15 @@ has_too_many_connection_errors: Var = Var.create_safe(
class WebsocketTargetURL(Bare): class WebsocketTargetURL(Bare):
"""A component that renders the websocket target URL.""" """A component that renders the websocket target URL."""
def _get_imports(self) -> imports.ImportDict: def add_imports(self) -> ImportDict:
"""Add imports for the websocket target URL component.
Returns:
The import dict.
"""
return { return {
f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")], f"/{Dirs.STATE_PATH}": [ImportVar(tag="getBackendURL")],
"/env.json": [imports.ImportVar(tag="env", is_default=True)], "/env.json": [ImportVar(tag="env", is_default=True)],
} }
@classmethod @classmethod
@ -98,7 +103,7 @@ def default_connection_error() -> list[str | Var | Component]:
class ConnectionToaster(Toaster): class ConnectionToaster(Toaster):
"""A connection toaster component.""" """A connection toaster component."""
def add_hooks(self) -> list[str]: def add_hooks(self) -> list[str | Var]:
"""Add the hooks for the connection toaster. """Add the hooks for the connection toaster.
Returns: Returns:
@ -116,7 +121,7 @@ class ConnectionToaster(Toaster):
duration=120000, duration=120000,
id=toast_id, id=toast_id,
) )
hook = Var.create( hook = Var.create_safe(
f""" f"""
const toast_props = {serialize(props)}; const toast_props = {serialize(props)};
const [userDismissed, setUserDismissed] = useState(false); const [userDismissed, setUserDismissed] = useState(false);
@ -135,22 +140,17 @@ useEffect(() => {{
}}, [{connect_errors}]);""", }}, [{connect_errors}]);""",
_var_is_string=False, _var_is_string=False,
) )
imports: ImportDict = {
hook._var_data = VarData.merge( # type: ignore "react": ["useEffect", "useState"],
**target_url._get_imports(), # type: ignore
}
hook._var_data = VarData.merge(
connect_errors._var_data, connect_errors._var_data,
VarData( VarData(imports=imports),
imports={
"react": [
imports.ImportVar(tag="useEffect"),
imports.ImportVar(tag="useState"),
],
**target_url._get_imports(),
}
),
) )
return [ return [
Hooks.EVENTS, Hooks.EVENTS,
hook, # type: ignore hook,
] ]
@ -216,10 +216,11 @@ class WifiOffPulse(Icon):
"""A wifi_off icon with an animated opacity pulse.""" """A wifi_off icon with an animated opacity pulse."""
@classmethod @classmethod
def create(cls, **props) -> Component: def create(cls, *children, **props) -> Icon:
"""Create a wifi_off icon with an animated opacity pulse. """Create a wifi_off icon with an animated opacity pulse.
Args: Args:
*children: The children of the component.
**props: The properties of the component. **props: The properties of the component.
Returns: Returns:
@ -237,11 +238,13 @@ class WifiOffPulse(Icon):
**props, **props,
) )
def _get_imports(self) -> imports.ImportDict: def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]:
return imports.merge_imports( """Add imports for the WifiOffPulse component.
super()._get_imports(),
{"@emotion/react": [imports.ImportVar(tag="keyframes")]}, Returns:
) The import dict.
"""
return {"@emotion/react": [ImportVar(tag="keyframes")]}
def _get_custom_code(self) -> str | None: def _get_custom_code(self) -> str | None:
return """ return """

View File

@ -23,7 +23,7 @@ from reflex.components.radix.themes.typography.text import Text
from reflex.components.sonner.toast import Toaster, ToastProps from reflex.components.sonner.toast import Toaster, ToastProps
from reflex.constants import Dirs, Hooks, Imports from reflex.constants import Dirs, Hooks, Imports
from reflex.constants.compiler import CompileVars from reflex.constants.compiler import CompileVars
from reflex.utils import imports from reflex.utils.imports import ImportDict, ImportVar
from reflex.utils.serializers import serialize from reflex.utils.serializers import serialize
from reflex.vars import Var, VarData from reflex.vars import Var, VarData
@ -35,6 +35,7 @@ has_connection_errors: Var
has_too_many_connection_errors: Var has_too_many_connection_errors: Var
class WebsocketTargetURL(Bare): class WebsocketTargetURL(Bare):
def add_imports(self) -> ImportDict: ...
@overload @overload
@classmethod @classmethod
def create( # type: ignore def create( # type: ignore
@ -104,7 +105,7 @@ class WebsocketTargetURL(Bare):
def default_connection_error() -> list[str | Var | Component]: ... def default_connection_error() -> list[str | Var | Component]: ...
class ConnectionToaster(Toaster): class ConnectionToaster(Toaster):
def add_hooks(self) -> list[str]: ... def add_hooks(self) -> list[str | Var]: ...
@overload @overload
@classmethod @classmethod
def create( # type: ignore def create( # type: ignore
@ -430,6 +431,7 @@ class WifiOffPulse(Icon):
"""Create a wifi_off icon with an animated opacity pulse. """Create a wifi_off icon with an animated opacity pulse.
Args: Args:
*children: The children of the component.
size: The size of the icon in pixels. size: The size of the icon in pixels.
style: The style of the component. style: The style of the component.
key: A unique key for the component. key: A unique key for the component.
@ -443,6 +445,7 @@ class WifiOffPulse(Icon):
The icon component with default props applied. The icon component with default props applied.
""" """
... ...
def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]: ...
class ConnectionPulser(Div): class ConnectionPulser(Div):
@overload @overload

View File

@ -10,11 +10,12 @@ from reflex.components.tags import CondTag, Tag
from reflex.constants import Dirs from reflex.constants import Dirs
from reflex.constants.colors import Color from reflex.constants.colors import Color
from reflex.style import LIGHT_COLOR_MODE, color_mode from reflex.style import LIGHT_COLOR_MODE, color_mode
from reflex.utils import format, imports from reflex.utils import format
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var, VarData from reflex.vars import Var, VarData
_IS_TRUE_IMPORT = { _IS_TRUE_IMPORT: ImportDict = {
f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="isTrue")], f"/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")],
} }
@ -96,12 +97,16 @@ class Cond(MemoizationLeaf):
cond_state=f"isTrue({self.cond._var_full_name})", cond_state=f"isTrue({self.cond._var_full_name})",
) )
def _get_imports(self) -> imports.ImportDict: def add_imports(self) -> ImportDict:
return imports.merge_imports( """Add imports for the Cond component.
super()._get_imports(),
getattr(self.cond._var_data, "imports", {}), Returns:
_IS_TRUE_IMPORT, The import dict for the component.
"""
cond_imports: dict[str, str | ImportVar | list[str | ImportVar]] = getattr(
self.cond._var_data, "imports", {}
) )
return {**cond_imports, **_IS_TRUE_IMPORT}
@overload @overload

View File

@ -8,8 +8,9 @@ from reflex.components.component import BaseComponent, Component, MemoizationLea
from reflex.components.core.colors import Color from reflex.components.core.colors import Color
from reflex.components.tags import MatchTag, Tag from reflex.components.tags import MatchTag, Tag
from reflex.style import Style from reflex.style import Style
from reflex.utils import format, imports, types from reflex.utils import format, types
from reflex.utils.exceptions import MatchTypeError from reflex.utils.exceptions import MatchTypeError
from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var, VarData from reflex.vars import BaseVar, Var, VarData
@ -268,11 +269,13 @@ class Match(MemoizationLeaf):
tag.name = "match" tag.name = "match"
return dict(tag) return dict(tag)
def _get_imports(self) -> imports.ImportDict: def add_imports(self) -> ImportDict:
return imports.merge_imports( """Add imports for the Match component.
super()._get_imports(),
getattr(self.cond._var_data, "imports", {}), Returns:
) The import dict.
"""
return getattr(self.cond._var_data, "imports", {})
match = Match.create match = Match.create

View File

@ -19,17 +19,15 @@ from reflex.event import (
call_script, call_script,
parse_args_spec, parse_args_spec,
) )
from reflex.utils import imports from reflex.utils.imports import ImportVar
from reflex.vars import BaseVar, CallableVar, Var, VarData from reflex.vars import BaseVar, CallableVar, Var, VarData
DEFAULT_UPLOAD_ID: str = "default" DEFAULT_UPLOAD_ID: str = "default"
upload_files_context_var_data: VarData = VarData( upload_files_context_var_data: VarData = VarData(
imports={ imports={
"react": [imports.ImportVar(tag="useContext")], "react": "useContext",
f"/{Dirs.CONTEXTS_PATH}": [ f"/{Dirs.CONTEXTS_PATH}": "UploadFilesContext",
imports.ImportVar(tag="UploadFilesContext"),
],
}, },
hooks={ hooks={
"const [filesById, setFilesById] = useContext(UploadFilesContext);": None, "const [filesById, setFilesById] = useContext(UploadFilesContext);": None,
@ -133,8 +131,8 @@ uploaded_files_url_prefix: Var = Var.create_safe(
_var_is_string=False, _var_is_string=False,
_var_data=VarData( _var_data=VarData(
imports={ imports={
f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")], f"/{Dirs.STATE_PATH}": "getBackendURL",
"/env.json": [imports.ImportVar(tag="env", is_default=True)], "/env.json": ImportVar(tag="env", is_default=True),
} }
), ),
) )

View File

@ -23,7 +23,7 @@ from reflex.event import (
call_script, call_script,
parse_args_spec, parse_args_spec,
) )
from reflex.utils import imports from reflex.utils.imports import ImportVar
from reflex.vars import BaseVar, CallableVar, Var, VarData from reflex.vars import BaseVar, CallableVar, Var, VarData
DEFAULT_UPLOAD_ID: str DEFAULT_UPLOAD_ID: str

View File

@ -12,8 +12,8 @@ from reflex.components.core.cond import color_mode_cond
from reflex.constants.colors import Color from reflex.constants.colors import Color
from reflex.event import set_clipboard from reflex.event import set_clipboard
from reflex.style import Style from reflex.style import Style
from reflex.utils import format, imports from reflex.utils import format
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var from reflex.vars import Var
LiteralCodeBlockTheme = Literal[ LiteralCodeBlockTheme = Literal[
@ -381,42 +381,45 @@ class CodeBlock(Component):
# Props passed down to the code tag. # Props passed down to the code tag.
code_tag_props: Var[Dict[str, str]] code_tag_props: Var[Dict[str, str]]
def _get_imports(self) -> imports.ImportDict: def add_imports(self) -> ImportDict:
merged_imports = super()._get_imports() """Add imports for the CodeBlock component.
# Get all themes from a cond literal
Returns:
The import dict.
"""
imports_: ImportDict = {}
themes = re.findall(r"`(.*?)`", self.theme._var_name) themes = re.findall(r"`(.*?)`", self.theme._var_name)
if not themes: if not themes:
themes = [self.theme._var_name] themes = [self.theme._var_name]
merged_imports = imports.merge_imports(
merged_imports, imports_.update(
{ {
f"react-syntax-highlighter/dist/cjs/styles/prism/{self.convert_theme_name(theme)}": { f"react-syntax-highlighter/dist/cjs/styles/prism/{self.convert_theme_name(theme)}": [
ImportVar( ImportVar(
tag=format.to_camel_case(self.convert_theme_name(theme)), tag=format.to_camel_case(self.convert_theme_name(theme)),
is_default=True, is_default=True,
install=False, install=False,
) )
} ]
for theme in themes for theme in themes
}, }
) )
if ( if (
self.language is not None self.language is not None
and self.language._var_name in LiteralCodeLanguage.__args__ # type: ignore and self.language._var_name in LiteralCodeLanguage.__args__ # type: ignore
): ):
merged_imports = imports.merge_imports( imports_[
merged_imports, f"react-syntax-highlighter/dist/cjs/languages/prism/{self.language._var_name}"
{ ] = [
f"react-syntax-highlighter/dist/cjs/languages/prism/{self.language._var_name}": { ImportVar(
ImportVar( tag=format.to_camel_case(self.language._var_name),
tag=format.to_camel_case(self.language._var_name), is_default=True,
is_default=True, install=False,
install=False, )
) ]
}
}, return imports_
)
return merged_imports
def _get_custom_code(self) -> Optional[str]: def _get_custom_code(self) -> Optional[str]:
if ( if (

View File

@ -17,8 +17,8 @@ from reflex.components.core.cond import color_mode_cond
from reflex.constants.colors import Color from reflex.constants.colors import Color
from reflex.event import set_clipboard from reflex.event import set_clipboard
from reflex.style import Style from reflex.style import Style
from reflex.utils import format, imports from reflex.utils import format
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var from reflex.vars import Var
LiteralCodeBlockTheme = Literal[ LiteralCodeBlockTheme = Literal[
@ -351,6 +351,7 @@ LiteralCodeLanguage = Literal[
] ]
class CodeBlock(Component): class CodeBlock(Component):
def add_imports(self) -> ImportDict: ...
@overload @overload
@classmethod @classmethod
def create( # type: ignore def create( # type: ignore

View File

@ -2,13 +2,14 @@
from __future__ import annotations from __future__ import annotations
from enum import Enum from enum import Enum
from typing import Any, Callable, Dict, List, Literal, Optional, Union from typing import Any, Dict, List, Literal, Optional, Union
from reflex.base import Base from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent from reflex.components.component import Component, NoSSRComponent
from reflex.components.literals import LiteralRowMarker from reflex.components.literals import LiteralRowMarker
from reflex.utils import console, format, imports, types from reflex.event import EventHandler
from reflex.utils.imports import ImportVar from reflex.utils import console, format, types
from reflex.utils.imports import ImportDict, ImportVar
from reflex.utils.serializers import serializer from reflex.utils.serializers import serializer
from reflex.vars import Var, get_unique_variable_name from reflex.vars import Var, get_unique_variable_name
@ -205,51 +206,66 @@ class DataEditor(NoSSRComponent):
# global theme # global theme
theme: Var[Union[DataEditorTheme, Dict]] theme: Var[Union[DataEditorTheme, Dict]]
def _get_imports(self): # Triggered when a cell is activated.
return imports.merge_imports( on_cell_activated: EventHandler[lambda pos: [pos]]
super()._get_imports(),
{
"": {
ImportVar(
tag=f"{format.format_library_name(self.library)}/dist/index.css"
)
},
self.library: {ImportVar(tag="GridCellKind")},
"/utils/helpers/dataeditor.js": {
ImportVar(
tag=f"formatDataEditorCells", is_default=False, install=False
),
},
},
)
def get_event_triggers(self) -> Dict[str, Callable]: # Triggered when a cell is clicked.
"""The event triggers of the component. on_cell_clicked: EventHandler[lambda pos: [pos]]
# Triggered when a cell is right-clicked.
on_cell_context_menu: EventHandler[lambda pos: [pos]]
# Triggered when a cell is edited.
on_cell_edited: EventHandler[lambda pos, data: [pos, data]]
# Triggered when a group header is clicked.
on_group_header_clicked: EventHandler[lambda pos, data: [pos, data]]
# Triggered when a group header is right-clicked.
on_group_header_context_menu: EventHandler[lambda grp_idx, data: [grp_idx, data]]
# Triggered when a group header is renamed.
on_group_header_renamed: EventHandler[lambda idx, val: [idx, val]]
# Triggered when a header is clicked.
on_header_clicked: EventHandler[lambda pos: [pos]]
# Triggered when a header is right-clicked.
on_header_context_menu: EventHandler[lambda pos: [pos]]
# Triggered when a header menu is clicked.
on_header_menu_click: EventHandler[lambda col, pos: [col, pos]]
# Triggered when an item is hovered.
on_item_hovered: EventHandler[lambda pos: [pos]]
# Triggered when a selection is deleted.
on_delete: EventHandler[lambda selection: [selection]]
# Triggered when editing is finished.
on_finished_editing: EventHandler[lambda new_value, movement: [new_value, movement]]
# Triggered when a row is appended.
on_row_appended: EventHandler[lambda: []]
# Triggered when the selection is cleared.
on_selection_cleared: EventHandler[lambda: []]
# Triggered when a column is resized.
on_column_resize: EventHandler[lambda col, width: [col, width]]
def add_imports(self) -> ImportDict:
"""Add imports for the component.
Returns: Returns:
The dict describing the event triggers. The import dict.
""" """
def edit_sig(pos, data: dict[str, Any]):
return [pos, data]
return { return {
"on_cell_activated": lambda pos: [pos], "": f"{format.format_library_name(self.library)}/dist/index.css",
"on_cell_clicked": lambda pos: [pos], self.library: "GridCellKind",
"on_cell_context_menu": lambda pos: [pos], "/utils/helpers/dataeditor.js": ImportVar(
"on_cell_edited": edit_sig, tag="formatDataEditorCells", is_default=False, install=False
"on_group_header_clicked": edit_sig, ),
"on_group_header_context_menu": lambda grp_idx, data: [grp_idx, data],
"on_group_header_renamed": lambda idx, val: [idx, val],
"on_header_clicked": lambda pos: [pos],
"on_header_context_menu": lambda pos: [pos],
"on_header_menu_click": lambda col, pos: [col, pos],
"on_item_hovered": lambda pos: [pos],
"on_delete": lambda selection: [selection],
"on_finished_editing": lambda new_value, movement: [new_value, movement],
"on_row_appended": lambda: [],
"on_selection_cleared": lambda: [],
"on_column_resize": lambda col, width: [col, width],
} }
def add_hooks(self) -> list[str]: def add_hooks(self) -> list[str]:

View File

@ -8,12 +8,13 @@ from reflex.vars import Var, BaseVar, ComputedVar
from reflex.event import EventChain, EventHandler, EventSpec from reflex.event import EventChain, EventHandler, EventSpec
from reflex.style import Style from reflex.style import Style
from enum import Enum from enum import Enum
from typing import Any, Callable, Dict, List, Literal, Optional, Union from typing import Any, Dict, List, Literal, Optional, Union
from reflex.base import Base from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent from reflex.components.component import Component, NoSSRComponent
from reflex.components.literals import LiteralRowMarker from reflex.components.literals import LiteralRowMarker
from reflex.utils import console, format, imports, types from reflex.event import EventHandler
from reflex.utils.imports import ImportVar from reflex.utils import console, format, types
from reflex.utils.imports import ImportDict, ImportVar
from reflex.utils.serializers import serializer from reflex.utils.serializers import serializer
from reflex.vars import Var, get_unique_variable_name from reflex.vars import Var, get_unique_variable_name
@ -80,7 +81,7 @@ class DataEditorTheme(Base):
text_medium: Optional[str] text_medium: Optional[str]
class DataEditor(NoSSRComponent): class DataEditor(NoSSRComponent):
def get_event_triggers(self) -> Dict[str, Callable]: ... def add_imports(self) -> ImportDict: ...
def add_hooks(self) -> list[str]: ... def add_hooks(self) -> list[str]: ...
@overload @overload
@classmethod @classmethod
@ -136,6 +137,9 @@ class DataEditor(NoSSRComponent):
class_name: Optional[Any] = None, class_name: Optional[Any] = None,
autofocus: Optional[bool] = None, autofocus: Optional[bool] = None,
custom_attrs: Optional[Dict[str, Union[Var, str]]] = None, custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
on_blur: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_cell_activated: Optional[ on_cell_activated: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar] Union[EventHandler, EventSpec, list, function, BaseVar]
] = None, ] = None,
@ -148,15 +152,27 @@ class DataEditor(NoSSRComponent):
on_cell_edited: Optional[ on_cell_edited: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar] Union[EventHandler, EventSpec, list, function, BaseVar]
] = None, ] = None,
on_click: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_column_resize: Optional[ on_column_resize: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar] Union[EventHandler, EventSpec, list, function, BaseVar]
] = None, ] = None,
on_context_menu: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_delete: Optional[ on_delete: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar] Union[EventHandler, EventSpec, list, function, BaseVar]
] = None, ] = None,
on_double_click: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_finished_editing: Optional[ on_finished_editing: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar] Union[EventHandler, EventSpec, list, function, BaseVar]
] = None, ] = None,
on_focus: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_group_header_clicked: Optional[ on_group_header_clicked: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar] Union[EventHandler, EventSpec, list, function, BaseVar]
] = None, ] = None,
@ -178,12 +194,42 @@ class DataEditor(NoSSRComponent):
on_item_hovered: Optional[ on_item_hovered: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar] Union[EventHandler, EventSpec, list, function, BaseVar]
] = None, ] = None,
on_mount: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_mouse_down: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_mouse_enter: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_mouse_leave: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_mouse_move: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_mouse_out: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_mouse_over: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_mouse_up: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_row_appended: Optional[ on_row_appended: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar] Union[EventHandler, EventSpec, list, function, BaseVar]
] = None, ] = None,
on_scroll: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
on_selection_cleared: Optional[ on_selection_cleared: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar] Union[EventHandler, EventSpec, list, function, BaseVar]
] = None, ] = None,
on_unmount: Optional[
Union[EventHandler, EventSpec, list, function, BaseVar]
] = None,
**props **props
) -> "DataEditor": ) -> "DataEditor":
"""Create the DataEditor component. """Create the DataEditor component.

View File

@ -11,8 +11,8 @@ from reflex.components.el.element import Element
from reflex.components.tags.tag import Tag from reflex.components.tags.tag import Tag
from reflex.constants import Dirs, EventTriggers from reflex.constants import Dirs, EventTriggers
from reflex.event import EventChain from reflex.event import EventChain
from reflex.utils import imports
from reflex.utils.format import format_event_chain from reflex.utils.format import format_event_chain
from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var from reflex.vars import BaseVar, Var
from .base import BaseHTML from .base import BaseHTML
@ -169,17 +169,16 @@ class Form(BaseHTML):
).hexdigest() ).hexdigest()
return form return form
def _get_imports(self) -> imports.ImportDict: def add_imports(self) -> ImportDict:
return imports.merge_imports( """Add imports needed by the form component.
super()._get_imports(),
{ Returns:
"react": {imports.ImportVar(tag="useCallback")}, The imports for the form component.
f"/{Dirs.STATE_PATH}": { """
imports.ImportVar(tag="getRefValue"), return {
imports.ImportVar(tag="getRefValues"), "react": "useCallback",
}, f"/{Dirs.STATE_PATH}": ["getRefValue", "getRefValues"],
}, }
)
def add_hooks(self) -> list[str]: def add_hooks(self) -> list[str]:
"""Add hooks for the form. """Add hooks for the form.

View File

@ -14,8 +14,8 @@ from reflex.components.el.element import Element
from reflex.components.tags.tag import Tag from reflex.components.tags.tag import Tag
from reflex.constants import Dirs, EventTriggers from reflex.constants import Dirs, EventTriggers
from reflex.event import EventChain from reflex.event import EventChain
from reflex.utils import imports
from reflex.utils.format import format_event_chain from reflex.utils.format import format_event_chain
from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var from reflex.vars import BaseVar, Var
from .base import BaseHTML from .base import BaseHTML
@ -581,6 +581,7 @@ class Form(BaseHTML):
The form component. The form component.
""" """
... ...
def add_imports(self) -> ImportDict: ...
def add_hooks(self) -> list[str]: ... def add_hooks(self) -> list[str]: ...
class Input(BaseHTML): class Input(BaseHTML):

View File

@ -6,7 +6,8 @@ from typing import Any, Dict, List, Union
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.tags import Tag from reflex.components.tags import Tag
from reflex.utils import imports, types from reflex.utils import types
from reflex.utils.imports import ImportDict
from reflex.utils.serializers import serialize from reflex.utils.serializers import serialize
from reflex.vars import BaseVar, ComputedVar, Var from reflex.vars import BaseVar, ComputedVar, Var
@ -102,11 +103,13 @@ class DataTable(Gridjs):
**props, **props,
) )
def _get_imports(self) -> imports.ImportDict: def add_imports(self) -> ImportDict:
return imports.merge_imports( """Add the imports for the datatable component.
super()._get_imports(),
{"": {imports.ImportVar(tag="gridjs/dist/theme/mermaid.css")}}, Returns:
) The import dict for the component.
"""
return {"": "gridjs/dist/theme/mermaid.css"}
def _render(self) -> Tag: def _render(self) -> Tag:
if isinstance(self.data, Var) and types.is_dataframe(self.data._var_type): if isinstance(self.data, Var) and types.is_dataframe(self.data._var_type):

View File

@ -10,7 +10,8 @@ from reflex.style import Style
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.tags import Tag from reflex.components.tags import Tag
from reflex.utils import imports, types from reflex.utils import types
from reflex.utils.imports import ImportDict
from reflex.utils.serializers import serialize from reflex.utils.serializers import serialize
from reflex.vars import BaseVar, ComputedVar, Var from reflex.vars import BaseVar, ComputedVar, Var
@ -180,3 +181,4 @@ class DataTable(Gridjs):
ValueError: If a pandas dataframe is passed in and columns are also provided. ValueError: If a pandas dataframe is passed in and columns are also provided.
""" """
... ...
def add_imports(self) -> ImportDict: ...

View File

@ -7,7 +7,6 @@ from functools import lru_cache
from hashlib import md5 from hashlib import md5
from typing import Any, Callable, Dict, Union from typing import Any, Callable, Dict, Union
from reflex.compiler import utils
from reflex.components.component import Component, CustomComponent from reflex.components.component import Component, CustomComponent
from reflex.components.radix.themes.layout.list import ( from reflex.components.radix.themes.layout.list import (
ListItem, ListItem,
@ -18,8 +17,8 @@ from reflex.components.radix.themes.typography.heading import Heading
from reflex.components.radix.themes.typography.link import Link from reflex.components.radix.themes.typography.link import Link
from reflex.components.radix.themes.typography.text import Text from reflex.components.radix.themes.typography.text import Text
from reflex.components.tags.tag import Tag from reflex.components.tags.tag import Tag
from reflex.utils import imports, types from reflex.utils import types
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var from reflex.vars import Var
# Special vars used in the component map. # Special vars used in the component map.
@ -145,47 +144,41 @@ class Markdown(Component):
return custom_components return custom_components
def _get_imports(self) -> imports.ImportDict: def add_imports(self) -> ImportDict | list[ImportDict]:
# Import here to avoid circular imports. """Add imports for the markdown component.
Returns:
The imports for the markdown component.
"""
from reflex.components.datadisplay.code import CodeBlock from reflex.components.datadisplay.code import CodeBlock
from reflex.components.radix.themes.typography.code import Code from reflex.components.radix.themes.typography.code import Code
imports = super()._get_imports() return [
# Special markdown imports.
imports.update(
{ {
"": [ImportVar(tag="katex/dist/katex.min.css")], "": "katex/dist/katex.min.css",
"remark-math@5.1.1": [ "remark-math@5.1.1": ImportVar(
ImportVar(tag=_REMARK_MATH._var_name, is_default=True) tag=_REMARK_MATH._var_name, is_default=True
], ),
"remark-gfm@3.0.1": [ "remark-gfm@3.0.1": ImportVar(
ImportVar(tag=_REMARK_GFM._var_name, is_default=True) tag=_REMARK_GFM._var_name, is_default=True
], ),
"remark-unwrap-images@4.0.0": [ "remark-unwrap-images@4.0.0": ImportVar(
ImportVar(tag=_REMARK_UNWRAP_IMAGES._var_name, is_default=True) tag=_REMARK_UNWRAP_IMAGES._var_name, is_default=True
], ),
"rehype-katex@6.0.3": [ "rehype-katex@6.0.3": ImportVar(
ImportVar(tag=_REHYPE_KATEX._var_name, is_default=True) tag=_REHYPE_KATEX._var_name, is_default=True
], ),
"rehype-raw@6.1.1": [ "rehype-raw@6.1.1": ImportVar(
ImportVar(tag=_REHYPE_RAW._var_name, is_default=True) tag=_REHYPE_RAW._var_name, is_default=True
], ),
} },
) *[
component(_MOCK_ARG)._get_imports() # type: ignore
# Get the imports for each component. for component in self.component_map.values()
for component in self.component_map.values(): ],
imports = utils.merge_imports( CodeBlock.create(theme="light")._get_imports(), # type: ignore,
imports, component(_MOCK_ARG)._get_all_imports() Code.create()._get_imports(), # type: ignore,
) ]
# Get the imports for the code components.
imports = utils.merge_imports(
imports, CodeBlock.create(theme="light")._get_imports()
)
imports = utils.merge_imports(imports, Code.create()._get_imports())
return imports
def get_component(self, tag: str, **props) -> Component: def get_component(self, tag: str, **props) -> Component:
"""Get the component for a tag and props. """Get the component for a tag and props.

View File

@ -11,7 +11,6 @@ import textwrap
from functools import lru_cache from functools import lru_cache
from hashlib import md5 from hashlib import md5
from typing import Any, Callable, Dict, Union from typing import Any, Callable, Dict, Union
from reflex.compiler import utils
from reflex.components.component import Component, CustomComponent from reflex.components.component import Component, CustomComponent
from reflex.components.radix.themes.layout.list import ( from reflex.components.radix.themes.layout.list import (
ListItem, ListItem,
@ -22,8 +21,8 @@ from reflex.components.radix.themes.typography.heading import Heading
from reflex.components.radix.themes.typography.link import Link from reflex.components.radix.themes.typography.link import Link
from reflex.components.radix.themes.typography.text import Text from reflex.components.radix.themes.typography.text import Text
from reflex.components.tags.tag import Tag from reflex.components.tags.tag import Tag
from reflex.utils import imports, types from reflex.utils import types
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var from reflex.vars import Var
_CHILDREN = Var.create_safe("children", _var_is_local=False, _var_is_string=False) _CHILDREN = Var.create_safe("children", _var_is_local=False, _var_is_string=False)
@ -124,6 +123,7 @@ class Markdown(Component):
The markdown component. The markdown component.
""" """
... ...
def add_imports(self) -> ImportDict | list[ImportDict]: ...
def get_component(self, tag: str, **props) -> Component: ... def get_component(self, tag: str, **props) -> Component: ...
def format_component(self, tag: str, **props) -> str: ... def format_component(self, tag: str, **props) -> str: ...
def format_component_map(self) -> dict[str, str]: ... def format_component_map(self) -> dict[str, str]: ...

View File

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
from reflex.base import Base from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent from reflex.components.component import Component, NoSSRComponent
from reflex.utils import imports from reflex.utils.imports import ImportDict
from reflex.vars import Var from reflex.vars import Var
@ -90,14 +90,15 @@ class Moment(NoSSRComponent):
# Display the date in the given timezone. # Display the date in the given timezone.
tz: Var[str] tz: Var[str]
def _get_imports(self) -> imports.ImportDict: def add_imports(self) -> ImportDict:
merged_imports = super()._get_imports() """Add the imports for the Moment component.
Returns:
The import dict for the component.
"""
if self.tz is not None: if self.tz is not None:
merged_imports = imports.merge_imports( return {"moment-timezone": ""}
merged_imports, return {}
{"moment-timezone": {imports.ImportVar(tag="")}},
)
return merged_imports
def get_event_triggers(self) -> Dict[str, Any]: def get_event_triggers(self) -> Dict[str, Any]:
"""Get the events triggers signatures for the component. """Get the events triggers signatures for the component.

View File

@ -10,7 +10,7 @@ from reflex.style import Style
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from reflex.base import Base from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent from reflex.components.component import Component, NoSSRComponent
from reflex.utils import imports from reflex.utils.imports import ImportDict
from reflex.vars import Var from reflex.vars import Var
class MomentDelta(Base): class MomentDelta(Base):
@ -25,6 +25,7 @@ class MomentDelta(Base):
milliseconds: Optional[int] milliseconds: Optional[int]
class Moment(NoSSRComponent): class Moment(NoSSRComponent):
def add_imports(self) -> ImportDict: ...
def get_event_triggers(self) -> Dict[str, Any]: ... def get_event_triggers(self) -> Dict[str, Any]: ...
@overload @overload
@classmethod @classmethod

View File

@ -11,7 +11,6 @@ from reflex.components.lucide.icon import Icon
from reflex.components.radix.primitives.base import RadixPrimitiveComponent from reflex.components.radix.primitives.base import RadixPrimitiveComponent
from reflex.components.radix.themes.base import LiteralAccentColor, LiteralRadius from reflex.components.radix.themes.base import LiteralAccentColor, LiteralRadius
from reflex.style import Style from reflex.style import Style
from reflex.utils import imports
from reflex.vars import Var, get_uuid_string_var from reflex.vars import Var, get_uuid_string_var
LiteralAccordionType = Literal["single", "multiple"] LiteralAccordionType = Literal["single", "multiple"]
@ -413,13 +412,13 @@ class AccordionContent(AccordionComponent):
alias = "RadixAccordionContent" alias = "RadixAccordionContent"
def add_imports(self) -> imports.ImportDict: def add_imports(self) -> dict:
"""Add imports to the component. """Add imports to the component.
Returns: Returns:
The imports of the component. The imports of the component.
""" """
return {"@emotion/react": [imports.ImportVar(tag="keyframes")]} return {"@emotion/react": "keyframes"}
@classmethod @classmethod
def create(cls, *children, **props) -> Component: def create(cls, *children, **props) -> Component:

View File

@ -15,7 +15,6 @@ from reflex.components.lucide.icon import Icon
from reflex.components.radix.primitives.base import RadixPrimitiveComponent from reflex.components.radix.primitives.base import RadixPrimitiveComponent
from reflex.components.radix.themes.base import LiteralAccentColor, LiteralRadius from reflex.components.radix.themes.base import LiteralAccentColor, LiteralRadius
from reflex.style import Style from reflex.style import Style
from reflex.utils import imports
from reflex.vars import Var, get_uuid_string_var from reflex.vars import Var, get_uuid_string_var
LiteralAccordionType = Literal["single", "multiple"] LiteralAccordionType = Literal["single", "multiple"]
@ -899,7 +898,7 @@ class AccordionIcon(Icon):
... ...
class AccordionContent(AccordionComponent): class AccordionContent(AccordionComponent):
def add_imports(self) -> imports.ImportDict: ... def add_imports(self) -> dict: ...
@overload @overload
@classmethod @classmethod
def create( # type: ignore def create( # type: ignore

View File

@ -7,7 +7,7 @@ from typing import Any, Dict, Literal
from reflex.components import Component from reflex.components import Component
from reflex.components.tags import Tag from reflex.components.tags import Tag
from reflex.config import get_config from reflex.config import get_config
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var from reflex.vars import Var
LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"] LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"]
@ -209,13 +209,13 @@ class Theme(RadixThemesComponent):
children = [ThemePanel.create(), *children] children = [ThemePanel.create(), *children]
return super().create(*children, **props) return super().create(*children, **props)
def add_imports(self) -> dict[str, list[ImportVar] | ImportVar]: def add_imports(self) -> ImportDict | list[ImportDict]:
"""Add imports for the Theme component. """Add imports for the Theme component.
Returns: Returns:
The import dict. The import dict.
""" """
_imports: dict[str, list[ImportVar] | ImportVar] = { _imports: ImportDict = {
"/utils/theme.js": [ImportVar(tag="theme", is_default=True)], "/utils/theme.js": [ImportVar(tag="theme", is_default=True)],
} }
if get_config().tailwind is None: if get_config().tailwind is None:

View File

@ -11,7 +11,7 @@ from typing import Any, Dict, Literal
from reflex.components import Component from reflex.components import Component
from reflex.components.tags import Tag from reflex.components.tags import Tag
from reflex.config import get_config from reflex.config import get_config
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var from reflex.vars import Var
LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"] LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"]
@ -580,7 +580,7 @@ class Theme(RadixThemesComponent):
A new component instance. A new component instance.
""" """
... ...
def add_imports(self) -> dict[str, list[ImportVar] | ImportVar]: ... def add_imports(self) -> ImportDict | list[ImportDict]: ...
class ThemePanel(RadixThemesComponent): class ThemePanel(RadixThemesComponent):
def add_imports(self) -> dict[str, str]: ... def add_imports(self) -> dict[str, str]: ...

View File

@ -12,7 +12,7 @@ from reflex.components.core.colors import color
from reflex.components.core.cond import cond from reflex.components.core.cond import cond
from reflex.components.el.elements.inline import A from reflex.components.el.elements.inline import A
from reflex.components.next.link import NextLink from reflex.components.next.link import NextLink
from reflex.utils import imports from reflex.utils.imports import ImportDict
from reflex.vars import Var from reflex.vars import Var
from ..base import ( from ..base import (
@ -59,8 +59,13 @@ class Link(RadixThemesComponent, A, MemoizationLeaf):
# If True, the link will open in a new tab # If True, the link will open in a new tab
is_external: Var[bool] is_external: Var[bool]
def _get_imports(self) -> imports.ImportDict: def add_imports(self) -> ImportDict:
return {**super()._get_imports(), **next_link._get_imports()} """Add imports for the Link component.
Returns:
The import dict.
"""
return next_link._get_imports() # type: ignore
@classmethod @classmethod
def create(cls, *children, **props) -> Component: def create(cls, *children, **props) -> Component:

View File

@ -13,7 +13,7 @@ from reflex.components.core.colors import color
from reflex.components.core.cond import cond from reflex.components.core.cond import cond
from reflex.components.el.elements.inline import A from reflex.components.el.elements.inline import A
from reflex.components.next.link import NextLink from reflex.components.next.link import NextLink
from reflex.utils import imports from reflex.utils.imports import ImportDict
from reflex.vars import Var from reflex.vars import Var
from ..base import LiteralAccentColor, RadixThemesComponent from ..base import LiteralAccentColor, RadixThemesComponent
from .base import LiteralTextSize, LiteralTextTrim, LiteralTextWeight from .base import LiteralTextSize, LiteralTextTrim, LiteralTextWeight
@ -22,6 +22,7 @@ LiteralLinkUnderline = Literal["auto", "hover", "always", "none"]
next_link = NextLink.create() next_link = NextLink.create()
class Link(RadixThemesComponent, A, MemoizationLeaf): class Link(RadixThemesComponent, A, MemoizationLeaf):
def add_imports(self) -> ImportDict: ...
@overload @overload
@classmethod @classmethod
def create( # type: ignore def create( # type: ignore

View File

@ -8,7 +8,7 @@ from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent from reflex.components.component import Component, NoSSRComponent
from reflex.constants import EventTriggers from reflex.constants import EventTriggers
from reflex.utils.format import to_camel_case from reflex.utils.format import to_camel_case
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var from reflex.vars import Var
@ -176,12 +176,15 @@ class Editor(NoSSRComponent):
# default: False # default: False
disable_toolbar: Var[bool] disable_toolbar: Var[bool]
def _get_imports(self): def add_imports(self) -> ImportDict:
imports = super()._get_imports() """Add imports for the Editor component.
imports[""] = [
ImportVar(tag="suneditor/dist/css/suneditor.min.css", install=False) Returns:
] The import dict.
return imports """
return {
"": ImportVar(tag="suneditor/dist/css/suneditor.min.css", install=False)
}
def get_event_triggers(self) -> Dict[str, Any]: def get_event_triggers(self) -> Dict[str, Any]:
"""Get the event triggers that pass the component's value to the handler. """Get the event triggers that pass the component's value to the handler.

View File

@ -13,7 +13,7 @@ from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent from reflex.components.component import Component, NoSSRComponent
from reflex.constants import EventTriggers from reflex.constants import EventTriggers
from reflex.utils.format import to_camel_case from reflex.utils.format import to_camel_case
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var from reflex.vars import Var
class EditorButtonList(list, enum.Enum): class EditorButtonList(list, enum.Enum):
@ -48,6 +48,7 @@ class EditorOptions(Base):
button_list: Optional[List[Union[List[str], str]]] button_list: Optional[List[Union[List[str], str]]]
class Editor(NoSSRComponent): class Editor(NoSSRComponent):
def add_imports(self) -> ImportDict: ...
def get_event_triggers(self) -> Dict[str, Any]: ... def get_event_triggers(self) -> Dict[str, Any]: ...
@overload @overload
@classmethod @classmethod

View File

@ -3,12 +3,12 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Optional from typing import Dict, List, Optional, Union
from reflex.base import Base from reflex.base import Base
def merge_imports(*imports) -> ImportDict: def merge_imports(*imports: ImportDict | ParsedImportDict) -> ParsedImportDict:
"""Merge multiple import dicts together. """Merge multiple import dicts together.
Args: Args:
@ -24,7 +24,31 @@ def merge_imports(*imports) -> ImportDict:
return all_imports return all_imports
def collapse_imports(imports: ImportDict) -> ImportDict: def parse_imports(imports: ImportDict | ParsedImportDict) -> ParsedImportDict:
"""Parse the import dict into a standard format.
Args:
imports: The import dict to parse.
Returns:
The parsed import dict.
"""
def _make_list(value: ImportTypes) -> list[str | ImportVar] | list[ImportVar]:
if isinstance(value, (str, ImportVar)):
return [value]
return value
return {
package: [
ImportVar(tag=tag) if isinstance(tag, str) else tag
for tag in _make_list(maybe_tags)
]
for package, maybe_tags in imports.items()
}
def collapse_imports(imports: ParsedImportDict) -> ParsedImportDict:
"""Remove all duplicate ImportVar within an ImportDict. """Remove all duplicate ImportVar within an ImportDict.
Args: Args:
@ -33,7 +57,10 @@ def collapse_imports(imports: ImportDict) -> ImportDict:
Returns: Returns:
The collapsed import dict. The collapsed import dict.
""" """
return {lib: list(set(import_vars)) for lib, import_vars in imports.items()} return {
lib: list(set(import_vars)) if isinstance(import_vars, list) else import_vars
for lib, import_vars in imports.items()
}
class ImportVar(Base): class ImportVar(Base):
@ -90,4 +117,6 @@ class ImportVar(Base):
) )
ImportDict = Dict[str, List[ImportVar]] ImportTypes = Union[str, ImportVar, List[Union[str, ImportVar]], List[ImportVar]]
ImportDict = Dict[str, ImportTypes]
ParsedImportDict = Dict[str, List[ImportVar]]

View File

@ -39,7 +39,12 @@ from reflex.utils import console, imports, serializers, types
from reflex.utils.exceptions import VarAttributeError, VarTypeError, VarValueError from reflex.utils.exceptions import VarAttributeError, VarTypeError, VarValueError
# This module used to export ImportVar itself, so we still import it for export here # This module used to export ImportVar itself, so we still import it for export here
from reflex.utils.imports import ImportDict, ImportVar from reflex.utils.imports import (
ImportDict,
ImportVar,
ParsedImportDict,
parse_imports,
)
from reflex.utils.types import override from reflex.utils.types import override
if TYPE_CHECKING: if TYPE_CHECKING:
@ -120,7 +125,7 @@ class VarData(Base):
state: str = "" state: str = ""
# Imports needed to render this var # Imports needed to render this var
imports: ImportDict = {} imports: ParsedImportDict = {}
# Hooks that need to be present in the component to render this var # Hooks that need to be present in the component to render this var
hooks: Dict[str, None] = {} hooks: Dict[str, None] = {}
@ -130,6 +135,19 @@ class VarData(Base):
# segments. # segments.
interpolations: List[Tuple[int, int]] = [] interpolations: List[Tuple[int, int]] = []
def __init__(
self, imports: Union[ImportDict, ParsedImportDict] | None = None, **kwargs: Any
):
"""Initialize the var data.
Args:
imports: The imports needed to render this var.
**kwargs: The var data fields.
"""
if imports:
kwargs["imports"] = parse_imports(imports)
super().__init__(**kwargs)
@classmethod @classmethod
def merge(cls, *others: VarData | None) -> VarData | None: def merge(cls, *others: VarData | None) -> VarData | None:
"""Merge multiple var data objects. """Merge multiple var data objects.

View File

@ -10,7 +10,7 @@ from reflex.base import Base as Base
from reflex.state import State as State from reflex.state import State as State
from reflex.state import BaseState as BaseState from reflex.state import BaseState as BaseState
from reflex.utils import console as console, format as format, types as types from reflex.utils import console as console, format as format, types as types
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportVar, ImportDict, ParsedImportDict
from types import FunctionType from types import FunctionType
from typing import ( from typing import (
Any, Any,
@ -36,7 +36,7 @@ def _extract_var_data(value: Iterable) -> list[VarData | None]: ...
class VarData(Base): class VarData(Base):
state: str = "" state: str = ""
imports: dict[str, List[ImportVar]] = {} imports: Union[ImportDict, ParsedImportDict] = {}
hooks: Dict[str, None] = {} hooks: Dict[str, None] = {}
interpolations: List[Tuple[int, int]] = [] interpolations: List[Tuple[int, int]] = []
@classmethod @classmethod

View File

@ -4,8 +4,7 @@ from typing import List
import pytest import pytest
from reflex.compiler import compiler, utils from reflex.compiler import compiler, utils
from reflex.utils import imports from reflex.utils.imports import ImportVar, ParsedImportDict
from reflex.utils.imports import ImportVar
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -93,7 +92,7 @@ def test_compile_import_statement(
), ),
], ],
) )
def test_compile_imports(import_dict: imports.ImportDict, test_dicts: List[dict]): def test_compile_imports(import_dict: ParsedImportDict, test_dicts: List[dict]):
"""Test the compile_imports function. """Test the compile_imports function.
Args: Args:

View File

@ -20,7 +20,7 @@ from reflex.event import EventChain, EventHandler, parse_args_spec
from reflex.state import BaseState from reflex.state import BaseState
from reflex.style import Style from reflex.style import Style
from reflex.utils import imports from reflex.utils import imports
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.vars import BaseVar, Var, VarData from reflex.vars import BaseVar, Var, VarData
@ -56,7 +56,7 @@ def component1() -> Type[Component]:
# A test string/number prop. # A test string/number prop.
text_or_number: Var[Union[int, str]] text_or_number: Var[Union[int, str]]
def _get_imports(self) -> imports.ImportDict: def _get_imports(self) -> ParsedImportDict:
return {"react": [ImportVar(tag="Component")]} return {"react": [ImportVar(tag="Component")]}
def _get_custom_code(self) -> str: def _get_custom_code(self) -> str:
@ -89,7 +89,7 @@ def component2() -> Type[Component]:
"on_close": lambda e0: [e0], "on_close": lambda e0: [e0],
} }
def _get_imports(self) -> imports.ImportDict: def _get_imports(self) -> ParsedImportDict:
return {"react-redux": [ImportVar(tag="connect")]} return {"react-redux": [ImportVar(tag="connect")]}
def _get_custom_code(self) -> str: def _get_custom_code(self) -> str:
@ -1773,21 +1773,15 @@ def test_invalid_event_trigger():
), ),
) )
def test_component_add_imports(tags): def test_component_add_imports(tags):
def _list_to_import_vars(tags: List[str]) -> List[ImportVar]:
return [
ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag
for tag in tags
]
class BaseComponent(Component): class BaseComponent(Component):
def _get_imports(self) -> imports.ImportDict: def _get_imports(self) -> ImportDict:
return {} return {}
class Reference(Component): class Reference(Component):
def _get_imports(self) -> imports.ImportDict: def _get_imports(self) -> ParsedImportDict:
return imports.merge_imports( return imports.merge_imports(
super()._get_imports(), super()._get_imports(),
{"react": _list_to_import_vars(tags)}, parse_imports({"react": tags}),
{"foo": [ImportVar(tag="bar")]}, {"foo": [ImportVar(tag="bar")]},
) )
@ -1806,10 +1800,12 @@ def test_component_add_imports(tags):
baseline = Reference.create() baseline = Reference.create()
test = Test.create() test = Test.create()
assert baseline._get_all_imports() == { assert baseline._get_all_imports() == parse_imports(
"react": _list_to_import_vars(tags), {
"foo": [ImportVar(tag="bar")], "react": tags,
} "foo": [ImportVar(tag="bar")],
}
)
assert test._get_all_imports() == baseline._get_all_imports() assert test._get_all_imports() == baseline._get_all_imports()

View File

@ -1,6 +1,12 @@
import pytest import pytest
from reflex.utils.imports import ImportVar, merge_imports from reflex.utils.imports import (
ImportDict,
ImportVar,
ParsedImportDict,
merge_imports,
parse_imports,
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -76,3 +82,32 @@ def test_merge_imports(input_1, input_2, output):
for key in output: for key in output:
assert set(res[key]) == set(output[key]) assert set(res[key]) == set(output[key])
@pytest.mark.parametrize(
"input, output",
[
({}, {}),
(
{"react": "Component"},
{"react": [ImportVar(tag="Component")]},
),
(
{"react": ["Component"]},
{"react": [ImportVar(tag="Component")]},
),
(
{"react": ["Component", ImportVar(tag="useState")]},
{"react": [ImportVar(tag="Component"), ImportVar(tag="useState")]},
),
(
{"react": ["Component"], "foo": "anotherFunction"},
{
"react": [ImportVar(tag="Component")],
"foo": [ImportVar(tag="anotherFunction")],
},
),
],
)
def test_parse_imports(input: ImportDict, output: ParsedImportDict):
assert parse_imports(input) == output