cache order of imports that create BaseState subclasses
This commit is contained in:
parent
6f4d328cde
commit
8fe5798f73
@ -8,16 +8,19 @@ import contextlib
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import io
|
||||
import json
|
||||
import multiprocessing
|
||||
import pickle
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from types import FunctionType, SimpleNamespace
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -39,11 +42,13 @@ from fastapi import FastAPI, HTTPException, Request, UploadFile
|
||||
from fastapi.middleware import cors
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from rich.console import ConsoleThreadLocals
|
||||
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
|
||||
from socketio import ASGIApp, AsyncNamespace, AsyncServer
|
||||
from starlette_admin.contrib.sqla.admin import Admin
|
||||
from starlette_admin.contrib.sqla.view import ModelView
|
||||
|
||||
import reflex.istate.dynamic
|
||||
from reflex import constants
|
||||
from reflex.admin import AdminDash
|
||||
from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
|
||||
@ -92,6 +97,7 @@ from reflex.route import (
|
||||
)
|
||||
from reflex.state import (
|
||||
BaseState,
|
||||
BaseState_import_order,
|
||||
RouterData,
|
||||
State,
|
||||
StateManager,
|
||||
@ -102,10 +108,34 @@ from reflex.state import (
|
||||
from reflex.utils import codespaces, console, exceptions, format, prerequisites, types
|
||||
from reflex.utils.exec import is_prod_mode, is_testing_env
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.vars.base import ComputedVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from reflex.vars import Var
|
||||
|
||||
try:
|
||||
import dill
|
||||
except ImportError:
|
||||
dill = None
|
||||
else:
|
||||
# Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes
|
||||
if not isinstance(State.validate.__func__, FunctionType):
|
||||
import builtins
|
||||
|
||||
cython_function_or_method = type(State.validate.__func__)
|
||||
builtins.cython_function_or_method = cython_function_or_method
|
||||
|
||||
@dill.register(cython_function_or_method)
|
||||
def _dill_reduce_cython_function_or_method(pickler, obj):
|
||||
# Ignore cython function when pickling.
|
||||
pass
|
||||
|
||||
@dill.register(ConsoleThreadLocals)
|
||||
def _dill_reduce_console_thread_locals(pickler, obj):
|
||||
# Ignore console thread locals when pickling.
|
||||
pass
|
||||
|
||||
|
||||
# Define custom types.
|
||||
ComponentCallable = Callable[[], Component]
|
||||
Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
|
||||
@ -380,6 +410,11 @@ class App(MiddlewareMixin, LifespanMixin):
|
||||
if not self._state:
|
||||
self._state = State
|
||||
self._setup_state()
|
||||
enable_state_marker = (
|
||||
prerequisites.get_web_dir() / "backend" / "enable_state"
|
||||
)
|
||||
enable_state_marker.parent.mkdir(parents=True, exist_ok=True)
|
||||
enable_state_marker.touch()
|
||||
|
||||
def _setup_state(self) -> None:
|
||||
"""Set up the state for the app.
|
||||
@ -498,8 +533,10 @@ class App(MiddlewareMixin, LifespanMixin):
|
||||
"""Add optional api endpoints (_upload)."""
|
||||
if not self.api:
|
||||
return
|
||||
|
||||
if Upload.is_used:
|
||||
upload_is_used_marker = (
|
||||
prerequisites.get_web_dir() / "backend" / "upload_is_used"
|
||||
)
|
||||
if Upload.is_used or upload_is_used_marker.exists():
|
||||
# To upload files.
|
||||
self.api.post(str(constants.Endpoint.UPLOAD))(upload(self))
|
||||
|
||||
@ -509,6 +546,9 @@ class App(MiddlewareMixin, LifespanMixin):
|
||||
StaticFiles(directory=get_upload_dir()),
|
||||
name="uploaded_files",
|
||||
)
|
||||
|
||||
upload_is_used_marker.parent.mkdir(parents=True, exist_ok=True)
|
||||
upload_is_used_marker.touch()
|
||||
if codespaces.is_running_in_codespaces():
|
||||
self.api.get(str(constants.Endpoint.AUTH_CODESPACE))(
|
||||
codespaces.auth_codespace
|
||||
@ -965,6 +1005,23 @@ class App(MiddlewareMixin, LifespanMixin):
|
||||
def get_compilation_time() -> str:
|
||||
return str(datetime.now().time()).split(".")[0]
|
||||
|
||||
should_compile = self._should_compile()
|
||||
backend_dir = prerequisites.get_web_dir() / "backend"
|
||||
if not should_compile and backend_dir.exists():
|
||||
enable_state_marker = backend_dir / "enable_state"
|
||||
if enable_state_marker.exists():
|
||||
import_order = pickle.loads(enable_state_marker.read_bytes())
|
||||
for bs_import in import_order:
|
||||
if bs_import.is_module:
|
||||
print(f"BE Importing stateful module: {bs_import.identifier}")
|
||||
importlib.import_module(bs_import.identifier)
|
||||
else:
|
||||
print(f"BE Evaluating stateful page: {bs_import.identifier}")
|
||||
self._compile_page(bs_import.identifier, save_page=False)
|
||||
self._enable_state()
|
||||
self._add_optional_endpoints()
|
||||
return
|
||||
|
||||
# Render a default 404 page if the user didn't supply one
|
||||
if constants.Page404.SLUG not in self._unevaluated_pages:
|
||||
self.add_page(route=constants.Page404.SLUG)
|
||||
@ -1204,6 +1261,13 @@ class App(MiddlewareMixin, LifespanMixin):
|
||||
for output_path, code in compile_results:
|
||||
compiler_utils.write_page(output_path, code)
|
||||
|
||||
# Pickle dynamic states
|
||||
if self._state is not None:
|
||||
enable_state = prerequisites.get_web_dir() / "backend" / "enable_state"
|
||||
enable_state.write_bytes(
|
||||
pickle.dumps(BaseState_import_order)
|
||||
)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
||||
"""Modify the state out of band.
|
||||
|
@ -347,6 +347,47 @@ async def _resolve_delta(delta: Delta) -> Delta:
|
||||
return delta
|
||||
|
||||
|
||||
# Tracking the import and potential exec history of BaseState subclasses.
|
||||
# This is used to reconstruct the state tree for the backend without pickling
|
||||
# the classes themselves.
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class BaseStateOrigin:
|
||||
"""A class to track the origin of BaseState subclasses.
|
||||
|
||||
The origin is either evaluating some page, or importing a module.
|
||||
"""
|
||||
|
||||
identifier: str
|
||||
is_module: bool
|
||||
|
||||
@classmethod
|
||||
def from_stack(cls):
|
||||
"""Find the most likely import in the stack."""
|
||||
stack = inspect.stack()
|
||||
for frame in stack:
|
||||
if (
|
||||
frame.code_context is not None
|
||||
and any("class " in ctx for ctx in frame.code_context)
|
||||
and frame.function == "<module>"
|
||||
):
|
||||
return cls(
|
||||
identifier=inspect.getmodule(frame.frame).__name__,
|
||||
is_module=True,
|
||||
)
|
||||
if (
|
||||
frame.function == "compile_unevaluated_page"
|
||||
and inspect.getmodule(frame.frame) == reflex.compiler.compiler
|
||||
):
|
||||
# We hit a page in the compiler that needs to be evaluated
|
||||
return cls(
|
||||
identifier=frame.frame.f_locals["route"],
|
||||
is_module=False,
|
||||
)
|
||||
|
||||
|
||||
BaseState_import_order: dict[BaseStateOrigin, None] = {}
|
||||
|
||||
|
||||
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
"""The state of the app."""
|
||||
|
||||
@ -644,6 +685,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
cls._var_dependencies = {}
|
||||
cls._init_var_dependency_dicts()
|
||||
|
||||
BaseState_import_order[BaseStateOrigin.from_stack()] = None
|
||||
|
||||
@staticmethod
|
||||
def _copy_fn(fn: Callable) -> Callable:
|
||||
"""Copy a function. Used to copy ComputedVars and EventHandlers from mixins.
|
||||
|
@ -1996,6 +1996,9 @@ class ComputedVar(Var[RETURN_TYPE]):
|
||||
default_factory=lambda: lambda _: None
|
||||
) # pyright: ignore [reportAssignmentType]
|
||||
|
||||
# Flag determines whether we are pickling the computed var itself
|
||||
_is_pickling: ClassVar[bool] = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fget: Callable[[BASE_STATE], RETURN_TYPE],
|
||||
@ -2407,6 +2410,8 @@ class ComputedVar(Var[RETURN_TYPE]):
|
||||
Returns:
|
||||
The class of the var.
|
||||
"""
|
||||
if self._is_pickling:
|
||||
return type(self)
|
||||
return FakeComputedVarBaseClass
|
||||
|
||||
@property
|
||||
|
Loading…
Reference in New Issue
Block a user