ConnectionModal and ConnectionBanner cleanup (#1379)

This commit is contained in:
Masen Furer 2023-08-28 18:04:52 -07:00 committed by GitHub
parent 51f0339fa4
commit 6b481ecfc3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 112 additions and 50 deletions

View File

@ -14,7 +14,7 @@ export default function Component() {
const focusRef = useRef(); const focusRef = useRef();
// Main event loop. // Main event loop.
const [Event, notConnected] = useContext(EventLoopContext) const [Event, connectError] = useContext(EventLoopContext)
// Set focus to the specified element. // Set focus to the specified element.
useEffect(() => { useEffect(() => {
@ -37,12 +37,7 @@ export default function Component() {
{% endfor %} {% endfor %}
return ( return (
<Fragment>
{%- if err_comp -%}
{{ utils.render(err_comp, indent_width=1) }}
{%- endif -%}
{{utils.render(render, indent_width=0)}} {{utils.render(render, indent_width=0)}}
</Fragment>
) )
} }
{% endblock %} {% endblock %}

View File

@ -15,12 +15,12 @@ const GlobalStyles = css`
`; `;
function EventLoopProvider({ children }) { function EventLoopProvider({ children }) {
const [state, Event, notConnected] = useEventLoop( const [state, Event, connectError] = useEventLoop(
initialState, initialState,
initialEvents, initialEvents,
) )
return ( return (
<EventLoopContext.Provider value={[Event, notConnected]}> <EventLoopContext.Provider value={[Event, connectError]}>
<StateContext.Provider value={state}> <StateContext.Provider value={state}>
{children} {children}
</StateContext.Provider> </StateContext.Provider>

View File

@ -250,14 +250,14 @@ export const processEvent = async (
* @param socket The socket object to connect. * @param socket The socket object to connect.
* @param dispatch The function to queue state update * @param dispatch The function to queue state update
* @param transports The transports to use. * @param transports The transports to use.
* @param setNotConnected The function to update connection state. * @param setConnectError The function to update connection error value.
* @param initial_events Array of events to seed the queue after connecting. * @param initial_events Array of events to seed the queue after connecting.
*/ */
export const connect = async ( export const connect = async (
socket, socket,
dispatch, dispatch,
transports, transports,
setNotConnected, setConnectError,
initial_events = [], initial_events = [],
) => { ) => {
// Get backend URL object from the endpoint. // Get backend URL object from the endpoint.
@ -272,11 +272,11 @@ export const connect = async (
// Once the socket is open, hydrate the page. // Once the socket is open, hydrate the page.
socket.current.on("connect", () => { socket.current.on("connect", () => {
queueEvents(initial_events, socket) queueEvents(initial_events, socket)
setNotConnected(false) setConnectError(null)
}); });
socket.current.on('connect_error', (error) => { socket.current.on('connect_error', (error) => {
setNotConnected(true) setConnectError(error)
}); });
// On each received message, queue the updates and events. // On each received message, queue the updates and events.
@ -357,10 +357,10 @@ export const E = (name, payload = {}, handler = null) => {
* @param initial_state The initial page state. * @param initial_state The initial page state.
* @param initial_events Array of events to seed the queue after connecting. * @param initial_events Array of events to seed the queue after connecting.
* *
* @returns [state, Event, notConnected] - * @returns [state, Event, connectError] -
* state is a reactive dict, * state is a reactive dict,
* Event is used to queue an event, and * Event is used to queue an event, and
* notConnected is a reactive boolean indicating whether the websocket is connected. * connectError is a reactive js error from the websocket connection (or null if connected).
*/ */
export const useEventLoop = ( export const useEventLoop = (
initial_state = {}, initial_state = {},
@ -369,7 +369,7 @@ export const useEventLoop = (
const socket = useRef(null) const socket = useRef(null)
const router = useRouter() const router = useRouter()
const [state, dispatch] = useReducer(applyDelta, initial_state) const [state, dispatch] = useReducer(applyDelta, initial_state)
const [notConnected, setNotConnected] = useState(false) const [connectError, setConnectError] = useState(null)
// Function to add new events to the event queue. // Function to add new events to the event queue.
const Event = (events, _e) => { const Event = (events, _e) => {
@ -386,7 +386,7 @@ export const useEventLoop = (
// Initialize the websocket connection. // Initialize the websocket connection.
if (!socket.current) { if (!socket.current) {
connect(socket, dispatch, ['websocket', 'polling'], setNotConnected, initial_events) connect(socket, dispatch, ['websocket', 'polling'], setConnectError, initial_events)
} }
(async () => { (async () => {
// Process all outstanding events. // Process all outstanding events.
@ -395,7 +395,7 @@ export const useEventLoop = (
} }
})() })()
}) })
return [state, Event, notConnected] return [state, Event, connectError]
} }
/*** /***

View File

@ -1,4 +1,5 @@
"""The main Reflex app.""" """The main Reflex app."""
from __future__ import annotations
import asyncio import asyncio
import inspect import inspect
@ -29,6 +30,7 @@ from reflex.admin import AdminDash
from reflex.base import Base from reflex.base import Base
from reflex.compiler import compiler from reflex.compiler import compiler
from reflex.compiler import utils as compiler_utils from reflex.compiler import utils as compiler_utils
from reflex.components import connection_modal
from reflex.components.component import Component, ComponentStyle from reflex.components.component import Component, ComponentStyle
from reflex.components.layout.fragment import Fragment from reflex.components.layout.fragment import Fragment
from reflex.config import get_config from reflex.config import get_config
@ -88,12 +90,12 @@ class App(Base):
# Admin dashboard # Admin dashboard
admin_dash: Optional[AdminDash] = None admin_dash: Optional[AdminDash] = None
# The component to render if there is a connection error to the server.
connect_error_component: Optional[Component] = None
# The async server name space # The async server name space
event_namespace: Optional[AsyncNamespace] = None event_namespace: Optional[AsyncNamespace] = None
# A component that is present on every page.
overlay_component: Optional[Union[Component, ComponentCallable]] = connection_modal
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Initialize the app. """Initialize the app.
@ -106,6 +108,10 @@ class App(Base):
Also, if there are multiple client subclasses of rx.State(Subclasses of rx.State should consist Also, if there are multiple client subclasses of rx.State(Subclasses of rx.State should consist
of the DefaultState and the client app state). of the DefaultState and the client app state).
""" """
if "connect_error_component" in kwargs:
raise ValueError(
"`connect_error_component` is deprecated, use `overlay_component` instead"
)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
state_subclasses = State.__subclasses__() state_subclasses = State.__subclasses__()
inferred_state = state_subclasses[-1] inferred_state = state_subclasses[-1]
@ -269,6 +275,31 @@ class App(Base):
else: else:
self.middleware.insert(index, middleware) self.middleware.insert(index, middleware)
@staticmethod
def _generate_component(component: Component | ComponentCallable) -> Component:
"""Generate a component from a callable.
Args:
component: The component function to call or Component to return as-is.
Returns:
The generated component.
Raises:
TypeError: When an invalid component function is passed.
"""
try:
return component if isinstance(component, Component) else component()
except TypeError as e:
message = str(e)
if "BaseVar" in message or "ComputedVar" in message:
raise TypeError(
"You may be trying to use an invalid Python function on a state var. "
"When referencing a var inside your render code, only limited var operations are supported. "
"See the var operation docs here: https://reflex.dev/docs/state/vars/#var-operations"
) from e
raise e
def add_page( def add_page(
self, self,
component: Union[Component, ComponentCallable], component: Union[Component, ComponentCallable],
@ -296,9 +327,6 @@ class App(Base):
on_load: The event handler(s) that will be called each time the page load. on_load: The event handler(s) that will be called each time the page load.
meta: The metadata of the page. meta: The metadata of the page.
script_tags: List of script tags to be added to component script_tags: List of script tags to be added to component
Raises:
TypeError: If an invalid var operation is used.
""" """
# If the route is not set, get it from the callable. # If the route is not set, get it from the callable.
if route is None: if route is None:
@ -314,19 +342,15 @@ class App(Base):
self.state.setup_dynamic_args(get_route_args(route)) self.state.setup_dynamic_args(get_route_args(route))
# Generate the component if it is a callable. # Generate the component if it is a callable.
try: component = self._generate_component(component)
component = component if isinstance(component, Component) else component()
except TypeError as e:
message = str(e)
if "BaseVar" in message or "ComputedVar" in message:
raise TypeError(
"You may be trying to use an invalid Python function on a state var. "
"When referencing a var inside your render code, only limited var operations are supported. "
"See the var operation docs here: https://reflex.dev/docs/state/vars/#var-operations"
) from e
raise e
# Wrap the component in a fragment. # Wrap the component in a fragment with optional overlay.
if self.overlay_component is not None:
component = Fragment.create(
self._generate_component(self.overlay_component),
component,
)
else:
component = Fragment.create(component) component = Fragment.create(component)
# Add meta information to the component. # Add meta information to the component.
@ -497,7 +521,6 @@ class App(Base):
route, route,
component, component,
self.state, self.state,
self.connect_error_component,
), ),
) )
) )

View File

@ -89,14 +89,12 @@ def _compile_contexts(state: Type[State]) -> str:
def _compile_page( def _compile_page(
component: Component, component: Component,
state: Type[State], state: Type[State],
connect_error_component,
) -> str: ) -> str:
"""Compile the component given the app state. """Compile the component given the app state.
Args: Args:
component: The component to compile. component: The component to compile.
state: The app state. state: The app state.
connect_error_component: The component to render on sever connection error.
Returns: Returns:
The compiled component. The compiled component.
@ -113,7 +111,6 @@ def _compile_page(
state_name=state.get_name(), state_name=state.get_name(),
hooks=component.get_hooks(), hooks=component.get_hooks(),
render=component.render(), render=component.render(),
err_comp=connect_error_component.render() if connect_error_component else None,
) )
@ -221,7 +218,6 @@ def compile_page(
path: str, path: str,
component: Component, component: Component,
state: Type[State], state: Type[State],
connect_error_component: Component,
) -> Tuple[str, str]: ) -> Tuple[str, str]:
"""Compile a single page. """Compile a single page.
@ -229,7 +225,6 @@ def compile_page(
path: The path to compile the page to. path: The path to compile the page to.
component: The component to compile. component: The component to compile.
state: The app state. state: The app state.
connect_error_component: The component to render on sever connection error.
Returns: Returns:
The path and code of the compiled page. The path and code of the compiled page.
@ -238,11 +233,7 @@ def compile_page(
output_path = utils.get_page_path(path) output_path = utils.get_page_path(path)
# Add the style to the component. # Add the style to the component.
code = _compile_page( code = _compile_page(component, state)
component,
state,
connect_error_component,
)
return output_path, code return output_path, code

View File

@ -31,6 +31,7 @@ badge = Badge.create
code = Code.create code = Code.create
code_block = CodeBlock.create code_block = CodeBlock.create
connection_banner = ConnectionBanner.create connection_banner = ConnectionBanner.create
connection_modal = ConnectionModal.create
data_table = DataTable.create data_table = DataTable.create
divider = Divider.create divider = Divider.create
list = List.create list = List.create

View File

@ -8,7 +8,7 @@ from .alertdialog import (
AlertDialogHeader, AlertDialogHeader,
AlertDialogOverlay, AlertDialogOverlay,
) )
from .banner import ConnectionBanner from .banner import ConnectionBanner, ConnectionModal
from .drawer import ( from .drawer import (
Drawer, Drawer,
DrawerBody, DrawerBody,

View File

@ -1,11 +1,41 @@
"""Banner components.""" """Banner components."""
from __future__ import annotations
from typing import Optional from typing import Optional
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.layout import Box, Cond, Fragment from reflex.components.layout import Box, Cond, Fragment
from reflex.components.overlay.modal import Modal
from reflex.components.typography import Text from reflex.components.typography import Text
from reflex.vars import Var from reflex.vars import Var
connection_error = Var.create_safe(
value="(connectError !== null) ? connectError.message : ''",
is_local=False,
is_string=False,
)
has_connection_error = Var.create_safe(
value="connectError !== null",
is_string=False,
)
has_connection_error.type_ = bool
def default_connection_error() -> list[str | Var]:
"""Get the default connection error message.
Returns:
The default connection error message.
"""
from reflex.config import get_config
return [
"Cannot connect to server: ",
connection_error,
". Check if server is reachable at ",
get_config().api_url or "<API_URL not set>",
]
class ConnectionBanner(Cond): class ConnectionBanner(Cond):
"""A connection banner component.""" """A connection banner component."""
@ -23,11 +53,33 @@ class ConnectionBanner(Cond):
if not comp: if not comp:
comp = Box.create( comp = Box.create(
Text.create( Text.create(
"cannot connect to server. Check if server is reachable", *default_connection_error(),
bg="red", bg="red",
color="white", color="white",
), ),
textAlign="center", textAlign="center",
) )
return super().create(Var.create("notConnected"), comp, Fragment.create()) # type: ignore return super().create(has_connection_error, comp, Fragment.create()) # type: ignore
class ConnectionModal(Modal):
"""A connection status modal window."""
@classmethod
def create(cls, comp: Optional[Component] = None) -> Component:
"""Create a connection banner component.
Args:
comp: The component to render when there's a server connection error.
Returns:
The connection banner component.
"""
if not comp:
comp = Text.create(*default_connection_error())
return super().create(
header="Connection Error",
body=comp,
is_open=has_connection_error,
)