improve hot reload handling (#4795)
This commit is contained in:
parent
90be664981
commit
a194c90d6f
@ -23,6 +23,7 @@ from typing import (
|
||||
Set,
|
||||
TypeVar,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from typing_extensions import Annotated, get_type_hints
|
||||
@ -304,6 +305,15 @@ def interpret_env_var_value(
|
||||
return interpret_path_env(value, field_name)
|
||||
elif field_type is ExistingPath:
|
||||
return interpret_existing_path_env(value, field_name)
|
||||
elif get_origin(field_type) is list:
|
||||
return [
|
||||
interpret_env_var_value(
|
||||
v,
|
||||
get_args(field_type)[0],
|
||||
f"{field_name}[{i}]",
|
||||
)
|
||||
for i, v in enumerate(value.split(":"))
|
||||
]
|
||||
elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
|
||||
return interpret_enum_env(value, field_type, field_name)
|
||||
|
||||
@ -387,7 +397,11 @@ class EnvVar(Generic[T]):
|
||||
else:
|
||||
if isinstance(value, enum.Enum):
|
||||
value = value.value
|
||||
os.environ[self.name] = str(value)
|
||||
if isinstance(value, list):
|
||||
str_value = ":".join(str(v) for v in value)
|
||||
else:
|
||||
str_value = str(value)
|
||||
os.environ[self.name] = str_value
|
||||
|
||||
|
||||
class env_var: # noqa: N801 # pyright: ignore [reportRedeclaration]
|
||||
@ -571,6 +585,12 @@ class EnvironmentVariables:
|
||||
# Whether to use the turbopack bundler.
|
||||
REFLEX_USE_TURBOPACK: EnvVar[bool] = env_var(True)
|
||||
|
||||
# Additional paths to include in the hot reload. Separated by a colon.
|
||||
REFLEX_HOT_RELOAD_INCLUDE_PATHS: EnvVar[List[Path]] = env_var([])
|
||||
|
||||
# Paths to exclude from the hot reload. Takes precedence over include paths. Separated by a colon.
|
||||
REFLEX_HOT_RELOAD_EXCLUDE_PATHS: EnvVar[List[Path]] = env_var([])
|
||||
|
||||
|
||||
environment = EnvironmentVariables()
|
||||
|
||||
|
@ -10,6 +10,7 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import psutil
|
||||
@ -242,14 +243,14 @@ def run_backend(
|
||||
run_uvicorn_backend(host, port, loglevel)
|
||||
|
||||
|
||||
def get_reload_dirs() -> list[Path]:
|
||||
"""Get the reload directories for the backend.
|
||||
def get_reload_paths() -> Sequence[Path]:
|
||||
"""Get the reload paths for the backend.
|
||||
|
||||
Returns:
|
||||
The reload directories for the backend.
|
||||
The reload paths for the backend.
|
||||
"""
|
||||
config = get_config()
|
||||
reload_dirs = [Path(config.app_name)]
|
||||
reload_paths = [Path(config.app_name).parent]
|
||||
if config.app_module is not None and config.app_module.__file__:
|
||||
module_path = Path(config.app_module.__file__).resolve().parent
|
||||
|
||||
@ -263,8 +264,43 @@ def get_reload_dirs() -> list[Path]:
|
||||
else:
|
||||
break
|
||||
|
||||
reload_dirs = [module_path]
|
||||
return reload_dirs
|
||||
reload_paths = [module_path]
|
||||
|
||||
include_dirs = tuple(
|
||||
map(Path.absolute, environment.REFLEX_HOT_RELOAD_INCLUDE_PATHS.get())
|
||||
)
|
||||
exclude_dirs = tuple(
|
||||
map(Path.absolute, environment.REFLEX_HOT_RELOAD_EXCLUDE_PATHS.get())
|
||||
)
|
||||
|
||||
def is_excluded_by_default(path: Path) -> bool:
|
||||
if path.is_dir():
|
||||
if path.name.startswith("."):
|
||||
# exclude hidden directories
|
||||
return True
|
||||
if path.name.startswith("__"):
|
||||
# ignore things like __pycache__
|
||||
return True
|
||||
return path.name not in (".gitignore", "uploaded_files")
|
||||
|
||||
reload_paths = (
|
||||
tuple(
|
||||
path.absolute()
|
||||
for dir in reload_paths
|
||||
for path in dir.iterdir()
|
||||
if not is_excluded_by_default(path)
|
||||
)
|
||||
+ include_dirs
|
||||
)
|
||||
|
||||
if exclude_dirs:
|
||||
reload_paths = tuple(
|
||||
path
|
||||
for path in reload_paths
|
||||
if all(not path.samefile(exclude) for exclude in exclude_dirs)
|
||||
)
|
||||
|
||||
return reload_paths
|
||||
|
||||
|
||||
def run_uvicorn_backend(host: str, port: int, loglevel: LogLevel):
|
||||
@ -283,7 +319,7 @@ def run_uvicorn_backend(host: str, port: int, loglevel: LogLevel):
|
||||
port=port,
|
||||
log_level=loglevel.value,
|
||||
reload=True,
|
||||
reload_dirs=list(map(str, get_reload_dirs())),
|
||||
reload_dirs=list(map(str, get_reload_paths())),
|
||||
)
|
||||
|
||||
|
||||
@ -310,8 +346,7 @@ def run_granian_backend(host: str, port: int, loglevel: LogLevel):
|
||||
interface=Interfaces.ASGI,
|
||||
log_level=LogLevels(loglevel.value),
|
||||
reload=True,
|
||||
reload_paths=get_reload_dirs(),
|
||||
reload_ignore_dirs=[".web", ".states"],
|
||||
reload_paths=get_reload_paths(),
|
||||
).serve()
|
||||
except ImportError:
|
||||
console.error(
|
||||
|
@ -252,6 +252,7 @@ def test_env_var():
|
||||
BLUBB: EnvVar[str] = env_var("default")
|
||||
INTERNAL: EnvVar[str] = env_var("default", internal=True)
|
||||
BOOLEAN: EnvVar[bool] = env_var(False)
|
||||
LIST: EnvVar[list[int]] = env_var([1, 2, 3])
|
||||
|
||||
assert TestEnv.BLUBB.get() == "default"
|
||||
assert TestEnv.BLUBB.name == "BLUBB"
|
||||
@ -280,3 +281,11 @@ def test_env_var():
|
||||
assert TestEnv.BOOLEAN.get() is False
|
||||
TestEnv.BOOLEAN.set(None)
|
||||
assert "BOOLEAN" not in os.environ
|
||||
|
||||
assert TestEnv.LIST.get() == [1, 2, 3]
|
||||
assert TestEnv.LIST.name == "LIST"
|
||||
TestEnv.LIST.set([4, 5, 6])
|
||||
assert os.environ.get("LIST") == "4:5:6"
|
||||
assert TestEnv.LIST.get() == [4, 5, 6]
|
||||
TestEnv.LIST.set(None)
|
||||
assert "LIST" not in os.environ
|
||||
|
Loading…
Reference in New Issue
Block a user