Track which pages create State subclasses during evaluation

These need to be replayed on the backend to ensure state alignment.
This commit is contained in:
Masen Furer 2025-02-07 09:44:38 -08:00
parent 8fe5798f73
commit 9b47e3e460
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
2 changed files with 20 additions and 86 deletions

View File

@ -8,19 +8,16 @@ import contextlib
import copy import copy
import dataclasses import dataclasses
import functools import functools
import importlib
import inspect import inspect
import io import io
import json import json
import multiprocessing import multiprocessing
import pickle
import platform import platform
import shutil
import sys import sys
import traceback import traceback
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from types import FunctionType, SimpleNamespace from types import SimpleNamespace
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -42,13 +39,11 @@ from fastapi import FastAPI, HTTPException, Request, UploadFile
from fastapi.middleware import cors from fastapi.middleware import cors
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from rich.console import ConsoleThreadLocals
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
from socketio import ASGIApp, AsyncNamespace, AsyncServer from socketio import ASGIApp, AsyncNamespace, AsyncServer
from starlette_admin.contrib.sqla.admin import Admin from starlette_admin.contrib.sqla.admin import Admin
from starlette_admin.contrib.sqla.view import ModelView from starlette_admin.contrib.sqla.view import ModelView
import reflex.istate.dynamic
from reflex import constants from reflex import constants
from reflex.admin import AdminDash from reflex.admin import AdminDash
from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
@ -97,44 +92,21 @@ from reflex.route import (
) )
from reflex.state import ( from reflex.state import (
BaseState, BaseState,
BaseState_import_order,
RouterData, RouterData,
State, State,
StateManager, StateManager,
StateUpdate, StateUpdate,
_substate_key, _substate_key,
all_base_state_classes,
code_uses_state_contexts, code_uses_state_contexts,
) )
from reflex.utils import codespaces, console, exceptions, format, prerequisites, types from reflex.utils import codespaces, console, exceptions, format, prerequisites, types
from reflex.utils.exec import is_prod_mode, is_testing_env from reflex.utils.exec import is_prod_mode, is_testing_env
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportVar
from reflex.vars.base import ComputedVar
if TYPE_CHECKING: if TYPE_CHECKING:
from reflex.vars import Var from reflex.vars import Var
try:
import dill
except ImportError:
dill = None
else:
# Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes
if not isinstance(State.validate.__func__, FunctionType):
import builtins
cython_function_or_method = type(State.validate.__func__)
builtins.cython_function_or_method = cython_function_or_method
@dill.register(cython_function_or_method)
def _dill_reduce_cython_function_or_method(pickler, obj):
# Ignore cython function when pickling.
pass
@dill.register(ConsoleThreadLocals)
def _dill_reduce_console_thread_locals(pickler, obj):
# Ignore console thread locals when pickling.
pass
# Define custom types. # Define custom types.
ComponentCallable = Callable[[], Component] ComponentCallable = Callable[[], Component]
@ -314,6 +286,9 @@ class App(MiddlewareMixin, LifespanMixin):
# A map from a page route to the component to render. Users should use `add_page`. # A map from a page route to the component to render. Users should use `add_page`.
_pages: Dict[str, Component] = dataclasses.field(default_factory=dict) _pages: Dict[str, Component] = dataclasses.field(default_factory=dict)
# A mapping of pages which created states as they were being evaluated.
_stateful_pages: Dict[str, None] = dataclasses.field(default_factory=dict)
# The backend API object. # The backend API object.
_api: FastAPI | None = None _api: FastAPI | None = None
@ -410,11 +385,6 @@ class App(MiddlewareMixin, LifespanMixin):
if not self._state: if not self._state:
self._state = State self._state = State
self._setup_state() self._setup_state()
enable_state_marker = (
prerequisites.get_web_dir() / "backend" / "enable_state"
)
enable_state_marker.parent.mkdir(parents=True, exist_ok=True)
enable_state_marker.touch()
def _setup_state(self) -> None: def _setup_state(self) -> None:
"""Set up the state for the app. """Set up the state for the app.
@ -691,13 +661,19 @@ class App(MiddlewareMixin, LifespanMixin):
route: The route of the page to compile. route: The route of the page to compile.
save_page: If True, the compiled page is saved to self._pages. save_page: If True, the compiled page is saved to self._pages.
""" """
n_states_before = len(all_base_state_classes)
component, enable_state = compiler.compile_unevaluated_page( component, enable_state = compiler.compile_unevaluated_page(
route, self._unevaluated_pages[route], self._state, self.style, self.theme route, self._unevaluated_pages[route], self._state, self.style, self.theme
) )
# Indicate that the app should use state.
if enable_state: if enable_state:
self._enable_state() self._enable_state()
# Indicate that evaluating this page creates one or more state classes.
if len(all_base_state_classes) > n_states_before:
self._stateful_pages[route] = None
# Add the page. # Add the page.
self._check_routes_conflict(route) self._check_routes_conflict(route)
if save_page: if save_page:
@ -1010,14 +986,10 @@ class App(MiddlewareMixin, LifespanMixin):
if not should_compile and backend_dir.exists(): if not should_compile and backend_dir.exists():
enable_state_marker = backend_dir / "enable_state" enable_state_marker = backend_dir / "enable_state"
if enable_state_marker.exists(): if enable_state_marker.exists():
import_order = pickle.loads(enable_state_marker.read_bytes()) stateful_pages = json.load(enable_state_marker.open("r"))
for bs_import in import_order: for route in stateful_pages:
if bs_import.is_module: console.info(f"BE Evaluating stateful page: {route}")
print(f"BE Importing stateful module: {bs_import.identifier}") self._compile_page(route, save_page=False)
importlib.import_module(bs_import.identifier)
else:
print(f"BE Evaluating stateful page: {bs_import.identifier}")
self._compile_page(bs_import.identifier, save_page=False)
self._enable_state() self._enable_state()
self._add_optional_endpoints() self._add_optional_endpoints()
return return
@ -1264,9 +1236,8 @@ class App(MiddlewareMixin, LifespanMixin):
# Pickle dynamic states # Pickle dynamic states
if self._state is not None: if self._state is not None:
enable_state = prerequisites.get_web_dir() / "backend" / "enable_state" enable_state = prerequisites.get_web_dir() / "backend" / "enable_state"
enable_state.write_bytes( enable_state.parent.mkdir(parents=True, exist_ok=True)
pickle.dumps(BaseState_import_order) json.dump(list(self._stateful_pages), enable_state.open("w"))
)
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[BaseState]: async def modify_state(self, token: str) -> AsyncIterator[BaseState]:

View File

@ -347,45 +347,7 @@ async def _resolve_delta(delta: Delta) -> Delta:
return delta return delta
# Tracking the import and potential exec history of BaseState subclasses. all_base_state_classes: dict[str, None] = {}
# This is used to reconstruct the state tree for the backend without pickling
# the classes themselves.
@dataclasses.dataclass(frozen=True)
class BaseStateOrigin:
"""A class to track the origin of BaseState subclasses.
The origin is either evaluating some page, or importing a module.
"""
identifier: str
is_module: bool
@classmethod
def from_stack(cls):
"""Find the most likely import in the stack."""
stack = inspect.stack()
for frame in stack:
if (
frame.code_context is not None
and any("class " in ctx for ctx in frame.code_context)
and frame.function == "<module>"
):
return cls(
identifier=inspect.getmodule(frame.frame).__name__,
is_module=True,
)
if (
frame.function == "compile_unevaluated_page"
and inspect.getmodule(frame.frame) == reflex.compiler.compiler
):
# We hit a page in the compiler that needs to be evaluated
return cls(
identifier=frame.frame.f_locals["route"],
is_module=False,
)
BaseState_import_order: dict[BaseStateOrigin, None] = {}
class BaseState(Base, ABC, extra=pydantic.Extra.allow): class BaseState(Base, ABC, extra=pydantic.Extra.allow):
@ -685,7 +647,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
cls._var_dependencies = {} cls._var_dependencies = {}
cls._init_var_dependency_dicts() cls._init_var_dependency_dicts()
BaseState_import_order[BaseStateOrigin.from_stack()] = None all_base_state_classes[cls.get_full_name()] = None
@staticmethod @staticmethod
def _copy_fn(fn: Callable) -> Callable: def _copy_fn(fn: Callable) -> Callable:
@ -4131,6 +4093,7 @@ def reload_state_module(
for subclass in tuple(state.class_subclasses): for subclass in tuple(state.class_subclasses):
reload_state_module(module=module, state=subclass) reload_state_module(module=module, state=subclass)
if subclass.__module__ == module and module is not None: if subclass.__module__ == module and module is not None:
all_base_state_classes.pop(subclass.get_full_name(), None)
state.class_subclasses.remove(subclass) state.class_subclasses.remove(subclass)
state._always_dirty_substates.discard(subclass.get_name()) state._always_dirty_substates.discard(subclass.get_name())
state._var_dependencies = {} state._var_dependencies = {}