Compare commits

...

6 Commits

Author SHA1 Message Date
Lendemor
76dce8bba3 Merge branch 'main' into lendemor/builtins_states 2024-10-25 12:49:36 +02:00
Lendemor
7320425285 Merge branch 'main' into lendemor/builtins_states 2024-10-23 17:28:15 +02:00
Lendemor
703a6da3d7 Merge branch 'main' into lendemor/builtins_states 2024-10-22 22:27:10 +02:00
Lendemor
bec66b894e fix tests and hardcoded constants 2024-10-22 21:05:10 +02:00
Lendemor
9f3752ea74 Merge branch 'main' into lendemor/builtins_states 2024-10-22 20:53:52 +02:00
Lendemor
4e8169e1ed move builtins states to their own file 2024-10-22 20:53:10 +02:00
23 changed files with 255 additions and 232 deletions

View File

@ -320,6 +320,7 @@ _MAPPING: dict = {
"upload_files", "upload_files",
"window_alert", "window_alert",
], ],
"istate.builtins": ["ComponentState", "State"],
"istate.storage": [ "istate.storage": [
"Cookie", "Cookie",
"LocalStorage", "LocalStorage",
@ -329,8 +330,6 @@ _MAPPING: dict = {
"model": ["session", "Model"], "model": ["session", "Model"],
"state": [ "state": [
"var", "var",
"ComponentState",
"State",
"dynamic", "dynamic",
], ],
"style": ["Style", "toggle_color_mode"], "style": ["Style", "toggle_color_mode"],

View File

@ -174,6 +174,8 @@ from .event import stop_propagation as stop_propagation
from .event import upload_files as upload_files from .event import upload_files as upload_files
from .event import window_alert as window_alert from .event import window_alert as window_alert
from .experimental import _x as _x from .experimental import _x as _x
from .istate.builtins import ComponentState as ComponentState
from .istate.builtins import State as State
from .istate.storage import Cookie as Cookie from .istate.storage import Cookie as Cookie
from .istate.storage import LocalStorage as LocalStorage from .istate.storage import LocalStorage as LocalStorage
from .istate.storage import SessionStorage as SessionStorage from .istate.storage import SessionStorage as SessionStorage
@ -182,8 +184,6 @@ from .middleware import middleware as middleware
from .model import Model as Model from .model import Model as Model
from .model import session as session from .model import session as session
from .page import page as page from .page import page as page
from .state import ComponentState as ComponentState
from .state import State as State
from .state import dynamic as dynamic from .state import dynamic as dynamic
from .state import var as var from .state import var as var
from .style import Style as Style from .style import Style as Style

View File

@ -66,6 +66,7 @@ from reflex.components.core.upload import Upload, get_upload_dir
from reflex.components.radix import themes from reflex.components.radix import themes
from reflex.config import environment, get_config from reflex.config import environment, get_config
from reflex.event import Event, EventHandler, EventSpec, window_alert from reflex.event import Event, EventHandler, EventSpec, window_alert
from reflex.istate.builtins import State
from reflex.model import Model, get_db_status from reflex.model import Model, get_db_status
from reflex.page import ( from reflex.page import (
DECORATED_PAGES, DECORATED_PAGES,
@ -78,7 +79,6 @@ from reflex.route import (
from reflex.state import ( from reflex.state import (
BaseState, BaseState,
RouterData, RouterData,
State,
StateManager, StateManager,
StateUpdate, StateUpdate,
_substate_key, _substate_key,

View File

@ -8,7 +8,7 @@ from reflex.compiler.compiler import _compile_component
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.el import div, p from reflex.components.el import div, p
from reflex.event import EventHandler from reflex.event import EventHandler
from reflex.state import FrontendEventExceptionState from reflex.istate.builtins import FrontendEventExceptionState
from reflex.vars.base import Var from reflex.vars.base import Var

View File

@ -21,7 +21,7 @@ from typing import (
Union, Union,
) )
import reflex.state import reflex.istate.builtins
from reflex.base import Base from reflex.base import Base
from reflex.compiler.templates import STATEFUL_COMPONENT from reflex.compiler.templates import STATEFUL_COMPONENT
from reflex.components.core.breakpoints import Breakpoints from reflex.components.core.breakpoints import Breakpoints
@ -225,7 +225,7 @@ class Component(BaseComponent, ABC):
_memoization_mode: MemoizationMode = MemoizationMode() _memoization_mode: MemoizationMode = MemoizationMode()
# State class associated with this component instance # State class associated with this component instance
State: Optional[Type[reflex.state.State]] = None State: Optional[Type[reflex.istate.builtins.State]] = None
def add_imports(self) -> ImportDict | list[ImportDict]: def add_imports(self) -> ImportDict | list[ImportDict]:
"""Add imports for the component. """Add imports for the component.

View File

@ -9,7 +9,7 @@ from reflex.components.base.fragment import Fragment
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.tags import IterTag from reflex.components.tags import IterTag
from reflex.constants import MemoizationMode from reflex.constants import MemoizationMode
from reflex.state import ComponentState from reflex.istate.builtins import ComponentState
from reflex.vars.base import LiteralVar, Var from reflex.vars.base import LiteralVar, Var

View File

@ -62,16 +62,20 @@ class CompileVars(SimpleNamespace):
# The name of the function for converting a dict to an event. # The name of the function for converting a dict to an event.
TO_EVENT = "Event" TO_EVENT = "Event"
# The name of the internal on_load event. # The name of the internal on_load event.
ON_LOAD_INTERNAL = "reflex___state____on_load_internal_state.on_load_internal" ON_LOAD_INTERNAL = (
"reflex___istate___builtins____on_load_internal_state.on_load_internal"
)
# The name of the internal event to update generic state vars. # The name of the internal event to update generic state vars.
UPDATE_VARS_INTERNAL = ( UPDATE_VARS_INTERNAL = (
"reflex___state____update_vars_internal_state.update_vars_internal" "reflex___istate___builtins____update_vars_internal_state.update_vars_internal"
) )
# The name of the frontend event exception state # The name of the frontend event exception state
FRONTEND_EXCEPTION_STATE = "reflex___state____frontend_event_exception_state" FRONTEND_EXCEPTION_STATE = (
"reflex___istate___builtins____frontend_event_exception_state"
)
# The full name of the frontend exception state # The full name of the frontend exception state
FRONTEND_EXCEPTION_STATE_FULL = ( FRONTEND_EXCEPTION_STATE_FULL = (
f"reflex___state____state.{FRONTEND_EXCEPTION_STATE}" f"reflex___istate___builtins____state.{FRONTEND_EXCEPTION_STATE}"
) )

View File

@ -14,7 +14,7 @@ from reflex.components.radix.themes.layout.container import Container
from reflex.components.radix.themes.layout.stack import HStack from reflex.components.radix.themes.layout.stack import HStack
from reflex.event import call_script from reflex.event import call_script
from reflex.experimental import hooks from reflex.experimental import hooks
from reflex.state import ComponentState from reflex.istate.builtins import ComponentState
from reflex.style import Style from reflex.style import Style
from reflex.vars.base import Var from reflex.vars.base import Var

View File

@ -11,7 +11,7 @@ from reflex.components.component import Component, ComponentNamespace, Memoizati
from reflex.components.radix.primitives.drawer import DrawerRoot from reflex.components.radix.primitives.drawer import DrawerRoot
from reflex.components.radix.themes.layout.box import Box from reflex.components.radix.themes.layout.box import Box
from reflex.event import EventType from reflex.event import EventType
from reflex.state import ComponentState from reflex.istate.builtins import ComponentState
from reflex.style import Style from reflex.style import Style
from reflex.vars.base import Var from reflex.vars.base import Var

207
reflex/istate/builtins.py Normal file
View File

@ -0,0 +1,207 @@
"""The built-in states used by reflex apps."""
from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING, Any, ClassVar, Type
import reflex.istate.dynamic
from reflex import constants, event
from reflex.event import Event, EventSpec, fix_events
from reflex.state import BaseState
from reflex.utils import prerequisites
if TYPE_CHECKING:
from reflex.components.component import Component
class State(BaseState):
"""The app Base State."""
# The hydrated bool.
is_hydrated: bool = False
class FrontendEventExceptionState(State):
"""Substate for handling frontend exceptions."""
@event
def handle_frontend_exception(self, stack: str, component_stack: str) -> None:
"""Handle frontend exceptions.
If a frontend exception handler is provided, it will be called.
Otherwise, the default frontend exception handler will be called.
Args:
stack: The stack trace of the exception.
component_stack: The stack trace of the component where the exception occurred.
"""
app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
app_instance.frontend_exception_handler(Exception(stack))
class UpdateVarsInternalState(State):
"""Substate for handling internal state var updates."""
async def update_vars_internal(self, vars: dict[str, Any]) -> None:
"""Apply updates to fully qualified state vars.
The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`,
and each value will be set on the appropriate substate instance.
This function is primarily used to apply cookie and local storage
updates from the frontend to the appropriate substate.
Args:
vars: The fully qualified vars and values to update.
"""
for var, value in vars.items():
state_name, _, var_name = var.rpartition(".")
var_state_cls = State.get_class_substate(state_name)
var_state = await self.get_state(var_state_cls)
setattr(var_state, var_name, value)
class OnLoadInternalState(State):
"""Substate for handling on_load event enumeration.
This is a separate substate to avoid deserializing the entire state tree for every page navigation.
"""
def on_load_internal(self) -> list[Event | EventSpec] | None:
"""Queue on_load handlers for the current page.
Returns:
The list of events to queue for on load handling.
"""
# Do not app._compile()! It should be already compiled by now.
app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
load_events = app.get_load_events(self.router.page.path)
if not load_events:
self.is_hydrated = True
return # Fast path for navigation with no on_load events defined.
self.is_hydrated = False
return [
*fix_events(
load_events,
self.router.session.client_token,
router_data=self.router_data,
),
State.set_is_hydrated(True), # type: ignore
]
class ComponentState(State, mixin=True):
"""Base class to allow for the creation of a state instance per component.
This allows for the bundling of UI and state logic into a single class,
where each instance has a separate instance of the state.
Subclass this class and define vars and event handlers in the traditional way.
Then define a `get_component` method that returns the UI for the component instance.
See the full [docs](https://reflex.dev/docs/substates/component-state/) for more.
Basic example:
```python
# Subclass ComponentState and define vars and event handlers.
class Counter(rx.ComponentState):
# Define vars that change.
count: int = 0
# Define event handlers.
def increment(self):
self.count += 1
def decrement(self):
self.count -= 1
@classmethod
def get_component(cls, **props):
# Access the state vars and event handlers using `cls`.
return rx.hstack(
rx.button("Decrement", on_click=cls.decrement),
rx.text(cls.count),
rx.button("Increment", on_click=cls.increment),
**props,
)
counter = Counter.create()
```
"""
# The number of components created from this class.
_per_component_state_instance_count: ClassVar[int] = 0
@classmethod
def __init_subclass__(cls, mixin: bool = True, **kwargs):
"""Overwrite mixin default to True.
Args:
mixin: Whether the subclass is a mixin and should not be initialized.
**kwargs: The kwargs to pass to the pydantic init_subclass method.
"""
super().__init_subclass__(mixin=mixin, **kwargs)
@classmethod
def get_component(cls, *children, **props) -> "Component":
"""Get the component instance.
Args:
children: The children of the component.
props: The props of the component.
Raises:
NotImplementedError: if the subclass does not override this method.
"""
raise NotImplementedError(
f"{cls.__name__} must implement get_component to return the component instance."
)
@classmethod
def create(cls, *children, **props) -> "Component":
"""Create a new instance of the Component.
Args:
children: The children of the component.
props: The props of the component.
Returns:
A new instance of the Component with an independent copy of the State.
"""
cls._per_component_state_instance_count += 1
state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}"
component_state = type(
state_cls_name,
(cls, State),
{"__module__": reflex.istate.dynamic.__name__},
mixin=False,
)
# Save a reference to the dynamic state for pickle/unpickle.
setattr(reflex.istate.dynamic, state_cls_name, component_state)
component = component_state.get_component(*children, **props)
component.State = component_state
return component
def reload_state_module(
module: str,
state: Type[BaseState] = State,
) -> None:
"""Reset rx.State subclasses to avoid conflict when reloading.
Args:
module: The module to reload.
state: Recursive argument for the state class to reload.
"""
for subclass in tuple(state.class_subclasses):
reload_state_module(module=module, state=subclass)
if subclass.__module__ == module and module is not None:
state.class_subclasses.remove(subclass)
state._always_dirty_substates.discard(subclass.get_name())
state._computed_var_dependencies = defaultdict(set)
state._substate_var_dependencies = defaultdict(set)
state._init_var_dependency_dicts()
state.get_class_substate.cache_clear()

View File

@ -40,7 +40,6 @@ from typing import (
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
from typing_extensions import Self from typing_extensions import Self
from reflex import event
from reflex.config import get_config from reflex.config import get_config
from reflex.istate.data import RouterData from reflex.istate.data import RouterData
from reflex.istate.storage import ( from reflex.istate.storage import (
@ -94,13 +93,12 @@ from reflex.utils.serializers import serializer
from reflex.utils.types import get_origin, override from reflex.utils.types import get_origin, override
from reflex.vars import VarData from reflex.vars import VarData
if TYPE_CHECKING:
from reflex.components.component import Component
Delta = Dict[str, Any] Delta = Dict[str, Any]
var = computed_var var = computed_var
if TYPE_CHECKING:
from reflex.components.component import Component
# If the state is this large, it's considered a performance issue. # If the state is this large, it's considered a performance issue.
TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb
@ -899,11 +897,13 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if len(path) == 0: if len(path) == 0:
return cls return cls
print("get_name: ", cls.get_name())
if path[0] == cls.get_name(): if path[0] == cls.get_name():
if len(path) == 1: if len(path) == 1:
return cls return cls
path = path[1:] path = path[1:]
for substate in cls.get_substates(): for substate in cls.get_substates():
print("substate get_name: ", substate.get_name())
if path[0] == substate.get_name(): if path[0] == substate.get_name():
return substate.get_class_substate(path[1:]) return substate.get_class_substate(path[1:])
raise ValueError(f"Invalid path: {path}") raise ValueError(f"Invalid path: {path}")
@ -2112,13 +2112,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
return state return state
class State(BaseState):
"""The app Base State."""
# The hydrated bool.
is_hydrated: bool = False
T = TypeVar("T", bound=BaseState) T = TypeVar("T", bound=BaseState)
@ -2164,169 +2157,6 @@ def dynamic(func: Callable[[T], Component]):
return wrapper return wrapper
class FrontendEventExceptionState(State):
"""Substate for handling frontend exceptions."""
@event
def handle_frontend_exception(self, stack: str, component_stack: str) -> None:
"""Handle frontend exceptions.
If a frontend exception handler is provided, it will be called.
Otherwise, the default frontend exception handler will be called.
Args:
stack: The stack trace of the exception.
component_stack: The stack trace of the component where the exception occurred.
"""
app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
app_instance.frontend_exception_handler(Exception(stack))
class UpdateVarsInternalState(State):
"""Substate for handling internal state var updates."""
async def update_vars_internal(self, vars: dict[str, Any]) -> None:
"""Apply updates to fully qualified state vars.
The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`,
and each value will be set on the appropriate substate instance.
This function is primarily used to apply cookie and local storage
updates from the frontend to the appropriate substate.
Args:
vars: The fully qualified vars and values to update.
"""
for var, value in vars.items():
state_name, _, var_name = var.rpartition(".")
var_state_cls = State.get_class_substate(state_name)
var_state = await self.get_state(var_state_cls)
setattr(var_state, var_name, value)
class OnLoadInternalState(State):
"""Substate for handling on_load event enumeration.
This is a separate substate to avoid deserializing the entire state tree for every page navigation.
"""
def on_load_internal(self) -> list[Event | EventSpec] | None:
"""Queue on_load handlers for the current page.
Returns:
The list of events to queue for on load handling.
"""
# Do not app._compile()! It should be already compiled by now.
app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
load_events = app.get_load_events(self.router.page.path)
if not load_events:
self.is_hydrated = True
return # Fast path for navigation with no on_load events defined.
self.is_hydrated = False
return [
*fix_events(
load_events,
self.router.session.client_token,
router_data=self.router_data,
),
State.set_is_hydrated(True), # type: ignore
]
class ComponentState(State, mixin=True):
"""Base class to allow for the creation of a state instance per component.
This allows for the bundling of UI and state logic into a single class,
where each instance has a separate instance of the state.
Subclass this class and define vars and event handlers in the traditional way.
Then define a `get_component` method that returns the UI for the component instance.
See the full [docs](https://reflex.dev/docs/substates/component-state/) for more.
Basic example:
```python
# Subclass ComponentState and define vars and event handlers.
class Counter(rx.ComponentState):
# Define vars that change.
count: int = 0
# Define event handlers.
def increment(self):
self.count += 1
def decrement(self):
self.count -= 1
@classmethod
def get_component(cls, **props):
# Access the state vars and event handlers using `cls`.
return rx.hstack(
rx.button("Decrement", on_click=cls.decrement),
rx.text(cls.count),
rx.button("Increment", on_click=cls.increment),
**props,
)
counter = Counter.create()
```
"""
# The number of components created from this class.
_per_component_state_instance_count: ClassVar[int] = 0
@classmethod
def __init_subclass__(cls, mixin: bool = True, **kwargs):
"""Overwrite mixin default to True.
Args:
mixin: Whether the subclass is a mixin and should not be initialized.
**kwargs: The kwargs to pass to the pydantic init_subclass method.
"""
super().__init_subclass__(mixin=mixin, **kwargs)
@classmethod
def get_component(cls, *children, **props) -> "Component":
"""Get the component instance.
Args:
children: The children of the component.
props: The props of the component.
Raises:
NotImplementedError: if the subclass does not override this method.
"""
raise NotImplementedError(
f"{cls.__name__} must implement get_component to return the component instance."
)
@classmethod
def create(cls, *children, **props) -> "Component":
"""Create a new instance of the Component.
Args:
children: The children of the component.
props: The props of the component.
Returns:
A new instance of the Component with an independent copy of the State.
"""
cls._per_component_state_instance_count += 1
state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}"
component_state = type(
state_cls_name,
(cls, State),
{"__module__": reflex.istate.dynamic.__name__},
mixin=False,
)
# Save a reference to the dynamic state for pickle/unpickle.
setattr(reflex.istate.dynamic, state_cls_name, component_state)
component = component_state.get_component(*children, **props)
component.State = component_state
return component
class StateProxy(wrapt.ObjectProxy): class StateProxy(wrapt.ObjectProxy):
"""Proxy of a state instance to control mutability of vars for a background task. """Proxy of a state instance to control mutability of vars for a background task.
@ -2408,7 +2238,7 @@ class StateProxy(wrapt.ObjectProxy):
super().__setattr__( super().__setattr__(
"__wrapped__", "__wrapped__",
await parent_state.get_state( await parent_state.get_state(
State.get_class_substate(self._self_substate_path) BaseState.get_class_substate(self._self_substate_path)
), ),
) )
return self return self
@ -3743,25 +3573,3 @@ def code_uses_state_contexts(javascript_code: str) -> bool:
True if the code attempts to access a member of StateContexts. True if the code attempts to access a member of StateContexts.
""" """
return bool("useContext(StateContexts" in javascript_code) return bool("useContext(StateContexts" in javascript_code)
def reload_state_module(
module: str,
state: Type[BaseState] = State,
) -> None:
"""Reset rx.State subclasses to avoid conflict when reloading.
Args:
module: The module to reload.
state: Recursive argument for the state class to reload.
"""
for subclass in tuple(state.class_subclasses):
reload_state_module(module=module, state=subclass)
if subclass.__module__ == module and module is not None:
state.class_subclasses.remove(subclass)
state._always_dirty_substates.discard(subclass.get_name())
state._computed_var_dependencies = defaultdict(set)
state._substate_var_dependencies = defaultdict(set)
state._init_var_dependency_dicts()
state.get_class_substate.cache_clear()

View File

@ -43,13 +43,13 @@ import reflex.utils.exec
import reflex.utils.format import reflex.utils.format
import reflex.utils.prerequisites import reflex.utils.prerequisites
import reflex.utils.processes import reflex.utils.processes
from reflex.istate.builtins import reload_state_module
from reflex.state import ( from reflex.state import (
BaseState, BaseState,
StateManager, StateManager,
StateManagerDisk, StateManagerDisk,
StateManagerMemory, StateManagerMemory,
StateManagerRedis, StateManagerRedis,
reload_state_module,
) )
try: try:

View File

@ -429,7 +429,7 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]:
# Get the function name # Get the function name
name = parts[-1] name = parts[-1]
from reflex.state import State from reflex.istate.builtins import State
if state_full_name == "state" and name not in State.__dict__: if state_full_name == "state" and name not in State.__dict__:
return ("", to_snake_case(handler.fn.__qualname__)) return ("", to_snake_case(handler.fn.__qualname__))

View File

@ -279,7 +279,7 @@ def get_app(reload: bool = False) -> ModuleType:
app = __import__(module, fromlist=(constants.CompileVars.APP,)) app = __import__(module, fromlist=(constants.CompileVars.APP,))
if reload: if reload:
from reflex.state import reload_state_module from reflex.istate.builtins import reload_state_module
# Reset rx.State subclasses to avoid conflict when reloading. # Reset rx.State subclasses to avoid conflict when reloading.
reload_state_module(module=module) reload_state_module(module=module)

View File

@ -5,7 +5,8 @@ from typing import Generator
import pytest import pytest
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from reflex.state import State, _substate_key from reflex.istate.builtins import State
from reflex.state import _substate_key
from reflex.testing import AppHarness from reflex.testing import AppHarness
from . import utils from . import utils

View File

@ -12,7 +12,8 @@ from reflex.components.core.foreach import (
) )
from reflex.components.radix.themes.layout.box import box from reflex.components.radix.themes.layout.box import box
from reflex.components.radix.themes.typography.text import text from reflex.components.radix.themes.typography.text import text
from reflex.state import BaseState, ComponentState from reflex.istate.builtins import ComponentState
from reflex.state import BaseState
from reflex.vars.base import Var from reflex.vars.base import Var
from reflex.vars.number import NumberVar from reflex.vars.number import NumberVar
from reflex.vars.sequence import ArrayVar from reflex.vars.sequence import ArrayVar

View File

@ -1,7 +1,7 @@
import pytest import pytest
from reflex.components.core.html import Html from reflex.components.core.html import Html
from reflex.state import State from reflex.istate.builtins import State
def test_html_no_children(): def test_html_no_children():

View File

@ -10,7 +10,7 @@ from reflex.components.core.upload import (
get_upload_url, get_upload_url,
) )
from reflex.event import EventSpec from reflex.event import EventSpec
from reflex.state import State from reflex.istate.builtins import State
from reflex.vars.base import LiteralVar, Var from reflex.vars.base import LiteralVar, Var

View File

@ -1,7 +1,7 @@
import pytest import pytest
from reflex.event import Event from reflex.event import Event
from reflex.state import State from reflex.istate.builtins import State
def create_event(name): def create_event(name):

View File

@ -3,8 +3,9 @@ from __future__ import annotations
import pytest import pytest
from reflex.app import App from reflex.app import App
from reflex.istate.builtins import State
from reflex.middleware.hydrate_middleware import HydrateMiddleware from reflex.middleware.hydrate_middleware import HydrateMiddleware
from reflex.state import State, StateUpdate from reflex.state import StateUpdate
class TestState(State): class TestState(State):
@ -38,7 +39,7 @@ async def test_preprocess_no_events(hydrate_middleware, event1, mocker):
event1: An Event. event1: An Event.
mocker: pytest mock object. mocker: pytest mock object.
""" """
mocker.patch("reflex.state.State.class_subclasses", {TestState}) mocker.patch("reflex.istate.builtins.State.class_subclasses", {TestState})
state = State() state = State()
update = await hydrate_middleware.preprocess( update = await hydrate_middleware.preprocess(
app=App(state=State), app=App(state=State),

View File

@ -4,7 +4,8 @@ from pathlib import Path
from typing import ClassVar, List from typing import ClassVar, List
import reflex as rx import reflex as rx
from reflex.state import BaseState, State from reflex.istate.builtins import State
from reflex.state import BaseState
class UploadState(BaseState): class UploadState(BaseState):

View File

@ -34,13 +34,12 @@ from reflex.components.base.fragment import Fragment
from reflex.components.core.cond import Cond from reflex.components.core.cond import Cond
from reflex.components.radix.themes.typography.text import Text from reflex.components.radix.themes.typography.text import Text
from reflex.event import Event from reflex.event import Event
from reflex.istate.builtins import OnLoadInternalState, State
from reflex.middleware import HydrateMiddleware from reflex.middleware import HydrateMiddleware
from reflex.model import Model from reflex.model import Model
from reflex.state import ( from reflex.state import (
BaseState, BaseState,
OnLoadInternalState,
RouterData, RouterData,
State,
StateManagerDisk, StateManagerDisk,
StateManagerMemory, StateManagerMemory,
StateManagerRedis, StateManagerRedis,
@ -760,7 +759,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
mocker: pytest mocker object. mocker: pytest mocker object.
""" """
mocker.patch( mocker.patch(
"reflex.state.State.class_subclasses", "reflex.istate.builtins.State.class_subclasses",
{state if state is FileUploadState else FileStateBase1}, {state if state is FileUploadState else FileStateBase1},
) )
state._tmp_path = tmp_path state._tmp_path = tmp_path

View File

@ -24,14 +24,13 @@ from reflex.base import Base
from reflex.components.sonner.toast import Toaster from reflex.components.sonner.toast import Toaster
from reflex.constants import CompileVars, RouteVar, SocketEvent from reflex.constants import CompileVars, RouteVar, SocketEvent
from reflex.event import Event, EventHandler from reflex.event import Event, EventHandler
from reflex.istate.builtins import OnLoadInternalState, State
from reflex.state import ( from reflex.state import (
BaseState, BaseState,
ImmutableStateError, ImmutableStateError,
LockExpiredError, LockExpiredError,
MutableProxy, MutableProxy,
OnLoadInternalState,
RouterData, RouterData,
State,
StateManager, StateManager,
StateManagerDisk, StateManagerDisk,
StateManagerMemory, StateManagerMemory,
@ -2779,7 +2778,8 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
mocker: pytest mock object. mocker: pytest mock object.
""" """
mocker.patch( mocker.patch(
"reflex.state.State.class_subclasses", {test_state, OnLoadInternalState} "reflex.istate.builtins.State.class_subclasses",
{test_state, OnLoadInternalState},
) )
app = app_module_mock.app = App( app = app_module_mock.app = App(
state=State, load_events={"index": [test_state.test_handler]} state=State, load_events={"index": [test_state.test_handler]}
@ -2825,7 +2825,8 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
mocker: pytest mock object. mocker: pytest mock object.
""" """
mocker.patch( mocker.patch(
"reflex.state.State.class_subclasses", {OnLoadState, OnLoadInternalState} "reflex.istate.builtins.State.class_subclasses",
{OnLoadState, OnLoadInternalState},
) )
app = app_module_mock.app = App( app = app_module_mock.app = App(
state=State, state=State,
@ -3231,7 +3232,8 @@ config = rx.Config(
with chdir(proj_root): with chdir(proj_root):
# reload config for each parameter to avoid stale values # reload config for each parameter to avoid stale values
reflex.config.get_config(reload=True) reflex.config.get_config(reload=True)
from reflex.state import State, StateManager from reflex.istate.builtins import State
from reflex.state import StateManager
state_manager = StateManager.create(state=State) state_manager = StateManager.create(state=State)
assert state_manager.lock_expiration == expected_values[0] # type: ignore assert state_manager.lock_expiration == expected_values[0] # type: ignore