From bb093c7873eaf20011b10b3ccf3fc9997d23bfb5 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Mon, 25 Nov 2024 23:57:47 +0100 Subject: [PATCH] feat: add env var validators --- reflex/config.py | 46 ++++++++++++++++++++++++++++++++++---- tests/units/test_config.py | 12 ++++++++++ 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/reflex/config.py b/reflex/config.py index 88230cefe..ed121d75c 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -15,6 +15,7 @@ from pathlib import Path from typing import ( TYPE_CHECKING, Any, + Callable, Dict, Generic, List, @@ -314,6 +315,8 @@ def interpret_env_var_value( T = TypeVar("T") +VALIDATOR_TYPE = List[Callable[[T], None]] + class EnvVar(Generic[T]): """Environment variable.""" @@ -321,18 +324,39 @@ class EnvVar(Generic[T]): name: str default: Any type_: T + validators: VALIDATOR_TYPE - def __init__(self, name: str, default: Any, type_: T) -> None: + def __init__( + self, + name: str, + default: Any, + type_: T, + validators: Optional[VALIDATOR_TYPE] = None, + ) -> None: """Initialize the environment variable. Args: name: The environment variable name. default: The default value. type_: The type of the value. + validators: The validators to apply to the value. """ self.name = name self.default = default self.type_ = type_ + self.validators = validators or [] + + self.validate(default) + self.validate(self.get()) + + def validate(self, value: T | None = None) -> None: + """Validate the environment variable. + + Args: + value: The value to validate. + """ + for validator in self.validators: + validator(value) def interpret(self, value: str) -> T: """Interpret the environment variable value. @@ -381,6 +405,7 @@ class EnvVar(Generic[T]): Args: value: The value to set. """ + self.validate(value) if value is None: _ = os.environ.pop(self.name, None) else: @@ -395,16 +420,24 @@ class env_var: # type: ignore name: str default: Any internal: bool = False + validators: VALIDATOR_TYPE - def __init__(self, default: Any, internal: bool = False) -> None: + def __init__( + self, + default: Any, + internal: bool = False, + validators: Optional[VALIDATOR_TYPE] = None, + ) -> None: """Initialize the descriptor. Args: default: The default value. internal: Whether the environment variable is reflex internal. + validators: The validators to apply to the value. """ self.default = default self.internal = internal + self.validators = validators or [] def __set_name__(self, owner, name): """Set the name of the descriptor. @@ -429,17 +462,22 @@ class env_var: # type: ignore env_name = self.name if self.internal: env_name = f"__{env_name}" - return EnvVar(name=env_name, default=self.default, type_=type_) + return EnvVar( + name=env_name, default=self.default, type_=type_, validators=self.validators + ) if TYPE_CHECKING: - def env_var(default, internal=False) -> EnvVar: + def env_var( + default, internal=False, validators: Optional[VALIDATOR_TYPE[T]] = None + ) -> EnvVar[T]: """Typing helper for the env_var descriptor. Args: default: The default value. internal: Whether the environment variable is reflex internal. + validators: The validators to apply to the EnvVar. Returns: The EnvVar instance. diff --git a/tests/units/test_config.py b/tests/units/test_config.py index e5d4622bd..ba5f756e7 100644 --- a/tests/units/test_config.py +++ b/tests/units/test_config.py @@ -248,11 +248,17 @@ def test_interpret_bool_env(value: str, expected: bool) -> None: def test_env_var(): + def validate_positive(value: int) -> None: + if value < 0: + raise ValueError("Value must be positive") + class TestEnv: BLUBB: EnvVar[str] = env_var("default") INTERNAL: EnvVar[str] = env_var("default", internal=True) BOOLEAN: EnvVar[bool] = env_var(False) + VALIDATE: EnvVar[int] = env_var(0, validators=[validate_positive]) + assert TestEnv.BLUBB.get() == "default" assert TestEnv.BLUBB.name == "BLUBB" TestEnv.BLUBB.set("new") @@ -280,3 +286,9 @@ def test_env_var(): assert TestEnv.BOOLEAN.get() is False TestEnv.BOOLEAN.set(None) assert "BOOLEAN" not in os.environ + + assert TestEnv.VALIDATE.get() == 0 + TestEnv.VALIDATE.set(1) + assert TestEnv.VALIDATE.get() == 1 + with pytest.raises(ValueError): + TestEnv.VALIDATE.set(-1)