Track which pages create State subclasses during evaluation
These need to be replayed on the backend to ensure state alignment.
This commit is contained in:
parent
8fe5798f73
commit
9b47e3e460
@ -8,19 +8,16 @@ import contextlib
|
|||||||
import copy
|
import copy
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import functools
|
import functools
|
||||||
import importlib
|
|
||||||
import inspect
|
import inspect
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import pickle
|
|
||||||
import platform
|
import platform
|
||||||
import shutil
|
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import FunctionType, SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
@ -42,13 +39,11 @@ from fastapi import FastAPI, HTTPException, Request, UploadFile
|
|||||||
from fastapi.middleware import cors
|
from fastapi.middleware import cors
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from rich.console import ConsoleThreadLocals
|
|
||||||
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
|
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
|
||||||
from socketio import ASGIApp, AsyncNamespace, AsyncServer
|
from socketio import ASGIApp, AsyncNamespace, AsyncServer
|
||||||
from starlette_admin.contrib.sqla.admin import Admin
|
from starlette_admin.contrib.sqla.admin import Admin
|
||||||
from starlette_admin.contrib.sqla.view import ModelView
|
from starlette_admin.contrib.sqla.view import ModelView
|
||||||
|
|
||||||
import reflex.istate.dynamic
|
|
||||||
from reflex import constants
|
from reflex import constants
|
||||||
from reflex.admin import AdminDash
|
from reflex.admin import AdminDash
|
||||||
from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
|
from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
|
||||||
@ -97,44 +92,21 @@ from reflex.route import (
|
|||||||
)
|
)
|
||||||
from reflex.state import (
|
from reflex.state import (
|
||||||
BaseState,
|
BaseState,
|
||||||
BaseState_import_order,
|
|
||||||
RouterData,
|
RouterData,
|
||||||
State,
|
State,
|
||||||
StateManager,
|
StateManager,
|
||||||
StateUpdate,
|
StateUpdate,
|
||||||
_substate_key,
|
_substate_key,
|
||||||
|
all_base_state_classes,
|
||||||
code_uses_state_contexts,
|
code_uses_state_contexts,
|
||||||
)
|
)
|
||||||
from reflex.utils import codespaces, console, exceptions, format, prerequisites, types
|
from reflex.utils import codespaces, console, exceptions, format, prerequisites, types
|
||||||
from reflex.utils.exec import is_prod_mode, is_testing_env
|
from reflex.utils.exec import is_prod_mode, is_testing_env
|
||||||
from reflex.utils.imports import ImportVar
|
from reflex.utils.imports import ImportVar
|
||||||
from reflex.vars.base import ComputedVar
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from reflex.vars import Var
|
from reflex.vars import Var
|
||||||
|
|
||||||
try:
|
|
||||||
import dill
|
|
||||||
except ImportError:
|
|
||||||
dill = None
|
|
||||||
else:
|
|
||||||
# Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes
|
|
||||||
if not isinstance(State.validate.__func__, FunctionType):
|
|
||||||
import builtins
|
|
||||||
|
|
||||||
cython_function_or_method = type(State.validate.__func__)
|
|
||||||
builtins.cython_function_or_method = cython_function_or_method
|
|
||||||
|
|
||||||
@dill.register(cython_function_or_method)
|
|
||||||
def _dill_reduce_cython_function_or_method(pickler, obj):
|
|
||||||
# Ignore cython function when pickling.
|
|
||||||
pass
|
|
||||||
|
|
||||||
@dill.register(ConsoleThreadLocals)
|
|
||||||
def _dill_reduce_console_thread_locals(pickler, obj):
|
|
||||||
# Ignore console thread locals when pickling.
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# Define custom types.
|
# Define custom types.
|
||||||
ComponentCallable = Callable[[], Component]
|
ComponentCallable = Callable[[], Component]
|
||||||
@ -314,6 +286,9 @@ class App(MiddlewareMixin, LifespanMixin):
|
|||||||
# A map from a page route to the component to render. Users should use `add_page`.
|
# A map from a page route to the component to render. Users should use `add_page`.
|
||||||
_pages: Dict[str, Component] = dataclasses.field(default_factory=dict)
|
_pages: Dict[str, Component] = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
|
# A mapping of pages which created states as they were being evaluated.
|
||||||
|
_stateful_pages: Dict[str, None] = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
# The backend API object.
|
# The backend API object.
|
||||||
_api: FastAPI | None = None
|
_api: FastAPI | None = None
|
||||||
|
|
||||||
@ -410,11 +385,6 @@ class App(MiddlewareMixin, LifespanMixin):
|
|||||||
if not self._state:
|
if not self._state:
|
||||||
self._state = State
|
self._state = State
|
||||||
self._setup_state()
|
self._setup_state()
|
||||||
enable_state_marker = (
|
|
||||||
prerequisites.get_web_dir() / "backend" / "enable_state"
|
|
||||||
)
|
|
||||||
enable_state_marker.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
enable_state_marker.touch()
|
|
||||||
|
|
||||||
def _setup_state(self) -> None:
|
def _setup_state(self) -> None:
|
||||||
"""Set up the state for the app.
|
"""Set up the state for the app.
|
||||||
@ -691,13 +661,19 @@ class App(MiddlewareMixin, LifespanMixin):
|
|||||||
route: The route of the page to compile.
|
route: The route of the page to compile.
|
||||||
save_page: If True, the compiled page is saved to self._pages.
|
save_page: If True, the compiled page is saved to self._pages.
|
||||||
"""
|
"""
|
||||||
|
n_states_before = len(all_base_state_classes)
|
||||||
component, enable_state = compiler.compile_unevaluated_page(
|
component, enable_state = compiler.compile_unevaluated_page(
|
||||||
route, self._unevaluated_pages[route], self._state, self.style, self.theme
|
route, self._unevaluated_pages[route], self._state, self.style, self.theme
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Indicate that the app should use state.
|
||||||
if enable_state:
|
if enable_state:
|
||||||
self._enable_state()
|
self._enable_state()
|
||||||
|
|
||||||
|
# Indicate that evaluating this page creates one or more state classes.
|
||||||
|
if len(all_base_state_classes) > n_states_before:
|
||||||
|
self._stateful_pages[route] = None
|
||||||
|
|
||||||
# Add the page.
|
# Add the page.
|
||||||
self._check_routes_conflict(route)
|
self._check_routes_conflict(route)
|
||||||
if save_page:
|
if save_page:
|
||||||
@ -1010,14 +986,10 @@ class App(MiddlewareMixin, LifespanMixin):
|
|||||||
if not should_compile and backend_dir.exists():
|
if not should_compile and backend_dir.exists():
|
||||||
enable_state_marker = backend_dir / "enable_state"
|
enable_state_marker = backend_dir / "enable_state"
|
||||||
if enable_state_marker.exists():
|
if enable_state_marker.exists():
|
||||||
import_order = pickle.loads(enable_state_marker.read_bytes())
|
stateful_pages = json.load(enable_state_marker.open("r"))
|
||||||
for bs_import in import_order:
|
for route in stateful_pages:
|
||||||
if bs_import.is_module:
|
console.info(f"BE Evaluating stateful page: {route}")
|
||||||
print(f"BE Importing stateful module: {bs_import.identifier}")
|
self._compile_page(route, save_page=False)
|
||||||
importlib.import_module(bs_import.identifier)
|
|
||||||
else:
|
|
||||||
print(f"BE Evaluating stateful page: {bs_import.identifier}")
|
|
||||||
self._compile_page(bs_import.identifier, save_page=False)
|
|
||||||
self._enable_state()
|
self._enable_state()
|
||||||
self._add_optional_endpoints()
|
self._add_optional_endpoints()
|
||||||
return
|
return
|
||||||
@ -1264,9 +1236,8 @@ class App(MiddlewareMixin, LifespanMixin):
|
|||||||
# Pickle dynamic states
|
# Pickle dynamic states
|
||||||
if self._state is not None:
|
if self._state is not None:
|
||||||
enable_state = prerequisites.get_web_dir() / "backend" / "enable_state"
|
enable_state = prerequisites.get_web_dir() / "backend" / "enable_state"
|
||||||
enable_state.write_bytes(
|
enable_state.parent.mkdir(parents=True, exist_ok=True)
|
||||||
pickle.dumps(BaseState_import_order)
|
json.dump(list(self._stateful_pages), enable_state.open("w"))
|
||||||
)
|
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
||||||
|
@ -347,45 +347,7 @@ async def _resolve_delta(delta: Delta) -> Delta:
|
|||||||
return delta
|
return delta
|
||||||
|
|
||||||
|
|
||||||
# Tracking the import and potential exec history of BaseState subclasses.
|
all_base_state_classes: dict[str, None] = {}
|
||||||
# This is used to reconstruct the state tree for the backend without pickling
|
|
||||||
# the classes themselves.
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
|
||||||
class BaseStateOrigin:
|
|
||||||
"""A class to track the origin of BaseState subclasses.
|
|
||||||
|
|
||||||
The origin is either evaluating some page, or importing a module.
|
|
||||||
"""
|
|
||||||
|
|
||||||
identifier: str
|
|
||||||
is_module: bool
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_stack(cls):
|
|
||||||
"""Find the most likely import in the stack."""
|
|
||||||
stack = inspect.stack()
|
|
||||||
for frame in stack:
|
|
||||||
if (
|
|
||||||
frame.code_context is not None
|
|
||||||
and any("class " in ctx for ctx in frame.code_context)
|
|
||||||
and frame.function == "<module>"
|
|
||||||
):
|
|
||||||
return cls(
|
|
||||||
identifier=inspect.getmodule(frame.frame).__name__,
|
|
||||||
is_module=True,
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
frame.function == "compile_unevaluated_page"
|
|
||||||
and inspect.getmodule(frame.frame) == reflex.compiler.compiler
|
|
||||||
):
|
|
||||||
# We hit a page in the compiler that needs to be evaluated
|
|
||||||
return cls(
|
|
||||||
identifier=frame.frame.f_locals["route"],
|
|
||||||
is_module=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
BaseState_import_order: dict[BaseStateOrigin, None] = {}
|
|
||||||
|
|
||||||
|
|
||||||
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||||
@ -685,7 +647,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
cls._var_dependencies = {}
|
cls._var_dependencies = {}
|
||||||
cls._init_var_dependency_dicts()
|
cls._init_var_dependency_dicts()
|
||||||
|
|
||||||
BaseState_import_order[BaseStateOrigin.from_stack()] = None
|
all_base_state_classes[cls.get_full_name()] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _copy_fn(fn: Callable) -> Callable:
|
def _copy_fn(fn: Callable) -> Callable:
|
||||||
@ -4131,6 +4093,7 @@ def reload_state_module(
|
|||||||
for subclass in tuple(state.class_subclasses):
|
for subclass in tuple(state.class_subclasses):
|
||||||
reload_state_module(module=module, state=subclass)
|
reload_state_module(module=module, state=subclass)
|
||||||
if subclass.__module__ == module and module is not None:
|
if subclass.__module__ == module and module is not None:
|
||||||
|
all_base_state_classes.pop(subclass.get_full_name(), None)
|
||||||
state.class_subclasses.remove(subclass)
|
state.class_subclasses.remove(subclass)
|
||||||
state._always_dirty_substates.discard(subclass.get_name())
|
state._always_dirty_substates.discard(subclass.get_name())
|
||||||
state._var_dependencies = {}
|
state._var_dependencies = {}
|
||||||
|
Loading…
Reference in New Issue
Block a user