[ENG-4713] Cache pages which add states when evaluating (#4788)

* cache order of imports that create BaseState subclasses

* Track which pages create State subclasses during evaluation

These need to be replayed on the backend to ensure state alignment.

* Clean up: use constants, remove unused code

Handle closing files with contextmanager

* Expose app.add_all_routes_endpoint for flexgen

* Include .web/backend directory in backend.zip when exporting
This commit is contained in:
Masen Furer 2025-02-19 12:43:20 -08:00 committed by GitHub
parent 7a6c7123bd
commit deb1f4f702
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 89 additions and 2 deletions

View File

@ -100,6 +100,7 @@ from reflex.state import (
StateManager,
StateUpdate,
_substate_key,
all_base_state_classes,
code_uses_state_contexts,
)
from reflex.utils import (
@ -117,6 +118,7 @@ from reflex.utils.imports import ImportVar
if TYPE_CHECKING:
from reflex.vars import Var
# Define custom types.
ComponentCallable = Callable[[], Component]
Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
@ -375,6 +377,9 @@ class App(MiddlewareMixin, LifespanMixin):
# A map from a page route to the component to render. Users should use `add_page`.
_pages: Dict[str, Component] = dataclasses.field(default_factory=dict)
# A mapping of pages which created states as they were being evaluated.
_stateful_pages: Dict[str, None] = dataclasses.field(default_factory=dict)
# The backend API object.
_api: FastAPI | None = None
@ -592,8 +597,10 @@ class App(MiddlewareMixin, LifespanMixin):
"""Add optional api endpoints (_upload)."""
if not self.api:
return
if Upload.is_used:
upload_is_used_marker = (
prerequisites.get_backend_dir() / constants.Dirs.UPLOAD_IS_USED
)
if Upload.is_used or upload_is_used_marker.exists():
# To upload files.
self.api.post(str(constants.Endpoint.UPLOAD))(upload(self))
@ -603,10 +610,15 @@ class App(MiddlewareMixin, LifespanMixin):
StaticFiles(directory=get_upload_dir()),
name="uploaded_files",
)
upload_is_used_marker.parent.mkdir(parents=True, exist_ok=True)
upload_is_used_marker.touch()
if codespaces.is_running_in_codespaces():
self.api.get(str(constants.Endpoint.AUTH_CODESPACE))(
codespaces.auth_codespace
)
if environment.REFLEX_ADD_ALL_ROUTES_ENDPOINT.get():
self.add_all_routes_endpoint()
def _add_cors(self):
"""Add CORS middleware to the app."""
@ -747,13 +759,19 @@ class App(MiddlewareMixin, LifespanMixin):
route: The route of the page to compile.
save_page: If True, the compiled page is saved to self._pages.
"""
n_states_before = len(all_base_state_classes)
component, enable_state = compiler.compile_unevaluated_page(
route, self._unevaluated_pages[route], self._state, self.style, self.theme
)
# Indicate that the app should use state.
if enable_state:
self._enable_state()
# Indicate that evaluating this page creates one or more state classes.
if len(all_base_state_classes) > n_states_before:
self._stateful_pages[route] = None
# Add the page.
self._check_routes_conflict(route)
if save_page:
@ -1042,6 +1060,20 @@ class App(MiddlewareMixin, LifespanMixin):
def get_compilation_time() -> str:
return str(datetime.now().time()).split(".")[0]
should_compile = self._should_compile()
backend_dir = prerequisites.get_backend_dir()
if not should_compile and backend_dir.exists():
stateful_pages_marker = backend_dir / constants.Dirs.STATEFUL_PAGES
if stateful_pages_marker.exists():
with stateful_pages_marker.open("r") as f:
stateful_pages = json.load(f)
for route in stateful_pages:
console.info(f"BE Evaluating stateful page: {route}")
self._compile_page(route, save_page=False)
self._enable_state()
self._add_optional_endpoints()
return
# Render a default 404 page if the user didn't supply one
if constants.Page404.SLUG not in self._unevaluated_pages:
self.add_page(route=constants.Page404.SLUG)
@ -1343,6 +1375,24 @@ class App(MiddlewareMixin, LifespanMixin):
for output_path, code in compile_results:
compiler_utils.write_page(output_path, code)
# Write list of routes that create dynamic states for backend to use.
if self._state is not None:
stateful_pages_marker = (
prerequisites.get_backend_dir() / constants.Dirs.STATEFUL_PAGES
)
stateful_pages_marker.parent.mkdir(parents=True, exist_ok=True)
with stateful_pages_marker.open("w") as f:
json.dump(list(self._stateful_pages), f)
def add_all_routes_endpoint(self):
"""Add an endpoint to the app that returns all the routes."""
if not self.api:
return
@self.api.get(str(constants.Endpoint.ALL_ROUTES))
async def all_routes():
return list(self._unevaluated_pages.keys())
@contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""Modify the state out of band.

View File

@ -713,6 +713,9 @@ class EnvironmentVariables:
# Paths to exclude from the hot reload. Takes precedence over include paths. Separated by a colon.
REFLEX_HOT_RELOAD_EXCLUDE_PATHS: EnvVar[List[Path]] = env_var([])
# Used by flexgen to enumerate the pages.
REFLEX_ADD_ALL_ROUTES_ENDPOINT: EnvVar[bool] = env_var(False)
environment = EnvironmentVariables()

View File

@ -53,6 +53,12 @@ class Dirs(SimpleNamespace):
POSTCSS_JS = "postcss.config.js"
# The name of the states directory.
STATES = ".states"
# Where compilation artifacts for the backend are stored.
BACKEND = "backend"
# JSON-encoded list of page routes that need to be evaluated on the backend.
STATEFUL_PAGES = "stateful_pages.json"
# Marker file indicating that upload component was used in the frontend.
UPLOAD_IS_USED = "upload_is_used"
class Reflex(SimpleNamespace):

View File

@ -12,6 +12,7 @@ class Endpoint(Enum):
UPLOAD = "_upload"
AUTH_CODESPACE = "auth-codespace"
HEALTH = "_health"
ALL_ROUTES = "_all_routes"
def __str__(self) -> str:
"""Get the string representation of the endpoint.

View File

@ -327,6 +327,9 @@ async def _resolve_delta(delta: Delta) -> Delta:
return delta
all_base_state_classes: dict[str, None] = {}
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""The state of the app."""
@ -624,6 +627,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
cls._var_dependencies = {}
cls._init_var_dependency_dicts()
all_base_state_classes[cls.get_full_name()] = None
@staticmethod
def _copy_fn(fn: Callable) -> Callable:
"""Copy a function. Used to copy ComputedVars and EventHandlers from mixins.
@ -4087,6 +4092,7 @@ def reload_state_module(
for subclass in tuple(state.class_subclasses):
reload_state_module(module=module, state=subclass)
if subclass.__module__ == module and module is not None:
all_base_state_classes.pop(subclass.get_full_name(), None)
state.class_subclasses.remove(subclass)
state._always_dirty_substates.discard(subclass.get_name())
state._var_dependencies = {}

View File

@ -60,6 +60,7 @@ def _zip(
dirs_to_exclude: set[str] | None = None,
files_to_exclude: set[str] | None = None,
top_level_dirs_to_exclude: set[str] | None = None,
globs_to_include: list[str] | None = None,
) -> None:
"""Zip utility function.
@ -72,6 +73,7 @@ def _zip(
dirs_to_exclude: The directories to exclude.
files_to_exclude: The files to exclude.
top_level_dirs_to_exclude: The top level directory names immediately under root_dir to exclude. Do not exclude folders by these names further in the sub-directories.
globs_to_include: Apply these globs from the root_dir and always include them in the zip.
"""
target = Path(target)
@ -103,6 +105,13 @@ def _zip(
files_to_zip += [
str(root / file) for file in files if file not in files_to_exclude
]
if globs_to_include:
for glob in globs_to_include:
files_to_zip += [
str(file)
for file in root_dir.glob(glob)
if file.name not in files_to_exclude
]
# Create a progress bar for zipping the component.
progress = Progress(
@ -160,6 +169,9 @@ def zip_app(
top_level_dirs_to_exclude={"assets"},
exclude_venv_dirs=True,
upload_db_file=upload_db_file,
globs_to_include=[
str(Path(constants.Dirs.WEB) / constants.Dirs.BACKEND / "*")
],
)

View File

@ -99,6 +99,15 @@ def get_states_dir() -> Path:
return environment.REFLEX_STATES_WORKDIR.get()
def get_backend_dir() -> Path:
"""Get the working directory for the backend.
Returns:
The working directory.
"""
return get_web_dir() / constants.Dirs.BACKEND
def check_latest_package_version(package_name: str):
"""Check if the latest version of the package is installed.