From 52d98b125af113c6c0a7e592fcf2d4d064d9f681 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 20 Jan 2025 14:20:00 -0800 Subject: [PATCH] WiP - pickle dynamic states to bring backend up faster --- reflex/app.py | 85 +++++++++++++++++++++++++++++++++++++++++++-- reflex/vars/base.py | 5 +++ 2 files changed, 88 insertions(+), 2 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 0d672e4c0..507e8ad86 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -13,11 +13,12 @@ import io import json import multiprocessing import platform +import shutil import sys import traceback from datetime import datetime from pathlib import Path -from types import SimpleNamespace +from types import FunctionType, SimpleNamespace from typing import ( TYPE_CHECKING, Any, @@ -39,11 +40,13 @@ from fastapi import FastAPI, HTTPException, Request, UploadFile from fastapi.middleware import cors from fastapi.responses import JSONResponse, StreamingResponse from fastapi.staticfiles import StaticFiles +from rich.console import ConsoleThreadLocals from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from socketio import ASGIApp, AsyncNamespace, AsyncServer from starlette_admin.contrib.sqla.admin import Admin from starlette_admin.contrib.sqla.view import ModelView +import reflex.istate.dynamic from reflex import constants from reflex.admin import AdminDash from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin @@ -97,10 +100,34 @@ from reflex.state import ( from reflex.utils import codespaces, console, exceptions, format, prerequisites, types from reflex.utils.exec import is_prod_mode, is_testing_env from reflex.utils.imports import ImportVar +from reflex.vars.base import ComputedVar if TYPE_CHECKING: from reflex.vars import Var +try: + import dill +except ImportError: + dill = None +else: + # Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes + if not isinstance(State.validate.__func__, FunctionType): + import builtins + + cython_function_or_method = type(State.validate.__func__) + builtins.cython_function_or_method = cython_function_or_method + + @dill.register(cython_function_or_method) + def _dill_reduce_cython_function_or_method(pickler, obj): + # Ignore cython function when pickling. + pass + + @dill.register(ConsoleThreadLocals) + def _dill_reduce_console_thread_locals(pickler, obj): + # Ignore console thread locals when pickling. + pass + + # Define custom types. ComponentCallable = Callable[[], Component] Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]] @@ -337,6 +364,11 @@ class App(MiddlewareMixin, LifespanMixin): if not self.state: self.state = State self._setup_state() + enable_state_marker = ( + prerequisites.get_web_dir() / "backend" / "enable_state" + ) + enable_state_marker.parent.mkdir(parents=True, exist_ok=True) + enable_state_marker.touch() def _setup_state(self) -> None: """Set up the state for the app. @@ -415,7 +447,10 @@ class App(MiddlewareMixin, LifespanMixin): def _add_optional_endpoints(self): """Add optional api endpoints (_upload).""" - if Upload.is_used: + upload_is_used_marker = ( + prerequisites.get_web_dir() / "backend" / "upload_is_used" + ) + if Upload.is_used or upload_is_used_marker.exists(): # To upload files. self.api.post(str(constants.Endpoint.UPLOAD))(upload(self)) @@ -425,6 +460,9 @@ class App(MiddlewareMixin, LifespanMixin): StaticFiles(directory=get_upload_dir()), name="uploaded_files", ) + + upload_is_used_marker.parent.mkdir(parents=True, exist_ok=True) + upload_is_used_marker.touch() if codespaces.is_running_in_codespaces(): self.api.get(str(constants.Endpoint.AUTH_CODESPACE))( codespaces.auth_codespace @@ -856,6 +894,18 @@ class App(MiddlewareMixin, LifespanMixin): def get_compilation_time() -> str: return str(datetime.now().time()).split(".")[0] + should_compile = self._should_compile() + backend_dir = prerequisites.get_web_dir() / "backend" + if not should_compile and backend_dir.exists(): + enable_state_marker = backend_dir / "enable_state" + if enable_state_marker.exists(): + self._enable_state() + pickle_states_root = backend_dir / "states" + if pickle_states_root.exists(): + self._unpickle_dynamic_states(pickle_states_root) + self._add_optional_endpoints() + return + # Render a default 404 page if the user didn't supply one if constants.Page404.SLUG not in self.unevaluated_pages: self.add_page(route=constants.Page404.SLUG) @@ -1077,6 +1127,37 @@ class App(MiddlewareMixin, LifespanMixin): for output_path, code in compile_results: compiler_utils.write_page(output_path, code) + # Pickle dynamic states + if self.state is not None and dill is not None: + pickle_dir = prerequisites.get_web_dir() / "backend" / "states" + if pickle_dir.exists(): + shutil.rmtree(pickle_dir) + pickle_dir.mkdir(parents=True, exist_ok=True) + unfuck_states = [] + for state in reflex.istate.dynamic.__dict__.values(): + if isinstance(state, type) and issubclass(state, self.state): + unfuck_states.append(state) + object.__setattr__(state.setvar, "state_cls", None) + ComputedVar._is_pickling = True + try: + dill.session.dump_session( + filename=pickle_dir / "dynamic.pkl", main=reflex.istate.dynamic + ) + except TypeError: + with dill.detect.trace(): + dill.session.dump_session( + filename=pickle_dir / "dynamic.pkl", main=reflex.istate.dynamic + ) + ComputedVar._is_pickling = False + for state in unfuck_states: + object.__setattr__(state.setvar, "state_cls", state) + + def _unpickle_dynamic_states(self, root: Path): + if dill is None: + raise ImportError("dill is required to unpickle dynamic states") + for pk_file in sorted(root.iterdir()): + dill.session.load_session(filename=pk_file, main=reflex.istate.dynamic) + @contextlib.asynccontextmanager async def modify_state(self, token: str) -> AsyncIterator[BaseState]: """Modify the state out of band. diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 2892d004d..f502e566a 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -1834,6 +1834,9 @@ class ComputedVar(Var[RETURN_TYPE]): default_factory=lambda: lambda _: None ) # type: ignore + # Flag determines whether we are pickling the computed var itself + _is_pickling: ClassVar[bool] = False + def __init__( self, fget: Callable[[BASE_STATE], RETURN_TYPE], @@ -2227,6 +2230,8 @@ class ComputedVar(Var[RETURN_TYPE]): Returns: The class of the var. """ + if self._is_pickling: + return type(self) return FakeComputedVarBaseClass @property