add backend disabled dialog (#4715)

* add backend disabled dialog

* pyi that guy

* pyi the other guy

* extend test_connection_banner to also test the cloud banner

* oops, need asyncio _inside_ the app

* Update reflex/components/core/banner.py

Co-authored-by: Masen Furer <m_github@0x26.net>

* use universal cookies

* fix pre-commit

* revert universal cookie 🍪

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
Khaleel Al-Adhami 2025-01-31 13:00:56 -08:00 committed by GitHub
parent 6231f82248
commit 335816cbf7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 295 additions and 17 deletions

View File

@ -106,6 +106,18 @@ export const getBackendURL = (url_str) => {
return endpoint; return endpoint;
}; };
/**
* Check if the backend is disabled.
*
* @returns True if the backend is disabled, false otherwise.
*/
export const isBackendDisabled = () => {
const cookie = document.cookie
.split("; ")
.find((row) => row.startsWith("backend-enabled="));
return cookie !== undefined && cookie.split("=")[1] == "false";
};
/** /**
* Determine if any event in the event queue is stateful. * Determine if any event in the event queue is stateful.
* *
@ -301,10 +313,7 @@ export const applyEvent = async (event, socket) => {
// Send the event to the server. // Send the event to the server.
if (socket) { if (socket) {
socket.emit( socket.emit("event", event);
"event",
event,
);
return true; return true;
} }
@ -497,7 +506,7 @@ export const uploadFiles = async (
return false; return false;
} }
const upload_ref_name = `__upload_controllers_${upload_id}` const upload_ref_name = `__upload_controllers_${upload_id}`;
if (refs[upload_ref_name]) { if (refs[upload_ref_name]) {
console.log("Upload already in progress for ", upload_id); console.log("Upload already in progress for ", upload_id);
@ -815,7 +824,7 @@ export const useEventLoop = (
return; return;
} }
// only use websockets if state is present // only use websockets if state is present
if (Object.keys(initialState).length > 1) { if (Object.keys(initialState).length > 1 && !isBackendDisabled()) {
// Initialize the websocket connection. // Initialize the websocket connection.
if (!socket.current) { if (!socket.current) {
connect( connect(

View File

@ -59,7 +59,11 @@ from reflex.components.component import (
ComponentStyle, ComponentStyle,
evaluate_style_namespaces, evaluate_style_namespaces,
) )
from reflex.components.core.banner import connection_pulser, connection_toaster from reflex.components.core.banner import (
backend_disabled,
connection_pulser,
connection_toaster,
)
from reflex.components.core.breakpoints import set_breakpoints from reflex.components.core.breakpoints import set_breakpoints
from reflex.components.core.client_side_routing import ( from reflex.components.core.client_side_routing import (
Default404Page, Default404Page,
@ -158,9 +162,12 @@ def default_overlay_component() -> Component:
Returns: Returns:
The default overlay_component, which is a connection_modal. The default overlay_component, which is a connection_modal.
""" """
config = get_config()
return Fragment.create( return Fragment.create(
connection_pulser(), connection_pulser(),
connection_toaster(), connection_toaster(),
*([backend_disabled()] if config.is_reflex_cloud else []),
*codespaces.codespaces_auto_redirect(), *codespaces.codespaces_auto_redirect(),
) )

View File

@ -4,8 +4,10 @@ from __future__ import annotations
from typing import Optional from typing import Optional
from reflex import constants
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.core.cond import cond from reflex.components.core.cond import cond
from reflex.components.datadisplay.logo import svg_logo
from reflex.components.el.elements.typography import Div from reflex.components.el.elements.typography import Div
from reflex.components.lucide.icon import Icon from reflex.components.lucide.icon import Icon
from reflex.components.radix.themes.components.dialog import ( from reflex.components.radix.themes.components.dialog import (
@ -293,7 +295,84 @@ class ConnectionPulser(Div):
) )
class BackendDisabled(Div):
"""A component that displays a message when the backend is disabled."""
@classmethod
def create(cls, **props) -> Component:
"""Create a backend disabled component.
Args:
**props: The properties of the component.
Returns:
The backend disabled component.
"""
import reflex as rx
is_backend_disabled = Var(
"backendDisabled",
_var_type=bool,
_var_data=VarData(
hooks={
"const [backendDisabled, setBackendDisabled] = useState(false);": None,
"useEffect(() => { setBackendDisabled(isBackendDisabled()); }, []);": None,
},
imports={
f"$/{constants.Dirs.STATE_PATH}": [
ImportVar(tag="isBackendDisabled")
],
},
),
)
return super().create(
rx.cond(
is_backend_disabled,
rx.box(
rx.box(
rx.card(
rx.vstack(
svg_logo(),
rx.text(
"You ran out of compute credits.",
),
rx.callout(
rx.fragment(
"Please upgrade your plan or raise your compute credits at ",
rx.link(
"Reflex Cloud.",
href="https://cloud.reflex.dev/",
),
),
width="100%",
icon="info",
variant="surface",
),
),
font_size="20px",
font_family='"Inter", "Helvetica", "Arial", sans-serif',
variant="classic",
),
position="fixed",
top="50%",
left="50%",
transform="translate(-50%, -50%)",
width="40ch",
max_width="90vw",
),
position="fixed",
z_index=9999,
backdrop_filter="grayscale(1) blur(5px)",
width="100dvw",
height="100dvh",
),
)
)
connection_banner = ConnectionBanner.create connection_banner = ConnectionBanner.create
connection_modal = ConnectionModal.create connection_modal = ConnectionModal.create
connection_toaster = ConnectionToaster.create connection_toaster = ConnectionToaster.create
connection_pulser = ConnectionPulser.create connection_pulser = ConnectionPulser.create
backend_disabled = BackendDisabled.create

View File

@ -350,7 +350,93 @@ class ConnectionPulser(Div):
""" """
... ...
class BackendDisabled(Div):
@overload
@classmethod
def create( # type: ignore
cls,
*children,
access_key: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
auto_capitalize: Optional[
Union[Var[Union[bool, int, str]], bool, int, str]
] = None,
content_editable: Optional[
Union[Var[Union[bool, int, str]], bool, int, str]
] = None,
context_menu: Optional[
Union[Var[Union[bool, int, str]], bool, int, str]
] = None,
dir: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
draggable: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
enter_key_hint: Optional[
Union[Var[Union[bool, int, str]], bool, int, str]
] = None,
hidden: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
input_mode: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
item_prop: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
lang: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
role: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
slot: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
spell_check: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
tab_index: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
title: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
style: Optional[Style] = None,
key: Optional[Any] = None,
id: Optional[Any] = None,
class_name: Optional[Any] = None,
autofocus: Optional[bool] = None,
custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None,
on_blur: Optional[EventType[[], BASE_STATE]] = None,
on_click: Optional[EventType[[], BASE_STATE]] = None,
on_context_menu: Optional[EventType[[], BASE_STATE]] = None,
on_double_click: Optional[EventType[[], BASE_STATE]] = None,
on_focus: Optional[EventType[[], BASE_STATE]] = None,
on_mount: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_down: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_enter: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_leave: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_move: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_out: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_over: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_up: Optional[EventType[[], BASE_STATE]] = None,
on_scroll: Optional[EventType[[], BASE_STATE]] = None,
on_unmount: Optional[EventType[[], BASE_STATE]] = None,
**props,
) -> "BackendDisabled":
"""Create a backend disabled component.
Args:
access_key: Provides a hint for generating a keyboard shortcut for the current element.
auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user.
content_editable: Indicates whether the element's content is editable.
context_menu: Defines the ID of a <menu> element which will serve as the element's context menu.
dir: Defines the text direction. Allowed values are ltr (Left-To-Right) or rtl (Right-To-Left)
draggable: Defines whether the element can be dragged.
enter_key_hint: Hints what media types the media element is able to play.
hidden: Defines whether the element is hidden.
input_mode: Defines the type of the element.
item_prop: Defines the name of the element for metadata purposes.
lang: Defines the language used in the element.
role: Defines the role of the element.
slot: Assigns a slot in a shadow DOM shadow tree to an element.
spell_check: Defines whether the element may be checked for spelling errors.
tab_index: Defines the position of the current element in the tabbing order.
title: Defines a tooltip for the element.
style: The style of the component.
key: A unique key for the component.
id: The id for the component.
class_name: The class name for the component.
autofocus: Whether the component should take the focus once the page is loaded
custom_attrs: custom attribute
**props: The properties of the component.
Returns:
The backend disabled component.
"""
...
connection_banner = ConnectionBanner.create connection_banner = ConnectionBanner.create
connection_modal = ConnectionModal.create connection_modal = ConnectionModal.create
connection_toaster = ConnectionToaster.create connection_toaster = ConnectionToaster.create
connection_pulser = ConnectionPulser.create connection_pulser = ConnectionPulser.create
backend_disabled = BackendDisabled.create

View File

@ -20,7 +20,7 @@ class Card(elements.Div, RadixThemesComponent):
# Card size: "1" - "5" # Card size: "1" - "5"
size: Var[Responsive[Literal["1", "2", "3", "4", "5"],]] size: Var[Responsive[Literal["1", "2", "3", "4", "5"],]]
# Variant of Card: "solid" | "soft" | "outline" | "ghost" # Variant of Card: "surface" | "classic" | "ghost"
variant: Var[Literal["surface", "classic", "ghost"]] variant: Var[Literal["surface", "classic", "ghost"]]

View File

@ -94,7 +94,7 @@ class Card(elements.Div, RadixThemesComponent):
*children: Child components. *children: Child components.
as_child: Change the default rendered element for the one passed as a child, merging their props and behavior. as_child: Change the default rendered element for the one passed as a child, merging their props and behavior.
size: Card size: "1" - "5" size: Card size: "1" - "5"
variant: Variant of Card: "solid" | "soft" | "outline" | "ghost" variant: Variant of Card: "surface" | "classic" | "ghost"
access_key: Provides a hint for generating a keyboard shortcut for the current element. access_key: Provides a hint for generating a keyboard shortcut for the current element.
auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user. auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user.
content_editable: Indicates whether the element's content is editable. content_editable: Indicates whether the element's content is editable.

View File

@ -703,6 +703,9 @@ class Config(Base):
# Path to file containing key-values pairs to override in the environment; Dotenv format. # Path to file containing key-values pairs to override in the environment; Dotenv format.
env_file: Optional[str] = None env_file: Optional[str] = None
# Whether the app is running in the reflex cloud environment.
is_reflex_cloud: bool = False
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Initialize the config values. """Initialize the config values.

View File

@ -1,5 +1,6 @@
"""Test case for displaying the connection banner when the websocket drops.""" """Test case for displaying the connection banner when the websocket drops."""
import functools
from typing import Generator from typing import Generator
import pytest import pytest
@ -11,12 +12,19 @@ from reflex.testing import AppHarness, WebDriver
from .utils import SessionStorage from .utils import SessionStorage
def ConnectionBanner(): def ConnectionBanner(is_reflex_cloud: bool = False):
"""App with a connection banner.""" """App with a connection banner.
Args:
is_reflex_cloud: The value for config.is_reflex_cloud.
"""
import asyncio import asyncio
import reflex as rx import reflex as rx
# Simulate reflex cloud deploy
rx.config.get_config().is_reflex_cloud = is_reflex_cloud
class State(rx.State): class State(rx.State):
foo: int = 0 foo: int = 0
@ -40,19 +48,43 @@ def ConnectionBanner():
app.add_page(index) app.add_page(index)
@pytest.fixture(
params=[False, True], ids=["reflex_cloud_disabled", "reflex_cloud_enabled"]
)
def simulate_is_reflex_cloud(request) -> bool:
"""Fixture to simulate reflex cloud deployment.
Args:
request: pytest request fixture.
Returns:
True if reflex cloud is enabled, False otherwise.
"""
return request.param
@pytest.fixture() @pytest.fixture()
def connection_banner(tmp_path) -> Generator[AppHarness, None, None]: def connection_banner(
tmp_path,
simulate_is_reflex_cloud: bool,
) -> Generator[AppHarness, None, None]:
"""Start ConnectionBanner app at tmp_path via AppHarness. """Start ConnectionBanner app at tmp_path via AppHarness.
Args: Args:
tmp_path: pytest tmp_path fixture tmp_path: pytest tmp_path fixture
simulate_is_reflex_cloud: Whether is_reflex_cloud is set for the app.
Yields: Yields:
running AppHarness instance running AppHarness instance
""" """
with AppHarness.create( with AppHarness.create(
root=tmp_path, root=tmp_path,
app_source=ConnectionBanner, app_source=functools.partial(
ConnectionBanner, is_reflex_cloud=simulate_is_reflex_cloud
),
app_name="connection_banner_reflex_cloud"
if simulate_is_reflex_cloud
else "connection_banner",
) as harness: ) as harness:
yield harness yield harness
@ -77,6 +109,38 @@ def has_error_modal(driver: WebDriver) -> bool:
return True return True
def has_cloud_banner(driver: WebDriver) -> bool:
"""Check if the cloud banner is displayed.
Args:
driver: Selenium webdriver instance.
Returns:
True if the banner is displayed, False otherwise.
"""
try:
driver.find_element(
By.XPATH, "//*[ contains(text(), 'You ran out of compute credits.') ]"
)
except NoSuchElementException:
return False
else:
return True
def _assert_token(connection_banner, driver):
"""Poll for backend to be up.
Args:
connection_banner: AppHarness instance.
driver: Selenium webdriver instance.
"""
ss = SessionStorage(driver)
assert connection_banner._poll_for(
lambda: ss.get("token") is not None
), "token not found"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_connection_banner(connection_banner: AppHarness): async def test_connection_banner(connection_banner: AppHarness):
"""Test that the connection banner is displayed when the websocket drops. """Test that the connection banner is displayed when the websocket drops.
@ -88,10 +152,7 @@ async def test_connection_banner(connection_banner: AppHarness):
assert connection_banner.backend is not None assert connection_banner.backend is not None
driver = connection_banner.frontend() driver = connection_banner.frontend()
ss = SessionStorage(driver) _assert_token(connection_banner, driver)
assert connection_banner._poll_for(
lambda: ss.get("token") is not None
), "token not found"
assert connection_banner._poll_for(lambda: not has_error_modal(driver)) assert connection_banner._poll_for(lambda: not has_error_modal(driver))
@ -132,3 +193,36 @@ async def test_connection_banner(connection_banner: AppHarness):
# Count should have incremented after coming back up # Count should have incremented after coming back up
assert connection_banner.poll_for_value(counter_element, exp_not_equal="1") == "2" assert connection_banner.poll_for_value(counter_element, exp_not_equal="1") == "2"
@pytest.mark.asyncio
async def test_cloud_banner(
connection_banner: AppHarness, simulate_is_reflex_cloud: bool
):
"""Test that the connection banner is displayed when the websocket drops.
Args:
connection_banner: AppHarness instance.
simulate_is_reflex_cloud: Whether is_reflex_cloud is set for the app.
"""
assert connection_banner.app_instance is not None
assert connection_banner.backend is not None
driver = connection_banner.frontend()
driver.add_cookie({"name": "backend-enabled", "value": "truly"})
driver.refresh()
_assert_token(connection_banner, driver)
assert connection_banner._poll_for(lambda: not has_cloud_banner(driver))
driver.add_cookie({"name": "backend-enabled", "value": "false"})
driver.refresh()
if simulate_is_reflex_cloud:
assert connection_banner._poll_for(lambda: has_cloud_banner(driver))
else:
_assert_token(connection_banner, driver)
assert connection_banner._poll_for(lambda: not has_cloud_banner(driver))
driver.delete_cookie("backend-enabled")
driver.refresh()
_assert_token(connection_banner, driver)
assert connection_banner._poll_for(lambda: not has_cloud_banner(driver))