diff --git a/reflex/app.py b/reflex/app.py index 7b7010521..281727b3f 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -164,11 +164,11 @@ def default_backend_exception_handler(exception: Exception) -> EventSpec: return window_alert("\n".join(error_message)) -def default_overlay_component() -> Component: - """Default overlay_component attribute for App. +def extra_overlay_function() -> Optional[Component]: + """Extra overlay function to add to the overlay component. Returns: - The default overlay_component, which is a connection_modal. + The extra overlay function. """ config = get_config() @@ -178,7 +178,8 @@ def default_overlay_component() -> Component: module, _, function_name = extra_config.rpartition(".") try: module = __import__(module) - config_overlay = getattr(module, function_name)() + config_overlay = Fragment.create(getattr(module, function_name)()) + config_overlay._get_all_imports() except Exception as e: from reflex.compiler.utils import save_error @@ -188,13 +189,27 @@ def default_overlay_component() -> Component: f"Error loading extra_overlay_function {extra_config}. Error saved to {log_path}" ) - return Fragment.create( - connection_pulser(), - connection_toaster(), - *([config_overlay] if config_overlay else []), - *([backend_disabled()] if config.is_reflex_cloud else []), - *codespaces.codespaces_auto_redirect(), - ) + return config_overlay + + +def default_overlay_component() -> Component: + """Default overlay_component attribute for App. + + Returns: + The default overlay_component, which is a connection_modal. + """ + config = get_config() + from reflex.components.component import memo + + def default_overlay_components(): + return Fragment.create( + connection_pulser(), + connection_toaster(), + *([backend_disabled()] if config.is_reflex_cloud else []), + *codespaces.codespaces_auto_redirect(), + ) + + return Fragment.create(memo(default_overlay_components)()) def default_error_boundary(*children: Component) -> Component: @@ -266,11 +281,26 @@ class App(MiddlewareMixin, LifespanMixin): # A component that is present on every page (defaults to the Connection Error banner). overlay_component: Optional[Union[Component, ComponentCallable]] = ( - dataclasses.field(default_factory=default_overlay_component) + dataclasses.field(default=None) ) # Error boundary component to wrap the app with. - error_boundary: Optional[ComponentCallable] = default_error_boundary + error_boundary: Optional[ComponentCallable] = dataclasses.field(default=None) + + # App wraps to be applied to the whole app. Expected to be a dictionary of (order, name) to a function that takes whether the state is enabled and optionally returns a component. + app_wraps: Dict[tuple[int, str], Callable[[bool], Optional[Component]]] = ( + dataclasses.field( + default_factory=lambda: { + (55, "ErrorBoundary"): ( + lambda stateful: default_error_boundary() if stateful else None + ), + (5, "Overlay"): ( + lambda stateful: default_overlay_component() if stateful else None + ), + (4, "ExtraOverlay"): lambda stateful: extra_overlay_function(), + } + ) + ) # Components to add to the head of every page. head_components: List[Component] = dataclasses.field(default_factory=list) @@ -880,25 +910,6 @@ class App(MiddlewareMixin, LifespanMixin): for k, component in self._pages.items(): self._pages[k] = self._add_overlay_to_component(component) - def _add_error_boundary_to_component(self, component: Component) -> Component: - if self.error_boundary is None: - return component - - component = self.error_boundary(*component.children) - - return component - - def _setup_error_boundary(self): - """If a State is not used and no error_boundary is specified, do not render the error boundary.""" - if self._state is None and self.error_boundary is default_error_boundary: - self.error_boundary = None - - for k, component in self._pages.items(): - # Skip the 404 page - if k == constants.Page404.SLUG: - continue - self._pages[k] = self._add_error_boundary_to_component(component) - def _setup_sticky_badge(self): """Add the sticky badge to the app.""" for k, component in self._pages.items(): @@ -1039,7 +1050,6 @@ class App(MiddlewareMixin, LifespanMixin): self._validate_var_dependencies() self._setup_overlay_component() - self._setup_error_boundary() if is_prod_mode() and config.show_built_with_reflex: self._setup_sticky_badge() @@ -1066,6 +1076,22 @@ class App(MiddlewareMixin, LifespanMixin): # Add the custom components from the page to the set. custom_components |= component._get_all_custom_components() + # Add the app wraps to the app. + for key, app_wrap in self.app_wraps.items(): + component = app_wrap(self._state is not None) + if component is not None: + app_wrappers[key] = component + custom_components |= component._get_all_custom_components() + + if self.error_boundary: + console.deprecate( + feature_name="App.error_boundary", + reason="Use app_wraps instead.", + deprecation_version="0.7.1", + removal_version="0.8.0", + ) + app_wrappers[(55, "ErrorBoundary")] = self.error_boundary() + # Perform auto-memoization of stateful components. with console.timing("Auto-memoize StatefulComponents"): ( diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index c2a76aad3..7cd87fb71 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -78,6 +78,7 @@ def _compile_app(app_root: Component) -> str: hooks=app_root._get_all_hooks(), window_libraries=window_libraries, render=app_root.render(), + dynamic_imports=app_root._get_all_dynamic_imports(), ) diff --git a/reflex/components/component.py b/reflex/components/component.py index 6e4c6c37f..9466933c5 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -23,6 +23,8 @@ from typing import ( Union, ) +from typing_extensions import Self + import reflex.state from reflex.base import Base from reflex.compiler.templates import STATEFUL_COMPONENT @@ -685,7 +687,7 @@ class Component(BaseComponent, ABC): } @classmethod - def create(cls, *children, **props) -> Component: + def create(cls, *children, **props) -> Self: """Create the component. Args: diff --git a/reflex/testing.py b/reflex/testing.py index 25f9e7aac..e463ddea7 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -43,6 +43,7 @@ import reflex.utils.exec import reflex.utils.format import reflex.utils.prerequisites import reflex.utils.processes +from reflex.components.component import CustomComponent from reflex.config import environment from reflex.state import ( BaseState, @@ -254,6 +255,7 @@ class AppHarness: # disable telemetry reporting for tests os.environ["TELEMETRY_ENABLED"] = "false" + CustomComponent.create().get_component.cache_clear() self.app_path.mkdir(parents=True, exist_ok=True) if self.app_source is not None: app_globals = self._get_globals_from_signature(self.app_source) diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py index 7318b7ff6..3e729cbf8 100644 --- a/reflex/utils/exec.py +++ b/reflex/utils/exec.py @@ -254,15 +254,12 @@ def get_reload_paths() -> Sequence[Path]: if config.app_module is not None and config.app_module.__file__: module_path = Path(config.app_module.__file__).resolve().parent - while module_path.parent.name: - if any( - sibling_file.name == "__init__.py" - for sibling_file in module_path.parent.iterdir() - ): - # go up a level to find dir without `__init__.py` - module_path = module_path.parent - else: - break + while module_path.parent.name and any( + sibling_file.name == "__init__.py" + for sibling_file in module_path.parent.iterdir() + ): + # go up a level to find dir without `__init__.py` + module_path = module_path.parent reload_paths = [module_path] diff --git a/tests/units/test_app.py b/tests/units/test_app.py index ae5a01c1a..5d25b09ac 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -1299,6 +1299,7 @@ def test_app_wrap_compile_theme( app_js_lines = [ line.strip() for line in app_js_contents.splitlines() if line.strip() ] + lines = "".join(app_js_lines) assert ( "function AppWrap({children}) {" "return (" @@ -1313,7 +1314,7 @@ def test_app_wrap_compile_theme( + ("" if react_strict_mode else "") + ")" "}" - ) in "".join(app_js_lines) + ) in lines @pytest.mark.parametrize( @@ -1362,6 +1363,7 @@ def test_app_wrap_priority( app_js_lines = [ line.strip() for line in app_js_contents.splitlines() if line.strip() ] + lines = "".join(app_js_lines) assert ( "function AppWrap({children}) {" "return (" + ("" if react_strict_mode else "") + "" @@ -1374,9 +1376,8 @@ def test_app_wrap_priority( "" "" "" - "" + ("" if react_strict_mode else "") + ")" - "}" - ) in "".join(app_js_lines) + "" + ("" if react_strict_mode else "") + ) in lines def test_app_state_determination():