diff --git a/reflex/.templates/jinja/web/pages/index.js.jinja2 b/reflex/.templates/jinja/web/pages/index.js.jinja2
index ed19b5279..87465e9e4 100644
--- a/reflex/.templates/jinja/web/pages/index.js.jinja2
+++ b/reflex/.templates/jinja/web/pages/index.js.jinja2
@@ -14,7 +14,7 @@ export default function Component() {
const focusRef = useRef();
// Main event loop.
- const [Event, notConnected] = useContext(EventLoopContext)
+ const [Event, connectError] = useContext(EventLoopContext)
// Set focus to the specified element.
useEffect(() => {
@@ -37,12 +37,7 @@ export default function Component() {
{% endfor %}
return (
-
- {%- if err_comp -%}
- {{ utils.render(err_comp, indent_width=1) }}
- {%- endif -%}
{{utils.render(render, indent_width=0)}}
-
)
}
{% endblock %}
diff --git a/reflex/.templates/web/pages/_app.js b/reflex/.templates/web/pages/_app.js
index 89461263f..6411ee9b5 100644
--- a/reflex/.templates/web/pages/_app.js
+++ b/reflex/.templates/web/pages/_app.js
@@ -15,12 +15,12 @@ const GlobalStyles = css`
`;
function EventLoopProvider({ children }) {
- const [state, Event, notConnected] = useEventLoop(
+ const [state, Event, connectError] = useEventLoop(
initialState,
initialEvents,
)
return (
-
+
{children}
diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js
index 5c382a3cf..e39ac8cce 100644
--- a/reflex/.templates/web/utils/state.js
+++ b/reflex/.templates/web/utils/state.js
@@ -250,14 +250,14 @@ export const processEvent = async (
* @param socket The socket object to connect.
* @param dispatch The function to queue state update
* @param transports The transports to use.
- * @param setNotConnected The function to update connection state.
+ * @param setConnectError The function to update connection error value.
* @param initial_events Array of events to seed the queue after connecting.
*/
export const connect = async (
socket,
dispatch,
transports,
- setNotConnected,
+ setConnectError,
initial_events = [],
) => {
// Get backend URL object from the endpoint.
@@ -272,11 +272,11 @@ export const connect = async (
// Once the socket is open, hydrate the page.
socket.current.on("connect", () => {
queueEvents(initial_events, socket)
- setNotConnected(false)
+ setConnectError(null)
});
socket.current.on('connect_error', (error) => {
- setNotConnected(true)
+ setConnectError(error)
});
// On each received message, queue the updates and events.
@@ -357,10 +357,10 @@ export const E = (name, payload = {}, handler = null) => {
* @param initial_state The initial page state.
* @param initial_events Array of events to seed the queue after connecting.
*
- * @returns [state, Event, notConnected] -
+ * @returns [state, Event, connectError] -
* state is a reactive dict,
* Event is used to queue an event, and
- * notConnected is a reactive boolean indicating whether the websocket is connected.
+ * connectError is a reactive js error from the websocket connection (or null if connected).
*/
export const useEventLoop = (
initial_state = {},
@@ -369,7 +369,7 @@ export const useEventLoop = (
const socket = useRef(null)
const router = useRouter()
const [state, dispatch] = useReducer(applyDelta, initial_state)
- const [notConnected, setNotConnected] = useState(false)
+ const [connectError, setConnectError] = useState(null)
// Function to add new events to the event queue.
const Event = (events, _e) => {
@@ -386,7 +386,7 @@ export const useEventLoop = (
// Initialize the websocket connection.
if (!socket.current) {
- connect(socket, dispatch, ['websocket', 'polling'], setNotConnected, initial_events)
+ connect(socket, dispatch, ['websocket', 'polling'], setConnectError, initial_events)
}
(async () => {
// Process all outstanding events.
@@ -395,7 +395,7 @@ export const useEventLoop = (
}
})()
})
- return [state, Event, notConnected]
+ return [state, Event, connectError]
}
/***
diff --git a/reflex/app.py b/reflex/app.py
index 80eb0fb1e..89fad454d 100644
--- a/reflex/app.py
+++ b/reflex/app.py
@@ -1,4 +1,5 @@
"""The main Reflex app."""
+from __future__ import annotations
import asyncio
import inspect
@@ -29,6 +30,7 @@ from reflex.admin import AdminDash
from reflex.base import Base
from reflex.compiler import compiler
from reflex.compiler import utils as compiler_utils
+from reflex.components import connection_modal
from reflex.components.component import Component, ComponentStyle
from reflex.components.layout.fragment import Fragment
from reflex.config import get_config
@@ -88,12 +90,12 @@ class App(Base):
# Admin dashboard
admin_dash: Optional[AdminDash] = None
- # The component to render if there is a connection error to the server.
- connect_error_component: Optional[Component] = None
-
# The async server name space
event_namespace: Optional[AsyncNamespace] = None
+ # A component that is present on every page.
+ overlay_component: Optional[Union[Component, ComponentCallable]] = connection_modal
+
def __init__(self, *args, **kwargs):
"""Initialize the app.
@@ -106,6 +108,10 @@ class App(Base):
Also, if there are multiple client subclasses of rx.State(Subclasses of rx.State should consist
of the DefaultState and the client app state).
"""
+ if "connect_error_component" in kwargs:
+ raise ValueError(
+ "`connect_error_component` is deprecated, use `overlay_component` instead"
+ )
super().__init__(*args, **kwargs)
state_subclasses = State.__subclasses__()
inferred_state = state_subclasses[-1]
@@ -269,6 +275,31 @@ class App(Base):
else:
self.middleware.insert(index, middleware)
+ @staticmethod
+ def _generate_component(component: Component | ComponentCallable) -> Component:
+ """Generate a component from a callable.
+
+ Args:
+ component: The component function to call or Component to return as-is.
+
+ Returns:
+ The generated component.
+
+ Raises:
+ TypeError: When an invalid component function is passed.
+ """
+ try:
+ return component if isinstance(component, Component) else component()
+ except TypeError as e:
+ message = str(e)
+ if "BaseVar" in message or "ComputedVar" in message:
+ raise TypeError(
+ "You may be trying to use an invalid Python function on a state var. "
+ "When referencing a var inside your render code, only limited var operations are supported. "
+ "See the var operation docs here: https://reflex.dev/docs/state/vars/#var-operations"
+ ) from e
+ raise e
+
def add_page(
self,
component: Union[Component, ComponentCallable],
@@ -296,9 +327,6 @@ class App(Base):
on_load: The event handler(s) that will be called each time the page load.
meta: The metadata of the page.
script_tags: List of script tags to be added to component
-
- Raises:
- TypeError: If an invalid var operation is used.
"""
# If the route is not set, get it from the callable.
if route is None:
@@ -314,20 +342,16 @@ class App(Base):
self.state.setup_dynamic_args(get_route_args(route))
# Generate the component if it is a callable.
- try:
- component = component if isinstance(component, Component) else component()
- except TypeError as e:
- message = str(e)
- if "BaseVar" in message or "ComputedVar" in message:
- raise TypeError(
- "You may be trying to use an invalid Python function on a state var. "
- "When referencing a var inside your render code, only limited var operations are supported. "
- "See the var operation docs here: https://reflex.dev/docs/state/vars/#var-operations"
- ) from e
- raise e
+ component = self._generate_component(component)
- # Wrap the component in a fragment.
- component = Fragment.create(component)
+ # Wrap the component in a fragment with optional overlay.
+ if self.overlay_component is not None:
+ component = Fragment.create(
+ self._generate_component(self.overlay_component),
+ component,
+ )
+ else:
+ component = Fragment.create(component)
# Add meta information to the component.
compiler_utils.add_meta(
@@ -497,7 +521,6 @@ class App(Base):
route,
component,
self.state,
- self.connect_error_component,
),
)
)
diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py
index 94a0a9d18..0d5a325fb 100644
--- a/reflex/compiler/compiler.py
+++ b/reflex/compiler/compiler.py
@@ -89,14 +89,12 @@ def _compile_contexts(state: Type[State]) -> str:
def _compile_page(
component: Component,
state: Type[State],
- connect_error_component,
) -> str:
"""Compile the component given the app state.
Args:
component: The component to compile.
state: The app state.
- connect_error_component: The component to render on sever connection error.
Returns:
The compiled component.
@@ -113,7 +111,6 @@ def _compile_page(
state_name=state.get_name(),
hooks=component.get_hooks(),
render=component.render(),
- err_comp=connect_error_component.render() if connect_error_component else None,
)
@@ -221,7 +218,6 @@ def compile_page(
path: str,
component: Component,
state: Type[State],
- connect_error_component: Component,
) -> Tuple[str, str]:
"""Compile a single page.
@@ -229,7 +225,6 @@ def compile_page(
path: The path to compile the page to.
component: The component to compile.
state: The app state.
- connect_error_component: The component to render on sever connection error.
Returns:
The path and code of the compiled page.
@@ -238,11 +233,7 @@ def compile_page(
output_path = utils.get_page_path(path)
# Add the style to the component.
- code = _compile_page(
- component,
- state,
- connect_error_component,
- )
+ code = _compile_page(component, state)
return output_path, code
diff --git a/reflex/components/__init__.py b/reflex/components/__init__.py
index 28db13d87..80e2bc662 100644
--- a/reflex/components/__init__.py
+++ b/reflex/components/__init__.py
@@ -31,6 +31,7 @@ badge = Badge.create
code = Code.create
code_block = CodeBlock.create
connection_banner = ConnectionBanner.create
+connection_modal = ConnectionModal.create
data_table = DataTable.create
divider = Divider.create
list = List.create
diff --git a/reflex/components/overlay/__init__.py b/reflex/components/overlay/__init__.py
index 227fdf536..9482d003f 100644
--- a/reflex/components/overlay/__init__.py
+++ b/reflex/components/overlay/__init__.py
@@ -8,7 +8,7 @@ from .alertdialog import (
AlertDialogHeader,
AlertDialogOverlay,
)
-from .banner import ConnectionBanner
+from .banner import ConnectionBanner, ConnectionModal
from .drawer import (
Drawer,
DrawerBody,
diff --git a/reflex/components/overlay/banner.py b/reflex/components/overlay/banner.py
index c21bdc626..b16c2ae0d 100644
--- a/reflex/components/overlay/banner.py
+++ b/reflex/components/overlay/banner.py
@@ -1,11 +1,41 @@
"""Banner components."""
+from __future__ import annotations
+
from typing import Optional
from reflex.components.component import Component
from reflex.components.layout import Box, Cond, Fragment
+from reflex.components.overlay.modal import Modal
from reflex.components.typography import Text
from reflex.vars import Var
+connection_error = Var.create_safe(
+ value="(connectError !== null) ? connectError.message : ''",
+ is_local=False,
+ is_string=False,
+)
+has_connection_error = Var.create_safe(
+ value="connectError !== null",
+ is_string=False,
+)
+has_connection_error.type_ = bool
+
+
+def default_connection_error() -> list[str | Var]:
+ """Get the default connection error message.
+
+ Returns:
+ The default connection error message.
+ """
+ from reflex.config import get_config
+
+ return [
+ "Cannot connect to server: ",
+ connection_error,
+ ". Check if server is reachable at ",
+ get_config().api_url or "",
+ ]
+
class ConnectionBanner(Cond):
"""A connection banner component."""
@@ -23,11 +53,33 @@ class ConnectionBanner(Cond):
if not comp:
comp = Box.create(
Text.create(
- "cannot connect to server. Check if server is reachable",
+ *default_connection_error(),
bg="red",
color="white",
),
textAlign="center",
)
- return super().create(Var.create("notConnected"), comp, Fragment.create()) # type: ignore
+ return super().create(has_connection_error, comp, Fragment.create()) # type: ignore
+
+
+class ConnectionModal(Modal):
+ """A connection status modal window."""
+
+ @classmethod
+ def create(cls, comp: Optional[Component] = None) -> Component:
+ """Create a connection banner component.
+
+ Args:
+ comp: The component to render when there's a server connection error.
+
+ Returns:
+ The connection banner component.
+ """
+ if not comp:
+ comp = Text.create(*default_connection_error())
+ return super().create(
+ header="Connection Error",
+ body=comp,
+ is_open=has_connection_error,
+ )