streamline env var interpretation with @adhami3310
This commit is contained in:
parent
44d60d1cc8
commit
4ddc014d28
@ -4,17 +4,16 @@ from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from typing_extensions import get_type_hints
|
||||
|
||||
from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError
|
||||
from reflex.utils.types import is_optional, is_union, value_inside_optional
|
||||
from reflex.utils.types import GenericType, is_union, value_inside_optional
|
||||
|
||||
try:
|
||||
import pydantic.v1 as pydantic
|
||||
@ -216,12 +215,15 @@ def interpret_path_env(value: str) -> Path:
|
||||
return path
|
||||
|
||||
|
||||
def interpret_env_var_value(value: str, field: dataclasses.Field) -> Any:
|
||||
def interpret_env_var_value(
|
||||
value: str, field_type: GenericType, field_name: str
|
||||
) -> Any:
|
||||
"""Interpret an environment variable value based on the field type.
|
||||
|
||||
Args:
|
||||
value: The environment variable value.
|
||||
field: The field.
|
||||
field_type: The field type.
|
||||
field_name: The field name.
|
||||
|
||||
Returns:
|
||||
The interpreted value.
|
||||
@ -229,7 +231,12 @@ def interpret_env_var_value(value: str, field: dataclasses.Field) -> Any:
|
||||
Raises:
|
||||
ValueError: If the value is invalid.
|
||||
"""
|
||||
field_type = value_inside_optional(field.type)
|
||||
field_type = value_inside_optional(field_type)
|
||||
|
||||
if is_union(field_type):
|
||||
raise ValueError(
|
||||
f"Union types are not supported for environment variables: {field_name}."
|
||||
)
|
||||
|
||||
if field_type is bool:
|
||||
return interpret_boolean_env(value)
|
||||
@ -242,7 +249,7 @@ def interpret_env_var_value(value: str, field: dataclasses.Field) -> Any:
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid type for environment variable {field.name}: {field_type}. This is probably an issue in Reflex."
|
||||
f"Invalid type for environment variable {field_name}: {field_type}. This is probably an issue in Reflex."
|
||||
)
|
||||
|
||||
|
||||
@ -317,7 +324,7 @@ class EnvironmentVariables:
|
||||
field.type = type_hints.get(field.name) or field.type
|
||||
|
||||
value = (
|
||||
interpret_env_var_value(raw_value, field)
|
||||
interpret_env_var_value(raw_value, field.type, field.name)
|
||||
if raw_value is not None
|
||||
else get_default_value_for_field(field)
|
||||
)
|
||||
@ -388,7 +395,7 @@ class Config(Base):
|
||||
telemetry_enabled: bool = True
|
||||
|
||||
# The bun path
|
||||
bun_path: Union[str, Path] = constants.Bun.DEFAULT_PATH
|
||||
bun_path: Path = constants.Bun.DEFAULT_PATH
|
||||
|
||||
# List of origins that are allowed to connect to the backend API.
|
||||
cors_allowed_origins: List[str] = ["*"]
|
||||
@ -475,12 +482,7 @@ class Config(Base):
|
||||
|
||||
Returns:
|
||||
The updated config values.
|
||||
|
||||
Raises:
|
||||
EnvVarValueError: If an environment variable is set to an invalid type.
|
||||
"""
|
||||
from reflex.utils.exceptions import EnvVarValueError
|
||||
|
||||
updated_values = {}
|
||||
# Iterate over the fields.
|
||||
for key, field in self.__fields__.items():
|
||||
@ -495,29 +497,11 @@ class Config(Base):
|
||||
dedupe=True,
|
||||
)
|
||||
|
||||
# Convert the env var to the expected type.
|
||||
type_ = field.type_
|
||||
if is_optional(type_):
|
||||
type_ = type_.__args__[0] or type_.__args__[1]
|
||||
# TODO: This just handles the first type in a Union. Needs refactoring.
|
||||
if is_union(type_):
|
||||
type_ = type_.__args__[0]
|
||||
try:
|
||||
if type_ is bool or (
|
||||
inspect.isclass(type_) and issubclass(type_, bool)
|
||||
):
|
||||
# special handling for bool values
|
||||
env_var = env_var.lower() in ["true", "1", "yes"]
|
||||
else:
|
||||
env_var = type_(env_var)
|
||||
except ValueError as ve:
|
||||
console.error(
|
||||
f"Could not convert {key.upper()}={env_var} to type {type_}"
|
||||
)
|
||||
raise EnvVarValueError from ve
|
||||
# Interpret the value.
|
||||
value = interpret_env_var_value(env_var, field.type_, field.name)
|
||||
|
||||
# Set the value.
|
||||
updated_values[key] = env_var
|
||||
updated_values[key] = value
|
||||
|
||||
return updated_values
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
@ -30,7 +31,6 @@ def test_set_app_name(base_config_values):
|
||||
"env_var, value",
|
||||
[
|
||||
("APP_NAME", "my_test_app"),
|
||||
("BUN_PATH", "/test"),
|
||||
("FRONTEND_PORT", 3001),
|
||||
("FRONTEND_PATH", "/test"),
|
||||
("BACKEND_PORT", 8001),
|
||||
@ -64,6 +64,29 @@ def test_update_from_env(
|
||||
assert getattr(config, env_var.lower()) == value
|
||||
|
||||
|
||||
def test_update_from_env_path(
|
||||
base_config_values: Dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
):
|
||||
"""Test that environment variables override config values.
|
||||
|
||||
Args:
|
||||
base_config_values: Config values.
|
||||
monkeypatch: The pytest monkeypatch object.
|
||||
tmp_path: The pytest tmp_path fixture object.
|
||||
"""
|
||||
monkeypatch.setenv("BUN_PATH", "/test")
|
||||
assert os.environ.get("BUN_PATH") == "/test"
|
||||
with pytest.raises(ValueError):
|
||||
rx.Config(**base_config_values)
|
||||
|
||||
monkeypatch.setenv("BUN_PATH", str(tmp_path))
|
||||
assert os.environ.get("BUN_PATH") == str(tmp_path)
|
||||
config = rx.Config(**base_config_values)
|
||||
assert config.bun_path == tmp_path
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kwargs, expected",
|
||||
[
|
||||
|
Loading…
Reference in New Issue
Block a user