App: only render default overlay_component when DefaultState is not used (#1744)
This commit is contained in:
parent
38c5503f94
commit
2e014422f5
90
integration/test_connection_banner.py
Normal file
90
integration/test_connection_banner.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
"""Test case for displaying the connection banner when the websocket drops."""
|
||||||
|
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from selenium.common.exceptions import NoSuchElementException
|
||||||
|
from selenium.webdriver.common.by import By
|
||||||
|
|
||||||
|
from reflex.testing import AppHarness, WebDriver
|
||||||
|
|
||||||
|
|
||||||
|
def ConnectionBanner():
|
||||||
|
"""App with a connection banner."""
|
||||||
|
import reflex as rx
|
||||||
|
|
||||||
|
class State(rx.State):
|
||||||
|
foo: int = 0
|
||||||
|
|
||||||
|
def index():
|
||||||
|
return rx.text("Hello World")
|
||||||
|
|
||||||
|
app = rx.App(state=State)
|
||||||
|
app.add_page(index)
|
||||||
|
app.compile()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def connection_banner(tmp_path) -> Generator[AppHarness, None, None]:
|
||||||
|
"""Start ConnectionBanner app at tmp_path via AppHarness.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tmp_path: pytest tmp_path fixture
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
running AppHarness instance
|
||||||
|
"""
|
||||||
|
with AppHarness.create(
|
||||||
|
root=tmp_path,
|
||||||
|
app_source=ConnectionBanner, # type: ignore
|
||||||
|
) as harness:
|
||||||
|
yield harness
|
||||||
|
|
||||||
|
|
||||||
|
CONNECTION_ERROR_XPATH = "//*[ text() = 'Connection Error' ]"
|
||||||
|
|
||||||
|
|
||||||
|
def has_error_modal(driver: WebDriver) -> bool:
|
||||||
|
"""Check if the connection error modal is displayed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
driver: Selenium webdriver instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the modal is displayed, False otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
driver.find_element(By.XPATH, CONNECTION_ERROR_XPATH)
|
||||||
|
return True
|
||||||
|
except NoSuchElementException:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def test_connection_banner(connection_banner: AppHarness):
|
||||||
|
"""Test that the connection banner is displayed when the websocket drops.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection_banner: AppHarness instance.
|
||||||
|
"""
|
||||||
|
assert connection_banner.app_instance is not None
|
||||||
|
assert connection_banner.backend is not None
|
||||||
|
driver = connection_banner.frontend()
|
||||||
|
|
||||||
|
connection_banner._poll_for(lambda: not has_error_modal(driver))
|
||||||
|
|
||||||
|
# Get the backend port
|
||||||
|
backend_port = connection_banner._poll_for_servers().getsockname()[1]
|
||||||
|
|
||||||
|
# Kill the backend
|
||||||
|
connection_banner.backend.should_exit = True
|
||||||
|
if connection_banner.backend_thread is not None:
|
||||||
|
connection_banner.backend_thread.join()
|
||||||
|
|
||||||
|
# Error modal should now be displayed
|
||||||
|
connection_banner._poll_for(lambda: has_error_modal(driver))
|
||||||
|
|
||||||
|
# Bring the backend back up
|
||||||
|
connection_banner._start_backend(port=backend_port)
|
||||||
|
|
||||||
|
# Banner should be gone now
|
||||||
|
connection_banner._poll_for(lambda: not has_error_modal(driver))
|
@ -57,6 +57,15 @@ ComponentCallable = Callable[[], Component]
|
|||||||
Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
|
Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
|
||||||
|
|
||||||
|
|
||||||
|
def default_overlay_component() -> Component:
|
||||||
|
"""Default overlay_component attribute for App.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The default overlay_component, which is a connection_modal.
|
||||||
|
"""
|
||||||
|
return connection_modal()
|
||||||
|
|
||||||
|
|
||||||
class App(Base):
|
class App(Base):
|
||||||
"""A Reflex application."""
|
"""A Reflex application."""
|
||||||
|
|
||||||
@ -97,7 +106,9 @@ class App(Base):
|
|||||||
event_namespace: Optional[AsyncNamespace] = None
|
event_namespace: Optional[AsyncNamespace] = None
|
||||||
|
|
||||||
# A component that is present on every page.
|
# A component that is present on every page.
|
||||||
overlay_component: Optional[Union[Component, ComponentCallable]] = connection_modal
|
overlay_component: Optional[
|
||||||
|
Union[Component, ComponentCallable]
|
||||||
|
] = default_overlay_component
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
"""Initialize the app.
|
"""Initialize the app.
|
||||||
@ -179,6 +190,13 @@ class App(Base):
|
|||||||
# Set up the admin dash.
|
# Set up the admin dash.
|
||||||
self.setup_admin_dash()
|
self.setup_admin_dash()
|
||||||
|
|
||||||
|
# If a State is not used and no overlay_component is specified, do not render the connection modal
|
||||||
|
if (
|
||||||
|
self.state is DefaultState
|
||||||
|
and self.overlay_component is default_overlay_component
|
||||||
|
):
|
||||||
|
self.overlay_component = None
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""Get the string representation of the app.
|
"""Get the string representation of the app.
|
||||||
|
|
||||||
|
@ -163,14 +163,14 @@ class AppHarness:
|
|||||||
self.app_module = reflex.utils.prerequisites.get_app(reload=True)
|
self.app_module = reflex.utils.prerequisites.get_app(reload=True)
|
||||||
self.app_instance = self.app_module.app
|
self.app_instance = self.app_module.app
|
||||||
|
|
||||||
def _start_backend(self):
|
def _start_backend(self, port=0):
|
||||||
if self.app_instance is None:
|
if self.app_instance is None:
|
||||||
raise RuntimeError("App was not initialized.")
|
raise RuntimeError("App was not initialized.")
|
||||||
self.backend = uvicorn.Server(
|
self.backend = uvicorn.Server(
|
||||||
uvicorn.Config(
|
uvicorn.Config(
|
||||||
app=self.app_instance.api,
|
app=self.app_instance.api,
|
||||||
host="127.0.0.1",
|
host="127.0.0.1",
|
||||||
port=0,
|
port=port,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.backend_thread = threading.Thread(target=self.backend.run)
|
self.backend_thread = threading.Thread(target=self.backend.run)
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import os.path
|
import os.path
|
||||||
import sys
|
import sys
|
||||||
@ -16,8 +18,15 @@ from starlette_admin.contrib.sqla.admin import Admin
|
|||||||
from starlette_admin.contrib.sqla.view import ModelView
|
from starlette_admin.contrib.sqla.view import ModelView
|
||||||
|
|
||||||
from reflex import AdminDash, constants
|
from reflex import AdminDash, constants
|
||||||
from reflex.app import App, DefaultState, process, upload
|
from reflex.app import (
|
||||||
from reflex.components import Box
|
App,
|
||||||
|
ComponentCallable,
|
||||||
|
DefaultState,
|
||||||
|
default_overlay_component,
|
||||||
|
process,
|
||||||
|
upload,
|
||||||
|
)
|
||||||
|
from reflex.components import Box, Component, Cond, Fragment, Text
|
||||||
from reflex.event import Event, get_hydrate_event
|
from reflex.event import Event, get_hydrate_event
|
||||||
from reflex.middleware import HydrateMiddleware
|
from reflex.middleware import HydrateMiddleware
|
||||||
from reflex.model import Model
|
from reflex.model import Model
|
||||||
@ -945,3 +954,55 @@ async def test_process_events(gen_state, mocker):
|
|||||||
|
|
||||||
assert app.state_manager.get_state("token").value == 5
|
assert app.state_manager.get_state("token").value == 5
|
||||||
assert app.postprocess.call_count == 6
|
assert app.postprocess.call_count == 6
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("state", "overlay_component", "exp_page_child"),
|
||||||
|
[
|
||||||
|
(DefaultState, default_overlay_component, None),
|
||||||
|
(DefaultState, None, None),
|
||||||
|
(DefaultState, Text.create("foo"), Text),
|
||||||
|
(State, default_overlay_component, Fragment),
|
||||||
|
(State, None, None),
|
||||||
|
(State, Text.create("foo"), Text),
|
||||||
|
(State, lambda: Text.create("foo"), Text),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_overlay_component(
|
||||||
|
state: State | None,
|
||||||
|
overlay_component: Component | ComponentCallable | None,
|
||||||
|
exp_page_child: Type[Component] | None,
|
||||||
|
):
|
||||||
|
"""Test that the overlay component is set correctly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The state class to pass to App.
|
||||||
|
overlay_component: The overlay_component to pass to App.
|
||||||
|
exp_page_child: The type of the expected child in the page fragment.
|
||||||
|
"""
|
||||||
|
app = App(state=state, overlay_component=overlay_component)
|
||||||
|
if exp_page_child is None:
|
||||||
|
assert app.overlay_component is None
|
||||||
|
elif isinstance(exp_page_child, Fragment):
|
||||||
|
assert app.overlay_component is not None
|
||||||
|
generated_component = app._generate_component(app.overlay_component)
|
||||||
|
assert isinstance(generated_component, Fragment)
|
||||||
|
assert isinstance(
|
||||||
|
generated_component.children[0],
|
||||||
|
Cond, # ConnectionModal is a Cond under the hood
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert app.overlay_component is not None
|
||||||
|
assert isinstance(
|
||||||
|
app._generate_component(app.overlay_component),
|
||||||
|
exp_page_child,
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_page(Box.create("Index"), route="/test")
|
||||||
|
page = app.pages["test"]
|
||||||
|
if exp_page_child is not None:
|
||||||
|
assert len(page.children) == 3
|
||||||
|
children_types = (type(child) for child in page.children)
|
||||||
|
assert exp_page_child in children_types
|
||||||
|
else:
|
||||||
|
assert len(page.children) == 2
|
||||||
|
Loading…
Reference in New Issue
Block a user