[ENG-4134]Allow specifying custom app module in rxconfig (#4556)

* Allow custom app module in rxconfig

* what was that pyscopg mess?

* fix another mess

* get this working with relative imports and hot reload

* typing to named tuple

* minor refactor

* revert redis knobs positions

* fix pyright except 1

* fix pyright hopefully

* use the resolved module path

* testing workflow

* move nba-proxy job to counter job

* just cast the type

* fix tests for python 3.9

* darglint

* CR Suggestions for #4556 (#4644)

* reload_dirs: search up from app_module for last directory containing __init__

* Change custom app_module to use an import string

* preserve sys.path entries added while loading rxconfig.py

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
Elijah Ahianyo 2025-01-20 18:12:54 +00:00 committed by GitHub
parent 4da32a122b
commit 268effe62e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 132 additions and 30 deletions

View File

@ -33,7 +33,7 @@ env:
PR_TITLE: ${{ github.event.pull_request.title }} PR_TITLE: ${{ github.event.pull_request.title }}
jobs: jobs:
example-counter: example-counter-and-nba-proxy:
env: env:
OUTPUT_FILE: import_benchmark.json OUTPUT_FILE: import_benchmark.json
timeout-minutes: 30 timeout-minutes: 30
@ -119,6 +119,26 @@ jobs:
--benchmark-json "./reflex-examples/counter/${{ env.OUTPUT_FILE }}" --benchmark-json "./reflex-examples/counter/${{ env.OUTPUT_FILE }}"
--branch-name "${{ github.head_ref || github.ref_name }}" --pr-id "${{ github.event.pull_request.id }}" --branch-name "${{ github.head_ref || github.ref_name }}" --pr-id "${{ github.event.pull_request.id }}"
--app-name "counter" --app-name "counter"
- name: Install requirements for nba proxy example
working-directory: ./reflex-examples/nba-proxy
run: |
poetry run uv pip install -r requirements.txt
- name: Install additional dependencies for DB access
run: poetry run uv pip install psycopg
- name: Check export --backend-only before init for nba-proxy example
working-directory: ./reflex-examples/nba-proxy
run: |
poetry run reflex export --backend-only
- name: Init Website for nba-proxy example
working-directory: ./reflex-examples/nba-proxy
run: |
poetry run reflex init --loglevel debug
- name: Run Website and Check for errors
run: |
# Check that npm is home
npm -v
poetry run bash scripts/integration.sh ./reflex-examples/nba-proxy dev
reflex-web: reflex-web:
strategy: strategy:

View File

@ -7,14 +7,13 @@ from concurrent.futures import ThreadPoolExecutor
from reflex import constants from reflex import constants
from reflex.utils import telemetry from reflex.utils import telemetry
from reflex.utils.exec import is_prod_mode from reflex.utils.exec import is_prod_mode
from reflex.utils.prerequisites import get_app from reflex.utils.prerequisites import get_and_validate_app
if constants.CompileVars.APP != "app": if constants.CompileVars.APP != "app":
raise AssertionError("unexpected variable name for 'app'") raise AssertionError("unexpected variable name for 'app'")
telemetry.send("compile") telemetry.send("compile")
app_module = get_app(reload=False) app, app_module = get_and_validate_app(reload=False)
app = getattr(app_module, constants.CompileVars.APP)
# For py3.9 compatibility when redis is used, we MUST add any decorator pages # For py3.9 compatibility when redis is used, we MUST add any decorator pages
# before compiling the app in a thread to avoid event loop error (REF-2172). # before compiling the app in a thread to avoid event loop error (REF-2172).
app._apply_decorated_pages() app._apply_decorated_pages()
@ -30,7 +29,7 @@ if is_prod_mode():
# ensure only "app" is exposed. # ensure only "app" is exposed.
del app_module del app_module
del compile_future del compile_future
del get_app del get_and_validate_app
del is_prod_mode del is_prod_mode
del telemetry del telemetry
del constants del constants

View File

@ -12,6 +12,7 @@ import threading
import urllib.parse import urllib.parse
from importlib.util import find_spec from importlib.util import find_spec
from pathlib import Path from pathlib import Path
from types import ModuleType
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -607,6 +608,9 @@ class Config(Base):
# The name of the app (should match the name of the app directory). # The name of the app (should match the name of the app directory).
app_name: str app_name: str
# The path to the app module.
app_module_import: Optional[str] = None
# The log level to use. # The log level to use.
loglevel: constants.LogLevel = constants.LogLevel.DEFAULT loglevel: constants.LogLevel = constants.LogLevel.DEFAULT
@ -729,6 +733,19 @@ class Config(Base):
"REDIS_URL is required when using the redis state manager." "REDIS_URL is required when using the redis state manager."
) )
@property
def app_module(self) -> ModuleType | None:
"""Return the app module if `app_module_import` is set.
Returns:
The app module.
"""
return (
importlib.import_module(self.app_module_import)
if self.app_module_import
else None
)
@property @property
def module(self) -> str: def module(self) -> str:
"""Get the module name of the app. """Get the module name of the app.
@ -736,6 +753,8 @@ class Config(Base):
Returns: Returns:
The module name. The module name.
""" """
if self.app_module is not None:
return self.app_module.__name__
return ".".join([self.app_name, self.app_name]) return ".".join([self.app_name, self.app_name])
def update_from_env(self) -> dict[str, Any]: def update_from_env(self) -> dict[str, Any]:
@ -874,7 +893,7 @@ def get_config(reload: bool = False) -> Config:
return cached_rxconfig.config return cached_rxconfig.config
with _config_lock: with _config_lock:
sys_path = sys.path.copy() orig_sys_path = sys.path.copy()
sys.path.clear() sys.path.clear()
sys.path.append(str(Path.cwd())) sys.path.append(str(Path.cwd()))
try: try:
@ -882,9 +901,14 @@ def get_config(reload: bool = False) -> Config:
return _get_config() return _get_config()
except Exception: except Exception:
# If the module import fails, try to import with the original sys.path. # If the module import fails, try to import with the original sys.path.
sys.path.extend(sys_path) sys.path.extend(orig_sys_path)
return _get_config() return _get_config()
finally: finally:
# Find any entries added to sys.path by rxconfig.py itself.
extra_paths = [
p for p in sys.path if p not in orig_sys_path and p != str(Path.cwd())
]
# Restore the original sys.path. # Restore the original sys.path.
sys.path.clear() sys.path.clear()
sys.path.extend(sys_path) sys.path.extend(extra_paths)
sys.path.extend(orig_sys_path)

View File

@ -1591,7 +1591,7 @@ def get_handler_args(
def fix_events( def fix_events(
events: list[EventHandler | EventSpec] | None, events: list[EventSpec | EventHandler] | None,
token: str, token: str,
router_data: dict[str, Any] | None = None, router_data: dict[str, Any] | None = None,
) -> list[Event]: ) -> list[Event]:

View File

@ -1776,9 +1776,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
except Exception as ex: except Exception as ex:
state._clean() state._clean()
app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) event_specs = (
prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
event_specs = app_instance.backend_exception_handler(ex) )
if event_specs is None: if event_specs is None:
return StateUpdate() return StateUpdate()
@ -1888,9 +1888,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
except Exception as ex: except Exception as ex:
telemetry.send_error(ex, context="backend") telemetry.send_error(ex, context="backend")
app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) event_specs = (
prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
event_specs = app_instance.backend_exception_handler(ex) )
yield state._as_state_update( yield state._as_state_update(
handler, handler,
@ -2403,8 +2403,9 @@ class FrontendEventExceptionState(State):
component_stack: The stack trace of the component where the exception occurred. component_stack: The stack trace of the component where the exception occurred.
""" """
app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) prerequisites.get_and_validate_app().app.frontend_exception_handler(
app_instance.frontend_exception_handler(Exception(stack)) Exception(stack)
)
class UpdateVarsInternalState(State): class UpdateVarsInternalState(State):
@ -2442,15 +2443,16 @@ class OnLoadInternalState(State):
The list of events to queue for on load handling. The list of events to queue for on load handling.
""" """
# Do not app._compile()! It should be already compiled by now. # Do not app._compile()! It should be already compiled by now.
app = getattr(prerequisites.get_app(), constants.CompileVars.APP) load_events = prerequisites.get_and_validate_app().app.get_load_events(
load_events = app.get_load_events(self.router.page.path) self.router.page.path
)
if not load_events: if not load_events:
self.is_hydrated = True self.is_hydrated = True
return # Fast path for navigation with no on_load events defined. return # Fast path for navigation with no on_load events defined.
self.is_hydrated = False self.is_hydrated = False
return [ return [
*fix_events( *fix_events(
load_events, cast(list[Union[EventSpec, EventHandler]], load_events),
self.router.session.client_token, self.router.session.client_token,
router_data=self.router_data, router_data=self.router_data,
), ),
@ -2609,7 +2611,7 @@ class StateProxy(wrapt.ObjectProxy):
""" """
super().__init__(state_instance) super().__init__(state_instance)
# compile is not relevant to backend logic # compile is not relevant to backend logic
self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP) self._self_app = prerequisites.get_and_validate_app().app
self._self_substate_path = tuple(state_instance.get_full_name().split(".")) self._self_substate_path = tuple(state_instance.get_full_name().split("."))
self._self_actx = None self._self_actx = None
self._self_mutable = False self._self_mutable = False
@ -3702,8 +3704,7 @@ def get_state_manager() -> StateManager:
Returns: Returns:
The state manager. The state manager.
""" """
app = getattr(prerequisites.get_app(), constants.CompileVars.APP) return prerequisites.get_and_validate_app().app.state_manager
return app.state_manager
class MutableProxy(wrapt.ObjectProxy): class MutableProxy(wrapt.ObjectProxy):

View File

@ -240,6 +240,28 @@ def run_backend(
run_uvicorn_backend(host, port, loglevel) run_uvicorn_backend(host, port, loglevel)
def get_reload_dirs() -> list[str]:
"""Get the reload directories for the backend.
Returns:
The reload directories for the backend.
"""
config = get_config()
reload_dirs = [config.app_name]
if config.app_module is not None and config.app_module.__file__:
module_path = Path(config.app_module.__file__).resolve().parent
while module_path.parent.name:
for parent_file in module_path.parent.iterdir():
if parent_file == "__init__.py":
# go up a level to find dir without `__init__.py`
module_path = module_path.parent
break
else:
break
reload_dirs.append(str(module_path))
return reload_dirs
def run_uvicorn_backend(host, port, loglevel: LogLevel): def run_uvicorn_backend(host, port, loglevel: LogLevel):
"""Run the backend in development mode using Uvicorn. """Run the backend in development mode using Uvicorn.
@ -256,7 +278,7 @@ def run_uvicorn_backend(host, port, loglevel: LogLevel):
port=port, port=port,
log_level=loglevel.value, log_level=loglevel.value,
reload=True, reload=True,
reload_dirs=[get_config().app_name], reload_dirs=get_reload_dirs(),
) )
@ -281,7 +303,7 @@ def run_granian_backend(host, port, loglevel: LogLevel):
interface=Interfaces.ASGI, interface=Interfaces.ASGI,
log_level=LogLevels(loglevel.value), log_level=LogLevels(loglevel.value),
reload=True, reload=True,
reload_paths=[Path(get_config().app_name)], reload_paths=get_reload_dirs(),
reload_ignore_dirs=[".web"], reload_ignore_dirs=[".web"],
).serve() ).serve()
except ImportError: except ImportError:

View File

@ -17,11 +17,12 @@ import stat
import sys import sys
import tempfile import tempfile
import time import time
import typing
import zipfile import zipfile
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from types import ModuleType from types import ModuleType
from typing import Callable, List, Optional from typing import Callable, List, NamedTuple, Optional
import httpx import httpx
import typer import typer
@ -42,9 +43,19 @@ from reflex.utils.exceptions import (
from reflex.utils.format import format_library_name from reflex.utils.format import format_library_name
from reflex.utils.registry import _get_npm_registry from reflex.utils.registry import _get_npm_registry
if typing.TYPE_CHECKING:
from reflex.app import App
CURRENTLY_INSTALLING_NODE = False CURRENTLY_INSTALLING_NODE = False
class AppInfo(NamedTuple):
"""A tuple containing the app instance and module."""
app: App
module: ModuleType
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class Template: class Template:
"""A template for a Reflex app.""" """A template for a Reflex app."""
@ -291,8 +302,11 @@ def get_app(reload: bool = False) -> ModuleType:
) )
module = config.module module = config.module
sys.path.insert(0, str(Path.cwd())) sys.path.insert(0, str(Path.cwd()))
app = __import__(module, fromlist=(constants.CompileVars.APP,)) app = (
__import__(module, fromlist=(constants.CompileVars.APP,))
if not config.app_module
else config.app_module
)
if reload: if reload:
from reflex.state import reload_state_module from reflex.state import reload_state_module
@ -308,6 +322,29 @@ def get_app(reload: bool = False) -> ModuleType:
raise raise
def get_and_validate_app(reload: bool = False) -> AppInfo:
"""Get the app instance based on the default config and validate it.
Args:
reload: Re-import the app module from disk
Returns:
The app instance and the app module.
Raises:
RuntimeError: If the app instance is not an instance of rx.App.
"""
from reflex.app import App
app_module = get_app(reload=reload)
app = getattr(app_module, constants.CompileVars.APP)
if not isinstance(app, App):
raise RuntimeError(
"The app instance in the specified app_module_import in rxconfig must be an instance of rx.App."
)
return AppInfo(app=app, module=app_module)
def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType: def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType:
"""Get the app module based on the default config after first compiling it. """Get the app module based on the default config after first compiling it.
@ -318,8 +355,7 @@ def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType:
Returns: Returns:
The compiled app based on the default config. The compiled app based on the default config.
""" """
app_module = get_app(reload=reload) app, app_module = get_and_validate_app(reload=reload)
app = getattr(app_module, constants.CompileVars.APP)
# For py3.9 compatibility when redis is used, we MUST add any decorator pages # For py3.9 compatibility when redis is used, we MUST add any decorator pages
# before compiling the app in a thread to avoid event loop error (REF-2172). # before compiling the app in a thread to avoid event loop error (REF-2172).
app._apply_decorated_pages() app._apply_decorated_pages()