streamline env var interpretation with @adhami3310

This commit is contained in:
Benedikt Bartscher 2024-10-21 22:46:23 +02:00
parent 44d60d1cc8
commit 4ddc014d28
No known key found for this signature in database
2 changed files with 43 additions and 36 deletions

View File

@ -4,17 +4,16 @@ from __future__ import annotations
import dataclasses import dataclasses
import importlib import importlib
import inspect
import os import os
import sys import sys
import urllib.parse import urllib.parse
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union from typing import Any, Dict, List, Optional, Set
from typing_extensions import get_type_hints from typing_extensions import get_type_hints
from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError
from reflex.utils.types import is_optional, is_union, value_inside_optional from reflex.utils.types import GenericType, is_union, value_inside_optional
try: try:
import pydantic.v1 as pydantic import pydantic.v1 as pydantic
@ -216,12 +215,15 @@ def interpret_path_env(value: str) -> Path:
return path return path
def interpret_env_var_value(value: str, field: dataclasses.Field) -> Any: def interpret_env_var_value(
value: str, field_type: GenericType, field_name: str
) -> Any:
"""Interpret an environment variable value based on the field type. """Interpret an environment variable value based on the field type.
Args: Args:
value: The environment variable value. value: The environment variable value.
field: The field. field_type: The field type.
field_name: The field name.
Returns: Returns:
The interpreted value. The interpreted value.
@ -229,7 +231,12 @@ def interpret_env_var_value(value: str, field: dataclasses.Field) -> Any:
Raises: Raises:
ValueError: If the value is invalid. ValueError: If the value is invalid.
""" """
field_type = value_inside_optional(field.type) field_type = value_inside_optional(field_type)
if is_union(field_type):
raise ValueError(
f"Union types are not supported for environment variables: {field_name}."
)
if field_type is bool: if field_type is bool:
return interpret_boolean_env(value) return interpret_boolean_env(value)
@ -242,7 +249,7 @@ def interpret_env_var_value(value: str, field: dataclasses.Field) -> Any:
else: else:
raise ValueError( raise ValueError(
f"Invalid type for environment variable {field.name}: {field_type}. This is probably an issue in Reflex." f"Invalid type for environment variable {field_name}: {field_type}. This is probably an issue in Reflex."
) )
@ -317,7 +324,7 @@ class EnvironmentVariables:
field.type = type_hints.get(field.name) or field.type field.type = type_hints.get(field.name) or field.type
value = ( value = (
interpret_env_var_value(raw_value, field) interpret_env_var_value(raw_value, field.type, field.name)
if raw_value is not None if raw_value is not None
else get_default_value_for_field(field) else get_default_value_for_field(field)
) )
@ -388,7 +395,7 @@ class Config(Base):
telemetry_enabled: bool = True telemetry_enabled: bool = True
# The bun path # The bun path
bun_path: Union[str, Path] = constants.Bun.DEFAULT_PATH bun_path: Path = constants.Bun.DEFAULT_PATH
# List of origins that are allowed to connect to the backend API. # List of origins that are allowed to connect to the backend API.
cors_allowed_origins: List[str] = ["*"] cors_allowed_origins: List[str] = ["*"]
@ -475,12 +482,7 @@ class Config(Base):
Returns: Returns:
The updated config values. The updated config values.
Raises:
EnvVarValueError: If an environment variable is set to an invalid type.
""" """
from reflex.utils.exceptions import EnvVarValueError
updated_values = {} updated_values = {}
# Iterate over the fields. # Iterate over the fields.
for key, field in self.__fields__.items(): for key, field in self.__fields__.items():
@ -495,29 +497,11 @@ class Config(Base):
dedupe=True, dedupe=True,
) )
# Convert the env var to the expected type. # Interpret the value.
type_ = field.type_ value = interpret_env_var_value(env_var, field.type_, field.name)
if is_optional(type_):
type_ = type_.__args__[0] or type_.__args__[1]
# TODO: This just handles the first type in a Union. Needs refactoring.
if is_union(type_):
type_ = type_.__args__[0]
try:
if type_ is bool or (
inspect.isclass(type_) and issubclass(type_, bool)
):
# special handling for bool values
env_var = env_var.lower() in ["true", "1", "yes"]
else:
env_var = type_(env_var)
except ValueError as ve:
console.error(
f"Could not convert {key.upper()}={env_var} to type {type_}"
)
raise EnvVarValueError from ve
# Set the value. # Set the value.
updated_values[key] = env_var updated_values[key] = value
return updated_values return updated_values

View File

@ -1,5 +1,6 @@
import multiprocessing import multiprocessing
import os import os
from pathlib import Path
from typing import Any, Dict from typing import Any, Dict
import pytest import pytest
@ -30,7 +31,6 @@ def test_set_app_name(base_config_values):
"env_var, value", "env_var, value",
[ [
("APP_NAME", "my_test_app"), ("APP_NAME", "my_test_app"),
("BUN_PATH", "/test"),
("FRONTEND_PORT", 3001), ("FRONTEND_PORT", 3001),
("FRONTEND_PATH", "/test"), ("FRONTEND_PATH", "/test"),
("BACKEND_PORT", 8001), ("BACKEND_PORT", 8001),
@ -64,6 +64,29 @@ def test_update_from_env(
assert getattr(config, env_var.lower()) == value assert getattr(config, env_var.lower()) == value
def test_update_from_env_path(
base_config_values: Dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
):
"""Test that environment variables override config values.
Args:
base_config_values: Config values.
monkeypatch: The pytest monkeypatch object.
tmp_path: The pytest tmp_path fixture object.
"""
monkeypatch.setenv("BUN_PATH", "/test")
assert os.environ.get("BUN_PATH") == "/test"
with pytest.raises(ValueError):
rx.Config(**base_config_values)
monkeypatch.setenv("BUN_PATH", str(tmp_path))
assert os.environ.get("BUN_PATH") == str(tmp_path)
config = rx.Config(**base_config_values)
assert config.bun_path == tmp_path
@pytest.mark.parametrize( @pytest.mark.parametrize(
"kwargs, expected", "kwargs, expected",
[ [