diff --git a/reflex/app.py b/reflex/app.py index c47077d0e..95c9de2c6 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -116,9 +116,6 @@ class App(Base): # The Socket.IO AsyncServer. sio: Optional[AsyncServer] = None - # The socket app. - socket_app: Optional[ASGIApp] = None - # The state class to use for the app. state: Optional[Type[BaseState]] = None @@ -213,7 +210,11 @@ class App(Base): self.setup_state() def setup_state(self) -> None: - """Set up the state for the app.""" + """Set up the state for the app. + + Raises: + RuntimeError: If custom `sio` does not use `async_mode='asgi'`. + """ if not self.state: return @@ -223,21 +224,27 @@ class App(Base): self._state_manager = StateManager.create(state=self.state) # Set up the Socket.IO AsyncServer. - self.sio = AsyncServer( - async_mode="asgi", - cors_allowed_origins=( - "*" - if config.cors_allowed_origins == ["*"] - else config.cors_allowed_origins - ), - cors_credentials=True, - max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE, - ping_interval=constants.Ping.INTERVAL, - ping_timeout=constants.Ping.TIMEOUT, - ) + if not self.sio: + self.sio = AsyncServer( + async_mode="asgi", + cors_allowed_origins=( + "*" + if config.cors_allowed_origins == ["*"] + else config.cors_allowed_origins + ), + cors_credentials=True, + max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE, + ping_interval=constants.Ping.INTERVAL, + ping_timeout=constants.Ping.TIMEOUT, + ) + elif getattr(self.sio, "async_mode", "") != "asgi": + raise RuntimeError( + f"Custom `sio` must use `async_mode='asgi'`, not '{self.sio.async_mode}'." + ) # Create the socket app. Note event endpoint constant replaces the default 'socket.io' path. - self.socket_app = ASGIApp(self.sio, socketio_path="") + socket_app = ASGIApp(self.sio, socketio_path="") + namespace = config.get_event_namespace() # Create the event namespace and attach the main app. Not related to any paths. @@ -246,7 +253,7 @@ class App(Base): # Register the event namespace with the socket. self.sio.register_namespace(self.event_namespace) # Mount the socket app with the API. - self.api.mount(str(constants.Endpoint.EVENT), self.socket_app) + self.api.mount(str(constants.Endpoint.EVENT), socket_app) def __repr__(self) -> str: """Get the string representation of the app.