From 4e8169e1ed02d40ae5a941062dacc751ae47d4ba Mon Sep 17 00:00:00 2001 From: Lendemor Date: Tue, 22 Oct 2024 20:53:10 +0200 Subject: [PATCH] move builtins states to their own file --- reflex/__init__.py | 3 +- reflex/__init__.pyi | 4 +- reflex/app.py | 2 +- reflex/components/base/error_boundary.py | 2 +- reflex/components/component.py | 4 +- reflex/components/core/foreach.py | 2 +- reflex/experimental/layout.py | 2 +- reflex/experimental/layout.pyi | 2 +- reflex/istate/builtins.py | 207 ++++++++++++++++++ reflex/state.py | 199 +---------------- reflex/testing.py | 2 +- reflex/utils/format.py | 2 +- reflex/utils/prerequisites.py | 2 +- tests/integration/test_component_state.py | 3 +- tests/units/components/core/test_foreach.py | 3 +- tests/units/components/core/test_html.py | 2 +- tests/units/components/core/test_upload.py | 2 +- tests/units/middleware/conftest.py | 2 +- .../middleware/test_hydrate_middleware.py | 5 +- tests/units/states/upload.py | 3 +- tests/units/test_app.py | 5 +- tests/units/test_state.py | 12 +- 22 files changed, 242 insertions(+), 228 deletions(-) create mode 100644 reflex/istate/builtins.py diff --git a/reflex/__init__.py b/reflex/__init__.py index ad51d2cf4..ba173ee4a 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -320,6 +320,7 @@ _MAPPING: dict = { "upload_files", "window_alert", ], + "istate.builtins": ["ComponentState", "State"], "middleware": ["middleware", "Middleware"], "model": ["session", "Model"], "state": [ @@ -327,8 +328,6 @@ _MAPPING: dict = { "Cookie", "LocalStorage", "SessionStorage", - "ComponentState", - "State", ], "style": ["Style", "toggle_color_mode"], "utils.imports": ["ImportVar"], diff --git a/reflex/__init__.pyi b/reflex/__init__.pyi index d928778d8..d4b14cfa0 100644 --- a/reflex/__init__.pyi +++ b/reflex/__init__.pyi @@ -174,16 +174,16 @@ from .event import stop_propagation as stop_propagation from .event import upload_files as upload_files from .event import window_alert as window_alert from .experimental import _x as _x +from .istate.builtins import ComponentState as ComponentState +from .istate.builtins import State as State from .middleware import Middleware as Middleware from .middleware import middleware as middleware from .model import Model as Model from .model import session as session from .page import page as page -from .state import ComponentState as ComponentState from .state import Cookie as Cookie from .state import LocalStorage as LocalStorage from .state import SessionStorage as SessionStorage -from .state import State as State from .state import var as var from .style import Style as Style from .style import toggle_color_mode as toggle_color_mode diff --git a/reflex/app.py b/reflex/app.py index abf0b5d41..037659fd6 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -66,6 +66,7 @@ from reflex.components.core.upload import Upload, get_upload_dir from reflex.components.radix import themes from reflex.config import environment, get_config from reflex.event import Event, EventHandler, EventSpec, window_alert +from reflex.istate.builtins import State from reflex.model import Model, get_db_status from reflex.page import ( DECORATED_PAGES, @@ -78,7 +79,6 @@ from reflex.route import ( from reflex.state import ( BaseState, RouterData, - State, StateManager, StateUpdate, _substate_key, diff --git a/reflex/components/base/error_boundary.py b/reflex/components/base/error_boundary.py index 83becc034..f9ca53b3d 100644 --- a/reflex/components/base/error_boundary.py +++ b/reflex/components/base/error_boundary.py @@ -8,7 +8,7 @@ from reflex.compiler.compiler import _compile_component from reflex.components.component import Component from reflex.components.el import div, p from reflex.event import EventHandler -from reflex.state import FrontendEventExceptionState +from reflex.istate.builtins import FrontendEventExceptionState from reflex.vars.base import Var diff --git a/reflex/components/component.py b/reflex/components/component.py index a0d9c93b0..d5649ba8f 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -21,7 +21,7 @@ from typing import ( Union, ) -import reflex.state +import reflex.istate.builtins from reflex.base import Base from reflex.compiler.templates import STATEFUL_COMPONENT from reflex.components.core.breakpoints import Breakpoints @@ -224,7 +224,7 @@ class Component(BaseComponent, ABC): _memoization_mode: MemoizationMode = MemoizationMode() # State class associated with this component instance - State: Optional[Type[reflex.state.State]] = None + State: Optional[Type[reflex.istate.builtins.State]] = None def add_imports(self) -> ImportDict | list[ImportDict]: """Add imports for the component. diff --git a/reflex/components/core/foreach.py b/reflex/components/core/foreach.py index a3f97d594..71a4d5a53 100644 --- a/reflex/components/core/foreach.py +++ b/reflex/components/core/foreach.py @@ -9,7 +9,7 @@ from reflex.components.base.fragment import Fragment from reflex.components.component import Component from reflex.components.tags import IterTag from reflex.constants import MemoizationMode -from reflex.state import ComponentState +from reflex.istate.builtins import ComponentState from reflex.vars.base import LiteralVar, Var diff --git a/reflex/experimental/layout.py b/reflex/experimental/layout.py index a3b76581a..b7360af5e 100644 --- a/reflex/experimental/layout.py +++ b/reflex/experimental/layout.py @@ -14,7 +14,7 @@ from reflex.components.radix.themes.layout.container import Container from reflex.components.radix.themes.layout.stack import HStack from reflex.event import call_script from reflex.experimental import hooks -from reflex.state import ComponentState +from reflex.istate.builtins import ComponentState from reflex.style import Style from reflex.vars.base import Var diff --git a/reflex/experimental/layout.pyi b/reflex/experimental/layout.pyi index dcdac5b5d..411b39575 100644 --- a/reflex/experimental/layout.pyi +++ b/reflex/experimental/layout.pyi @@ -11,7 +11,7 @@ from reflex.components.component import Component, ComponentNamespace, Memoizati from reflex.components.radix.primitives.drawer import DrawerRoot from reflex.components.radix.themes.layout.box import Box from reflex.event import EventType -from reflex.state import ComponentState +from reflex.istate.builtins import ComponentState from reflex.style import Style from reflex.vars.base import Var diff --git a/reflex/istate/builtins.py b/reflex/istate/builtins.py new file mode 100644 index 000000000..8730b6388 --- /dev/null +++ b/reflex/istate/builtins.py @@ -0,0 +1,207 @@ +"""The built-in states used by reflex apps.""" + +from __future__ import annotations + +from collections import defaultdict +from typing import TYPE_CHECKING, Any, ClassVar, Type + +import reflex.istate.dynamic +from reflex import constants, event +from reflex.event import Event, EventSpec, fix_events +from reflex.state import BaseState +from reflex.utils import prerequisites + +if TYPE_CHECKING: + from reflex.components.component import Component + + +class State(BaseState): + """The app Base State.""" + + # The hydrated bool. + is_hydrated: bool = False + + +class FrontendEventExceptionState(State): + """Substate for handling frontend exceptions.""" + + @event + def handle_frontend_exception(self, stack: str, component_stack: str) -> None: + """Handle frontend exceptions. + + If a frontend exception handler is provided, it will be called. + Otherwise, the default frontend exception handler will be called. + + Args: + stack: The stack trace of the exception. + component_stack: The stack trace of the component where the exception occurred. + + """ + app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) + app_instance.frontend_exception_handler(Exception(stack)) + + +class UpdateVarsInternalState(State): + """Substate for handling internal state var updates.""" + + async def update_vars_internal(self, vars: dict[str, Any]) -> None: + """Apply updates to fully qualified state vars. + + The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`, + and each value will be set on the appropriate substate instance. + + This function is primarily used to apply cookie and local storage + updates from the frontend to the appropriate substate. + + Args: + vars: The fully qualified vars and values to update. + """ + for var, value in vars.items(): + state_name, _, var_name = var.rpartition(".") + var_state_cls = State.get_class_substate(state_name) + var_state = await self.get_state(var_state_cls) + setattr(var_state, var_name, value) + + +class OnLoadInternalState(State): + """Substate for handling on_load event enumeration. + + This is a separate substate to avoid deserializing the entire state tree for every page navigation. + """ + + def on_load_internal(self) -> list[Event | EventSpec] | None: + """Queue on_load handlers for the current page. + + Returns: + The list of events to queue for on load handling. + """ + # Do not app._compile()! It should be already compiled by now. + app = getattr(prerequisites.get_app(), constants.CompileVars.APP) + load_events = app.get_load_events(self.router.page.path) + if not load_events: + self.is_hydrated = True + return # Fast path for navigation with no on_load events defined. + self.is_hydrated = False + return [ + *fix_events( + load_events, + self.router.session.client_token, + router_data=self.router_data, + ), + State.set_is_hydrated(True), # type: ignore + ] + + +class ComponentState(State, mixin=True): + """Base class to allow for the creation of a state instance per component. + + This allows for the bundling of UI and state logic into a single class, + where each instance has a separate instance of the state. + + Subclass this class and define vars and event handlers in the traditional way. + Then define a `get_component` method that returns the UI for the component instance. + + See the full [docs](https://reflex.dev/docs/substates/component-state/) for more. + + Basic example: + ```python + # Subclass ComponentState and define vars and event handlers. + class Counter(rx.ComponentState): + # Define vars that change. + count: int = 0 + + # Define event handlers. + def increment(self): + self.count += 1 + + def decrement(self): + self.count -= 1 + + @classmethod + def get_component(cls, **props): + # Access the state vars and event handlers using `cls`. + return rx.hstack( + rx.button("Decrement", on_click=cls.decrement), + rx.text(cls.count), + rx.button("Increment", on_click=cls.increment), + **props, + ) + + counter = Counter.create() + ``` + """ + + # The number of components created from this class. + _per_component_state_instance_count: ClassVar[int] = 0 + + @classmethod + def __init_subclass__(cls, mixin: bool = True, **kwargs): + """Overwrite mixin default to True. + + Args: + mixin: Whether the subclass is a mixin and should not be initialized. + **kwargs: The kwargs to pass to the pydantic init_subclass method. + """ + super().__init_subclass__(mixin=mixin, **kwargs) + + @classmethod + def get_component(cls, *children, **props) -> "Component": + """Get the component instance. + + Args: + children: The children of the component. + props: The props of the component. + + Raises: + NotImplementedError: if the subclass does not override this method. + """ + raise NotImplementedError( + f"{cls.__name__} must implement get_component to return the component instance." + ) + + @classmethod + def create(cls, *children, **props) -> "Component": + """Create a new instance of the Component. + + Args: + children: The children of the component. + props: The props of the component. + + Returns: + A new instance of the Component with an independent copy of the State. + """ + cls._per_component_state_instance_count += 1 + state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}" + component_state = type( + state_cls_name, + (cls, State), + {"__module__": reflex.istate.dynamic.__name__}, + mixin=False, + ) + # Save a reference to the dynamic state for pickle/unpickle. + setattr(reflex.istate.dynamic, state_cls_name, component_state) + component = component_state.get_component(*children, **props) + component.State = component_state + return component + + +def reload_state_module( + module: str, + state: Type[BaseState] = State, +) -> None: + """Reset rx.State subclasses to avoid conflict when reloading. + + Args: + module: The module to reload. + state: Recursive argument for the state class to reload. + + """ + for subclass in tuple(state.class_subclasses): + reload_state_module(module=module, state=subclass) + if subclass.__module__ == module and module is not None: + state.class_subclasses.remove(subclass) + state._always_dirty_substates.discard(subclass.get_name()) + state._computed_var_dependencies = defaultdict(set) + state._substate_var_dependencies = defaultdict(set) + state._init_var_dependency_dicts() + state.get_class_substate.cache_clear() diff --git a/reflex/state.py b/reflex/state.py index 3422d1ba7..cddecb9c7 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -39,7 +39,6 @@ from typing import ( from sqlalchemy.orm import DeclarativeBase from typing_extensions import Self -from reflex import event from reflex.config import get_config from reflex.istate.data import RouterData from reflex.vars.base import ( @@ -89,10 +88,6 @@ from reflex.utils.serializers import serializer from reflex.utils.types import get_origin, override from reflex.vars import VarData -if TYPE_CHECKING: - from reflex.components.component import Component - - Delta = Dict[str, Any] var = computed_var @@ -2085,176 +2080,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): return state -class State(BaseState): - """The app Base State.""" - - # The hydrated bool. - is_hydrated: bool = False - - -class FrontendEventExceptionState(State): - """Substate for handling frontend exceptions.""" - - @event - def handle_frontend_exception(self, stack: str, component_stack: str) -> None: - """Handle frontend exceptions. - - If a frontend exception handler is provided, it will be called. - Otherwise, the default frontend exception handler will be called. - - Args: - stack: The stack trace of the exception. - component_stack: The stack trace of the component where the exception occurred. - - """ - app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) - app_instance.frontend_exception_handler(Exception(stack)) - - -class UpdateVarsInternalState(State): - """Substate for handling internal state var updates.""" - - async def update_vars_internal(self, vars: dict[str, Any]) -> None: - """Apply updates to fully qualified state vars. - - The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`, - and each value will be set on the appropriate substate instance. - - This function is primarily used to apply cookie and local storage - updates from the frontend to the appropriate substate. - - Args: - vars: The fully qualified vars and values to update. - """ - for var, value in vars.items(): - state_name, _, var_name = var.rpartition(".") - var_state_cls = State.get_class_substate(state_name) - var_state = await self.get_state(var_state_cls) - setattr(var_state, var_name, value) - - -class OnLoadInternalState(State): - """Substate for handling on_load event enumeration. - - This is a separate substate to avoid deserializing the entire state tree for every page navigation. - """ - - def on_load_internal(self) -> list[Event | EventSpec] | None: - """Queue on_load handlers for the current page. - - Returns: - The list of events to queue for on load handling. - """ - # Do not app._compile()! It should be already compiled by now. - app = getattr(prerequisites.get_app(), constants.CompileVars.APP) - load_events = app.get_load_events(self.router.page.path) - if not load_events: - self.is_hydrated = True - return # Fast path for navigation with no on_load events defined. - self.is_hydrated = False - return [ - *fix_events( - load_events, - self.router.session.client_token, - router_data=self.router_data, - ), - State.set_is_hydrated(True), # type: ignore - ] - - -class ComponentState(State, mixin=True): - """Base class to allow for the creation of a state instance per component. - - This allows for the bundling of UI and state logic into a single class, - where each instance has a separate instance of the state. - - Subclass this class and define vars and event handlers in the traditional way. - Then define a `get_component` method that returns the UI for the component instance. - - See the full [docs](https://reflex.dev/docs/substates/component-state/) for more. - - Basic example: - ```python - # Subclass ComponentState and define vars and event handlers. - class Counter(rx.ComponentState): - # Define vars that change. - count: int = 0 - - # Define event handlers. - def increment(self): - self.count += 1 - - def decrement(self): - self.count -= 1 - - @classmethod - def get_component(cls, **props): - # Access the state vars and event handlers using `cls`. - return rx.hstack( - rx.button("Decrement", on_click=cls.decrement), - rx.text(cls.count), - rx.button("Increment", on_click=cls.increment), - **props, - ) - - counter = Counter.create() - ``` - """ - - # The number of components created from this class. - _per_component_state_instance_count: ClassVar[int] = 0 - - @classmethod - def __init_subclass__(cls, mixin: bool = True, **kwargs): - """Overwrite mixin default to True. - - Args: - mixin: Whether the subclass is a mixin and should not be initialized. - **kwargs: The kwargs to pass to the pydantic init_subclass method. - """ - super().__init_subclass__(mixin=mixin, **kwargs) - - @classmethod - def get_component(cls, *children, **props) -> "Component": - """Get the component instance. - - Args: - children: The children of the component. - props: The props of the component. - - Raises: - NotImplementedError: if the subclass does not override this method. - """ - raise NotImplementedError( - f"{cls.__name__} must implement get_component to return the component instance." - ) - - @classmethod - def create(cls, *children, **props) -> "Component": - """Create a new instance of the Component. - - Args: - children: The children of the component. - props: The props of the component. - - Returns: - A new instance of the Component with an independent copy of the State. - """ - cls._per_component_state_instance_count += 1 - state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}" - component_state = type( - state_cls_name, - (cls, State), - {"__module__": reflex.istate.dynamic.__name__}, - mixin=False, - ) - # Save a reference to the dynamic state for pickle/unpickle. - setattr(reflex.istate.dynamic, state_cls_name, component_state) - component = component_state.get_component(*children, **props) - component.State = component_state - return component - - class StateProxy(wrapt.ObjectProxy): """Proxy of a state instance to control mutability of vars for a background task. @@ -2336,7 +2161,7 @@ class StateProxy(wrapt.ObjectProxy): super().__setattr__( "__wrapped__", await parent_state.get_state( - State.get_class_substate(self._self_substate_path) + BaseState.get_class_substate(self._self_substate_path) ), ) return self @@ -3804,25 +3629,3 @@ def code_uses_state_contexts(javascript_code: str) -> bool: True if the code attempts to access a member of StateContexts. """ return bool("useContext(StateContexts" in javascript_code) - - -def reload_state_module( - module: str, - state: Type[BaseState] = State, -) -> None: - """Reset rx.State subclasses to avoid conflict when reloading. - - Args: - module: The module to reload. - state: Recursive argument for the state class to reload. - - """ - for subclass in tuple(state.class_subclasses): - reload_state_module(module=module, state=subclass) - if subclass.__module__ == module and module is not None: - state.class_subclasses.remove(subclass) - state._always_dirty_substates.discard(subclass.get_name()) - state._computed_var_dependencies = defaultdict(set) - state._substate_var_dependencies = defaultdict(set) - state._init_var_dependency_dicts() - state.get_class_substate.cache_clear() diff --git a/reflex/testing.py b/reflex/testing.py index 6a45c51eb..414b49c53 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -43,13 +43,13 @@ import reflex.utils.exec import reflex.utils.format import reflex.utils.prerequisites import reflex.utils.processes +from reflex.istate.builtins import reload_state_module from reflex.state import ( BaseState, StateManager, StateManagerDisk, StateManagerMemory, StateManagerRedis, - reload_state_module, ) try: diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 65c0f049b..82f979f5d 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -429,7 +429,7 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]: # Get the function name name = parts[-1] - from reflex.state import State + from reflex.istate.builtins import State if state_full_name == "state" and name not in State.__dict__: return ("", to_snake_case(handler.fn.__qualname__)) diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 33165af0e..4402d5a20 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -279,7 +279,7 @@ def get_app(reload: bool = False) -> ModuleType: app = __import__(module, fromlist=(constants.CompileVars.APP,)) if reload: - from reflex.state import reload_state_module + from reflex.istate.builtins import reload_state_module # Reset rx.State subclasses to avoid conflict when reloading. reload_state_module(module=module) diff --git a/tests/integration/test_component_state.py b/tests/integration/test_component_state.py index f4a295d07..acf6b05b8 100644 --- a/tests/integration/test_component_state.py +++ b/tests/integration/test_component_state.py @@ -5,7 +5,8 @@ from typing import Generator import pytest from selenium.webdriver.common.by import By -from reflex.state import State, _substate_key +from reflex.istate.builtins import State +from reflex.state import _substate_key from reflex.testing import AppHarness from . import utils diff --git a/tests/units/components/core/test_foreach.py b/tests/units/components/core/test_foreach.py index 228165d3e..6ec3bb1fd 100644 --- a/tests/units/components/core/test_foreach.py +++ b/tests/units/components/core/test_foreach.py @@ -12,7 +12,8 @@ from reflex.components.core.foreach import ( ) from reflex.components.radix.themes.layout.box import box from reflex.components.radix.themes.typography.text import text -from reflex.state import BaseState, ComponentState +from reflex.istate.builtins import ComponentState +from reflex.state import BaseState from reflex.vars.base import Var from reflex.vars.number import NumberVar from reflex.vars.sequence import ArrayVar diff --git a/tests/units/components/core/test_html.py b/tests/units/components/core/test_html.py index 4847e1d5a..883531b0f 100644 --- a/tests/units/components/core/test_html.py +++ b/tests/units/components/core/test_html.py @@ -1,7 +1,7 @@ import pytest from reflex.components.core.html import Html -from reflex.state import State +from reflex.istate.builtins import State def test_html_no_children(): diff --git a/tests/units/components/core/test_upload.py b/tests/units/components/core/test_upload.py index 710baa161..09e6983a7 100644 --- a/tests/units/components/core/test_upload.py +++ b/tests/units/components/core/test_upload.py @@ -10,7 +10,7 @@ from reflex.components.core.upload import ( get_upload_url, ) from reflex.event import EventSpec -from reflex.state import State +from reflex.istate.builtins import State from reflex.vars.base import LiteralVar, Var diff --git a/tests/units/middleware/conftest.py b/tests/units/middleware/conftest.py index d786db652..356e8d6c3 100644 --- a/tests/units/middleware/conftest.py +++ b/tests/units/middleware/conftest.py @@ -1,7 +1,7 @@ import pytest from reflex.event import Event -from reflex.state import State +from reflex.istate.builtins import State def create_event(name): diff --git a/tests/units/middleware/test_hydrate_middleware.py b/tests/units/middleware/test_hydrate_middleware.py index 9ee8d8d25..255e2eb77 100644 --- a/tests/units/middleware/test_hydrate_middleware.py +++ b/tests/units/middleware/test_hydrate_middleware.py @@ -3,8 +3,9 @@ from __future__ import annotations import pytest from reflex.app import App +from reflex.istate.builtins import State from reflex.middleware.hydrate_middleware import HydrateMiddleware -from reflex.state import State, StateUpdate +from reflex.state import StateUpdate class TestState(State): @@ -38,7 +39,7 @@ async def test_preprocess_no_events(hydrate_middleware, event1, mocker): event1: An Event. mocker: pytest mock object. """ - mocker.patch("reflex.state.State.class_subclasses", {TestState}) + mocker.patch("reflex.istate.builtins.State.class_subclasses", {TestState}) state = State() update = await hydrate_middleware.preprocess( app=App(state=State), diff --git a/tests/units/states/upload.py b/tests/units/states/upload.py index f81e9f235..a00f9a68c 100644 --- a/tests/units/states/upload.py +++ b/tests/units/states/upload.py @@ -4,7 +4,8 @@ from pathlib import Path from typing import ClassVar, List import reflex as rx -from reflex.state import BaseState, State +from reflex.istate.builtins import State +from reflex.state import BaseState class UploadState(BaseState): diff --git a/tests/units/test_app.py b/tests/units/test_app.py index a4ecfc5f7..0daecdc3c 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -34,13 +34,12 @@ from reflex.components.base.fragment import Fragment from reflex.components.core.cond import Cond from reflex.components.radix.themes.typography.text import Text from reflex.event import Event +from reflex.istate.builtins import OnLoadInternalState, State from reflex.middleware import HydrateMiddleware from reflex.model import Model from reflex.state import ( BaseState, - OnLoadInternalState, RouterData, - State, StateManagerDisk, StateManagerMemory, StateManagerRedis, @@ -760,7 +759,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker): mocker: pytest mocker object. """ mocker.patch( - "reflex.state.State.class_subclasses", + "reflex.istate.builtins.State.class_subclasses", {state if state is FileUploadState else FileStateBase1}, ) state._tmp_path = tmp_path diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 610d69110..161f01f94 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -24,14 +24,13 @@ from reflex.base import Base from reflex.components.sonner.toast import Toaster from reflex.constants import CompileVars, RouteVar, SocketEvent from reflex.event import Event, EventHandler +from reflex.istate.builtins import OnLoadInternalState, State from reflex.state import ( BaseState, ImmutableStateError, LockExpiredError, MutableProxy, - OnLoadInternalState, RouterData, - State, StateManager, StateManagerDisk, StateManagerMemory, @@ -2768,7 +2767,8 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker): mocker: pytest mock object. """ mocker.patch( - "reflex.state.State.class_subclasses", {test_state, OnLoadInternalState} + "reflex.istate.builtins.State.class_subclasses", + {test_state, OnLoadInternalState}, ) app = app_module_mock.app = App( state=State, load_events={"index": [test_state.test_handler]} @@ -2814,7 +2814,8 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker): mocker: pytest mock object. """ mocker.patch( - "reflex.state.State.class_subclasses", {OnLoadState, OnLoadInternalState} + "reflex.istate.builtins.State.class_subclasses", + {OnLoadState, OnLoadInternalState}, ) app = app_module_mock.app = App( state=State, @@ -3213,7 +3214,8 @@ config = rx.Config( with chdir(proj_root): # reload config for each parameter to avoid stale values reflex.config.get_config(reload=True) - from reflex.state import State, StateManager + from reflex.istate.builtins import State + from reflex.state import StateManager state_manager = StateManager.create(state=State) assert state_manager.lock_expiration == expected_values[0] # type: ignore