From 19f6dc5edccf4d274ee24aa149ecb75dfd1ac218 Mon Sep 17 00:00:00 2001 From: KronosDev-Pro Date: Sat, 23 Nov 2024 15:32:10 +0000 Subject: [PATCH] [IMPL] - add `get_backend_bind()` & `shutdown()` --- reflex/server/base.py | 16 ++++++++++++++-- reflex/server/granian.py | 13 +++++++++++-- reflex/server/gunicorn.py | 23 ++++++++++++++++++++++- reflex/server/uvicorn.py | 20 ++++++++++++++++++-- reflex/utils/exec.py | 6 ++---- 5 files changed, 67 insertions(+), 11 deletions(-) diff --git a/reflex/server/base.py b/reflex/server/base.py index 012511233..b19a6acb5 100644 --- a/reflex/server/base.py +++ b/reflex/server/base.py @@ -7,7 +7,7 @@ from abc import abstractmethod from dataclasses import Field, dataclass from dataclasses import field as dc_field from pathlib import Path -from typing import Any, Callable, Sequence +from typing import Any, Callable, Sequence, ClassVar from reflex import constants from reflex.constants.base import Env, LogLevel @@ -156,7 +156,9 @@ def field_( class CustomBackendServer: """BackendServer base.""" - _app_uri: str = field_(default="", metadata_cli=None, exclude=True) + _env: ClassVar[Env] = field_(default=Env.DEV, metadata_cli=None, exclude=True, repr = False, init = False) + _app: ClassVar[Any] = field_(default=None, metadata_cli=None, exclude=True, repr = False, init = False) + _app_uri: ClassVar[str] = field_(default="", metadata_cli=None, exclude=True, repr = False, init = False) @staticmethod def get_app_module(for_granian_target: bool = False, add_extra_api: bool = False): @@ -246,6 +248,11 @@ class CustomBackendServer: return False + @abstractmethod + def get_backend_bind(self) -> tuple[str, int]: + """Return the backend host and port""" + raise NotImplementedError() + @abstractmethod def check_import(self): """Check package importation.""" @@ -265,3 +272,8 @@ class CustomBackendServer: def run_dev(self): """Run in development mode.""" raise NotImplementedError() + + @abstractmethod + async def shutdown(self): + """Shutdown the backend server.""" + raise NotImplementedError() diff --git a/reflex/server/granian.py b/reflex/server/granian.py index 2a76d8c53..547929398 100644 --- a/reflex/server/granian.py +++ b/reflex/server/granian.py @@ -191,6 +191,9 @@ class GranianBackendServer(CustomBackendServer): default=None, metadata_cli=CliType.default("--pid-file {value}") ) + def get_backend_bind(self) -> tuple[str, int]: + return self.address, self.port + def check_import(self): """Check package importation.""" from importlib.util import find_spec @@ -218,6 +221,7 @@ class GranianBackendServer(CustomBackendServer): self.address = host self.port = port self.interface = "asgi" # NOTE: prevent obvious error + self._env = env if env == Env.PROD: if self.workers == self.get_fields()["workers"].default: @@ -273,7 +277,7 @@ class GranianBackendServer(CustomBackendServer): "http2_max_headers_size", "http2_max_send_buffer_size", ) - Granian( + self._app = Granian( **{ **{ key: value @@ -301,4 +305,9 @@ class GranianBackendServer(CustomBackendServer): self.http2_max_send_buffer_size, ), } - ).serve() + ) + self._app.serve() + + async def shutdown(self): + if self._app and self._env == Env.DEV: + self._app.shutdown() diff --git a/reflex/server/gunicorn.py b/reflex/server/gunicorn.py index 4cdb150fa..be0b62729 100644 --- a/reflex/server/gunicorn.py +++ b/reflex/server/gunicorn.py @@ -278,6 +278,11 @@ class GunicornBackendServer(CustomBackendServer): default="drop", metadata_cli=CliType.default("--header-map {value}") ) + def get_backend_bind(self) -> tuple[str, int]: + """Return the backend host and port""" + host, port = self.bind[0].split(":") + return host, int(port) + def check_import(self): """Check package importation.""" from importlib.util import find_spec @@ -303,6 +308,7 @@ class GunicornBackendServer(CustomBackendServer): self._app_uri = f"{self.get_app_module()}()" self.loglevel = loglevel.value # type: ignore self.bind = [f"{host}:{port}"] + self._env = env if env == Env.PROD: if self.workers == self.get_fields()["workers"].default: @@ -370,5 +376,20 @@ class GunicornBackendServer(CustomBackendServer): def load(self): return gunicorn_import_app(self._app_uri) + + def stop(self): + from gunicorn.arbiter import Arbiter - StandaloneApplication(app_uri=self._app_uri, options=options_).run() + Arbiter(self).stop() + + self._app = StandaloneApplication(app_uri=self._app_uri, options=options_) + self._app.run() + + async def shutdown(self): + """Shutdown the backend server.""" + if self._app and self._env == Env.DEV: + self._app.stop() # type: ignore + + # TODO: complicated because currently `*BackendServer` don't execute the server command, he just create it + # if self._env == Env.PROD: + # pass diff --git a/reflex/server/uvicorn.py b/reflex/server/uvicorn.py index 8d167eacb..7f96c6070 100644 --- a/reflex/server/uvicorn.py +++ b/reflex/server/uvicorn.py @@ -183,6 +183,10 @@ class UvicornBackendServer(CustomBackendServer): metadata_cli=CliType.default("--h11-max-incomplete-event-size {value}"), ) + def get_backend_bind(self) -> tuple[str, int]: + """Return the backend host and port""" + return self.host, self.port + def check_import(self): """Check package importation.""" from importlib.util import find_spec @@ -211,6 +215,7 @@ class UvicornBackendServer(CustomBackendServer): self.log_level = loglevel.value self.host = host self.port = port + self._env = env if env == Env.PROD: if self.workers == self.get_fields()["workers"].default: @@ -250,6 +255,17 @@ class UvicornBackendServer(CustomBackendServer): if not self.is_default_value(key, value) } - Server( + self._app = Server( config=Config(**options_, app=self._app_uri), - ).run() + ) + self._app.run() + + async def shutdown(self): + """Shutdown the backend server.""" + if self._app and self._env == Env.DEV: + self._app.shutdown() # type: ignore + + # TODO: hard because currently `*BackendServer` don't execute the server command, he just create it + # if self._env == Env.PROD: + # pass + diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py index 4daaa69b8..12fa555e6 100644 --- a/reflex/utils/exec.py +++ b/reflex/utils/exec.py @@ -178,7 +178,6 @@ def run_frontend_prod(root: Path, port: str, backend_present=True): ) -### REWORK <-- def run_backend( host: str, port: int, @@ -237,12 +236,11 @@ def run_backend_prod( run=True, show_logs=True, env={ - environment.REFLEX_SKIP_COMPILE.name: "true" - }, # skip compile for prod backend + environment.REFLEX_SKIP_COMPILE.name: "true" # skip compile for prod backend + }, ) -### REWORK--> def output_system_info(): """Show system information if the loglevel is in DEBUG.""" if console._LOG_LEVEL > constants.LogLevel.DEBUG: