Clean up config parameters (#1591)

This commit is contained in:
Nikhil Rao 2023-08-18 14:22:20 -07:00 committed by GitHub
parent 042710ca91
commit 0beb7a409f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 200 additions and 397 deletions

View File

@ -43,7 +43,6 @@ watchdog = "^2.3.1"
watchfiles = "^0.19.0"
websockets = "^10.4"
starlette-admin = "^0.9.0"
python-dotenv = "^0.13.0"
importlib-metadata = {version = "^6.7.0", python = ">=3.7, <3.8"}
alembic = "^1.11.1"
platformdirs = "^3.10.0"

View File

@ -18,7 +18,6 @@ from .components.graphing.victory import data as data
from .config import Config as Config
from .config import DBConfig as DBConfig
from .constants import Env as Env
from .constants import Transports as Transports
from .event import EVENT_ARG as EVENT_ARG
from .event import EventChain as EventChain
from .event import FileUpload as upload_files

View File

@ -144,10 +144,10 @@ class App(Base):
self.sio = AsyncServer(
async_mode="asgi",
cors_allowed_origins="*"
if config.cors_allowed_origins == constants.CORS_ALLOWED_ORIGINS
if config.cors_allowed_origins == ["*"]
else config.cors_allowed_origins,
cors_credentials=config.cors_credentials,
max_http_buffer_size=config.polling_max_http_buffer_size,
cors_credentials=True,
max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
ping_interval=constants.PING_INTERVAL,
ping_timeout=constants.PING_TIMEOUT,
)
@ -438,13 +438,14 @@ class App(Base):
def setup_admin_dash(self):
"""Setup the admin dash."""
# Get the config.
config = get_config()
if config.admin_dash and config.admin_dash.models:
# Get the admin dash.
admin_dash = self.admin_dash
if admin_dash and admin_dash.models:
# Build the admin dashboard
admin = (
config.admin_dash.admin
if config.admin_dash.admin
admin_dash.admin
if admin_dash.admin
else Admin(
engine=Model.get_db_engine(),
title="Reflex Admin Dashboard",
@ -452,8 +453,8 @@ class App(Base):
)
)
for model in config.admin_dash.models:
view = config.admin_dash.view_overrides.get(model, ModelView)
for model in admin_dash.models:
view = admin_dash.view_overrides.get(model, ModelView)
admin.add_view(view(model))
admin.mount_to(self.api)

View File

@ -113,7 +113,6 @@ def _compile_page(
state_name=state.get_name(),
hooks=component.get_hooks(),
render=component.render(),
transports=constants.Transports.POLLING_WEBSOCKET.get_transports(),
err_comp=connect_error_component.render() if connect_error_component else None,
)

View File

@ -39,7 +39,6 @@ class ReflexJinjaEnvironment(Environment):
"toggle_color_mode": constants.TOGGLE_COLOR_MODE,
"use_color_mode": constants.USE_COLOR_MODE,
"hydrate": constants.HYDRATE,
"db_url": constants.DB_URL,
}

View File

@ -8,11 +8,9 @@ import sys
import urllib.parse
from typing import Any, Dict, List, Optional
from dotenv import load_dotenv
from reflex import constants
from reflex.admin import AdminDash
from reflex.base import Base
from reflex.utils import console
class DBConfig(Base):
@ -126,130 +124,135 @@ class DBConfig(Base):
class Config(Base):
"""A Reflex config."""
class Config:
"""Pydantic config for the config."""
validate_assignment = True
# The name of the app.
app_name: str
# The username.
username: Optional[str] = None
# The log level to use.
loglevel: constants.LogLevel = constants.LogLevel.INFO
# The frontend port.
frontend_port: str = constants.FRONTEND_PORT
# The port to run the frontend on.
frontend_port: int = 3000
# The backend port.
backend_port: str = constants.BACKEND_PORT
# The port to run the backend on.
backend_port: int = 8000
# The backend host.
backend_host: str = constants.BACKEND_HOST
# The backend url the frontend will connect to.
api_url: str = f"http://localhost:{backend_port}"
# The backend API url.
api_url: str = constants.API_URL
# The url the frontend will be hosted on.
deploy_url: Optional[str] = f"http://localhost:{frontend_port}"
# The deploy url.
deploy_url: Optional[str] = constants.DEPLOY_URL
# The url the backend will be hosted on.
backend_host: str = "0.0.0.0"
# The database url.
db_url: Optional[str] = constants.DB_URL
# The database config.
db_config: Optional[DBConfig] = None
db_url: Optional[str] = "sqlite:///reflex.db"
# The redis url.
redis_url: Optional[str] = constants.REDIS_URL
redis_url: Optional[str] = None
# Telemetry opt-in.
telemetry_enabled: bool = True
# The rxdeploy url.
rxdeploy_url: Optional[str] = None
# The environment mode.
env: constants.Env = constants.Env.DEV
# Additional frontend packages to install.
frontend_packages: List[str] = []
# The bun path
bun_path: str = constants.BUN_PATH
# The Admin Dash.
admin_dash: Optional[AdminDash] = None
# Backend transport methods.
backend_transports: Optional[
constants.Transports
] = constants.Transports.WEBSOCKET_POLLING
bun_path: str = constants.DEFAULT_BUN_PATH
# List of origins that are allowed to connect to the backend API.
cors_allowed_origins: Optional[list] = constants.CORS_ALLOWED_ORIGINS
# Whether credentials (cookies, authentication) are allowed in requests to the backend API.
cors_credentials: Optional[bool] = True
# The maximum size of a message when using the polling backend transport.
polling_max_http_buffer_size: Optional[int] = constants.POLLING_MAX_HTTP_BUFFER_SIZE
# Dotenv file path.
env_path: Optional[str] = constants.DOT_ENV_FILE
# Whether to override OS environment variables.
override_os_envs: Optional[bool] = True
cors_allowed_origins: List[str] = ["*"]
# Tailwind config.
tailwind: Optional[Dict[str, Any]] = None
# Timeout when launching the gunicorn server.
timeout: int = constants.TIMEOUT
# Timeout when launching the gunicorn server. TODO(rename this to backend_timeout?)
timeout: int = 120
# Whether to enable or disable nextJS gzip compression.
next_compression: bool = True
# The event namespace for ws connection
event_namespace: Optional[str] = constants.EVENT_NAMESPACE
event_namespace: Optional[str] = None
# Params to remove eventually.
# Additional frontend packages to install. (TODO: these can be inferred from the imports)
frontend_packages: List[str] = []
# For rest are for deploy only.
# The rxdeploy url.
rxdeploy_url: Optional[str] = None
# The username.
username: Optional[str] = None
def __init__(self, *args, **kwargs):
"""Initialize the config values.
If db_url is not provided gets it from db_config.
Args:
*args: The args to pass to the Pydantic init method.
**kwargs: The kwargs to pass to the Pydantic init method.
"""
if "db_url" not in kwargs and "db_config" in kwargs:
kwargs["db_url"] = kwargs["db_config"].get_url()
super().__init__(*args, **kwargs)
# set overriden class attribute values as os env variables to avoid losing them
for key, value in dict(self).items():
key = key.upper()
if (
key.startswith("_")
or key in os.environ
or (value is None and key != "DB_URL")
):
continue
os.environ[key] = str(value)
# Check for deprecated values.
self.check_deprecated_values(**kwargs)
# Avoid overriding if env_path is not provided or does not exist
if self.env_path is not None and os.path.isfile(self.env_path):
load_dotenv(self.env_path, override=self.override_os_envs) # type: ignore
# Recompute constants after loading env variables
importlib.reload(constants)
# Recompute instance attributes
self.recompute_field_values()
# Update the config from environment variables.
self.update_from_env()
def recompute_field_values(self):
"""Recompute instance field values to reflect new values after reloading
constant values.
@staticmethod
def check_deprecated_values(**kwargs):
"""Check for deprecated config values.
Args:
**kwargs: The kwargs passed to the config.
Raises:
ValueError: If a deprecated config value is found.
"""
for field in self.get_fields():
try:
if field.startswith("_"):
continue
setattr(self, field, getattr(constants, f"{field.upper()}"))
except AttributeError:
pass
if "db_config" in kwargs:
raise ValueError("db_config is deprecated - use db_url instead")
if "admin_dash" in kwargs:
raise ValueError(
"admin_dash is deprecated in the config - pass it as a param to rx.App instead"
)
if "env_path" in kwargs:
raise ValueError(
"env_path is deprecated - use environment variables instead"
)
def update_from_env(self):
"""Update the config from environment variables.
Raises:
ValueError: If an environment variable is set to an invalid type.
"""
# Iterate over the fields.
for key, field in self.__fields__.items():
# The env var name is the key in uppercase.
env_var = os.environ.get(key.upper())
# If the env var is set, override the config value.
if env_var is not None:
console.info(
f"Overriding config value {key} with env var {key.upper()}={env_var}"
)
# Convert the env var to the expected type.
try:
env_var = field.type_(env_var)
except ValueError:
console.error(
f"Could not convert {key.upper()}={env_var} to type {field.type_}"
)
raise
# Set the value.
setattr(self, key, env_var)
def get_event_namespace(self) -> Optional[str]:
"""Get the websocket event namespace.

View File

@ -6,7 +6,6 @@ import platform
import re
from enum import Enum
from types import SimpleNamespace
from typing import Any, Type
from platformdirs import PlatformDirs
@ -19,31 +18,6 @@ except ImportError:
IS_WINDOWS = platform.system() == "Windows"
def get_value(key: str, default: Any = None, type_: Type = str) -> Type:
"""Get the value for the constant.
Obtain os env value and cast non-string types into
their original types.
Args:
key: constant name.
default: default value if key doesn't exist.
type_: the type of the constant.
Returns:
the value of the constant in its designated type
"""
value = os.getenv(key, default)
try:
if value and type_ != str:
value = eval(value)
except Exception:
pass
finally:
# Special case for db_url expects None to be a valid input when
# user explicitly overrides db_url as None
return value if value != "None" else None # noqa B012
# App names and versions.
# The name of the Reflex package.
MODULE_NAME = "reflex"
@ -85,8 +59,6 @@ MIN_BUN_VERSION = "0.7.0"
BUN_ROOT_PATH = os.path.join(REFLEX_DIR, ".bun")
# Default bun path.
DEFAULT_BUN_PATH = os.path.join(BUN_ROOT_PATH, "bin", "bun")
# The bun path.
BUN_PATH = get_value("BUN_PATH", DEFAULT_BUN_PATH)
# URL to bun install script.
BUN_INSTALL_URL = "https://bun.sh/install"
@ -161,29 +133,6 @@ REFLEX_JSON = os.path.join(WEB_DIR, "reflex.json")
# The env json file.
ENV_JSON = os.path.join(WEB_DIR, "env.json")
# Commands to run the app.
DOT_ENV_FILE = ".env"
# The frontend default port.
FRONTEND_PORT = get_value("FRONTEND_PORT", "3000")
# The backend default port.
BACKEND_PORT = get_value("BACKEND_PORT", "8000")
# The backend api url.
API_URL = get_value("API_URL", "http://localhost:8000")
# The deploy url
DEPLOY_URL = get_value("DEPLOY_URL")
# Default host in dev mode.
BACKEND_HOST = get_value("BACKEND_HOST", "0.0.0.0")
# The default timeout when launching the gunicorn server.
TIMEOUT = get_value("TIMEOUT", 120, type_=int)
# The command to run the backend in production mode.
RUN_BACKEND_PROD = f"gunicorn --worker-class uvicorn.workers.UvicornH11Worker --preload --timeout {TIMEOUT} --log-level critical".split()
RUN_BACKEND_PROD_WINDOWS = f"uvicorn --timeout-keep-alive {TIMEOUT}".split()
# Socket.IO web server
PING_INTERVAL = 25
PING_TIMEOUT = 120
# flag to make the engine print all the SQL statements it executes
SQLALCHEMY_ECHO = get_value("SQLALCHEMY_ECHO", False, type_=bool)
# Compiler variables.
# The extension for compiled Javascript files.
JS_EXT = ".js"
@ -223,12 +172,6 @@ SETTER_PREFIX = "set_"
FRONTEND_ZIP = "frontend.zip"
# The name of the backend zip during deployment.
BACKEND_ZIP = "backend.zip"
# The name of the sqlite database.
DB_NAME = os.getenv("DB_NAME", "reflex.db")
# The sqlite url.
DB_URL = get_value("DB_URL", f"sqlite:///{DB_NAME}")
# The redis url
REDIS_URL = get_value("REDIS_URL")
# The default title to show for Reflex apps.
DEFAULT_TITLE = "Reflex App"
# The default description to show for Reflex apps.
@ -252,8 +195,6 @@ OLD_CONFIG_FILE = f"pcconfig{PY_EXT}"
PRODUCTION_BACKEND_URL = "https://{username}-{app_name}.api.pynecone.app"
# Token expiration time in seconds.
TOKEN_EXPIRATION = 60 * 60
# The event namespace for websocket
EVENT_NAMESPACE = get_value("EVENT_NAMESPACE")
# Testing variables.
# Testing os env set by pytest when running a test case.
@ -351,36 +292,6 @@ class SocketEvent(Enum):
return str(self.value)
class Transports(Enum):
"""Socket transports used by the reflex backend API."""
POLLING_WEBSOCKET = "['polling', 'websocket']"
WEBSOCKET_POLLING = "['websocket', 'polling']"
WEBSOCKET_ONLY = "['websocket']"
POLLING_ONLY = "['polling']"
def __str__(self) -> str:
"""Get the string representation of the transports.
Returns:
The transports string.
"""
return str(self.value)
def get_transports(self) -> str:
"""Get the transports config for the backend.
Returns:
The transports config for the backend.
"""
# Import here to avoid circular imports.
from reflex.config import get_config
# Get the API URL from the config.
config = get_config()
return str(config.backend_transports)
class RouteArgType(SimpleNamespace):
"""Type of dynamic route arg extracted from URI route."""
@ -430,8 +341,9 @@ COLOR_MODE = "colorMode"
TOGGLE_COLOR_MODE = "toggleColorMode"
# Server socket configuration variables
CORS_ALLOWED_ORIGINS = get_value("CORS_ALLOWED_ORIGINS", ["*"], list)
POLLING_MAX_HTTP_BUFFER_SIZE = 1000 * 1000
PING_INTERVAL = 25
PING_TIMEOUT = 120
# Alembic migrations
ALEMBIC_CONFIG = os.environ.get("ALEMBIC_CONFIG", "alembic.ini")

View File

@ -1,5 +1,6 @@
"""Database built into Reflex."""
import os
from collections import defaultdict
from pathlib import Path
from typing import Any, Optional
@ -40,13 +41,13 @@ def get_engine(url: Optional[str] = None):
console.warn(
"Database is not initialized, run [bold]reflex db init[/bold] first."
)
echo_db_query = False
if conf.env == constants.Env.DEV and constants.SQLALCHEMY_ECHO:
echo_db_query = True
# Print the SQL queries if the log level is INFO or lower.
echo_db_query = os.environ.get("SQLALCHEMY_ECHO") == "True"
return sqlmodel.create_engine(
url,
echo=echo_db_query,
connect_args={"check_same_thread": False} if conf.admin_dash else {},
# Needed for the admin dash.
connect_args={"check_same_thread": False},
)

View File

@ -15,6 +15,9 @@ from reflex.utils import build, console, exec, prerequisites, processes, telemet
# Create the app.
cli = typer.Typer(add_completion=False)
# Get the config.
config = get_config()
def version(value: bool):
"""Get the Reflex version.
@ -48,13 +51,13 @@ def main(
@cli.command()
def init(
name: str = typer.Option(
None, metavar="APP_NAME", help="The name of the app to be initialized."
None, metavar="APP_NAME", help="The name of the app to initialize."
),
template: constants.Template = typer.Option(
constants.Template.DEFAULT, help="The template to initialize the app with."
),
loglevel: constants.LogLevel = typer.Option(
console.LOG_LEVEL, help="The log level to use."
config.loglevel, help="The log level to use."
),
):
"""Initialize a new Reflex app in the current directory."""
@ -75,7 +78,6 @@ def init(
prerequisites.migrate_to_reflex()
# Set up the app directory, only if the config doesn't exist.
config = get_config()
if not os.path.exists(constants.CONFIG_FILE):
prerequisites.create_config(app_name)
prerequisites.initialize_app_directory(app_name, template)
@ -94,17 +96,23 @@ def init(
@cli.command()
def run(
env: constants.Env = typer.Option(
get_config().env, help="The environment to run the app in."
constants.Env.DEV, help="The environment to run the app in."
),
frontend: bool = typer.Option(
False, "--frontend-only", help="Execute only frontend."
),
backend: bool = typer.Option(False, "--backend-only", help="Execute only backend."),
frontend_port: str = typer.Option(None, help="Specify a different frontend port."),
backend_port: str = typer.Option(None, help="Specify a different backend port."),
backend_host: str = typer.Option(None, help="Specify the backend host."),
frontend_port: str = typer.Option(
config.frontend_port, help="Specify a different frontend port."
),
backend_port: str = typer.Option(
config.backend_port, help="Specify a different backend port."
),
backend_host: str = typer.Option(
config.backend_host, help="Specify the backend host."
),
loglevel: constants.LogLevel = typer.Option(
console.LOG_LEVEL, help="The log level to use."
config.loglevel, help="The log level to use."
),
):
"""Run the app in the current directory."""
@ -114,20 +122,6 @@ def run(
# Show system info
exec.output_system_info()
# Set ports as os env variables to take precedence over config and
# .env variables(if override_os_envs flag in config is set to False).
build.set_os_env(
frontend_port=frontend_port,
backend_port=backend_port,
backend_host=backend_host,
)
# Get the ports from the config.
config = get_config()
frontend_port = config.frontend_port if frontend_port is None else frontend_port
backend_port = config.backend_port if backend_port is None else backend_port
backend_host = config.backend_host if backend_host is None else backend_host
# If no --frontend-only and no --backend-only, then turn on frontend and backend both
if not frontend and not backend:
frontend = True
@ -147,9 +141,6 @@ def run(
console.rule("[bold]Starting Reflex App")
app = prerequisites.get_app()
# Check the admin dashboard settings.
prerequisites.check_admin_settings()
# Warn if schema is not up to date.
prerequisites.check_schema_up_to_date()
@ -199,9 +190,6 @@ def deploy(
# Show system info
exec.output_system_info()
# Get the app config.
config = get_config()
# Check if the deploy url is set.
if config.rxdeploy_url is None:
console.info("This feature is coming soon!")
@ -264,7 +252,6 @@ def export(
build.setup_frontend(Path.cwd())
# Export the app.
config = get_config()
build.export(
backend=backend,
frontend=frontend,
@ -294,7 +281,7 @@ db_cli = typer.Typer()
def db_init():
"""Create database schema and migration configuration."""
# Check the database url.
if get_config().db_url is None:
if config.db_url is None:
console.error("db_url is not configured, cannot initialize.")
return

View File

@ -41,7 +41,7 @@ def set_reflex_project_hash():
update_json_file(constants.REFLEX_JSON, {"project_hash": project_hash})
def set_environment_variables():
def set_env_json():
"""Write the upload url to a REFLEX_JSON."""
update_json_file(
constants.ENV_JSON,
@ -102,13 +102,15 @@ def export(
# Remove the static folder.
path_ops.rm(constants.WEB_STATIC_DIR)
# Generate the sitemap file.
# The export command to run.
command = "export"
if deploy_url is not None:
generate_sitemap_config(deploy_url)
command = "export-sitemap"
if frontend:
# Generate a sitemap if a deploy URL is provided.
if deploy_url is not None:
generate_sitemap_config(deploy_url)
command = "export-sitemap"
checkpoints = [
"Linting and checking ",
"Compiled successfully",
@ -188,7 +190,7 @@ def setup_frontend(
)
# Set the environment variables in client (env.json).
set_environment_variables()
set_env_json()
# Disable the Next telemetry.
if disable_telemetry:

View File

@ -59,7 +59,7 @@ def run_frontend(
# Run the frontend in development mode.
console.rule("[bold green]App Running")
os.environ["PORT"] = get_config().frontend_port if port is None else port
os.environ["PORT"] = str(get_config().frontend_port if port is None else port)
run_process_and_launch_url([prerequisites.get_package_manager(), "run", "dev"])
@ -74,7 +74,7 @@ def run_frontend_prod(
port: The port to run the frontend on.
"""
# Set the port.
os.environ["PORT"] = get_config().frontend_port if port is None else port
os.environ["PORT"] = str(get_config().frontend_port if port is None else port)
# Run the frontend in production mode.
console.rule("[bold green]App Running")
@ -129,9 +129,12 @@ def run_backend_prod(
loglevel: The log level.
"""
num_workers = processes.get_num_workers()
config = get_config()
RUN_BACKEND_PROD = f"gunicorn --worker-class uvicorn.workers.UvicornH11Worker --preload --timeout {config.timeout} --log-level critical".split()
RUN_BACKEND_PROD_WINDOWS = f"uvicorn --timeout-keep-alive {config.timeout}".split()
command = (
[
*constants.RUN_BACKEND_PROD_WINDOWS,
*RUN_BACKEND_PROD_WINDOWS,
"--host",
host,
"--port",
@ -140,7 +143,7 @@ def run_backend_prod(
]
if constants.IS_WINDOWS
else [
*constants.RUN_BACKEND_PROD,
*RUN_BACKEND_PROD,
"--bind",
f"{host}:{port}",
"--threads",
@ -176,7 +179,6 @@ def output_system_info():
dependencies.extend(
[
f"[NVM {constants.NVM_VERSION} (Expected: {constants.NVM_VERSION}) (PATH: {constants.NVM_PATH})]",
f"[Bun {prerequisites.get_bun_version()} (Expected: {constants.BUN_VERSION}) (PATH: {constants.BUN_PATH})]",
],
)
else:
@ -202,4 +204,3 @@ def output_system_info():
console.debug(f"Using package executer at: {prerequisites.get_package_manager()}")
if system != "Windows":
console.debug(f"Unzip path: {path_ops.which('unzip')}")
# exit()

View File

@ -482,23 +482,6 @@ def initialize_frontend_dependencies():
initialize_web_directory()
def check_admin_settings():
"""Check if admin settings are set and valid for logging in cli app."""
admin_dash = get_config().admin_dash
if admin_dash:
if not admin_dash.models:
console.log(
f"[yellow][Admin Dashboard][/yellow] :megaphone: Admin dashboard enabled, but no models defined in [bold magenta]rxconfig.py[/bold magenta]."
)
else:
console.log(
f"[yellow][Admin Dashboard][/yellow] Admin enabled, building admin dashboard."
)
console.log(
"Admin dashboard running at: [bold green]http://localhost:8000/admin[/bold green]"
)
def check_db_initialized() -> bool:
"""Check if the database migrations are initialized.

View File

@ -9,13 +9,11 @@ import signal
import subprocess
from concurrent import futures
from typing import Callable, List, Optional, Tuple, Union
from urllib.parse import urlparse
import psutil
import typer
from reflex import constants
from reflex.config import get_config
from reflex.utils import console, prerequisites
@ -37,19 +35,6 @@ def get_num_workers() -> int:
return 1 if prerequisites.get_redis() is None else (os.cpu_count() or 1) * 2 + 1
def get_api_port() -> int:
"""Get the API port.
Returns:
The API port.
"""
port = urlparse(get_config().api_url).port
if port is None:
port = urlparse(constants.API_URL).port
assert port is not None
return port
def get_process_on_port(port) -> Optional[psutil.Process]:
"""Get the process on the given port.

View File

@ -8,7 +8,6 @@ from typing import Dict, Generator, List, Set, Union
import pytest
import reflex as rx
from reflex import constants
from reflex.app import App
from reflex.event import EventSpec
@ -392,7 +391,7 @@ def base_config_values() -> Dict:
Returns:
Dictionary of base config values
"""
return {"app_name": "app", "db_url": constants.DB_URL, "env": rx.Env.DEV}
return {"app_name": "app"}
@pytest.fixture

View File

@ -83,7 +83,7 @@ def redundant_test_state() -> Type[State]:
return RedundantTestState
@pytest.fixture()
@pytest.fixture(scope="session")
def test_model() -> Type[Model]:
"""A default model.
@ -91,13 +91,13 @@ def test_model() -> Type[Model]:
A default model.
"""
class TestModel(Model):
class TestModel(Model, table=True): # type: ignore
pass
return TestModel
@pytest.fixture()
@pytest.fixture(scope="session")
def test_model_auth() -> Type[Model]:
"""A default model.
@ -105,7 +105,7 @@ def test_model_auth() -> Type[Model]:
A default model.
"""
class TestModelAuth(Model):
class TestModelAuth(Model, table=True): # type: ignore
"""A test model with auth."""
pass

View File

@ -1,138 +1,73 @@
import os
from typing import Dict
import pytest
import reflex as rx
from reflex import constants
from reflex.config import DBConfig, get_config
from reflex.constants import get_value
from reflex.config import get_config
@pytest.fixture
def config_no_db_url_values(base_config_values) -> Dict:
"""Create config values with no db_url.
def test_requires_app_name():
"""Test that a config requires an app_name."""
with pytest.raises(ValueError):
rx.Config()
def test_set_app_name(base_config_values):
"""Test that the app name is set to the value passed in.
Args:
base_config_values: Base config fixture.
Returns:
Config values.
base_config_values: Config values.
"""
base_config_values.pop("db_url")
return base_config_values
@pytest.fixture(autouse=True)
def config_empty_db_url_values(base_config_values):
"""Create config values with empty db_url.
Args:
base_config_values: Base config values fixture.
Yields:
Config values
"""
base_config_values["db_url"] = None
yield base_config_values
os.environ.pop("DB_URL", None)
@pytest.fixture
def config_none_db_url_values(base_config_values):
"""Create config values with None (string) db_url.
Args:
base_config_values: Base config values fixture.
Yields:
Config values
"""
base_config_values["db_url"] = "None"
yield base_config_values
os.environ.pop("DB_URL")
def test_config_db_url(base_config_values):
"""Test defined db_url is not changed.
Args:
base_config_values: base_config_values fixture.
"""
os.environ.pop("DB_URL")
config = rx.Config(**base_config_values)
assert config.db_url == base_config_values["db_url"]
def test_default_db_url(config_no_db_url_values):
"""Test that db_url is assigned the default value if not passed.
Args:
config_no_db_url_values: Config values with no db_url defined.
"""
config = rx.Config(**config_no_db_url_values)
assert config.db_url == constants.DB_URL
def test_empty_db_url(config_empty_db_url_values):
"""Test that db_url is not automatically assigned if an empty value is defined.
Args:
config_empty_db_url_values: Config values with empty db_url.
"""
config = rx.Config(**config_empty_db_url_values)
assert config.db_url is None
def test_none_db_url(config_none_db_url_values):
"""Test that db_url is set 'None' (string) assigned if an 'None' (string) value is defined.
Args:
config_none_db_url_values: Config values with None (string) db_url.
"""
config = rx.Config(**config_none_db_url_values)
assert config.db_url == "None"
def test_db_url_precedence(base_config_values, sqlite_db_config_values):
"""Test that db_url is not overwritten when db_url is defined.
Args:
base_config_values: config values that include db_ur.
sqlite_db_config_values: DB config values.
"""
db_config = DBConfig(**sqlite_db_config_values)
base_config_values["db_config"] = db_config
config = rx.Config(**base_config_values)
assert config.db_url == base_config_values["db_url"]
def test_db_url_from_db_config(config_no_db_url_values, sqlite_db_config_values):
"""Test db_url generation from db_config.
Args:
config_no_db_url_values: Config values with no db_url.
sqlite_db_config_values: DB config values.
"""
db_config = DBConfig(**sqlite_db_config_values)
config_no_db_url_values["db_config"] = db_config
config = rx.Config(**config_no_db_url_values)
assert config.db_url == db_config.get_url()
assert config.app_name == base_config_values["app_name"]
@pytest.mark.parametrize(
"key, value, expected_value_type_in_config",
(
("TIMEOUT", "1", int),
("CORS_ALLOWED_ORIGINS", "[1, 2, 3]", list),
("DB_NAME", "dbname", str),
),
"param",
[
"db_config",
"admin_dash",
"env_path",
],
)
def test_get_value(monkeypatch, key, value, expected_value_type_in_config):
monkeypatch.setenv(key, value)
casted_value = get_value(key, type_=expected_value_type_in_config)
def test_deprecated_params(base_config_values, param):
"""Test that deprecated params are removed from the config.
assert isinstance(casted_value, expected_value_type_in_config)
Args:
base_config_values: Config values.
param: The deprecated param.
"""
with pytest.raises(ValueError):
rx.Config(**base_config_values, **{param: "test"})
@pytest.mark.parametrize(
"env_var, value",
[
("APP_NAME", "my_test_app"),
("FRONTEND_PORT", 3001),
("BACKEND_PORT", 8001),
("API_URL", "https://mybackend.com:8000"),
("DEPLOY_URL", "https://myfrontend.com"),
("BACKEND_HOST", "127.0.0.1"),
("DB_URL", "postgresql://user:pass@localhost:5432/db"),
("REDIS_URL", "redis://localhost:6379"),
("TIMEOUT", 600),
],
)
def test_update_from_env(base_config_values, monkeypatch, env_var, value):
"""Test that environment variables override config values.
Args:
base_config_values: Config values.
monkeypatch: The pytest monkeypatch object.
env_var: The environment variable name.
value: The environment variable value.
"""
monkeypatch.setenv(env_var, value)
assert os.environ.get(env_var) == str(value)
config = rx.Config(**base_config_values)
assert getattr(config, env_var.lower()) == value
@pytest.mark.parametrize(

View File

@ -7,7 +7,7 @@ import pytest
import typer
from packaging import version
from reflex import Env, constants
from reflex import constants
from reflex.base import Base
from reflex.utils import (
build,
@ -323,7 +323,7 @@ def test_setup_frontend(tmp_path, mocker):
(assets / "favicon.ico").touch()
mocker.patch("reflex.utils.prerequisites.install_frontend_packages")
mocker.patch("reflex.utils.build.set_environment_variables")
mocker.patch("reflex.utils.build.set_env_json")
build.setup_frontend(tmp_path, disable_telemetry=False)
assert web_public_folder.exists()
@ -421,8 +421,6 @@ def test_create_config_e2e(tmp_working_dir):
exec((tmp_working_dir / constants.CONFIG_FILE).read_text(), eval_globals)
config = eval_globals["config"]
assert config.app_name == app_name
assert config.db_url == constants.DB_URL
assert config.env == Env.DEV
@pytest.mark.parametrize(