App: only render default overlay_component when DefaultState is not used (#1744)

This commit is contained in:
Masen Furer 2023-09-05 16:22:25 -07:00 committed by GitHub
parent 38c5503f94
commit 2e014422f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 174 additions and 5 deletions

View File

@ -0,0 +1,90 @@
"""Test case for displaying the connection banner when the websocket drops."""
from typing import Generator
import pytest
from selenium.common.exceptions import NoSuchElementException
from selenium.webdriver.common.by import By
from reflex.testing import AppHarness, WebDriver
def ConnectionBanner():
"""App with a connection banner."""
import reflex as rx
class State(rx.State):
foo: int = 0
def index():
return rx.text("Hello World")
app = rx.App(state=State)
app.add_page(index)
app.compile()
@pytest.fixture()
def connection_banner(tmp_path) -> Generator[AppHarness, None, None]:
"""Start ConnectionBanner app at tmp_path via AppHarness.
Args:
tmp_path: pytest tmp_path fixture
Yields:
running AppHarness instance
"""
with AppHarness.create(
root=tmp_path,
app_source=ConnectionBanner, # type: ignore
) as harness:
yield harness
CONNECTION_ERROR_XPATH = "//*[ text() = 'Connection Error' ]"
def has_error_modal(driver: WebDriver) -> bool:
"""Check if the connection error modal is displayed.
Args:
driver: Selenium webdriver instance.
Returns:
True if the modal is displayed, False otherwise.
"""
try:
driver.find_element(By.XPATH, CONNECTION_ERROR_XPATH)
return True
except NoSuchElementException:
return False
def test_connection_banner(connection_banner: AppHarness):
"""Test that the connection banner is displayed when the websocket drops.
Args:
connection_banner: AppHarness instance.
"""
assert connection_banner.app_instance is not None
assert connection_banner.backend is not None
driver = connection_banner.frontend()
connection_banner._poll_for(lambda: not has_error_modal(driver))
# Get the backend port
backend_port = connection_banner._poll_for_servers().getsockname()[1]
# Kill the backend
connection_banner.backend.should_exit = True
if connection_banner.backend_thread is not None:
connection_banner.backend_thread.join()
# Error modal should now be displayed
connection_banner._poll_for(lambda: has_error_modal(driver))
# Bring the backend back up
connection_banner._start_backend(port=backend_port)
# Banner should be gone now
connection_banner._poll_for(lambda: not has_error_modal(driver))

View File

@ -57,6 +57,15 @@ ComponentCallable = Callable[[], Component]
Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
def default_overlay_component() -> Component:
"""Default overlay_component attribute for App.
Returns:
The default overlay_component, which is a connection_modal.
"""
return connection_modal()
class App(Base):
"""A Reflex application."""
@ -97,7 +106,9 @@ class App(Base):
event_namespace: Optional[AsyncNamespace] = None
# A component that is present on every page.
overlay_component: Optional[Union[Component, ComponentCallable]] = connection_modal
overlay_component: Optional[
Union[Component, ComponentCallable]
] = default_overlay_component
def __init__(self, *args, **kwargs):
"""Initialize the app.
@ -179,6 +190,13 @@ class App(Base):
# Set up the admin dash.
self.setup_admin_dash()
# If a State is not used and no overlay_component is specified, do not render the connection modal
if (
self.state is DefaultState
and self.overlay_component is default_overlay_component
):
self.overlay_component = None
def __repr__(self) -> str:
"""Get the string representation of the app.

View File

@ -163,14 +163,14 @@ class AppHarness:
self.app_module = reflex.utils.prerequisites.get_app(reload=True)
self.app_instance = self.app_module.app
def _start_backend(self):
def _start_backend(self, port=0):
if self.app_instance is None:
raise RuntimeError("App was not initialized.")
self.backend = uvicorn.Server(
uvicorn.Config(
app=self.app_instance.api,
host="127.0.0.1",
port=0,
port=port,
)
)
self.backend_thread = threading.Thread(target=self.backend.run)

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import io
import os.path
import sys
@ -16,8 +18,15 @@ from starlette_admin.contrib.sqla.admin import Admin
from starlette_admin.contrib.sqla.view import ModelView
from reflex import AdminDash, constants
from reflex.app import App, DefaultState, process, upload
from reflex.components import Box
from reflex.app import (
App,
ComponentCallable,
DefaultState,
default_overlay_component,
process,
upload,
)
from reflex.components import Box, Component, Cond, Fragment, Text
from reflex.event import Event, get_hydrate_event
from reflex.middleware import HydrateMiddleware
from reflex.model import Model
@ -945,3 +954,55 @@ async def test_process_events(gen_state, mocker):
assert app.state_manager.get_state("token").value == 5
assert app.postprocess.call_count == 6
@pytest.mark.parametrize(
("state", "overlay_component", "exp_page_child"),
[
(DefaultState, default_overlay_component, None),
(DefaultState, None, None),
(DefaultState, Text.create("foo"), Text),
(State, default_overlay_component, Fragment),
(State, None, None),
(State, Text.create("foo"), Text),
(State, lambda: Text.create("foo"), Text),
],
)
def test_overlay_component(
state: State | None,
overlay_component: Component | ComponentCallable | None,
exp_page_child: Type[Component] | None,
):
"""Test that the overlay component is set correctly.
Args:
state: The state class to pass to App.
overlay_component: The overlay_component to pass to App.
exp_page_child: The type of the expected child in the page fragment.
"""
app = App(state=state, overlay_component=overlay_component)
if exp_page_child is None:
assert app.overlay_component is None
elif isinstance(exp_page_child, Fragment):
assert app.overlay_component is not None
generated_component = app._generate_component(app.overlay_component)
assert isinstance(generated_component, Fragment)
assert isinstance(
generated_component.children[0],
Cond, # ConnectionModal is a Cond under the hood
)
else:
assert app.overlay_component is not None
assert isinstance(
app._generate_component(app.overlay_component),
exp_page_child,
)
app.add_page(Box.create("Index"), route="/test")
page = app.pages["test"]
if exp_page_child is not None:
assert len(page.children) == 3
children_types = (type(child) for child in page.children)
assert exp_page_child in children_types
else:
assert len(page.children) == 2