diff --git a/integration/test_connection_banner.py b/integration/test_connection_banner.py index 293d2e412..b83a493ce 100644 --- a/integration/test_connection_banner.py +++ b/integration/test_connection_banner.py @@ -8,16 +8,32 @@ from selenium.webdriver.common.by import By from reflex.testing import AppHarness, WebDriver +from .utils import SessionStorage + def ConnectionBanner(): """App with a connection banner.""" + import asyncio + import reflex as rx class State(rx.State): foo: int = 0 + async def delay(self): + await asyncio.sleep(5) + def index(): - return rx.text("Hello World") + return rx.vstack( + rx.text("Hello World"), + rx.input(value=State.foo, read_only=True, id="counter"), + rx.button( + "Increment", + id="increment", + on_click=State.set_foo(State.foo + 1), # type: ignore + ), + rx.button("Delay", id="delay", on_click=State.delay), + ) app = rx.App(state=rx.State) app.add_page(index) @@ -40,7 +56,7 @@ def connection_banner(tmp_path) -> Generator[AppHarness, None, None]: yield harness -CONNECTION_ERROR_XPATH = "//*[ text() = 'Connection Error' ]" +CONNECTION_ERROR_XPATH = "//*[ contains(text(), 'Cannot connect to server') ]" def has_error_modal(driver: WebDriver) -> bool: @@ -59,7 +75,8 @@ def has_error_modal(driver: WebDriver) -> bool: return False -def test_connection_banner(connection_banner: AppHarness): +@pytest.mark.asyncio +async def test_connection_banner(connection_banner: AppHarness): """Test that the connection banner is displayed when the websocket drops. Args: @@ -69,7 +86,23 @@ def test_connection_banner(connection_banner: AppHarness): assert connection_banner.backend is not None driver = connection_banner.frontend() - connection_banner._poll_for(lambda: not has_error_modal(driver)) + ss = SessionStorage(driver) + assert connection_banner._poll_for( + lambda: ss.get("token") is not None + ), "token not found" + + assert connection_banner._poll_for(lambda: not has_error_modal(driver)) + + delay_button = driver.find_element(By.ID, "delay") + increment_button = driver.find_element(By.ID, "increment") + counter_element = driver.find_element(By.ID, "counter") + + # Increment the counter + increment_button.click() + assert connection_banner.poll_for_value(counter_element, exp_not_equal="0") == "1" + + # Start an long event before killing the backend, to mark event_processing=true + delay_button.click() # Get the backend port backend_port = connection_banner._poll_for_servers().getsockname()[1] @@ -80,10 +113,20 @@ def test_connection_banner(connection_banner: AppHarness): connection_banner.backend_thread.join() # Error modal should now be displayed - connection_banner._poll_for(lambda: has_error_modal(driver)) + assert connection_banner._poll_for(lambda: has_error_modal(driver)) + + # Increment the counter with backend down + increment_button.click() + assert connection_banner.poll_for_value(counter_element, exp_not_equal="0") == "1" # Bring the backend back up connection_banner._start_backend(port=backend_port) + # Create a new StateManager to avoid async loop affinity issues w/ redis + await connection_banner._reset_backend_state_manager() + # Banner should be gone now - connection_banner._poll_for(lambda: not has_error_modal(driver)) + assert connection_banner._poll_for(lambda: not has_error_modal(driver)) + + # Count should have incremented after coming back up + assert connection_banner.poll_for_value(counter_element, exp_not_equal="1") == "2" diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 8386261e9..35a129337 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -358,6 +358,12 @@ export const connect = async ( socket.current.on("connect_error", (error) => { setConnectErrors((connectErrors) => [connectErrors.slice(-9), error]); }); + + // When the socket disconnects reset the event_processing flag + socket.current.on("disconnect", () => { + event_processing = false; + }); + // On each received message, queue the updates and events. socket.current.on("event", (message) => { const update = JSON5.parse(message); diff --git a/reflex/testing.py b/reflex/testing.py index d27396892..8fcb50d60 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -297,6 +297,27 @@ class AppHarness: self.backend_thread = threading.Thread(target=self.backend.run) self.backend_thread.start() + async def _reset_backend_state_manager(self): + """Reset the StateManagerRedis event loop affinity. + + This is necessary when the backend is restarted and the state manager is a + StateManagerRedis instance. + """ + if ( + self.app_instance is not None + and isinstance( + self.app_instance.state_manager, + StateManagerRedis, + ) + and self.app_instance.state is not None + ): + with contextlib.suppress(RuntimeError): + await self.app_instance.state_manager.close() + self.app_instance._state_manager = StateManagerRedis.create( + state=self.app_instance.state, + ) + assert isinstance(self.app_instance.state_manager, StateManagerRedis) + def _start_frontend(self): # Set up the frontend. with chdir(self.app_path):