WiP - pickle dynamic states to bring backend up faster
This commit is contained in:
parent
4dc106545b
commit
52d98b125a
@ -13,11 +13,12 @@ import io
|
|||||||
import json
|
import json
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import platform
|
import platform
|
||||||
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import FunctionType, SimpleNamespace
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
@ -39,11 +40,13 @@ from fastapi import FastAPI, HTTPException, Request, UploadFile
|
|||||||
from fastapi.middleware import cors
|
from fastapi.middleware import cors
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from rich.console import ConsoleThreadLocals
|
||||||
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
|
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
|
||||||
from socketio import ASGIApp, AsyncNamespace, AsyncServer
|
from socketio import ASGIApp, AsyncNamespace, AsyncServer
|
||||||
from starlette_admin.contrib.sqla.admin import Admin
|
from starlette_admin.contrib.sqla.admin import Admin
|
||||||
from starlette_admin.contrib.sqla.view import ModelView
|
from starlette_admin.contrib.sqla.view import ModelView
|
||||||
|
|
||||||
|
import reflex.istate.dynamic
|
||||||
from reflex import constants
|
from reflex import constants
|
||||||
from reflex.admin import AdminDash
|
from reflex.admin import AdminDash
|
||||||
from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
|
from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
|
||||||
@ -97,10 +100,34 @@ from reflex.state import (
|
|||||||
from reflex.utils import codespaces, console, exceptions, format, prerequisites, types
|
from reflex.utils import codespaces, console, exceptions, format, prerequisites, types
|
||||||
from reflex.utils.exec import is_prod_mode, is_testing_env
|
from reflex.utils.exec import is_prod_mode, is_testing_env
|
||||||
from reflex.utils.imports import ImportVar
|
from reflex.utils.imports import ImportVar
|
||||||
|
from reflex.vars.base import ComputedVar
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from reflex.vars import Var
|
from reflex.vars import Var
|
||||||
|
|
||||||
|
try:
|
||||||
|
import dill
|
||||||
|
except ImportError:
|
||||||
|
dill = None
|
||||||
|
else:
|
||||||
|
# Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes
|
||||||
|
if not isinstance(State.validate.__func__, FunctionType):
|
||||||
|
import builtins
|
||||||
|
|
||||||
|
cython_function_or_method = type(State.validate.__func__)
|
||||||
|
builtins.cython_function_or_method = cython_function_or_method
|
||||||
|
|
||||||
|
@dill.register(cython_function_or_method)
|
||||||
|
def _dill_reduce_cython_function_or_method(pickler, obj):
|
||||||
|
# Ignore cython function when pickling.
|
||||||
|
pass
|
||||||
|
|
||||||
|
@dill.register(ConsoleThreadLocals)
|
||||||
|
def _dill_reduce_console_thread_locals(pickler, obj):
|
||||||
|
# Ignore console thread locals when pickling.
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# Define custom types.
|
# Define custom types.
|
||||||
ComponentCallable = Callable[[], Component]
|
ComponentCallable = Callable[[], Component]
|
||||||
Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
|
Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
|
||||||
@ -337,6 +364,11 @@ class App(MiddlewareMixin, LifespanMixin):
|
|||||||
if not self.state:
|
if not self.state:
|
||||||
self.state = State
|
self.state = State
|
||||||
self._setup_state()
|
self._setup_state()
|
||||||
|
enable_state_marker = (
|
||||||
|
prerequisites.get_web_dir() / "backend" / "enable_state"
|
||||||
|
)
|
||||||
|
enable_state_marker.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
enable_state_marker.touch()
|
||||||
|
|
||||||
def _setup_state(self) -> None:
|
def _setup_state(self) -> None:
|
||||||
"""Set up the state for the app.
|
"""Set up the state for the app.
|
||||||
@ -415,7 +447,10 @@ class App(MiddlewareMixin, LifespanMixin):
|
|||||||
|
|
||||||
def _add_optional_endpoints(self):
|
def _add_optional_endpoints(self):
|
||||||
"""Add optional api endpoints (_upload)."""
|
"""Add optional api endpoints (_upload)."""
|
||||||
if Upload.is_used:
|
upload_is_used_marker = (
|
||||||
|
prerequisites.get_web_dir() / "backend" / "upload_is_used"
|
||||||
|
)
|
||||||
|
if Upload.is_used or upload_is_used_marker.exists():
|
||||||
# To upload files.
|
# To upload files.
|
||||||
self.api.post(str(constants.Endpoint.UPLOAD))(upload(self))
|
self.api.post(str(constants.Endpoint.UPLOAD))(upload(self))
|
||||||
|
|
||||||
@ -425,6 +460,9 @@ class App(MiddlewareMixin, LifespanMixin):
|
|||||||
StaticFiles(directory=get_upload_dir()),
|
StaticFiles(directory=get_upload_dir()),
|
||||||
name="uploaded_files",
|
name="uploaded_files",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
upload_is_used_marker.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
upload_is_used_marker.touch()
|
||||||
if codespaces.is_running_in_codespaces():
|
if codespaces.is_running_in_codespaces():
|
||||||
self.api.get(str(constants.Endpoint.AUTH_CODESPACE))(
|
self.api.get(str(constants.Endpoint.AUTH_CODESPACE))(
|
||||||
codespaces.auth_codespace
|
codespaces.auth_codespace
|
||||||
@ -856,6 +894,18 @@ class App(MiddlewareMixin, LifespanMixin):
|
|||||||
def get_compilation_time() -> str:
|
def get_compilation_time() -> str:
|
||||||
return str(datetime.now().time()).split(".")[0]
|
return str(datetime.now().time()).split(".")[0]
|
||||||
|
|
||||||
|
should_compile = self._should_compile()
|
||||||
|
backend_dir = prerequisites.get_web_dir() / "backend"
|
||||||
|
if not should_compile and backend_dir.exists():
|
||||||
|
enable_state_marker = backend_dir / "enable_state"
|
||||||
|
if enable_state_marker.exists():
|
||||||
|
self._enable_state()
|
||||||
|
pickle_states_root = backend_dir / "states"
|
||||||
|
if pickle_states_root.exists():
|
||||||
|
self._unpickle_dynamic_states(pickle_states_root)
|
||||||
|
self._add_optional_endpoints()
|
||||||
|
return
|
||||||
|
|
||||||
# Render a default 404 page if the user didn't supply one
|
# Render a default 404 page if the user didn't supply one
|
||||||
if constants.Page404.SLUG not in self.unevaluated_pages:
|
if constants.Page404.SLUG not in self.unevaluated_pages:
|
||||||
self.add_page(route=constants.Page404.SLUG)
|
self.add_page(route=constants.Page404.SLUG)
|
||||||
@ -1077,6 +1127,37 @@ class App(MiddlewareMixin, LifespanMixin):
|
|||||||
for output_path, code in compile_results:
|
for output_path, code in compile_results:
|
||||||
compiler_utils.write_page(output_path, code)
|
compiler_utils.write_page(output_path, code)
|
||||||
|
|
||||||
|
# Pickle dynamic states
|
||||||
|
if self.state is not None and dill is not None:
|
||||||
|
pickle_dir = prerequisites.get_web_dir() / "backend" / "states"
|
||||||
|
if pickle_dir.exists():
|
||||||
|
shutil.rmtree(pickle_dir)
|
||||||
|
pickle_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
unfuck_states = []
|
||||||
|
for state in reflex.istate.dynamic.__dict__.values():
|
||||||
|
if isinstance(state, type) and issubclass(state, self.state):
|
||||||
|
unfuck_states.append(state)
|
||||||
|
object.__setattr__(state.setvar, "state_cls", None)
|
||||||
|
ComputedVar._is_pickling = True
|
||||||
|
try:
|
||||||
|
dill.session.dump_session(
|
||||||
|
filename=pickle_dir / "dynamic.pkl", main=reflex.istate.dynamic
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
with dill.detect.trace():
|
||||||
|
dill.session.dump_session(
|
||||||
|
filename=pickle_dir / "dynamic.pkl", main=reflex.istate.dynamic
|
||||||
|
)
|
||||||
|
ComputedVar._is_pickling = False
|
||||||
|
for state in unfuck_states:
|
||||||
|
object.__setattr__(state.setvar, "state_cls", state)
|
||||||
|
|
||||||
|
def _unpickle_dynamic_states(self, root: Path):
|
||||||
|
if dill is None:
|
||||||
|
raise ImportError("dill is required to unpickle dynamic states")
|
||||||
|
for pk_file in sorted(root.iterdir()):
|
||||||
|
dill.session.load_session(filename=pk_file, main=reflex.istate.dynamic)
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
||||||
"""Modify the state out of band.
|
"""Modify the state out of band.
|
||||||
|
@ -1834,6 +1834,9 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
default_factory=lambda: lambda _: None
|
default_factory=lambda: lambda _: None
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
|
# Flag determines whether we are pickling the computed var itself
|
||||||
|
_is_pickling: ClassVar[bool] = False
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
fget: Callable[[BASE_STATE], RETURN_TYPE],
|
fget: Callable[[BASE_STATE], RETURN_TYPE],
|
||||||
@ -2227,6 +2230,8 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
Returns:
|
Returns:
|
||||||
The class of the var.
|
The class of the var.
|
||||||
"""
|
"""
|
||||||
|
if self._is_pickling:
|
||||||
|
return type(self)
|
||||||
return FakeComputedVarBaseClass
|
return FakeComputedVarBaseClass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
Loading…
Reference in New Issue
Block a user