From 6b481ecfc3864bd8a6da90b01a2be178a0237eb8 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 28 Aug 2023 18:04:52 -0700 Subject: [PATCH] ConnectionModal and ConnectionBanner cleanup (#1379) --- .../jinja/web/pages/index.js.jinja2 | 7 +-- reflex/.templates/web/pages/_app.js | 4 +- reflex/.templates/web/utils/state.js | 18 +++--- reflex/app.py | 63 +++++++++++++------ reflex/compiler/compiler.py | 11 +--- reflex/components/__init__.py | 1 + reflex/components/overlay/__init__.py | 2 +- reflex/components/overlay/banner.py | 56 ++++++++++++++++- 8 files changed, 112 insertions(+), 50 deletions(-) diff --git a/reflex/.templates/jinja/web/pages/index.js.jinja2 b/reflex/.templates/jinja/web/pages/index.js.jinja2 index ed19b5279..87465e9e4 100644 --- a/reflex/.templates/jinja/web/pages/index.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/index.js.jinja2 @@ -14,7 +14,7 @@ export default function Component() { const focusRef = useRef(); // Main event loop. - const [Event, notConnected] = useContext(EventLoopContext) + const [Event, connectError] = useContext(EventLoopContext) // Set focus to the specified element. useEffect(() => { @@ -37,12 +37,7 @@ export default function Component() { {% endfor %} return ( - - {%- if err_comp -%} - {{ utils.render(err_comp, indent_width=1) }} - {%- endif -%} {{utils.render(render, indent_width=0)}} - ) } {% endblock %} diff --git a/reflex/.templates/web/pages/_app.js b/reflex/.templates/web/pages/_app.js index 89461263f..6411ee9b5 100644 --- a/reflex/.templates/web/pages/_app.js +++ b/reflex/.templates/web/pages/_app.js @@ -15,12 +15,12 @@ const GlobalStyles = css` `; function EventLoopProvider({ children }) { - const [state, Event, notConnected] = useEventLoop( + const [state, Event, connectError] = useEventLoop( initialState, initialEvents, ) return ( - + {children} diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 5c382a3cf..e39ac8cce 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -250,14 +250,14 @@ export const processEvent = async ( * @param socket The socket object to connect. * @param dispatch The function to queue state update * @param transports The transports to use. - * @param setNotConnected The function to update connection state. + * @param setConnectError The function to update connection error value. * @param initial_events Array of events to seed the queue after connecting. */ export const connect = async ( socket, dispatch, transports, - setNotConnected, + setConnectError, initial_events = [], ) => { // Get backend URL object from the endpoint. @@ -272,11 +272,11 @@ export const connect = async ( // Once the socket is open, hydrate the page. socket.current.on("connect", () => { queueEvents(initial_events, socket) - setNotConnected(false) + setConnectError(null) }); socket.current.on('connect_error', (error) => { - setNotConnected(true) + setConnectError(error) }); // On each received message, queue the updates and events. @@ -357,10 +357,10 @@ export const E = (name, payload = {}, handler = null) => { * @param initial_state The initial page state. * @param initial_events Array of events to seed the queue after connecting. * - * @returns [state, Event, notConnected] - + * @returns [state, Event, connectError] - * state is a reactive dict, * Event is used to queue an event, and - * notConnected is a reactive boolean indicating whether the websocket is connected. + * connectError is a reactive js error from the websocket connection (or null if connected). */ export const useEventLoop = ( initial_state = {}, @@ -369,7 +369,7 @@ export const useEventLoop = ( const socket = useRef(null) const router = useRouter() const [state, dispatch] = useReducer(applyDelta, initial_state) - const [notConnected, setNotConnected] = useState(false) + const [connectError, setConnectError] = useState(null) // Function to add new events to the event queue. const Event = (events, _e) => { @@ -386,7 +386,7 @@ export const useEventLoop = ( // Initialize the websocket connection. if (!socket.current) { - connect(socket, dispatch, ['websocket', 'polling'], setNotConnected, initial_events) + connect(socket, dispatch, ['websocket', 'polling'], setConnectError, initial_events) } (async () => { // Process all outstanding events. @@ -395,7 +395,7 @@ export const useEventLoop = ( } })() }) - return [state, Event, notConnected] + return [state, Event, connectError] } /*** diff --git a/reflex/app.py b/reflex/app.py index 80eb0fb1e..89fad454d 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1,4 +1,5 @@ """The main Reflex app.""" +from __future__ import annotations import asyncio import inspect @@ -29,6 +30,7 @@ from reflex.admin import AdminDash from reflex.base import Base from reflex.compiler import compiler from reflex.compiler import utils as compiler_utils +from reflex.components import connection_modal from reflex.components.component import Component, ComponentStyle from reflex.components.layout.fragment import Fragment from reflex.config import get_config @@ -88,12 +90,12 @@ class App(Base): # Admin dashboard admin_dash: Optional[AdminDash] = None - # The component to render if there is a connection error to the server. - connect_error_component: Optional[Component] = None - # The async server name space event_namespace: Optional[AsyncNamespace] = None + # A component that is present on every page. + overlay_component: Optional[Union[Component, ComponentCallable]] = connection_modal + def __init__(self, *args, **kwargs): """Initialize the app. @@ -106,6 +108,10 @@ class App(Base): Also, if there are multiple client subclasses of rx.State(Subclasses of rx.State should consist of the DefaultState and the client app state). """ + if "connect_error_component" in kwargs: + raise ValueError( + "`connect_error_component` is deprecated, use `overlay_component` instead" + ) super().__init__(*args, **kwargs) state_subclasses = State.__subclasses__() inferred_state = state_subclasses[-1] @@ -269,6 +275,31 @@ class App(Base): else: self.middleware.insert(index, middleware) + @staticmethod + def _generate_component(component: Component | ComponentCallable) -> Component: + """Generate a component from a callable. + + Args: + component: The component function to call or Component to return as-is. + + Returns: + The generated component. + + Raises: + TypeError: When an invalid component function is passed. + """ + try: + return component if isinstance(component, Component) else component() + except TypeError as e: + message = str(e) + if "BaseVar" in message or "ComputedVar" in message: + raise TypeError( + "You may be trying to use an invalid Python function on a state var. " + "When referencing a var inside your render code, only limited var operations are supported. " + "See the var operation docs here: https://reflex.dev/docs/state/vars/#var-operations" + ) from e + raise e + def add_page( self, component: Union[Component, ComponentCallable], @@ -296,9 +327,6 @@ class App(Base): on_load: The event handler(s) that will be called each time the page load. meta: The metadata of the page. script_tags: List of script tags to be added to component - - Raises: - TypeError: If an invalid var operation is used. """ # If the route is not set, get it from the callable. if route is None: @@ -314,20 +342,16 @@ class App(Base): self.state.setup_dynamic_args(get_route_args(route)) # Generate the component if it is a callable. - try: - component = component if isinstance(component, Component) else component() - except TypeError as e: - message = str(e) - if "BaseVar" in message or "ComputedVar" in message: - raise TypeError( - "You may be trying to use an invalid Python function on a state var. " - "When referencing a var inside your render code, only limited var operations are supported. " - "See the var operation docs here: https://reflex.dev/docs/state/vars/#var-operations" - ) from e - raise e + component = self._generate_component(component) - # Wrap the component in a fragment. - component = Fragment.create(component) + # Wrap the component in a fragment with optional overlay. + if self.overlay_component is not None: + component = Fragment.create( + self._generate_component(self.overlay_component), + component, + ) + else: + component = Fragment.create(component) # Add meta information to the component. compiler_utils.add_meta( @@ -497,7 +521,6 @@ class App(Base): route, component, self.state, - self.connect_error_component, ), ) ) diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 94a0a9d18..0d5a325fb 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -89,14 +89,12 @@ def _compile_contexts(state: Type[State]) -> str: def _compile_page( component: Component, state: Type[State], - connect_error_component, ) -> str: """Compile the component given the app state. Args: component: The component to compile. state: The app state. - connect_error_component: The component to render on sever connection error. Returns: The compiled component. @@ -113,7 +111,6 @@ def _compile_page( state_name=state.get_name(), hooks=component.get_hooks(), render=component.render(), - err_comp=connect_error_component.render() if connect_error_component else None, ) @@ -221,7 +218,6 @@ def compile_page( path: str, component: Component, state: Type[State], - connect_error_component: Component, ) -> Tuple[str, str]: """Compile a single page. @@ -229,7 +225,6 @@ def compile_page( path: The path to compile the page to. component: The component to compile. state: The app state. - connect_error_component: The component to render on sever connection error. Returns: The path and code of the compiled page. @@ -238,11 +233,7 @@ def compile_page( output_path = utils.get_page_path(path) # Add the style to the component. - code = _compile_page( - component, - state, - connect_error_component, - ) + code = _compile_page(component, state) return output_path, code diff --git a/reflex/components/__init__.py b/reflex/components/__init__.py index 28db13d87..80e2bc662 100644 --- a/reflex/components/__init__.py +++ b/reflex/components/__init__.py @@ -31,6 +31,7 @@ badge = Badge.create code = Code.create code_block = CodeBlock.create connection_banner = ConnectionBanner.create +connection_modal = ConnectionModal.create data_table = DataTable.create divider = Divider.create list = List.create diff --git a/reflex/components/overlay/__init__.py b/reflex/components/overlay/__init__.py index 227fdf536..9482d003f 100644 --- a/reflex/components/overlay/__init__.py +++ b/reflex/components/overlay/__init__.py @@ -8,7 +8,7 @@ from .alertdialog import ( AlertDialogHeader, AlertDialogOverlay, ) -from .banner import ConnectionBanner +from .banner import ConnectionBanner, ConnectionModal from .drawer import ( Drawer, DrawerBody, diff --git a/reflex/components/overlay/banner.py b/reflex/components/overlay/banner.py index c21bdc626..b16c2ae0d 100644 --- a/reflex/components/overlay/banner.py +++ b/reflex/components/overlay/banner.py @@ -1,11 +1,41 @@ """Banner components.""" +from __future__ import annotations + from typing import Optional from reflex.components.component import Component from reflex.components.layout import Box, Cond, Fragment +from reflex.components.overlay.modal import Modal from reflex.components.typography import Text from reflex.vars import Var +connection_error = Var.create_safe( + value="(connectError !== null) ? connectError.message : ''", + is_local=False, + is_string=False, +) +has_connection_error = Var.create_safe( + value="connectError !== null", + is_string=False, +) +has_connection_error.type_ = bool + + +def default_connection_error() -> list[str | Var]: + """Get the default connection error message. + + Returns: + The default connection error message. + """ + from reflex.config import get_config + + return [ + "Cannot connect to server: ", + connection_error, + ". Check if server is reachable at ", + get_config().api_url or "", + ] + class ConnectionBanner(Cond): """A connection banner component.""" @@ -23,11 +53,33 @@ class ConnectionBanner(Cond): if not comp: comp = Box.create( Text.create( - "cannot connect to server. Check if server is reachable", + *default_connection_error(), bg="red", color="white", ), textAlign="center", ) - return super().create(Var.create("notConnected"), comp, Fragment.create()) # type: ignore + return super().create(has_connection_error, comp, Fragment.create()) # type: ignore + + +class ConnectionModal(Modal): + """A connection status modal window.""" + + @classmethod + def create(cls, comp: Optional[Component] = None) -> Component: + """Create a connection banner component. + + Args: + comp: The component to render when there's a server connection error. + + Returns: + The connection banner component. + """ + if not comp: + comp = Text.create(*default_connection_error()) + return super().create( + header="Connection Error", + body=comp, + is_open=has_connection_error, + )