From 4dfe739e45d54987964d05d8072020b73d3211ca Mon Sep 17 00:00:00 2001 From: Lendemor Date: Sat, 18 Jan 2025 16:53:59 +0100 Subject: [PATCH] fix public API for some attributes of App() --- reflex/app.py | 66 +++++++++++++++++++++++------------- reflex/constants/compiler.py | 2 +- reflex/testing.py | 4 +-- tests/units/test_app.py | 32 ++++++++--------- tests/units/test_state.py | 12 +++---- 5 files changed, 67 insertions(+), 49 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index a76d82213..712dcee9f 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -274,7 +274,7 @@ class App(MiddlewareMixin, LifespanMixin): ) # Admin dashboard to view and manage the database. - _admin_dash: Optional[AdminDash] = None + admin_dash: Optional[AdminDash] = None # The async server name space. PRIVATE. _event_namespace: Optional[EventNamespace] = None @@ -292,6 +292,24 @@ class App(MiddlewareMixin, LifespanMixin): [Exception], Union[EventSpec, List[EventSpec], None] ] = default_backend_exception_handler + @property + def api(self) -> FastAPI | None: + """Get the backend api. + + Returns: + The backend api. + """ + return self._api + + @property + def event_namespace(self) -> EventNamespace | None: + """Get the event namespace. + + Returns: + The event namespace. + """ + return self._event_namespace + def __post_init__(self): """Initialize the app. @@ -384,10 +402,10 @@ class App(MiddlewareMixin, LifespanMixin): self._event_namespace = EventNamespace(namespace, self) # Register the event namespace with the socket. - self.sio.register_namespace(self._event_namespace) + self.sio.register_namespace(self.event_namespace) # Mount the socket app with the API. - if self._api: - self._api.mount(str(constants.Endpoint.EVENT), socket_app) + if self.api: + self.api.mount(str(constants.Endpoint.EVENT), socket_app) # Check the exception handlers self._validate_exception_handlers() @@ -409,44 +427,44 @@ class App(MiddlewareMixin, LifespanMixin): Returns: The backend api. """ - if not self._api: + if not self.api: raise ValueError("The app has not been initialized.") - return self._api + return self.api def _add_default_endpoints(self): """Add default api endpoints (ping).""" # To test the server. - if not self._api: + if not self.api: return - self._api.get(str(constants.Endpoint.PING))(ping) - self._api.get(str(constants.Endpoint.HEALTH))(health) + self.api.get(str(constants.Endpoint.PING))(ping) + self.api.get(str(constants.Endpoint.HEALTH))(health) def _add_optional_endpoints(self): """Add optional api endpoints (_upload).""" - if not self._api: + if not self.api: return if Upload.is_used: # To upload files. - self._api.post(str(constants.Endpoint.UPLOAD))(upload(self)) + self.api.post(str(constants.Endpoint.UPLOAD))(upload(self)) # To access uploaded files. - self._api.mount( + self.api.mount( str(constants.Endpoint.UPLOAD), StaticFiles(directory=get_upload_dir()), name="uploaded_files", ) if codespaces.is_running_in_codespaces(): - self._api.get(str(constants.Endpoint.AUTH_CODESPACE))( + self.api.get(str(constants.Endpoint.AUTH_CODESPACE))( codespaces.auth_codespace ) def _add_cors(self): """Add CORS middleware to the app.""" - if not self._api: + if not self.api: return - self._api.add_middleware( + self.api.add_middleware( cors.CORSMiddleware, allow_credentials=True, allow_methods=["*"], @@ -689,10 +707,10 @@ class App(MiddlewareMixin, LifespanMixin): def _setup_admin_dash(self): """Setup the admin dash.""" # Get the admin dash. - if not self._api: + if not self.api: return - admin_dash = self._admin_dash + admin_dash = self.admin_dash if admin_dash and admin_dash.models: # Build the admin dashboard @@ -710,7 +728,7 @@ class App(MiddlewareMixin, LifespanMixin): view = admin_dash.view_overrides.get(model, ModelView) admin.add_view(view(model)) - admin.mount_to(self._api) + admin.mount_to(self.api) def _get_frontend_packages(self, imports: Dict[str, set[ImportVar]]): """Gets the frontend packages to be installed and filters out the unnecessary ones. @@ -1113,7 +1131,7 @@ class App(MiddlewareMixin, LifespanMixin): Raises: RuntimeError: If the app has not been initialized yet. """ - if self._event_namespace is None: + if self.event_namespace is None: raise RuntimeError("App has not been initialized yet.") # Get exclusive access to the state. @@ -1124,7 +1142,7 @@ class App(MiddlewareMixin, LifespanMixin): if delta: # When the state is modified reset dirty status and emit the delta to the frontend. state._clean() - await self._event_namespace.emit_update( + await self.event_namespace.emit_update( update=StateUpdate(delta=delta), sid=state.router.session.session_id, ) @@ -1152,7 +1170,7 @@ class App(MiddlewareMixin, LifespanMixin): Raises: RuntimeError: If the app has not been initialized yet. """ - if self._event_namespace is None: + if self.event_namespace is None: raise RuntimeError("App has not been initialized yet.") # Process the event. @@ -1163,7 +1181,7 @@ class App(MiddlewareMixin, LifespanMixin): update = await self._postprocess(state, event, update) # Send the update to the client. - await self._event_namespace.emit_update( + await self.event_namespace.emit_update( update=update, sid=state.router.session.session_id, ) @@ -1308,10 +1326,10 @@ async def process( if ( not state.router_data and event.name != get_hydrate_event(state) - and app._event_namespace is not None + and app.event_namespace is not None ): await asyncio.create_task( - app._event_namespace.emit( + app.event_namespace.emit( "reload", data=event, to=sid, diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index 13982eab8..d98c04d76 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -36,7 +36,7 @@ class CompileVars(SimpleNamespace): # The expected variable name where the rx.App is stored. APP = "app" # The expected variable name where the API object is stored for deployment. - API = "_api" + API = "api" # The name of the router variable. ROUTER = "router" # The name of the socket variable. diff --git a/reflex/testing.py b/reflex/testing.py index fc5e90e1b..f9bef2c09 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -323,11 +323,11 @@ class AppHarness: return _shutdown_redis def _start_backend(self, port=0): - if self.app_instance is None or self.app_instance._api is None: + if self.app_instance is None or self.app_instance.api is None: raise RuntimeError("App was not initialized.") self.backend = uvicorn.Server( uvicorn.Config( - app=self.app_instance._api, + app=self.app_instance.api, host="127.0.0.1", port=port, ) diff --git a/tests/units/test_app.py b/tests/units/test_app.py index aef7791e5..a09fde972 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -212,7 +212,7 @@ def test_default_app(app: App): """ assert app.middleware == [HydrateMiddleware()] assert app.style == Style() - assert app._admin_dash is None + assert app.admin_dash is None def test_multiple_states_error(monkeypatch, test_state, redundant_test_state): @@ -357,10 +357,10 @@ def test_initialize_with_admin_dashboard(test_model): Args: test_model: The default model. """ - app = App(_admin_dash=AdminDash(models=[test_model])) - assert app._admin_dash is not None - assert len(app._admin_dash.models) > 0 - assert app._admin_dash.models[0] == test_model + app = App(admin_dash=AdminDash(models=[test_model])) + assert app.admin_dash is not None + assert len(app.admin_dash.models) > 0 + assert app.admin_dash.models[0] == test_model def test_initialize_with_custom_admin_dashboard( @@ -377,12 +377,12 @@ def test_initialize_with_custom_admin_dashboard( """ custom_auth_provider = test_custom_auth_admin() custom_admin = Admin(engine=test_get_engine, auth_provider=custom_auth_provider) - app = App(_admin_dash=AdminDash(models=[test_model_auth], admin=custom_admin)) - assert app._admin_dash is not None - assert app._admin_dash.admin is not None - assert len(app._admin_dash.models) > 0 - assert app._admin_dash.models[0] == test_model_auth - assert app._admin_dash.admin.auth_provider == custom_auth_provider + app = App(admin_dash=AdminDash(models=[test_model_auth], admin=custom_admin)) + assert app.admin_dash is not None + assert app.admin_dash.admin is not None + assert len(app.admin_dash.models) > 0 + assert app.admin_dash.models[0] == test_model_auth + assert app.admin_dash.admin.auth_provider == custom_auth_provider def test_initialize_admin_dashboard_with_view_overrides(test_model): @@ -396,13 +396,13 @@ def test_initialize_admin_dashboard_with_view_overrides(test_model): pass app = App( - _admin_dash=AdminDash( + admin_dash=AdminDash( models=[test_model], view_overrides={test_model: TestModelView} ) ) - assert app._admin_dash is not None - assert app._admin_dash.models == [test_model] - assert app._admin_dash.view_overrides[test_model] == TestModelView + assert app.admin_dash is not None + assert app.admin_dash.models == [test_model] + assert app.admin_dash.view_overrides[test_model] == TestModelView @pytest.mark.asyncio @@ -772,7 +772,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker): # The App state must be the "root" of the state tree app = App() app._enable_state() - app._event_namespace.emit = AsyncMock() # type: ignore + app.event_namespace.emit = AsyncMock() # type: ignore current_state = await app.state_manager.get_state(_substate_key(token, state)) data = b"This is binary data" diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 3aa6a0167..cf3363770 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -1913,7 +1913,7 @@ def mock_app_simple(monkeypatch) -> rx.App: setattr(app_module, CompileVars.APP, app) app._state = TestState - app._event_namespace.emit = CopyingAsyncMock() # type: ignore + app.event_namespace.emit = CopyingAsyncMock() # type: ignore def _mock_get_app(*args, **kwargs): return app_module @@ -2021,9 +2021,9 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): assert gotten_grandchild_state.value2 == "42" # ensure state update was emitted - assert mock_app._event_namespace is not None - mock_app._event_namespace.emit.assert_called_once() - mcall = mock_app._event_namespace.emit.mock_calls[0] + assert mock_app.event_namespace is not None + mock_app.event_namespace.emit.assert_called_once() + mcall = mock_app.event_namespace.emit.mock_calls[0] assert mcall.args[0] == str(SocketEvent.EVENT) assert mcall.args[1] == StateUpdate( delta={ @@ -2225,8 +2225,8 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): ) ).order == exp_order - assert mock_app._event_namespace is not None - emit_mock = mock_app._event_namespace.emit + assert mock_app.event_namespace is not None + emit_mock = mock_app.event_namespace.emit first_ws_message = emit_mock.mock_calls[0].args[1] assert (