add internal and enum env var support

This commit is contained in:
Benedikt Bartscher 2024-10-26 14:16:04 +02:00
parent 79be00217b
commit 832b8038fc
No known key found for this signature in database
8 changed files with 60 additions and 20 deletions

View File

@ -505,7 +505,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
# Check if the route given is valid # Check if the route given is valid
verify_route_validity(route) verify_route_validity(route)
if route in self.unevaluated_pages and os.getenv(constants.RELOAD_CONFIG): if route in self.unevaluated_pages and environment.RELOAD_CONFIG.get:
# when the app is reloaded(typically for app harness tests), we should maintain # when the app is reloaded(typically for app harness tests), we should maintain
# the latest render function of a route.This applies typically to decorated pages # the latest render function of a route.This applies typically to decorated pages
# since they are only added when app._compile is called. # since they are only added when app._compile is called.

View File

@ -31,7 +31,7 @@ def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None
""" """
from reflex.utils.exceptions import VarNameError from reflex.utils.exceptions import VarNameError
reload = os.getenv(constants.RELOAD_CONFIG) == "True" reload = os.getenv("RELOAD_CONFIG") == "True"
for base in bases: for base in bases:
try: try:
if not reload and getattr(base, field_name, None): if not reload and getattr(base, field_name, None):

View File

@ -3,7 +3,9 @@
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
import enum
import importlib import importlib
import inspect
import os import os
import sys import sys
import urllib.parse import urllib.parse
@ -231,6 +233,28 @@ def interpret_path_env(value: str, field_name: str) -> Path:
return path return path
def interpret_enum_env(value: str, field_type: GenericType, field_name: str) -> Any:
"""Interpret an enum environment variable value.
Args:
value: The environment variable value.
field_type: The field type.
field_name: The field name.
Returns:
The interpreted value.
Raises:
EnvironmentVarValueError: If the value is invalid.
"""
try:
return field_type(value)
except ValueError as ve:
raise EnvironmentVarValueError(
f"Invalid enum value: {value} for {field_name}"
) from ve
def interpret_env_var_value( def interpret_env_var_value(
value: str, field_type: GenericType, field_name: str value: str, field_type: GenericType, field_name: str
) -> Any: ) -> Any:
@ -262,6 +286,8 @@ def interpret_env_var_value(
return interpret_int_env(value, field_name) return interpret_int_env(value, field_name)
elif field_type is Path: elif field_type is Path:
return interpret_path_env(value, field_name) return interpret_path_env(value, field_name)
elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
return interpret_enum_env(value, field_type, field_name)
else: else:
raise ValueError( raise ValueError(
@ -291,6 +317,7 @@ class EnvVar(Generic[T]):
self.default = default self.default = default
self.type_ = type_ self.type_ = type_
@property
def getenv(self) -> Any: def getenv(self) -> Any:
"""Get the environment variable from os.environ. """Get the environment variable from os.environ.
@ -306,13 +333,13 @@ class EnvVar(Generic[T]):
Returns: Returns:
The interpreted value. The interpreted value.
""" """
env_value = self.getenv() env_value = self.getenv
if env_value is not None: if env_value is not None:
return interpret_env_var_value(env_value, self.type_, self.name) return interpret_env_var_value(env_value, self.type_, self.name)
return self.default return self.default
def set(self, value: T) -> None: def set(self, value: T | None) -> None:
"""Set the environment variable. """Set the environment variable. None unsets the variable.
Args: Args:
value: The value to set. value: The value to set.
@ -320,6 +347,8 @@ class EnvVar(Generic[T]):
if value is None: if value is None:
_ = os.environ.pop(self.name, None) _ = os.environ.pop(self.name, None)
else: else:
if isinstance(value, enum.Enum):
value = value.value
os.environ[self.name] = str(value) os.environ[self.name] = str(value)
@ -328,14 +357,17 @@ class env_var: # type: ignore
name: str name: str
default: Any default: Any
internal: bool = False
def __init__(self, default: Any): def __init__(self, default: Any, internal: bool = False) -> None:
"""Initialize the descriptor. """Initialize the descriptor.
Args: Args:
default: The default value. default: The default value.
internal: Whether the environment variable is reflex internal.
""" """
self.default = default self.default = default
self.internal = internal
def __set_name__(self, owner, name): def __set_name__(self, owner, name):
"""Set the name of the descriptor. """Set the name of the descriptor.
@ -354,12 +386,15 @@ class env_var: # type: ignore
owner: The owner class. owner: The owner class.
""" """
type_ = get_args(get_type_hints(owner)[self.name])[0] type_ = get_args(get_type_hints(owner)[self.name])[0]
return EnvVar[type_](name=self.name, default=self.default, type_=type_) name = self.name
if self.internal:
name = f"__{name}"
return EnvVar(name=self.name, default=self.default, type_=type_)
if TYPE_CHECKING: if TYPE_CHECKING:
def env_var(default) -> EnvVar: def env_var(default, internal=False) -> EnvVar:
"""Typing helper for the env_var descriptor.""" """Typing helper for the env_var descriptor."""
return default return default
@ -430,9 +465,6 @@ class EnvironmentVariables:
constants.Templates.REFLEX_BUILD_BACKEND constants.Templates.REFLEX_BUILD_BACKEND
) )
# If this env var is set to "yes", App.compile will be a no-op
REFLEX_SKIP_COMPILE: EnvVar[bool] = env_var(False)
# This env var stores the execution mode of the app # This env var stores the execution mode of the app
REFLEX_ENV_MODE: EnvVar[constants.Env] = env_var(constants.Env.DEV) REFLEX_ENV_MODE: EnvVar[constants.Env] = env_var(constants.Env.DEV)
@ -445,6 +477,12 @@ class EnvironmentVariables:
# Whether to send telemetry data to Reflex. # Whether to send telemetry data to Reflex.
TELEMETRY_ENABLED: EnvVar[bool] = env_var(True) TELEMETRY_ENABLED: EnvVar[bool] = env_var(True)
# Reflex internal env to reload the config.
RELOAD_CONFIG: EnvVar[bool] = env_var(False, internal=True)
# If this env var is set to "yes", App.compile will be a no-op
REFLEX_SKIP_COMPILE: EnvVar[bool] = env_var(False, internal=True)
environment = EnvironmentVariables() environment = EnvironmentVariables()

View File

@ -8,7 +8,6 @@ from .base import (
PYTEST_CURRENT_TEST, PYTEST_CURRENT_TEST,
REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_CLOSING_TAG,
REFLEX_VAR_OPENING_TAG, REFLEX_VAR_OPENING_TAG,
RELOAD_CONFIG,
SESSION_STORAGE, SESSION_STORAGE,
ColorMode, ColorMode,
Dirs, Dirs,
@ -102,7 +101,6 @@ __ALL__ = [
POLLING_MAX_HTTP_BUFFER_SIZE, POLLING_MAX_HTTP_BUFFER_SIZE,
PYTEST_CURRENT_TEST, PYTEST_CURRENT_TEST,
Reflex, Reflex,
RELOAD_CONFIG,
RequirementsTxt, RequirementsTxt,
RouteArgType, RouteArgType,
RouteRegex, RouteRegex,

View File

@ -243,7 +243,6 @@ SESSION_STORAGE = "session_storage"
# Testing variables. # Testing variables.
# Testing os env set by pytest when running a test case. # Testing os env set by pytest when running a test case.
PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST" PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST"
RELOAD_CONFIG = "__REFLEX_RELOAD_CONFIG"
REFLEX_VAR_OPENING_TAG = "<reflex.Var>" REFLEX_VAR_OPENING_TAG = "<reflex.Var>"
REFLEX_VAR_CLOSING_TAG = "</reflex.Var>" REFLEX_VAR_CLOSING_TAG = "</reflex.Var>"

View File

@ -491,9 +491,5 @@ def is_prod_mode() -> bool:
Returns: Returns:
True if the app is running in production mode or False if running in dev mode. True if the app is running in production mode or False if running in dev mode.
""" """
current_mode = os.environ.get(
environment.REFLEX_ENV_MODE.get,
constants.Env.DEV.value,
)
current_mode = environment.REFLEX_ENV_MODE.get current_mode = environment.REFLEX_ENV_MODE.get
return current_mode == constants.Env.PROD.value return current_mode == constants.Env.PROD

View File

@ -267,7 +267,7 @@ def get_app(reload: bool = False) -> ModuleType:
from reflex.utils import telemetry from reflex.utils import telemetry
try: try:
os.environ[constants.RELOAD_CONFIG] = str(reload) environment.RELOAD_CONFIG.set(reload)
config = get_config() config = get_config()
if not config.app_name: if not config.app_name:
raise RuntimeError( raise RuntimeError(

View File

@ -10,6 +10,7 @@ from packaging import version
from reflex import constants from reflex import constants
from reflex.base import Base from reflex.base import Base
from reflex.config import environment
from reflex.event import EventHandler from reflex.event import EventHandler
from reflex.state import BaseState from reflex.state import BaseState
from reflex.utils import ( from reflex.utils import (
@ -552,3 +553,11 @@ def test_style_prop_with_event_handler_value(callable):
rx.box( rx.box(
style=style, # type: ignore style=style, # type: ignore
) )
def test_is_prod_mode() -> None:
"""Test that the prod mode is correctly determined."""
environment.REFLEX_ENV_MODE.set(constants.Env.PROD)
assert utils_exec.is_prod_mode()
environment.REFLEX_ENV_MODE.set(None)
assert not utils_exec.is_prod_mode()