* port enum env var support from #4248 * add some tests for interpret env var functions
This commit is contained in:
parent
01ca42648b
commit
08d8d54b50
@ -3,7 +3,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import urllib.parse
|
||||
@ -221,6 +223,28 @@ def interpret_path_env(value: str, field_name: str) -> Path:
|
||||
return path
|
||||
|
||||
|
||||
def interpret_enum_env(value: str, field_type: GenericType, field_name: str) -> Any:
|
||||
"""Interpret an enum environment variable value.
|
||||
|
||||
Args:
|
||||
value: The environment variable value.
|
||||
field_type: The field type.
|
||||
field_name: The field name.
|
||||
|
||||
Returns:
|
||||
The interpreted value.
|
||||
|
||||
Raises:
|
||||
EnvironmentVarValueError: If the value is invalid.
|
||||
"""
|
||||
try:
|
||||
return field_type(value)
|
||||
except ValueError as ve:
|
||||
raise EnvironmentVarValueError(
|
||||
f"Invalid enum value: {value} for {field_name}"
|
||||
) from ve
|
||||
|
||||
|
||||
def interpret_env_var_value(
|
||||
value: str, field_type: GenericType, field_name: str
|
||||
) -> Any:
|
||||
@ -252,6 +276,8 @@ def interpret_env_var_value(
|
||||
return interpret_int_env(value, field_name)
|
||||
elif field_type is Path:
|
||||
return interpret_path_env(value, field_name)
|
||||
elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
|
||||
return interpret_enum_env(value, field_type, field_name)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
|
@ -7,8 +7,13 @@ import pytest
|
||||
|
||||
import reflex as rx
|
||||
import reflex.config
|
||||
from reflex.config import environment
|
||||
from reflex.constants import Endpoint
|
||||
from reflex.config import (
|
||||
environment,
|
||||
interpret_boolean_env,
|
||||
interpret_enum_env,
|
||||
interpret_int_env,
|
||||
)
|
||||
from reflex.constants import Endpoint, Env
|
||||
|
||||
|
||||
def test_requires_app_name():
|
||||
@ -208,11 +213,11 @@ def test_replace_defaults(
|
||||
assert getattr(c, key) == value
|
||||
|
||||
|
||||
def reflex_dir_constant():
|
||||
def reflex_dir_constant() -> Path:
|
||||
return environment.REFLEX_DIR
|
||||
|
||||
|
||||
def test_reflex_dir_env_var(monkeypatch, tmp_path):
|
||||
def test_reflex_dir_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||
"""Test that the REFLEX_DIR environment variable is used to set the Reflex.DIR constant.
|
||||
|
||||
Args:
|
||||
@ -224,3 +229,16 @@ def test_reflex_dir_env_var(monkeypatch, tmp_path):
|
||||
mp_ctx = multiprocessing.get_context(method="spawn")
|
||||
with mp_ctx.Pool(processes=1) as pool:
|
||||
assert pool.apply(reflex_dir_constant) == tmp_path
|
||||
|
||||
|
||||
def test_interpret_enum_env() -> None:
|
||||
assert interpret_enum_env(Env.PROD.value, Env, "REFLEX_ENV") == Env.PROD
|
||||
|
||||
|
||||
def test_interpret_int_env() -> None:
|
||||
assert interpret_int_env("3001", "FRONTEND_PORT") == 3001
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value, expected", [("true", True), ("false", False)])
|
||||
def test_interpret_bool_env(value: str, expected: bool) -> None:
|
||||
assert interpret_boolean_env(value, "TELEMETRY_ENABLED") == expected
|
||||
|
Loading…
Reference in New Issue
Block a user