WiP - pickle dynamic states to bring backend up faster

This commit is contained in:
Masen Furer 2025-01-20 14:20:00 -08:00
parent 4dc106545b
commit 52d98b125a
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
2 changed files with 88 additions and 2 deletions

View File

@ -13,11 +13,12 @@ import io
import json import json
import multiprocessing import multiprocessing
import platform import platform
import shutil
import sys import sys
import traceback import traceback
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import FunctionType, SimpleNamespace
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -39,11 +40,13 @@ from fastapi import FastAPI, HTTPException, Request, UploadFile
from fastapi.middleware import cors from fastapi.middleware import cors
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from rich.console import ConsoleThreadLocals
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
from socketio import ASGIApp, AsyncNamespace, AsyncServer from socketio import ASGIApp, AsyncNamespace, AsyncServer
from starlette_admin.contrib.sqla.admin import Admin from starlette_admin.contrib.sqla.admin import Admin
from starlette_admin.contrib.sqla.view import ModelView from starlette_admin.contrib.sqla.view import ModelView
import reflex.istate.dynamic
from reflex import constants from reflex import constants
from reflex.admin import AdminDash from reflex.admin import AdminDash
from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
@ -97,10 +100,34 @@ from reflex.state import (
from reflex.utils import codespaces, console, exceptions, format, prerequisites, types from reflex.utils import codespaces, console, exceptions, format, prerequisites, types
from reflex.utils.exec import is_prod_mode, is_testing_env from reflex.utils.exec import is_prod_mode, is_testing_env
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportVar
from reflex.vars.base import ComputedVar
if TYPE_CHECKING: if TYPE_CHECKING:
from reflex.vars import Var from reflex.vars import Var
try:
import dill
except ImportError:
dill = None
else:
# Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes
if not isinstance(State.validate.__func__, FunctionType):
import builtins
cython_function_or_method = type(State.validate.__func__)
builtins.cython_function_or_method = cython_function_or_method
@dill.register(cython_function_or_method)
def _dill_reduce_cython_function_or_method(pickler, obj):
# Ignore cython function when pickling.
pass
@dill.register(ConsoleThreadLocals)
def _dill_reduce_console_thread_locals(pickler, obj):
# Ignore console thread locals when pickling.
pass
# Define custom types. # Define custom types.
ComponentCallable = Callable[[], Component] ComponentCallable = Callable[[], Component]
Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]] Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
@ -337,6 +364,11 @@ class App(MiddlewareMixin, LifespanMixin):
if not self.state: if not self.state:
self.state = State self.state = State
self._setup_state() self._setup_state()
enable_state_marker = (
prerequisites.get_web_dir() / "backend" / "enable_state"
)
enable_state_marker.parent.mkdir(parents=True, exist_ok=True)
enable_state_marker.touch()
def _setup_state(self) -> None: def _setup_state(self) -> None:
"""Set up the state for the app. """Set up the state for the app.
@ -415,7 +447,10 @@ class App(MiddlewareMixin, LifespanMixin):
def _add_optional_endpoints(self): def _add_optional_endpoints(self):
"""Add optional api endpoints (_upload).""" """Add optional api endpoints (_upload)."""
if Upload.is_used: upload_is_used_marker = (
prerequisites.get_web_dir() / "backend" / "upload_is_used"
)
if Upload.is_used or upload_is_used_marker.exists():
# To upload files. # To upload files.
self.api.post(str(constants.Endpoint.UPLOAD))(upload(self)) self.api.post(str(constants.Endpoint.UPLOAD))(upload(self))
@ -425,6 +460,9 @@ class App(MiddlewareMixin, LifespanMixin):
StaticFiles(directory=get_upload_dir()), StaticFiles(directory=get_upload_dir()),
name="uploaded_files", name="uploaded_files",
) )
upload_is_used_marker.parent.mkdir(parents=True, exist_ok=True)
upload_is_used_marker.touch()
if codespaces.is_running_in_codespaces(): if codespaces.is_running_in_codespaces():
self.api.get(str(constants.Endpoint.AUTH_CODESPACE))( self.api.get(str(constants.Endpoint.AUTH_CODESPACE))(
codespaces.auth_codespace codespaces.auth_codespace
@ -856,6 +894,18 @@ class App(MiddlewareMixin, LifespanMixin):
def get_compilation_time() -> str: def get_compilation_time() -> str:
return str(datetime.now().time()).split(".")[0] return str(datetime.now().time()).split(".")[0]
should_compile = self._should_compile()
backend_dir = prerequisites.get_web_dir() / "backend"
if not should_compile and backend_dir.exists():
enable_state_marker = backend_dir / "enable_state"
if enable_state_marker.exists():
self._enable_state()
pickle_states_root = backend_dir / "states"
if pickle_states_root.exists():
self._unpickle_dynamic_states(pickle_states_root)
self._add_optional_endpoints()
return
# Render a default 404 page if the user didn't supply one # Render a default 404 page if the user didn't supply one
if constants.Page404.SLUG not in self.unevaluated_pages: if constants.Page404.SLUG not in self.unevaluated_pages:
self.add_page(route=constants.Page404.SLUG) self.add_page(route=constants.Page404.SLUG)
@ -1077,6 +1127,37 @@ class App(MiddlewareMixin, LifespanMixin):
for output_path, code in compile_results: for output_path, code in compile_results:
compiler_utils.write_page(output_path, code) compiler_utils.write_page(output_path, code)
# Pickle dynamic states
if self.state is not None and dill is not None:
pickle_dir = prerequisites.get_web_dir() / "backend" / "states"
if pickle_dir.exists():
shutil.rmtree(pickle_dir)
pickle_dir.mkdir(parents=True, exist_ok=True)
unfuck_states = []
for state in reflex.istate.dynamic.__dict__.values():
if isinstance(state, type) and issubclass(state, self.state):
unfuck_states.append(state)
object.__setattr__(state.setvar, "state_cls", None)
ComputedVar._is_pickling = True
try:
dill.session.dump_session(
filename=pickle_dir / "dynamic.pkl", main=reflex.istate.dynamic
)
except TypeError:
with dill.detect.trace():
dill.session.dump_session(
filename=pickle_dir / "dynamic.pkl", main=reflex.istate.dynamic
)
ComputedVar._is_pickling = False
for state in unfuck_states:
object.__setattr__(state.setvar, "state_cls", state)
def _unpickle_dynamic_states(self, root: Path):
if dill is None:
raise ImportError("dill is required to unpickle dynamic states")
for pk_file in sorted(root.iterdir()):
dill.session.load_session(filename=pk_file, main=reflex.istate.dynamic)
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[BaseState]: async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""Modify the state out of band. """Modify the state out of band.

View File

@ -1834,6 +1834,9 @@ class ComputedVar(Var[RETURN_TYPE]):
default_factory=lambda: lambda _: None default_factory=lambda: lambda _: None
) # type: ignore ) # type: ignore
# Flag determines whether we are pickling the computed var itself
_is_pickling: ClassVar[bool] = False
def __init__( def __init__(
self, self,
fget: Callable[[BASE_STATE], RETURN_TYPE], fget: Callable[[BASE_STATE], RETURN_TYPE],
@ -2227,6 +2230,8 @@ class ComputedVar(Var[RETURN_TYPE]):
Returns: Returns:
The class of the var. The class of the var.
""" """
if self._is_pickling:
return type(self)
return FakeComputedVarBaseClass return FakeComputedVarBaseClass
@property @property