diff --git a/reflex/app.py b/reflex/app.py index 5c698d1a7..f7764829d 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -114,6 +114,9 @@ class App(Base): # The async server name space event_namespace: Optional[EventNamespace] = None + # Components to add to the head of every page. + head_components: List[Component] = [] + # A component that is present on every page. overlay_component: Optional[ Union[Component, ComponentCallable] @@ -401,6 +404,12 @@ class App(Base): # Add script tags if given if script_tags: + console.deprecate( + feature_name="Passing script tags to add_page", + reason="Add script components as children to the page component instead", + deprecation_version="v0.2.9", + removal_version="v0.2.11", + ) component.children.extend(script_tags) # Add the page. @@ -629,7 +638,7 @@ class App(Base): compile_results.append(compiler.compile_root_stylesheet(self.stylesheets)) # Compile the root document. - compile_results.append(compiler.compile_document_root()) + compile_results.append(compiler.compile_document_root(self.head_components)) # Compile the theme. compile_results.append(compiler.compile_theme(self.style)) diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 1799f9883..9a7c76a29 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -3,7 +3,7 @@ from __future__ import annotations import os from pathlib import Path -from typing import List, Set, Tuple, Type +from typing import Type from reflex import constants from reflex.compiler import templates, utils @@ -121,7 +121,7 @@ def _compile_page( ) -def compile_root_stylesheet(stylesheets: List[str]) -> Tuple[str, str]: +def compile_root_stylesheet(stylesheets: list[str]) -> tuple[str, str]: """Compile the root stylesheet. Args: @@ -137,7 +137,7 @@ def compile_root_stylesheet(stylesheets: List[str]) -> Tuple[str, str]: return output_path, code -def _compile_root_stylesheet(stylesheets: List[str]) -> str: +def _compile_root_stylesheet(stylesheets: list[str]) -> str: """Compile the root stylesheet. Args: @@ -182,7 +182,7 @@ def _compile_component(component: Component) -> str: return templates.COMPONENT.render(component=component) -def _compile_components(components: Set[CustomComponent]) -> str: +def _compile_components(components: set[CustomComponent]) -> str: """Compile the components. Args: @@ -226,9 +226,12 @@ def _compile_tailwind( ) -def compile_document_root() -> Tuple[str, str]: +def compile_document_root(head_components: list[Component]) -> tuple[str, str]: """Compile the document root. + Args: + head_components: The components to include in the head. + Returns: The path and code of the compiled document root. """ @@ -236,13 +239,14 @@ def compile_document_root() -> Tuple[str, str]: output_path = utils.get_page_path(constants.DOCUMENT_ROOT) # Create the document root. - document_root = utils.create_document_root() + document_root = utils.create_document_root(head_components) + # Compile the document root. code = _compile_document_root(document_root) return output_path, code -def compile_theme(style: ComponentStyle) -> Tuple[str, str]: +def compile_theme(style: ComponentStyle) -> tuple[str, str]: """Compile the theme. Args: @@ -261,9 +265,7 @@ def compile_theme(style: ComponentStyle) -> Tuple[str, str]: return output_path, code -def compile_contexts( - state: Type[State], -) -> Tuple[str, str]: +def compile_contexts(state: Type[State]) -> tuple[str, str]: """Compile the initial state / context. Args: @@ -279,10 +281,8 @@ def compile_contexts( def compile_page( - path: str, - component: Component, - state: Type[State], -) -> Tuple[str, str]: + path: str, component: Component, state: Type[State] +) -> tuple[str, str]: """Compile a single page. Args: @@ -301,7 +301,7 @@ def compile_page( return output_path, code -def compile_components(components: Set[CustomComponent]): +def compile_components(components: set[CustomComponent]): """Compile the custom components. Args: diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 27c4d930c..617427312 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -2,7 +2,7 @@ from __future__ import annotations import os -from typing import Any, Dict, List, Optional, Set, Tuple, Type +from typing import Any, Type from urllib.parse import urlparse from pydantic.fields import ModelField @@ -31,7 +31,7 @@ from reflex.vars import ImportVar merge_imports = imports.merge_imports -def compile_import_statement(fields: Set[ImportVar]) -> Tuple[str, Set[str]]: +def compile_import_statement(fields: set[ImportVar]) -> tuple[str, set[str]]: """Compile an import statement. Args: @@ -79,7 +79,7 @@ def validate_imports(imports: imports.ImportDict): used_tags[import_name] = lib -def compile_imports(imports: imports.ImportDict) -> List[dict]: +def compile_imports(imports: imports.ImportDict) -> list[dict]: """Compile an import dict. Args: @@ -112,7 +112,7 @@ def compile_imports(imports: imports.ImportDict) -> List[dict]: return import_dicts -def get_import_dict(lib: str, default: str = "", rest: Optional[Set] = None) -> Dict: +def get_import_dict(lib: str, default: str = "", rest: set[str] | None = None) -> dict: """Get dictionary for import template. Args: @@ -130,7 +130,7 @@ def get_import_dict(lib: str, default: str = "", rest: Optional[Set] = None) -> } -def compile_state(state: Type[State]) -> Dict: +def compile_state(state: Type[State]) -> dict: """Compile the state of the app. Args: @@ -225,7 +225,7 @@ def compile_client_storage(state: Type[State]) -> dict[str, dict]: def compile_custom_component( component: CustomComponent, -) -> Tuple[dict, imports.ImportDict]: +) -> tuple[dict, imports.ImportDict]: """Compile a custom component. Args: @@ -258,14 +258,18 @@ def compile_custom_component( ) -def create_document_root() -> Component: +def create_document_root(head_components: list[Component] | None = None) -> Component: """Create the document root. + Args: + head_components: The components to add to the head. + Returns: The document root. """ + head_components = head_components or [] return Html.create( - DocumentHead.create(), + DocumentHead.create(*head_components), Body.create( ColorModeScript.create(), Main.create(), @@ -274,7 +278,7 @@ def create_document_root() -> Component: ) -def create_theme(style: ComponentStyle) -> Dict: +def create_theme(style: ComponentStyle) -> dict: """Create the base style for the app. Args: @@ -350,11 +354,11 @@ def get_components_path() -> str: return os.path.join(constants.WEB_UTILS_DIR, "components" + constants.JS_EXT) -def get_asset_path(filename: Optional[str] = None) -> str: +def get_asset_path(filename: str | None = None) -> str: """Get the path for an asset. Args: - filename: Optional, if given, is added to the root path of assets dir. + filename: If given, is added to the root path of assets dir. Returns: The path of the asset. @@ -366,7 +370,7 @@ def get_asset_path(filename: Optional[str] = None) -> str: def add_meta( - page: Component, title: str, image: str, description: str, meta: List[Dict] + page: Component, title: str, image: str, description: str, meta: list[dict] ) -> Component: """Add metadata to a page. @@ -406,7 +410,7 @@ def write_page(path: str, code: str): f.write(code) -def empty_dir(path: str, keep_files: Optional[List[str]] = None): +def empty_dir(path: str, keep_files: list[str] | None = None): """Remove all files and folders in a directory except for the keep_files. Args: diff --git a/tests/compiler/test_compiler.py b/tests/compiler/test_compiler.py index 0b6cd6625..5329b9778 100644 --- a/tests/compiler/test_compiler.py +++ b/tests/compiler/test_compiler.py @@ -189,3 +189,22 @@ def test_compile_nonexistent_stylesheet(tmp_path, mocker): with pytest.raises(FileNotFoundError): compiler.compile_root_stylesheet(stylesheets) + + +def test_create_document_root(): + """Test that the document root is created correctly.""" + # Test with no components. + root = utils.create_document_root() + assert isinstance(root, utils.Html) + assert isinstance(root.children[0], utils.DocumentHead) + # No children in head. + assert len(root.children[0].children) == 0 + + # Test with components. + comps = [ + utils.NextScript.create(src="foo.js"), + utils.NextScript.create(src="bar.js"), + ] + root = utils.create_document_root(head_components=comps) # type: ignore + # Two children in head. + assert len(root.children[0].children) == 2