From 832b8038fcd92fff7444e9c9f54ce56bbe5b971f Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Sat, 26 Oct 2024 14:16:04 +0200 Subject: [PATCH] add internal and enum env var support --- reflex/app.py | 2 +- reflex/base.py | 2 +- reflex/config.py | 56 +++++++++++++++++++++++++++------ reflex/constants/__init__.py | 2 -- reflex/constants/base.py | 1 - reflex/utils/exec.py | 6 +--- reflex/utils/prerequisites.py | 2 +- tests/units/utils/test_utils.py | 9 ++++++ 8 files changed, 60 insertions(+), 20 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 793ead1e3..ea292d364 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -505,7 +505,7 @@ class App(MiddlewareMixin, LifespanMixin, Base): # Check if the route given is valid verify_route_validity(route) - if route in self.unevaluated_pages and os.getenv(constants.RELOAD_CONFIG): + if route in self.unevaluated_pages and environment.RELOAD_CONFIG.get: # when the app is reloaded(typically for app harness tests), we should maintain # the latest render function of a route.This applies typically to decorated pages # since they are only added when app._compile is called. diff --git a/reflex/base.py b/reflex/base.py index c334ddf56..d6b980422 100644 --- a/reflex/base.py +++ b/reflex/base.py @@ -31,7 +31,7 @@ def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None """ from reflex.utils.exceptions import VarNameError - reload = os.getenv(constants.RELOAD_CONFIG) == "True" + reload = os.getenv("RELOAD_CONFIG") == "True" for base in bases: try: if not reload and getattr(base, field_name, None): diff --git a/reflex/config.py b/reflex/config.py index bab8bcaf0..7aeea04b9 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -3,7 +3,9 @@ from __future__ import annotations import dataclasses +import enum import importlib +import inspect import os import sys import urllib.parse @@ -231,6 +233,28 @@ def interpret_path_env(value: str, field_name: str) -> Path: return path +def interpret_enum_env(value: str, field_type: GenericType, field_name: str) -> Any: + """Interpret an enum environment variable value. + + Args: + value: The environment variable value. + field_type: The field type. + field_name: The field name. + + Returns: + The interpreted value. + + Raises: + EnvironmentVarValueError: If the value is invalid. + """ + try: + return field_type(value) + except ValueError as ve: + raise EnvironmentVarValueError( + f"Invalid enum value: {value} for {field_name}" + ) from ve + + def interpret_env_var_value( value: str, field_type: GenericType, field_name: str ) -> Any: @@ -262,6 +286,8 @@ def interpret_env_var_value( return interpret_int_env(value, field_name) elif field_type is Path: return interpret_path_env(value, field_name) + elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum): + return interpret_enum_env(value, field_type, field_name) else: raise ValueError( @@ -291,6 +317,7 @@ class EnvVar(Generic[T]): self.default = default self.type_ = type_ + @property def getenv(self) -> Any: """Get the environment variable from os.environ. @@ -306,13 +333,13 @@ class EnvVar(Generic[T]): Returns: The interpreted value. """ - env_value = self.getenv() + env_value = self.getenv if env_value is not None: return interpret_env_var_value(env_value, self.type_, self.name) return self.default - def set(self, value: T) -> None: - """Set the environment variable. + def set(self, value: T | None) -> None: + """Set the environment variable. None unsets the variable. Args: value: The value to set. @@ -320,6 +347,8 @@ class EnvVar(Generic[T]): if value is None: _ = os.environ.pop(self.name, None) else: + if isinstance(value, enum.Enum): + value = value.value os.environ[self.name] = str(value) @@ -328,14 +357,17 @@ class env_var: # type: ignore name: str default: Any + internal: bool = False - def __init__(self, default: Any): + def __init__(self, default: Any, internal: bool = False) -> None: """Initialize the descriptor. Args: default: The default value. + internal: Whether the environment variable is reflex internal. """ self.default = default + self.internal = internal def __set_name__(self, owner, name): """Set the name of the descriptor. @@ -354,12 +386,15 @@ class env_var: # type: ignore owner: The owner class. """ type_ = get_args(get_type_hints(owner)[self.name])[0] - return EnvVar[type_](name=self.name, default=self.default, type_=type_) + name = self.name + if self.internal: + name = f"__{name}" + return EnvVar(name=self.name, default=self.default, type_=type_) if TYPE_CHECKING: - def env_var(default) -> EnvVar: + def env_var(default, internal=False) -> EnvVar: """Typing helper for the env_var descriptor.""" return default @@ -430,9 +465,6 @@ class EnvironmentVariables: constants.Templates.REFLEX_BUILD_BACKEND ) - # If this env var is set to "yes", App.compile will be a no-op - REFLEX_SKIP_COMPILE: EnvVar[bool] = env_var(False) - # This env var stores the execution mode of the app REFLEX_ENV_MODE: EnvVar[constants.Env] = env_var(constants.Env.DEV) @@ -445,6 +477,12 @@ class EnvironmentVariables: # Whether to send telemetry data to Reflex. TELEMETRY_ENABLED: EnvVar[bool] = env_var(True) + # Reflex internal env to reload the config. + RELOAD_CONFIG: EnvVar[bool] = env_var(False, internal=True) + + # If this env var is set to "yes", App.compile will be a no-op + REFLEX_SKIP_COMPILE: EnvVar[bool] = env_var(False, internal=True) + environment = EnvironmentVariables() diff --git a/reflex/constants/__init__.py b/reflex/constants/__init__.py index fd77941db..a1233a3fa 100644 --- a/reflex/constants/__init__.py +++ b/reflex/constants/__init__.py @@ -8,7 +8,6 @@ from .base import ( PYTEST_CURRENT_TEST, REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG, - RELOAD_CONFIG, SESSION_STORAGE, ColorMode, Dirs, @@ -102,7 +101,6 @@ __ALL__ = [ POLLING_MAX_HTTP_BUFFER_SIZE, PYTEST_CURRENT_TEST, Reflex, - RELOAD_CONFIG, RequirementsTxt, RouteArgType, RouteRegex, diff --git a/reflex/constants/base.py b/reflex/constants/base.py index 28cc48e18..5319d872a 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -243,7 +243,6 @@ SESSION_STORAGE = "session_storage" # Testing variables. # Testing os env set by pytest when running a test case. PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST" -RELOAD_CONFIG = "__REFLEX_RELOAD_CONFIG" REFLEX_VAR_OPENING_TAG = "" REFLEX_VAR_CLOSING_TAG = "" diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py index 90a12adac..6b2e74fb3 100644 --- a/reflex/utils/exec.py +++ b/reflex/utils/exec.py @@ -491,9 +491,5 @@ def is_prod_mode() -> bool: Returns: True if the app is running in production mode or False if running in dev mode. """ - current_mode = os.environ.get( - environment.REFLEX_ENV_MODE.get, - constants.Env.DEV.value, - ) current_mode = environment.REFLEX_ENV_MODE.get - return current_mode == constants.Env.PROD.value + return current_mode == constants.Env.PROD diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index b45ecded4..a47505483 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -267,7 +267,7 @@ def get_app(reload: bool = False) -> ModuleType: from reflex.utils import telemetry try: - os.environ[constants.RELOAD_CONFIG] = str(reload) + environment.RELOAD_CONFIG.set(reload) config = get_config() if not config.app_name: raise RuntimeError( diff --git a/tests/units/utils/test_utils.py b/tests/units/utils/test_utils.py index 81579acc7..d25a06347 100644 --- a/tests/units/utils/test_utils.py +++ b/tests/units/utils/test_utils.py @@ -10,6 +10,7 @@ from packaging import version from reflex import constants from reflex.base import Base +from reflex.config import environment from reflex.event import EventHandler from reflex.state import BaseState from reflex.utils import ( @@ -552,3 +553,11 @@ def test_style_prop_with_event_handler_value(callable): rx.box( style=style, # type: ignore ) + + +def test_is_prod_mode() -> None: + """Test that the prod mode is correctly determined.""" + environment.REFLEX_ENV_MODE.set(constants.Env.PROD) + assert utils_exec.is_prod_mode() + environment.REFLEX_ENV_MODE.set(None) + assert not utils_exec.is_prod_mode()