From 08d8d54b50ea4beea07f989af6aa1760f8fc39c5 Mon Sep 17 00:00:00 2001 From: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com> Date: Mon, 28 Oct 2024 19:56:40 +0100 Subject: [PATCH] port enum env var support from #4248 (#4251) * port enum env var support from #4248 * add some tests for interpret env var functions --- reflex/config.py | 26 ++++++++++++++++++++++++++ tests/units/test_config.py | 26 ++++++++++++++++++++++---- 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/reflex/config.py b/reflex/config.py index ba86d911d..22a04c50c 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -3,7 +3,9 @@ from __future__ import annotations import dataclasses +import enum import importlib +import inspect import os import sys import urllib.parse @@ -221,6 +223,28 @@ def interpret_path_env(value: str, field_name: str) -> Path: return path +def interpret_enum_env(value: str, field_type: GenericType, field_name: str) -> Any: + """Interpret an enum environment variable value. + + Args: + value: The environment variable value. + field_type: The field type. + field_name: The field name. + + Returns: + The interpreted value. + + Raises: + EnvironmentVarValueError: If the value is invalid. + """ + try: + return field_type(value) + except ValueError as ve: + raise EnvironmentVarValueError( + f"Invalid enum value: {value} for {field_name}" + ) from ve + + def interpret_env_var_value( value: str, field_type: GenericType, field_name: str ) -> Any: @@ -252,6 +276,8 @@ def interpret_env_var_value( return interpret_int_env(value, field_name) elif field_type is Path: return interpret_path_env(value, field_name) + elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum): + return interpret_enum_env(value, field_type, field_name) else: raise ValueError( diff --git a/tests/units/test_config.py b/tests/units/test_config.py index 1027042c9..0c63abc96 100644 --- a/tests/units/test_config.py +++ b/tests/units/test_config.py @@ -7,8 +7,13 @@ import pytest import reflex as rx import reflex.config -from reflex.config import environment -from reflex.constants import Endpoint +from reflex.config import ( + environment, + interpret_boolean_env, + interpret_enum_env, + interpret_int_env, +) +from reflex.constants import Endpoint, Env def test_requires_app_name(): @@ -208,11 +213,11 @@ def test_replace_defaults( assert getattr(c, key) == value -def reflex_dir_constant(): +def reflex_dir_constant() -> Path: return environment.REFLEX_DIR -def test_reflex_dir_env_var(monkeypatch, tmp_path): +def test_reflex_dir_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: """Test that the REFLEX_DIR environment variable is used to set the Reflex.DIR constant. Args: @@ -224,3 +229,16 @@ def test_reflex_dir_env_var(monkeypatch, tmp_path): mp_ctx = multiprocessing.get_context(method="spawn") with mp_ctx.Pool(processes=1) as pool: assert pool.apply(reflex_dir_constant) == tmp_path + + +def test_interpret_enum_env() -> None: + assert interpret_enum_env(Env.PROD.value, Env, "REFLEX_ENV") == Env.PROD + + +def test_interpret_int_env() -> None: + assert interpret_int_env("3001", "FRONTEND_PORT") == 3001 + + +@pytest.mark.parametrize("value, expected", [("true", True), ("false", False)]) +def test_interpret_bool_env(value: str, expected: bool) -> None: + assert interpret_boolean_env(value, "TELEMETRY_ENABLED") == expected