From 433ccda3a6f16aebdcd76a2eef5999073860617e Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Fri, 13 Oct 2023 21:54:59 +0000 Subject: [PATCH] No state No Websocket (#1950) --- reflex/.templates/web/utils/state.js | 22 +++++++------- reflex/app.py | 45 ++++++++++++++-------------- tests/test_state.py | 10 ++++--- 3 files changed, 41 insertions(+), 36 deletions(-) diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 98d8db273..9b596d0ff 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -507,17 +507,19 @@ export const useEventLoop = ( if (!router.isReady) { return; } - - // Initialize the websocket connection. - if (!socket.current) { - connect(socket, dispatch, ['websocket', 'polling'], setConnectError, client_storage) - } - (async () => { - // Process all outstanding events. - while (event_queue.length > 0 && !event_processing) { - await processEvent(socket.current) + // only use websockets if state is present + if (Object.keys(state).length > 0) { + // Initialize the websocket connection. + if (!socket.current) { + connect(socket, dispatch, ['websocket', 'polling'], setConnectError, client_storage) } - })() + (async () => { + // Process all outstanding events. + while (event_queue.length > 0 && !event_processing) { + await processEvent(socket.current) + } + })() + } }) return [state, addEvents, connectError] } diff --git a/reflex/app.py b/reflex/app.py index 0fdacea48..de3eba566 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -175,32 +175,33 @@ class App(Base): self.add_cors() self.add_default_endpoints() - # Set up the Socket.IO AsyncServer. - self.sio = AsyncServer( - async_mode="asgi", - cors_allowed_origins="*" - if config.cors_allowed_origins == ["*"] - else config.cors_allowed_origins, - cors_credentials=True, - max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE, - ping_interval=constants.Ping.INTERVAL, - ping_timeout=constants.Ping.TIMEOUT, - ) + if self.state is not DefaultState: + # Set up the Socket.IO AsyncServer. + self.sio = AsyncServer( + async_mode="asgi", + cors_allowed_origins="*" + if config.cors_allowed_origins == ["*"] + else config.cors_allowed_origins, + cors_credentials=True, + max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE, + ping_interval=constants.Ping.INTERVAL, + ping_timeout=constants.Ping.TIMEOUT, + ) - # Create the socket app. Note event endpoint constant replaces the default 'socket.io' path. - self.socket_app = ASGIApp(self.sio, socketio_path="") - namespace = config.get_event_namespace() + # Create the socket app. Note event endpoint constant replaces the default 'socket.io' path. + self.socket_app = ASGIApp(self.sio, socketio_path="") + namespace = config.get_event_namespace() - if not namespace: - raise ValueError("event namespace must be provided in the config.") + if not namespace: + raise ValueError("event namespace must be provided in the config.") - # Create the event namespace and attach the main app. Not related to any paths. - self.event_namespace = EventNamespace(namespace, self) + # Create the event namespace and attach the main app. Not related to any paths. + self.event_namespace = EventNamespace(namespace, self) - # Register the event namespace with the socket. - self.sio.register_namespace(self.event_namespace) - # Mount the socket app with the API. - self.api.mount(str(constants.Endpoint.EVENT), self.socket_app) + # Register the event namespace with the socket. + self.sio.register_namespace(self.event_namespace) + # Mount the socket app with the API. + self.api.mount(str(constants.Endpoint.EVENT), self.socket_app) # Set up the admin dash. self.setup_admin_dash() diff --git a/tests/test_state.py b/tests/test_state.py index 308d9282b..0995f3403 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -14,6 +14,7 @@ import pytest from plotly.graph_objects import Figure import reflex as rx +from reflex.app import App from reflex.base import Base from reflex.constants import CompileVars, RouteVar, SocketEvent from reflex.event import Event, EventHandler @@ -1528,23 +1529,24 @@ async def test_state_manager_lock_expire_contend( @pytest.fixture(scope="function") -def mock_app(monkeypatch, app: rx.App, state_manager: StateManager) -> rx.App: +def mock_app(monkeypatch, state_manager: StateManager) -> rx.App: """Mock app fixture. Args: monkeypatch: Pytest monkeypatch object. - app: An app. state_manager: A state manager. Returns: The app, after mocking out prerequisites.get_app() """ + app = App(state=TestState) + app_module = Mock() + setattr(app_module, CompileVars.APP, app) app.state = TestState app.state_manager = state_manager - assert app.event_namespace is not None - app.event_namespace.emit = AsyncMock() + app.event_namespace.emit = AsyncMock() # type: ignore monkeypatch.setattr(prerequisites, "get_app", lambda: app_module) return app