diff --git a/integration/test_connection_banner.py b/integration/test_connection_banner.py new file mode 100644 index 000000000..c078df1f7 --- /dev/null +++ b/integration/test_connection_banner.py @@ -0,0 +1,90 @@ +"""Test case for displaying the connection banner when the websocket drops.""" + +from typing import Generator + +import pytest +from selenium.common.exceptions import NoSuchElementException +from selenium.webdriver.common.by import By + +from reflex.testing import AppHarness, WebDriver + + +def ConnectionBanner(): + """App with a connection banner.""" + import reflex as rx + + class State(rx.State): + foo: int = 0 + + def index(): + return rx.text("Hello World") + + app = rx.App(state=State) + app.add_page(index) + app.compile() + + +@pytest.fixture() +def connection_banner(tmp_path) -> Generator[AppHarness, None, None]: + """Start ConnectionBanner app at tmp_path via AppHarness. + + Args: + tmp_path: pytest tmp_path fixture + + Yields: + running AppHarness instance + """ + with AppHarness.create( + root=tmp_path, + app_source=ConnectionBanner, # type: ignore + ) as harness: + yield harness + + +CONNECTION_ERROR_XPATH = "//*[ text() = 'Connection Error' ]" + + +def has_error_modal(driver: WebDriver) -> bool: + """Check if the connection error modal is displayed. + + Args: + driver: Selenium webdriver instance. + + Returns: + True if the modal is displayed, False otherwise. + """ + try: + driver.find_element(By.XPATH, CONNECTION_ERROR_XPATH) + return True + except NoSuchElementException: + return False + + +def test_connection_banner(connection_banner: AppHarness): + """Test that the connection banner is displayed when the websocket drops. + + Args: + connection_banner: AppHarness instance. + """ + assert connection_banner.app_instance is not None + assert connection_banner.backend is not None + driver = connection_banner.frontend() + + connection_banner._poll_for(lambda: not has_error_modal(driver)) + + # Get the backend port + backend_port = connection_banner._poll_for_servers().getsockname()[1] + + # Kill the backend + connection_banner.backend.should_exit = True + if connection_banner.backend_thread is not None: + connection_banner.backend_thread.join() + + # Error modal should now be displayed + connection_banner._poll_for(lambda: has_error_modal(driver)) + + # Bring the backend back up + connection_banner._start_backend(port=backend_port) + + # Banner should be gone now + connection_banner._poll_for(lambda: not has_error_modal(driver)) diff --git a/reflex/app.py b/reflex/app.py index a1ba828ec..a46bd89d2 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -57,6 +57,15 @@ ComponentCallable = Callable[[], Component] Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]] +def default_overlay_component() -> Component: + """Default overlay_component attribute for App. + + Returns: + The default overlay_component, which is a connection_modal. + """ + return connection_modal() + + class App(Base): """A Reflex application.""" @@ -97,7 +106,9 @@ class App(Base): event_namespace: Optional[AsyncNamespace] = None # A component that is present on every page. - overlay_component: Optional[Union[Component, ComponentCallable]] = connection_modal + overlay_component: Optional[ + Union[Component, ComponentCallable] + ] = default_overlay_component def __init__(self, *args, **kwargs): """Initialize the app. @@ -179,6 +190,13 @@ class App(Base): # Set up the admin dash. self.setup_admin_dash() + # If a State is not used and no overlay_component is specified, do not render the connection modal + if ( + self.state is DefaultState + and self.overlay_component is default_overlay_component + ): + self.overlay_component = None + def __repr__(self) -> str: """Get the string representation of the app. diff --git a/reflex/testing.py b/reflex/testing.py index c40f7744d..953f1ddbe 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -163,14 +163,14 @@ class AppHarness: self.app_module = reflex.utils.prerequisites.get_app(reload=True) self.app_instance = self.app_module.app - def _start_backend(self): + def _start_backend(self, port=0): if self.app_instance is None: raise RuntimeError("App was not initialized.") self.backend = uvicorn.Server( uvicorn.Config( app=self.app_instance.api, host="127.0.0.1", - port=0, + port=port, ) ) self.backend_thread = threading.Thread(target=self.backend.run) diff --git a/tests/test_app.py b/tests/test_app.py index 1e032500d..980c31849 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import io import os.path import sys @@ -16,8 +18,15 @@ from starlette_admin.contrib.sqla.admin import Admin from starlette_admin.contrib.sqla.view import ModelView from reflex import AdminDash, constants -from reflex.app import App, DefaultState, process, upload -from reflex.components import Box +from reflex.app import ( + App, + ComponentCallable, + DefaultState, + default_overlay_component, + process, + upload, +) +from reflex.components import Box, Component, Cond, Fragment, Text from reflex.event import Event, get_hydrate_event from reflex.middleware import HydrateMiddleware from reflex.model import Model @@ -945,3 +954,55 @@ async def test_process_events(gen_state, mocker): assert app.state_manager.get_state("token").value == 5 assert app.postprocess.call_count == 6 + + +@pytest.mark.parametrize( + ("state", "overlay_component", "exp_page_child"), + [ + (DefaultState, default_overlay_component, None), + (DefaultState, None, None), + (DefaultState, Text.create("foo"), Text), + (State, default_overlay_component, Fragment), + (State, None, None), + (State, Text.create("foo"), Text), + (State, lambda: Text.create("foo"), Text), + ], +) +def test_overlay_component( + state: State | None, + overlay_component: Component | ComponentCallable | None, + exp_page_child: Type[Component] | None, +): + """Test that the overlay component is set correctly. + + Args: + state: The state class to pass to App. + overlay_component: The overlay_component to pass to App. + exp_page_child: The type of the expected child in the page fragment. + """ + app = App(state=state, overlay_component=overlay_component) + if exp_page_child is None: + assert app.overlay_component is None + elif isinstance(exp_page_child, Fragment): + assert app.overlay_component is not None + generated_component = app._generate_component(app.overlay_component) + assert isinstance(generated_component, Fragment) + assert isinstance( + generated_component.children[0], + Cond, # ConnectionModal is a Cond under the hood + ) + else: + assert app.overlay_component is not None + assert isinstance( + app._generate_component(app.overlay_component), + exp_page_child, + ) + + app.add_page(Box.create("Index"), route="/test") + page = app.pages["test"] + if exp_page_child is not None: + assert len(page.children) == 3 + children_types = (type(child) for child in page.children) + assert exp_page_child in children_types + else: + assert len(page.children) == 2