From 62f15a4dbe63bd9a43f2eb1ffa6e7edc162a111b Mon Sep 17 00:00:00 2001 From: Elijah Date: Wed, 18 Dec 2024 19:22:56 +0000 Subject: [PATCH] Allow custom app module in rxconfig --- reflex/config.py | 34 ++++++++++++++++++++++++++++------ reflex/utils/prerequisites.py | 21 ++++++++++++++++++++- 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/reflex/config.py b/reflex/config.py index 0579b019f..452e182f7 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -12,6 +12,7 @@ import threading import urllib.parse from importlib.util import find_spec from pathlib import Path +from types import ModuleType from typing import ( TYPE_CHECKING, Any, @@ -82,7 +83,7 @@ class DBConfig(Base): ) @classmethod - def postgresql_psycopg( + def postgresql_psycopg2( cls, database: str, username: str, @@ -90,7 +91,7 @@ class DBConfig(Base): host: str | None = None, port: int | None = 5432, ) -> DBConfig: - """Create an instance with postgresql+psycopg engine. + """Create an instance with postgresql+psycopg2 engine. Args: database: Database name. @@ -103,7 +104,7 @@ class DBConfig(Base): DBConfig instance. """ return cls( - engine="postgresql+psycopg", + engine="postgresql+psycopg2", username=username, password=password, host=host, @@ -604,6 +605,9 @@ class Config(Base): # The name of the app (should match the name of the app directory). app_name: str + # The path to the app module. + app_module_path: Optional[str] = None + # The log level to use. loglevel: constants.LogLevel = constants.LogLevel.DEFAULT @@ -684,9 +688,6 @@ class Config(Base): # Maximum expiration lock time for redis state manager redis_lock_expiration: int = constants.Expiration.LOCK - # Maximum lock time before warning for redis state manager. - redis_lock_warning_threshold: int = constants.Expiration.LOCK_WARNING_THRESHOLD - # Token expiration time for redis state manager redis_token_expiration: int = constants.Expiration.TOKEN @@ -726,6 +727,25 @@ class Config(Base): "REDIS_URL is required when using the redis state manager." ) + @property + def app_module(self) -> ModuleType | None: + """Get the app module. + + Returns: + The app module. + """ + + def load_via_spec(path): + module_name = Path(path).stem + spec = importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + if self.app_module_path: + return load_via_spec(self.app_module_path) + return None + @property def module(self) -> str: """Get the module name of the app. @@ -733,6 +753,8 @@ class Config(Base): Returns: The module name. """ + if self.app_module: + return f"{Path(self.app_module.__file__).parent.name}.{Path(self.app_module.__file__).stem}" return ".".join([self.app_name, self.app_name]) def update_from_env(self) -> dict[str, Any]: diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index d838c0eea..f181f3b1c 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -291,7 +291,11 @@ def get_app(reload: bool = False) -> ModuleType: ) module = config.module sys.path.insert(0, str(Path.cwd())) - app = __import__(module, fromlist=(constants.CompileVars.APP,)) + app = ( + __import__(module, fromlist=(constants.CompileVars.APP,)) + if not config.app_module + else config.app_module + ) if reload: from reflex.state import reload_state_module @@ -308,6 +312,21 @@ def get_app(reload: bool = False) -> ModuleType: raise +def get_and_validate_app(reload: bool = False): + """Get the app module based on the default config and validate it. + + Args: + reload: Re-import the app module from disk + """ + from reflex.app import App + + app_module = get_app(reload=reload) + app = getattr(app_module, constants.CompileVars.APP) + if not isinstance(app, App): + raise RuntimeError("The app object in rxconfig must be an instance of rx.App.") + return app + + def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType: """Get the app module based on the default config after first compiling it.