From deb1f4f702e3207d3ebe32b3637759cef26e9b35 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 19 Feb 2025 12:43:20 -0800 Subject: [PATCH] [ENG-4713] Cache pages which add states when evaluating (#4788) * cache order of imports that create BaseState subclasses * Track which pages create State subclasses during evaluation These need to be replayed on the backend to ensure state alignment. * Clean up: use constants, remove unused code Handle closing files with contextmanager * Expose app.add_all_routes_endpoint for flexgen * Include .web/backend directory in backend.zip when exporting --- reflex/app.py | 54 +++++++++++++++++++++++++++++++++-- reflex/config.py | 3 ++ reflex/constants/base.py | 6 ++++ reflex/constants/event.py | 1 + reflex/state.py | 6 ++++ reflex/utils/build.py | 12 ++++++++ reflex/utils/prerequisites.py | 9 ++++++ 7 files changed, 89 insertions(+), 2 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 03382751a..65cb5bfdf 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -100,6 +100,7 @@ from reflex.state import ( StateManager, StateUpdate, _substate_key, + all_base_state_classes, code_uses_state_contexts, ) from reflex.utils import ( @@ -117,6 +118,7 @@ from reflex.utils.imports import ImportVar if TYPE_CHECKING: from reflex.vars import Var + # Define custom types. ComponentCallable = Callable[[], Component] Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]] @@ -375,6 +377,9 @@ class App(MiddlewareMixin, LifespanMixin): # A map from a page route to the component to render. Users should use `add_page`. _pages: Dict[str, Component] = dataclasses.field(default_factory=dict) + # A mapping of pages which created states as they were being evaluated. + _stateful_pages: Dict[str, None] = dataclasses.field(default_factory=dict) + # The backend API object. _api: FastAPI | None = None @@ -592,8 +597,10 @@ class App(MiddlewareMixin, LifespanMixin): """Add optional api endpoints (_upload).""" if not self.api: return - - if Upload.is_used: + upload_is_used_marker = ( + prerequisites.get_backend_dir() / constants.Dirs.UPLOAD_IS_USED + ) + if Upload.is_used or upload_is_used_marker.exists(): # To upload files. self.api.post(str(constants.Endpoint.UPLOAD))(upload(self)) @@ -603,10 +610,15 @@ class App(MiddlewareMixin, LifespanMixin): StaticFiles(directory=get_upload_dir()), name="uploaded_files", ) + + upload_is_used_marker.parent.mkdir(parents=True, exist_ok=True) + upload_is_used_marker.touch() if codespaces.is_running_in_codespaces(): self.api.get(str(constants.Endpoint.AUTH_CODESPACE))( codespaces.auth_codespace ) + if environment.REFLEX_ADD_ALL_ROUTES_ENDPOINT.get(): + self.add_all_routes_endpoint() def _add_cors(self): """Add CORS middleware to the app.""" @@ -747,13 +759,19 @@ class App(MiddlewareMixin, LifespanMixin): route: The route of the page to compile. save_page: If True, the compiled page is saved to self._pages. """ + n_states_before = len(all_base_state_classes) component, enable_state = compiler.compile_unevaluated_page( route, self._unevaluated_pages[route], self._state, self.style, self.theme ) + # Indicate that the app should use state. if enable_state: self._enable_state() + # Indicate that evaluating this page creates one or more state classes. + if len(all_base_state_classes) > n_states_before: + self._stateful_pages[route] = None + # Add the page. self._check_routes_conflict(route) if save_page: @@ -1042,6 +1060,20 @@ class App(MiddlewareMixin, LifespanMixin): def get_compilation_time() -> str: return str(datetime.now().time()).split(".")[0] + should_compile = self._should_compile() + backend_dir = prerequisites.get_backend_dir() + if not should_compile and backend_dir.exists(): + stateful_pages_marker = backend_dir / constants.Dirs.STATEFUL_PAGES + if stateful_pages_marker.exists(): + with stateful_pages_marker.open("r") as f: + stateful_pages = json.load(f) + for route in stateful_pages: + console.info(f"BE Evaluating stateful page: {route}") + self._compile_page(route, save_page=False) + self._enable_state() + self._add_optional_endpoints() + return + # Render a default 404 page if the user didn't supply one if constants.Page404.SLUG not in self._unevaluated_pages: self.add_page(route=constants.Page404.SLUG) @@ -1343,6 +1375,24 @@ class App(MiddlewareMixin, LifespanMixin): for output_path, code in compile_results: compiler_utils.write_page(output_path, code) + # Write list of routes that create dynamic states for backend to use. + if self._state is not None: + stateful_pages_marker = ( + prerequisites.get_backend_dir() / constants.Dirs.STATEFUL_PAGES + ) + stateful_pages_marker.parent.mkdir(parents=True, exist_ok=True) + with stateful_pages_marker.open("w") as f: + json.dump(list(self._stateful_pages), f) + + def add_all_routes_endpoint(self): + """Add an endpoint to the app that returns all the routes.""" + if not self.api: + return + + @self.api.get(str(constants.Endpoint.ALL_ROUTES)) + async def all_routes(): + return list(self._unevaluated_pages.keys()) + @contextlib.asynccontextmanager async def modify_state(self, token: str) -> AsyncIterator[BaseState]: """Modify the state out of band. diff --git a/reflex/config.py b/reflex/config.py index 296b01805..590b57f46 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -713,6 +713,9 @@ class EnvironmentVariables: # Paths to exclude from the hot reload. Takes precedence over include paths. Separated by a colon. REFLEX_HOT_RELOAD_EXCLUDE_PATHS: EnvVar[List[Path]] = env_var([]) + # Used by flexgen to enumerate the pages. + REFLEX_ADD_ALL_ROUTES_ENDPOINT: EnvVar[bool] = env_var(False) + environment = EnvironmentVariables() diff --git a/reflex/constants/base.py b/reflex/constants/base.py index 7fbcdf18a..0611c7d4c 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -53,6 +53,12 @@ class Dirs(SimpleNamespace): POSTCSS_JS = "postcss.config.js" # The name of the states directory. STATES = ".states" + # Where compilation artifacts for the backend are stored. + BACKEND = "backend" + # JSON-encoded list of page routes that need to be evaluated on the backend. + STATEFUL_PAGES = "stateful_pages.json" + # Marker file indicating that upload component was used in the frontend. + UPLOAD_IS_USED = "upload_is_used" class Reflex(SimpleNamespace): diff --git a/reflex/constants/event.py b/reflex/constants/event.py index d454e6ea8..7b58c99cf 100644 --- a/reflex/constants/event.py +++ b/reflex/constants/event.py @@ -12,6 +12,7 @@ class Endpoint(Enum): UPLOAD = "_upload" AUTH_CODESPACE = "auth-codespace" HEALTH = "_health" + ALL_ROUTES = "_all_routes" def __str__(self) -> str: """Get the string representation of the endpoint. diff --git a/reflex/state.py b/reflex/state.py index 2689ba910..0f0ba97f9 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -327,6 +327,9 @@ async def _resolve_delta(delta: Delta) -> Delta: return delta +all_base_state_classes: dict[str, None] = {} + + class BaseState(Base, ABC, extra=pydantic.Extra.allow): """The state of the app.""" @@ -624,6 +627,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): cls._var_dependencies = {} cls._init_var_dependency_dicts() + all_base_state_classes[cls.get_full_name()] = None + @staticmethod def _copy_fn(fn: Callable) -> Callable: """Copy a function. Used to copy ComputedVars and EventHandlers from mixins. @@ -4087,6 +4092,7 @@ def reload_state_module( for subclass in tuple(state.class_subclasses): reload_state_module(module=module, state=subclass) if subclass.__module__ == module and module is not None: + all_base_state_classes.pop(subclass.get_full_name(), None) state.class_subclasses.remove(subclass) state._always_dirty_substates.discard(subclass.get_name()) state._var_dependencies = {} diff --git a/reflex/utils/build.py b/reflex/utils/build.py index 9e35ab984..c02a30c7b 100644 --- a/reflex/utils/build.py +++ b/reflex/utils/build.py @@ -60,6 +60,7 @@ def _zip( dirs_to_exclude: set[str] | None = None, files_to_exclude: set[str] | None = None, top_level_dirs_to_exclude: set[str] | None = None, + globs_to_include: list[str] | None = None, ) -> None: """Zip utility function. @@ -72,6 +73,7 @@ def _zip( dirs_to_exclude: The directories to exclude. files_to_exclude: The files to exclude. top_level_dirs_to_exclude: The top level directory names immediately under root_dir to exclude. Do not exclude folders by these names further in the sub-directories. + globs_to_include: Apply these globs from the root_dir and always include them in the zip. """ target = Path(target) @@ -103,6 +105,13 @@ def _zip( files_to_zip += [ str(root / file) for file in files if file not in files_to_exclude ] + if globs_to_include: + for glob in globs_to_include: + files_to_zip += [ + str(file) + for file in root_dir.glob(glob) + if file.name not in files_to_exclude + ] # Create a progress bar for zipping the component. progress = Progress( @@ -160,6 +169,9 @@ def zip_app( top_level_dirs_to_exclude={"assets"}, exclude_venv_dirs=True, upload_db_file=upload_db_file, + globs_to_include=[ + str(Path(constants.Dirs.WEB) / constants.Dirs.BACKEND / "*") + ], ) diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 145b5324c..b5987f4e8 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -99,6 +99,15 @@ def get_states_dir() -> Path: return environment.REFLEX_STATES_WORKDIR.get() +def get_backend_dir() -> Path: + """Get the working directory for the backend. + + Returns: + The working directory. + """ + return get_web_dir() / constants.Dirs.BACKEND + + def check_latest_package_version(package_name: str): """Check if the latest version of the package is installed.