From 8fe5798f73398caf773fa542237f63730c274c1f Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 6 Feb 2025 13:03:19 -0800 Subject: [PATCH] cache order of imports that create BaseState subclasses --- reflex/app.py | 70 +++++++++++++++++++++++++++++++++++++++++++-- reflex/state.py | 43 ++++++++++++++++++++++++++++ reflex/vars/base.py | 5 ++++ 3 files changed, 115 insertions(+), 3 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index a3d0d8e10..5da2e1e71 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -8,16 +8,19 @@ import contextlib import copy import dataclasses import functools +import importlib import inspect import io import json import multiprocessing +import pickle import platform +import shutil import sys import traceback from datetime import datetime from pathlib import Path -from types import SimpleNamespace +from types import FunctionType, SimpleNamespace from typing import ( TYPE_CHECKING, Any, @@ -39,11 +42,13 @@ from fastapi import FastAPI, HTTPException, Request, UploadFile from fastapi.middleware import cors from fastapi.responses import JSONResponse, StreamingResponse from fastapi.staticfiles import StaticFiles +from rich.console import ConsoleThreadLocals from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from socketio import ASGIApp, AsyncNamespace, AsyncServer from starlette_admin.contrib.sqla.admin import Admin from starlette_admin.contrib.sqla.view import ModelView +import reflex.istate.dynamic from reflex import constants from reflex.admin import AdminDash from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin @@ -92,6 +97,7 @@ from reflex.route import ( ) from reflex.state import ( BaseState, + BaseState_import_order, RouterData, State, StateManager, @@ -102,10 +108,34 @@ from reflex.state import ( from reflex.utils import codespaces, console, exceptions, format, prerequisites, types from reflex.utils.exec import is_prod_mode, is_testing_env from reflex.utils.imports import ImportVar +from reflex.vars.base import ComputedVar if TYPE_CHECKING: from reflex.vars import Var +try: + import dill +except ImportError: + dill = None +else: + # Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes + if not isinstance(State.validate.__func__, FunctionType): + import builtins + + cython_function_or_method = type(State.validate.__func__) + builtins.cython_function_or_method = cython_function_or_method + + @dill.register(cython_function_or_method) + def _dill_reduce_cython_function_or_method(pickler, obj): + # Ignore cython function when pickling. + pass + + @dill.register(ConsoleThreadLocals) + def _dill_reduce_console_thread_locals(pickler, obj): + # Ignore console thread locals when pickling. + pass + + # Define custom types. ComponentCallable = Callable[[], Component] Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]] @@ -380,6 +410,11 @@ class App(MiddlewareMixin, LifespanMixin): if not self._state: self._state = State self._setup_state() + enable_state_marker = ( + prerequisites.get_web_dir() / "backend" / "enable_state" + ) + enable_state_marker.parent.mkdir(parents=True, exist_ok=True) + enable_state_marker.touch() def _setup_state(self) -> None: """Set up the state for the app. @@ -498,8 +533,10 @@ class App(MiddlewareMixin, LifespanMixin): """Add optional api endpoints (_upload).""" if not self.api: return - - if Upload.is_used: + upload_is_used_marker = ( + prerequisites.get_web_dir() / "backend" / "upload_is_used" + ) + if Upload.is_used or upload_is_used_marker.exists(): # To upload files. self.api.post(str(constants.Endpoint.UPLOAD))(upload(self)) @@ -509,6 +546,9 @@ class App(MiddlewareMixin, LifespanMixin): StaticFiles(directory=get_upload_dir()), name="uploaded_files", ) + + upload_is_used_marker.parent.mkdir(parents=True, exist_ok=True) + upload_is_used_marker.touch() if codespaces.is_running_in_codespaces(): self.api.get(str(constants.Endpoint.AUTH_CODESPACE))( codespaces.auth_codespace @@ -965,6 +1005,23 @@ class App(MiddlewareMixin, LifespanMixin): def get_compilation_time() -> str: return str(datetime.now().time()).split(".")[0] + should_compile = self._should_compile() + backend_dir = prerequisites.get_web_dir() / "backend" + if not should_compile and backend_dir.exists(): + enable_state_marker = backend_dir / "enable_state" + if enable_state_marker.exists(): + import_order = pickle.loads(enable_state_marker.read_bytes()) + for bs_import in import_order: + if bs_import.is_module: + print(f"BE Importing stateful module: {bs_import.identifier}") + importlib.import_module(bs_import.identifier) + else: + print(f"BE Evaluating stateful page: {bs_import.identifier}") + self._compile_page(bs_import.identifier, save_page=False) + self._enable_state() + self._add_optional_endpoints() + return + # Render a default 404 page if the user didn't supply one if constants.Page404.SLUG not in self._unevaluated_pages: self.add_page(route=constants.Page404.SLUG) @@ -1204,6 +1261,13 @@ class App(MiddlewareMixin, LifespanMixin): for output_path, code in compile_results: compiler_utils.write_page(output_path, code) + # Pickle dynamic states + if self._state is not None: + enable_state = prerequisites.get_web_dir() / "backend" / "enable_state" + enable_state.write_bytes( + pickle.dumps(BaseState_import_order) + ) + @contextlib.asynccontextmanager async def modify_state(self, token: str) -> AsyncIterator[BaseState]: """Modify the state out of band. diff --git a/reflex/state.py b/reflex/state.py index 92aaa4710..b05201948 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -347,6 +347,47 @@ async def _resolve_delta(delta: Delta) -> Delta: return delta +# Tracking the import and potential exec history of BaseState subclasses. +# This is used to reconstruct the state tree for the backend without pickling +# the classes themselves. +@dataclasses.dataclass(frozen=True) +class BaseStateOrigin: + """A class to track the origin of BaseState subclasses. + + The origin is either evaluating some page, or importing a module. + """ + + identifier: str + is_module: bool + + @classmethod + def from_stack(cls): + """Find the most likely import in the stack.""" + stack = inspect.stack() + for frame in stack: + if ( + frame.code_context is not None + and any("class " in ctx for ctx in frame.code_context) + and frame.function == "" + ): + return cls( + identifier=inspect.getmodule(frame.frame).__name__, + is_module=True, + ) + if ( + frame.function == "compile_unevaluated_page" + and inspect.getmodule(frame.frame) == reflex.compiler.compiler + ): + # We hit a page in the compiler that needs to be evaluated + return cls( + identifier=frame.frame.f_locals["route"], + is_module=False, + ) + + +BaseState_import_order: dict[BaseStateOrigin, None] = {} + + class BaseState(Base, ABC, extra=pydantic.Extra.allow): """The state of the app.""" @@ -644,6 +685,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): cls._var_dependencies = {} cls._init_var_dependency_dicts() + BaseState_import_order[BaseStateOrigin.from_stack()] = None + @staticmethod def _copy_fn(fn: Callable) -> Callable: """Copy a function. Used to copy ComputedVars and EventHandlers from mixins. diff --git a/reflex/vars/base.py b/reflex/vars/base.py index c9dd81986..50de0a022 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -1996,6 +1996,9 @@ class ComputedVar(Var[RETURN_TYPE]): default_factory=lambda: lambda _: None ) # pyright: ignore [reportAssignmentType] + # Flag determines whether we are pickling the computed var itself + _is_pickling: ClassVar[bool] = False + def __init__( self, fget: Callable[[BASE_STATE], RETURN_TYPE], @@ -2407,6 +2410,8 @@ class ComputedVar(Var[RETURN_TYPE]): Returns: The class of the var. """ + if self._is_pickling: + return type(self) return FakeComputedVarBaseClass @property