extend test_connection_banner to also test the cloud banner

This commit is contained in:
Masen Furer 2025-01-30 11:05:57 -08:00
parent 11beaab954
commit c86e4c9ee0
No known key found for this signature in database
GPG Key ID: 2AE2BD5531FF94F4

View File

@ -1,5 +1,7 @@
"""Test case for displaying the connection banner when the websocket drops.""" """Test case for displaying the connection banner when the websocket drops."""
import asyncio
import functools
from typing import Generator from typing import Generator
import pytest import pytest
@ -11,12 +13,13 @@ from reflex.testing import AppHarness, WebDriver
from .utils import SessionStorage from .utils import SessionStorage
def ConnectionBanner(): def ConnectionBanner(is_reflex_cloud: bool = False):
"""App with a connection banner.""" """App with a connection banner."""
import asyncio
import reflex as rx import reflex as rx
# Simulate reflex cloud deploy
rx.config.get_config().is_reflex_cloud = is_reflex_cloud
class State(rx.State): class State(rx.State):
foo: int = 0 foo: int = 0
@ -40,19 +43,43 @@ def ConnectionBanner():
app.add_page(index) app.add_page(index)
@pytest.fixture(
params=[False, True], ids=["reflex_cloud_disabled", "reflex_cloud_enabled"]
)
def simulate_is_reflex_cloud(request) -> bool:
"""Fixture to simulate reflex cloud deployment.
Args:
request: pytest request fixture.
Returns:
True if reflex cloud is enabled, False otherwise.
"""
return request.param
@pytest.fixture() @pytest.fixture()
def connection_banner(tmp_path) -> Generator[AppHarness, None, None]: def connection_banner(
tmp_path,
simulate_is_reflex_cloud: bool,
) -> Generator[AppHarness, None, None]:
"""Start ConnectionBanner app at tmp_path via AppHarness. """Start ConnectionBanner app at tmp_path via AppHarness.
Args: Args:
tmp_path: pytest tmp_path fixture tmp_path: pytest tmp_path fixture
simulate_is_reflex_cloud: Whether is_reflex_cloud is set for the app.
Yields: Yields:
running AppHarness instance running AppHarness instance
""" """
with AppHarness.create( with AppHarness.create(
root=tmp_path, root=tmp_path,
app_source=ConnectionBanner, app_source=functools.partial(
ConnectionBanner, is_reflex_cloud=simulate_is_reflex_cloud
),
app_name="connection_banner_reflex_cloud"
if simulate_is_reflex_cloud
else "connection_banner",
) as harness: ) as harness:
yield harness yield harness
@ -77,6 +104,38 @@ def has_error_modal(driver: WebDriver) -> bool:
return True return True
def has_cloud_banner(driver: WebDriver) -> bool:
"""Check if the cloud banner is displayed.
Args:
driver: Selenium webdriver instance.
Returns:
True if the banner is displayed, False otherwise.
"""
try:
driver.find_element(
By.XPATH, "//*[ contains(text(), 'You ran out of compute credits.') ]"
)
except NoSuchElementException:
return False
else:
return True
def _assert_token(connection_banner, driver):
"""Poll for backend to be up.
Args:
connection_banner: AppHarness instance.
driver: Selenium webdriver instance.
"""
ss = SessionStorage(driver)
assert connection_banner._poll_for(
lambda: ss.get("token") is not None
), "token not found"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_connection_banner(connection_banner: AppHarness): async def test_connection_banner(connection_banner: AppHarness):
"""Test that the connection banner is displayed when the websocket drops. """Test that the connection banner is displayed when the websocket drops.
@ -88,10 +147,7 @@ async def test_connection_banner(connection_banner: AppHarness):
assert connection_banner.backend is not None assert connection_banner.backend is not None
driver = connection_banner.frontend() driver = connection_banner.frontend()
ss = SessionStorage(driver) _assert_token(connection_banner, driver)
assert connection_banner._poll_for(
lambda: ss.get("token") is not None
), "token not found"
assert connection_banner._poll_for(lambda: not has_error_modal(driver)) assert connection_banner._poll_for(lambda: not has_error_modal(driver))
@ -132,3 +188,36 @@ async def test_connection_banner(connection_banner: AppHarness):
# Count should have incremented after coming back up # Count should have incremented after coming back up
assert connection_banner.poll_for_value(counter_element, exp_not_equal="1") == "2" assert connection_banner.poll_for_value(counter_element, exp_not_equal="1") == "2"
@pytest.mark.asyncio
async def test_cloud_banner(
connection_banner: AppHarness, simulate_is_reflex_cloud: bool
):
"""Test that the connection banner is displayed when the websocket drops.
Args:
connection_banner: AppHarness instance.
simulate_is_reflex_cloud: Whether is_reflex_cloud is set for the app.
"""
assert connection_banner.app_instance is not None
assert connection_banner.backend is not None
driver = connection_banner.frontend()
driver.add_cookie({"name": "backend-enabled", "value": "truly"})
driver.refresh()
_assert_token(connection_banner, driver)
assert connection_banner._poll_for(lambda: not has_cloud_banner(driver))
driver.add_cookie({"name": "backend-enabled", "value": "false"})
driver.refresh()
if simulate_is_reflex_cloud:
assert connection_banner._poll_for(lambda: has_cloud_banner(driver))
else:
_assert_token(connection_banner, driver)
assert connection_banner._poll_for(lambda: not has_cloud_banner(driver))
driver.delete_cookie("backend-enabled")
driver.refresh()
_assert_token(connection_banner, driver)
assert connection_banner._poll_for(lambda: not has_cloud_banner(driver))