diff --git a/reflex/app.py b/reflex/app.py index 5da2e1e71..0aefe0a9a 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -8,19 +8,16 @@ import contextlib import copy import dataclasses import functools -import importlib import inspect import io import json import multiprocessing -import pickle import platform -import shutil import sys import traceback from datetime import datetime from pathlib import Path -from types import FunctionType, SimpleNamespace +from types import SimpleNamespace from typing import ( TYPE_CHECKING, Any, @@ -42,13 +39,11 @@ from fastapi import FastAPI, HTTPException, Request, UploadFile from fastapi.middleware import cors from fastapi.responses import JSONResponse, StreamingResponse from fastapi.staticfiles import StaticFiles -from rich.console import ConsoleThreadLocals from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from socketio import ASGIApp, AsyncNamespace, AsyncServer from starlette_admin.contrib.sqla.admin import Admin from starlette_admin.contrib.sqla.view import ModelView -import reflex.istate.dynamic from reflex import constants from reflex.admin import AdminDash from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin @@ -97,44 +92,21 @@ from reflex.route import ( ) from reflex.state import ( BaseState, - BaseState_import_order, RouterData, State, StateManager, StateUpdate, _substate_key, + all_base_state_classes, code_uses_state_contexts, ) from reflex.utils import codespaces, console, exceptions, format, prerequisites, types from reflex.utils.exec import is_prod_mode, is_testing_env from reflex.utils.imports import ImportVar -from reflex.vars.base import ComputedVar if TYPE_CHECKING: from reflex.vars import Var -try: - import dill -except ImportError: - dill = None -else: - # Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes - if not isinstance(State.validate.__func__, FunctionType): - import builtins - - cython_function_or_method = type(State.validate.__func__) - builtins.cython_function_or_method = cython_function_or_method - - @dill.register(cython_function_or_method) - def _dill_reduce_cython_function_or_method(pickler, obj): - # Ignore cython function when pickling. - pass - - @dill.register(ConsoleThreadLocals) - def _dill_reduce_console_thread_locals(pickler, obj): - # Ignore console thread locals when pickling. - pass - # Define custom types. ComponentCallable = Callable[[], Component] @@ -314,6 +286,9 @@ class App(MiddlewareMixin, LifespanMixin): # A map from a page route to the component to render. Users should use `add_page`. _pages: Dict[str, Component] = dataclasses.field(default_factory=dict) + # A mapping of pages which created states as they were being evaluated. + _stateful_pages: Dict[str, None] = dataclasses.field(default_factory=dict) + # The backend API object. _api: FastAPI | None = None @@ -410,11 +385,6 @@ class App(MiddlewareMixin, LifespanMixin): if not self._state: self._state = State self._setup_state() - enable_state_marker = ( - prerequisites.get_web_dir() / "backend" / "enable_state" - ) - enable_state_marker.parent.mkdir(parents=True, exist_ok=True) - enable_state_marker.touch() def _setup_state(self) -> None: """Set up the state for the app. @@ -691,13 +661,19 @@ class App(MiddlewareMixin, LifespanMixin): route: The route of the page to compile. save_page: If True, the compiled page is saved to self._pages. """ + n_states_before = len(all_base_state_classes) component, enable_state = compiler.compile_unevaluated_page( route, self._unevaluated_pages[route], self._state, self.style, self.theme ) + # Indicate that the app should use state. if enable_state: self._enable_state() + # Indicate that evaluating this page creates one or more state classes. + if len(all_base_state_classes) > n_states_before: + self._stateful_pages[route] = None + # Add the page. self._check_routes_conflict(route) if save_page: @@ -1010,14 +986,10 @@ class App(MiddlewareMixin, LifespanMixin): if not should_compile and backend_dir.exists(): enable_state_marker = backend_dir / "enable_state" if enable_state_marker.exists(): - import_order = pickle.loads(enable_state_marker.read_bytes()) - for bs_import in import_order: - if bs_import.is_module: - print(f"BE Importing stateful module: {bs_import.identifier}") - importlib.import_module(bs_import.identifier) - else: - print(f"BE Evaluating stateful page: {bs_import.identifier}") - self._compile_page(bs_import.identifier, save_page=False) + stateful_pages = json.load(enable_state_marker.open("r")) + for route in stateful_pages: + console.info(f"BE Evaluating stateful page: {route}") + self._compile_page(route, save_page=False) self._enable_state() self._add_optional_endpoints() return @@ -1264,9 +1236,8 @@ class App(MiddlewareMixin, LifespanMixin): # Pickle dynamic states if self._state is not None: enable_state = prerequisites.get_web_dir() / "backend" / "enable_state" - enable_state.write_bytes( - pickle.dumps(BaseState_import_order) - ) + enable_state.parent.mkdir(parents=True, exist_ok=True) + json.dump(list(self._stateful_pages), enable_state.open("w")) @contextlib.asynccontextmanager async def modify_state(self, token: str) -> AsyncIterator[BaseState]: diff --git a/reflex/state.py b/reflex/state.py index b05201948..a1362d15e 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -347,45 +347,7 @@ async def _resolve_delta(delta: Delta) -> Delta: return delta -# Tracking the import and potential exec history of BaseState subclasses. -# This is used to reconstruct the state tree for the backend without pickling -# the classes themselves. -@dataclasses.dataclass(frozen=True) -class BaseStateOrigin: - """A class to track the origin of BaseState subclasses. - - The origin is either evaluating some page, or importing a module. - """ - - identifier: str - is_module: bool - - @classmethod - def from_stack(cls): - """Find the most likely import in the stack.""" - stack = inspect.stack() - for frame in stack: - if ( - frame.code_context is not None - and any("class " in ctx for ctx in frame.code_context) - and frame.function == "" - ): - return cls( - identifier=inspect.getmodule(frame.frame).__name__, - is_module=True, - ) - if ( - frame.function == "compile_unevaluated_page" - and inspect.getmodule(frame.frame) == reflex.compiler.compiler - ): - # We hit a page in the compiler that needs to be evaluated - return cls( - identifier=frame.frame.f_locals["route"], - is_module=False, - ) - - -BaseState_import_order: dict[BaseStateOrigin, None] = {} +all_base_state_classes: dict[str, None] = {} class BaseState(Base, ABC, extra=pydantic.Extra.allow): @@ -685,7 +647,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): cls._var_dependencies = {} cls._init_var_dependency_dicts() - BaseState_import_order[BaseStateOrigin.from_stack()] = None + all_base_state_classes[cls.get_full_name()] = None @staticmethod def _copy_fn(fn: Callable) -> Callable: @@ -4131,6 +4093,7 @@ def reload_state_module( for subclass in tuple(state.class_subclasses): reload_state_module(module=module, state=subclass) if subclass.__module__ == module and module is not None: + all_base_state_classes.pop(subclass.get_full_name(), None) state.class_subclasses.remove(subclass) state._always_dirty_substates.discard(subclass.get_name()) state._var_dependencies = {}