From 4ddc014d2864676dfe3aff458dcff594fddad2ae Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Mon, 21 Oct 2024 22:46:23 +0200 Subject: [PATCH] streamline env var interpretation with @adhami3310 --- reflex/config.py | 54 ++++++++++++++------------------------ tests/units/test_config.py | 25 +++++++++++++++++- 2 files changed, 43 insertions(+), 36 deletions(-) diff --git a/reflex/config.py b/reflex/config.py index 44813d7e0..c055fe3ae 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -4,17 +4,16 @@ from __future__ import annotations import dataclasses import importlib -import inspect import os import sys import urllib.parse from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set from typing_extensions import get_type_hints from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError -from reflex.utils.types import is_optional, is_union, value_inside_optional +from reflex.utils.types import GenericType, is_union, value_inside_optional try: import pydantic.v1 as pydantic @@ -216,12 +215,15 @@ def interpret_path_env(value: str) -> Path: return path -def interpret_env_var_value(value: str, field: dataclasses.Field) -> Any: +def interpret_env_var_value( + value: str, field_type: GenericType, field_name: str +) -> Any: """Interpret an environment variable value based on the field type. Args: value: The environment variable value. - field: The field. + field_type: The field type. + field_name: The field name. Returns: The interpreted value. @@ -229,7 +231,12 @@ def interpret_env_var_value(value: str, field: dataclasses.Field) -> Any: Raises: ValueError: If the value is invalid. """ - field_type = value_inside_optional(field.type) + field_type = value_inside_optional(field_type) + + if is_union(field_type): + raise ValueError( + f"Union types are not supported for environment variables: {field_name}." + ) if field_type is bool: return interpret_boolean_env(value) @@ -242,7 +249,7 @@ def interpret_env_var_value(value: str, field: dataclasses.Field) -> Any: else: raise ValueError( - f"Invalid type for environment variable {field.name}: {field_type}. This is probably an issue in Reflex." + f"Invalid type for environment variable {field_name}: {field_type}. This is probably an issue in Reflex." ) @@ -317,7 +324,7 @@ class EnvironmentVariables: field.type = type_hints.get(field.name) or field.type value = ( - interpret_env_var_value(raw_value, field) + interpret_env_var_value(raw_value, field.type, field.name) if raw_value is not None else get_default_value_for_field(field) ) @@ -388,7 +395,7 @@ class Config(Base): telemetry_enabled: bool = True # The bun path - bun_path: Union[str, Path] = constants.Bun.DEFAULT_PATH + bun_path: Path = constants.Bun.DEFAULT_PATH # List of origins that are allowed to connect to the backend API. cors_allowed_origins: List[str] = ["*"] @@ -475,12 +482,7 @@ class Config(Base): Returns: The updated config values. - - Raises: - EnvVarValueError: If an environment variable is set to an invalid type. """ - from reflex.utils.exceptions import EnvVarValueError - updated_values = {} # Iterate over the fields. for key, field in self.__fields__.items(): @@ -495,29 +497,11 @@ class Config(Base): dedupe=True, ) - # Convert the env var to the expected type. - type_ = field.type_ - if is_optional(type_): - type_ = type_.__args__[0] or type_.__args__[1] - # TODO: This just handles the first type in a Union. Needs refactoring. - if is_union(type_): - type_ = type_.__args__[0] - try: - if type_ is bool or ( - inspect.isclass(type_) and issubclass(type_, bool) - ): - # special handling for bool values - env_var = env_var.lower() in ["true", "1", "yes"] - else: - env_var = type_(env_var) - except ValueError as ve: - console.error( - f"Could not convert {key.upper()}={env_var} to type {type_}" - ) - raise EnvVarValueError from ve + # Interpret the value. + value = interpret_env_var_value(env_var, field.type_, field.name) # Set the value. - updated_values[key] = env_var + updated_values[key] = value return updated_values diff --git a/tests/units/test_config.py b/tests/units/test_config.py index fa3b64a61..1027042c9 100644 --- a/tests/units/test_config.py +++ b/tests/units/test_config.py @@ -1,5 +1,6 @@ import multiprocessing import os +from pathlib import Path from typing import Any, Dict import pytest @@ -30,7 +31,6 @@ def test_set_app_name(base_config_values): "env_var, value", [ ("APP_NAME", "my_test_app"), - ("BUN_PATH", "/test"), ("FRONTEND_PORT", 3001), ("FRONTEND_PATH", "/test"), ("BACKEND_PORT", 8001), @@ -64,6 +64,29 @@ def test_update_from_env( assert getattr(config, env_var.lower()) == value +def test_update_from_env_path( + base_config_values: Dict[str, Any], + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +): + """Test that environment variables override config values. + + Args: + base_config_values: Config values. + monkeypatch: The pytest monkeypatch object. + tmp_path: The pytest tmp_path fixture object. + """ + monkeypatch.setenv("BUN_PATH", "/test") + assert os.environ.get("BUN_PATH") == "/test" + with pytest.raises(ValueError): + rx.Config(**base_config_values) + + monkeypatch.setenv("BUN_PATH", str(tmp_path)) + assert os.environ.get("BUN_PATH") == str(tmp_path) + config = rx.Config(**base_config_values) + assert config.bun_path == tmp_path + + @pytest.mark.parametrize( "kwargs, expected", [