diff --git a/reflex/__init__.py b/reflex/__init__.py index ffc4426f9..c3ab12753 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -320,6 +320,7 @@ _MAPPING: dict = { "upload_files", "window_alert", ], + "istate.builtins": ["ComponentState", "State"], "istate.storage": [ "Cookie", "LocalStorage", @@ -329,8 +330,6 @@ _MAPPING: dict = { "model": ["session", "Model"], "state": [ "var", - "ComponentState", - "State", "dynamic", ], "style": ["Style", "toggle_color_mode"], diff --git a/reflex/__init__.pyi b/reflex/__init__.pyi index aa1c92b72..2efa7c47f 100644 --- a/reflex/__init__.pyi +++ b/reflex/__init__.pyi @@ -174,6 +174,8 @@ from .event import stop_propagation as stop_propagation from .event import upload_files as upload_files from .event import window_alert as window_alert from .experimental import _x as _x +from .istate.builtins import ComponentState as ComponentState +from .istate.builtins import State as State from .istate.storage import Cookie as Cookie from .istate.storage import LocalStorage as LocalStorage from .istate.storage import SessionStorage as SessionStorage @@ -182,8 +184,6 @@ from .middleware import middleware as middleware from .model import Model as Model from .model import session as session from .page import page as page -from .state import ComponentState as ComponentState -from .state import State as State from .state import dynamic as dynamic from .state import var as var from .style import Style as Style diff --git a/reflex/app.py b/reflex/app.py index 617ffc933..aa9b58a8c 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -71,6 +71,7 @@ from reflex.components.core.upload import Upload, get_upload_dir from reflex.components.radix import themes from reflex.config import environment, get_config from reflex.event import Event, EventHandler, EventSpec, window_alert +from reflex.istate.builtins import State from reflex.model import Model, get_db_status from reflex.page import ( DECORATED_PAGES, @@ -83,7 +84,6 @@ from reflex.route import ( from reflex.state import ( BaseState, RouterData, - State, StateManager, StateUpdate, _substate_key, diff --git a/reflex/components/base/error_boundary.py b/reflex/components/base/error_boundary.py index 83becc034..f9ca53b3d 100644 --- a/reflex/components/base/error_boundary.py +++ b/reflex/components/base/error_boundary.py @@ -8,7 +8,7 @@ from reflex.compiler.compiler import _compile_component from reflex.components.component import Component from reflex.components.el import div, p from reflex.event import EventHandler -from reflex.state import FrontendEventExceptionState +from reflex.istate.builtins import FrontendEventExceptionState from reflex.vars.base import Var diff --git a/reflex/components/component.py b/reflex/components/component.py index 9fea2f05b..f03a68e47 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -21,7 +21,7 @@ from typing import ( Union, ) -import reflex.state +import reflex.istate.builtins from reflex.base import Base from reflex.compiler.templates import STATEFUL_COMPONENT from reflex.components.core.breakpoints import Breakpoints @@ -225,7 +225,7 @@ class Component(BaseComponent, ABC): _memoization_mode: MemoizationMode = MemoizationMode() # State class associated with this component instance - State: Optional[Type[reflex.state.State]] = None + State: Optional[Type[reflex.istate.builtins.State]] = None def add_imports(self) -> ImportDict | list[ImportDict]: """Add imports for the component. diff --git a/reflex/components/core/foreach.py b/reflex/components/core/foreach.py index a3f97d594..71a4d5a53 100644 --- a/reflex/components/core/foreach.py +++ b/reflex/components/core/foreach.py @@ -9,7 +9,7 @@ from reflex.components.base.fragment import Fragment from reflex.components.component import Component from reflex.components.tags import IterTag from reflex.constants import MemoizationMode -from reflex.state import ComponentState +from reflex.istate.builtins import ComponentState from reflex.vars.base import LiteralVar, Var diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index b7ffef161..3afbe3c7a 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -62,16 +62,20 @@ class CompileVars(SimpleNamespace): # The name of the function for converting a dict to an event. TO_EVENT = "Event" # The name of the internal on_load event. - ON_LOAD_INTERNAL = "reflex___state____on_load_internal_state.on_load_internal" + ON_LOAD_INTERNAL = ( + "reflex___istate___builtins____on_load_internal_state.on_load_internal" + ) # The name of the internal event to update generic state vars. UPDATE_VARS_INTERNAL = ( - "reflex___state____update_vars_internal_state.update_vars_internal" + "reflex___istate___builtins____update_vars_internal_state.update_vars_internal" ) # The name of the frontend event exception state - FRONTEND_EXCEPTION_STATE = "reflex___state____frontend_event_exception_state" + FRONTEND_EXCEPTION_STATE = ( + "reflex___istate___builtins____frontend_event_exception_state" + ) # The full name of the frontend exception state FRONTEND_EXCEPTION_STATE_FULL = ( - f"reflex___state____state.{FRONTEND_EXCEPTION_STATE}" + f"reflex___istate___builtins____state.{FRONTEND_EXCEPTION_STATE}" ) diff --git a/reflex/experimental/layout.py b/reflex/experimental/layout.py index a3b76581a..b7360af5e 100644 --- a/reflex/experimental/layout.py +++ b/reflex/experimental/layout.py @@ -14,7 +14,7 @@ from reflex.components.radix.themes.layout.container import Container from reflex.components.radix.themes.layout.stack import HStack from reflex.event import call_script from reflex.experimental import hooks -from reflex.state import ComponentState +from reflex.istate.builtins import ComponentState from reflex.style import Style from reflex.vars.base import Var diff --git a/reflex/experimental/layout.pyi b/reflex/experimental/layout.pyi index dcdac5b5d..411b39575 100644 --- a/reflex/experimental/layout.pyi +++ b/reflex/experimental/layout.pyi @@ -11,7 +11,7 @@ from reflex.components.component import Component, ComponentNamespace, Memoizati from reflex.components.radix.primitives.drawer import DrawerRoot from reflex.components.radix.themes.layout.box import Box from reflex.event import EventType -from reflex.state import ComponentState +from reflex.istate.builtins import ComponentState from reflex.style import Style from reflex.vars.base import Var diff --git a/reflex/istate/builtins.py b/reflex/istate/builtins.py new file mode 100644 index 000000000..8730b6388 --- /dev/null +++ b/reflex/istate/builtins.py @@ -0,0 +1,207 @@ +"""The built-in states used by reflex apps.""" + +from __future__ import annotations + +from collections import defaultdict +from typing import TYPE_CHECKING, Any, ClassVar, Type + +import reflex.istate.dynamic +from reflex import constants, event +from reflex.event import Event, EventSpec, fix_events +from reflex.state import BaseState +from reflex.utils import prerequisites + +if TYPE_CHECKING: + from reflex.components.component import Component + + +class State(BaseState): + """The app Base State.""" + + # The hydrated bool. + is_hydrated: bool = False + + +class FrontendEventExceptionState(State): + """Substate for handling frontend exceptions.""" + + @event + def handle_frontend_exception(self, stack: str, component_stack: str) -> None: + """Handle frontend exceptions. + + If a frontend exception handler is provided, it will be called. + Otherwise, the default frontend exception handler will be called. + + Args: + stack: The stack trace of the exception. + component_stack: The stack trace of the component where the exception occurred. + + """ + app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) + app_instance.frontend_exception_handler(Exception(stack)) + + +class UpdateVarsInternalState(State): + """Substate for handling internal state var updates.""" + + async def update_vars_internal(self, vars: dict[str, Any]) -> None: + """Apply updates to fully qualified state vars. + + The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`, + and each value will be set on the appropriate substate instance. + + This function is primarily used to apply cookie and local storage + updates from the frontend to the appropriate substate. + + Args: + vars: The fully qualified vars and values to update. + """ + for var, value in vars.items(): + state_name, _, var_name = var.rpartition(".") + var_state_cls = State.get_class_substate(state_name) + var_state = await self.get_state(var_state_cls) + setattr(var_state, var_name, value) + + +class OnLoadInternalState(State): + """Substate for handling on_load event enumeration. + + This is a separate substate to avoid deserializing the entire state tree for every page navigation. + """ + + def on_load_internal(self) -> list[Event | EventSpec] | None: + """Queue on_load handlers for the current page. + + Returns: + The list of events to queue for on load handling. + """ + # Do not app._compile()! It should be already compiled by now. + app = getattr(prerequisites.get_app(), constants.CompileVars.APP) + load_events = app.get_load_events(self.router.page.path) + if not load_events: + self.is_hydrated = True + return # Fast path for navigation with no on_load events defined. + self.is_hydrated = False + return [ + *fix_events( + load_events, + self.router.session.client_token, + router_data=self.router_data, + ), + State.set_is_hydrated(True), # type: ignore + ] + + +class ComponentState(State, mixin=True): + """Base class to allow for the creation of a state instance per component. + + This allows for the bundling of UI and state logic into a single class, + where each instance has a separate instance of the state. + + Subclass this class and define vars and event handlers in the traditional way. + Then define a `get_component` method that returns the UI for the component instance. + + See the full [docs](https://reflex.dev/docs/substates/component-state/) for more. + + Basic example: + ```python + # Subclass ComponentState and define vars and event handlers. + class Counter(rx.ComponentState): + # Define vars that change. + count: int = 0 + + # Define event handlers. + def increment(self): + self.count += 1 + + def decrement(self): + self.count -= 1 + + @classmethod + def get_component(cls, **props): + # Access the state vars and event handlers using `cls`. + return rx.hstack( + rx.button("Decrement", on_click=cls.decrement), + rx.text(cls.count), + rx.button("Increment", on_click=cls.increment), + **props, + ) + + counter = Counter.create() + ``` + """ + + # The number of components created from this class. + _per_component_state_instance_count: ClassVar[int] = 0 + + @classmethod + def __init_subclass__(cls, mixin: bool = True, **kwargs): + """Overwrite mixin default to True. + + Args: + mixin: Whether the subclass is a mixin and should not be initialized. + **kwargs: The kwargs to pass to the pydantic init_subclass method. + """ + super().__init_subclass__(mixin=mixin, **kwargs) + + @classmethod + def get_component(cls, *children, **props) -> "Component": + """Get the component instance. + + Args: + children: The children of the component. + props: The props of the component. + + Raises: + NotImplementedError: if the subclass does not override this method. + """ + raise NotImplementedError( + f"{cls.__name__} must implement get_component to return the component instance." + ) + + @classmethod + def create(cls, *children, **props) -> "Component": + """Create a new instance of the Component. + + Args: + children: The children of the component. + props: The props of the component. + + Returns: + A new instance of the Component with an independent copy of the State. + """ + cls._per_component_state_instance_count += 1 + state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}" + component_state = type( + state_cls_name, + (cls, State), + {"__module__": reflex.istate.dynamic.__name__}, + mixin=False, + ) + # Save a reference to the dynamic state for pickle/unpickle. + setattr(reflex.istate.dynamic, state_cls_name, component_state) + component = component_state.get_component(*children, **props) + component.State = component_state + return component + + +def reload_state_module( + module: str, + state: Type[BaseState] = State, +) -> None: + """Reset rx.State subclasses to avoid conflict when reloading. + + Args: + module: The module to reload. + state: Recursive argument for the state class to reload. + + """ + for subclass in tuple(state.class_subclasses): + reload_state_module(module=module, state=subclass) + if subclass.__module__ == module and module is not None: + state.class_subclasses.remove(subclass) + state._always_dirty_substates.discard(subclass.get_name()) + state._computed_var_dependencies = defaultdict(set) + state._substate_var_dependencies = defaultdict(set) + state._init_var_dependency_dicts() + state.get_class_substate.cache_clear() diff --git a/reflex/state.py b/reflex/state.py index 6e229b97d..8e6a3985d 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -40,7 +40,6 @@ from typing import ( from sqlalchemy.orm import DeclarativeBase from typing_extensions import Self -from reflex import event from reflex.config import get_config from reflex.istate.data import RouterData from reflex.istate.storage import ( @@ -94,13 +93,12 @@ from reflex.utils.serializers import serializer from reflex.utils.types import get_origin, override from reflex.vars import VarData -if TYPE_CHECKING: - from reflex.components.component import Component - - Delta = Dict[str, Any] var = computed_var +if TYPE_CHECKING: + from reflex.components.component import Component + # If the state is this large, it's considered a performance issue. TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb @@ -899,11 +897,13 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if len(path) == 0: return cls + print("get_name: ", cls.get_name()) if path[0] == cls.get_name(): if len(path) == 1: return cls path = path[1:] for substate in cls.get_substates(): + print("substate get_name: ", substate.get_name()) if path[0] == substate.get_name(): return substate.get_class_substate(path[1:]) raise ValueError(f"Invalid path: {path}") @@ -2112,13 +2112,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): return state -class State(BaseState): - """The app Base State.""" - - # The hydrated bool. - is_hydrated: bool = False - - T = TypeVar("T", bound=BaseState) @@ -2164,169 +2157,6 @@ def dynamic(func: Callable[[T], Component]): return wrapper -class FrontendEventExceptionState(State): - """Substate for handling frontend exceptions.""" - - @event - def handle_frontend_exception(self, stack: str, component_stack: str) -> None: - """Handle frontend exceptions. - - If a frontend exception handler is provided, it will be called. - Otherwise, the default frontend exception handler will be called. - - Args: - stack: The stack trace of the exception. - component_stack: The stack trace of the component where the exception occurred. - - """ - app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) - app_instance.frontend_exception_handler(Exception(stack)) - - -class UpdateVarsInternalState(State): - """Substate for handling internal state var updates.""" - - async def update_vars_internal(self, vars: dict[str, Any]) -> None: - """Apply updates to fully qualified state vars. - - The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`, - and each value will be set on the appropriate substate instance. - - This function is primarily used to apply cookie and local storage - updates from the frontend to the appropriate substate. - - Args: - vars: The fully qualified vars and values to update. - """ - for var, value in vars.items(): - state_name, _, var_name = var.rpartition(".") - var_state_cls = State.get_class_substate(state_name) - var_state = await self.get_state(var_state_cls) - setattr(var_state, var_name, value) - - -class OnLoadInternalState(State): - """Substate for handling on_load event enumeration. - - This is a separate substate to avoid deserializing the entire state tree for every page navigation. - """ - - def on_load_internal(self) -> list[Event | EventSpec] | None: - """Queue on_load handlers for the current page. - - Returns: - The list of events to queue for on load handling. - """ - # Do not app._compile()! It should be already compiled by now. - app = getattr(prerequisites.get_app(), constants.CompileVars.APP) - load_events = app.get_load_events(self.router.page.path) - if not load_events: - self.is_hydrated = True - return # Fast path for navigation with no on_load events defined. - self.is_hydrated = False - return [ - *fix_events( - load_events, - self.router.session.client_token, - router_data=self.router_data, - ), - State.set_is_hydrated(True), # type: ignore - ] - - -class ComponentState(State, mixin=True): - """Base class to allow for the creation of a state instance per component. - - This allows for the bundling of UI and state logic into a single class, - where each instance has a separate instance of the state. - - Subclass this class and define vars and event handlers in the traditional way. - Then define a `get_component` method that returns the UI for the component instance. - - See the full [docs](https://reflex.dev/docs/substates/component-state/) for more. - - Basic example: - ```python - # Subclass ComponentState and define vars and event handlers. - class Counter(rx.ComponentState): - # Define vars that change. - count: int = 0 - - # Define event handlers. - def increment(self): - self.count += 1 - - def decrement(self): - self.count -= 1 - - @classmethod - def get_component(cls, **props): - # Access the state vars and event handlers using `cls`. - return rx.hstack( - rx.button("Decrement", on_click=cls.decrement), - rx.text(cls.count), - rx.button("Increment", on_click=cls.increment), - **props, - ) - - counter = Counter.create() - ``` - """ - - # The number of components created from this class. - _per_component_state_instance_count: ClassVar[int] = 0 - - @classmethod - def __init_subclass__(cls, mixin: bool = True, **kwargs): - """Overwrite mixin default to True. - - Args: - mixin: Whether the subclass is a mixin and should not be initialized. - **kwargs: The kwargs to pass to the pydantic init_subclass method. - """ - super().__init_subclass__(mixin=mixin, **kwargs) - - @classmethod - def get_component(cls, *children, **props) -> "Component": - """Get the component instance. - - Args: - children: The children of the component. - props: The props of the component. - - Raises: - NotImplementedError: if the subclass does not override this method. - """ - raise NotImplementedError( - f"{cls.__name__} must implement get_component to return the component instance." - ) - - @classmethod - def create(cls, *children, **props) -> "Component": - """Create a new instance of the Component. - - Args: - children: The children of the component. - props: The props of the component. - - Returns: - A new instance of the Component with an independent copy of the State. - """ - cls._per_component_state_instance_count += 1 - state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}" - component_state = type( - state_cls_name, - (cls, State), - {"__module__": reflex.istate.dynamic.__name__}, - mixin=False, - ) - # Save a reference to the dynamic state for pickle/unpickle. - setattr(reflex.istate.dynamic, state_cls_name, component_state) - component = component_state.get_component(*children, **props) - component.State = component_state - return component - - class StateProxy(wrapt.ObjectProxy): """Proxy of a state instance to control mutability of vars for a background task. @@ -2408,7 +2238,7 @@ class StateProxy(wrapt.ObjectProxy): super().__setattr__( "__wrapped__", await parent_state.get_state( - State.get_class_substate(self._self_substate_path) + BaseState.get_class_substate(self._self_substate_path) ), ) return self @@ -3743,25 +3573,3 @@ def code_uses_state_contexts(javascript_code: str) -> bool: True if the code attempts to access a member of StateContexts. """ return bool("useContext(StateContexts" in javascript_code) - - -def reload_state_module( - module: str, - state: Type[BaseState] = State, -) -> None: - """Reset rx.State subclasses to avoid conflict when reloading. - - Args: - module: The module to reload. - state: Recursive argument for the state class to reload. - - """ - for subclass in tuple(state.class_subclasses): - reload_state_module(module=module, state=subclass) - if subclass.__module__ == module and module is not None: - state.class_subclasses.remove(subclass) - state._always_dirty_substates.discard(subclass.get_name()) - state._computed_var_dependencies = defaultdict(set) - state._substate_var_dependencies = defaultdict(set) - state._init_var_dependency_dicts() - state.get_class_substate.cache_clear() diff --git a/reflex/testing.py b/reflex/testing.py index b41e56884..60f3b95ee 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -43,13 +43,13 @@ import reflex.utils.exec import reflex.utils.format import reflex.utils.prerequisites import reflex.utils.processes +from reflex.istate.builtins import reload_state_module from reflex.state import ( BaseState, StateManager, StateManagerDisk, StateManagerMemory, StateManagerRedis, - reload_state_module, ) try: diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 65c0f049b..82f979f5d 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -429,7 +429,7 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]: # Get the function name name = parts[-1] - from reflex.state import State + from reflex.istate.builtins import State if state_full_name == "state" and name not in State.__dict__: return ("", to_snake_case(handler.fn.__qualname__)) diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 33165af0e..4402d5a20 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -279,7 +279,7 @@ def get_app(reload: bool = False) -> ModuleType: app = __import__(module, fromlist=(constants.CompileVars.APP,)) if reload: - from reflex.state import reload_state_module + from reflex.istate.builtins import reload_state_module # Reset rx.State subclasses to avoid conflict when reloading. reload_state_module(module=module) diff --git a/tests/integration/test_component_state.py b/tests/integration/test_component_state.py index f4a295d07..acf6b05b8 100644 --- a/tests/integration/test_component_state.py +++ b/tests/integration/test_component_state.py @@ -5,7 +5,8 @@ from typing import Generator import pytest from selenium.webdriver.common.by import By -from reflex.state import State, _substate_key +from reflex.istate.builtins import State +from reflex.state import _substate_key from reflex.testing import AppHarness from . import utils diff --git a/tests/units/components/core/test_foreach.py b/tests/units/components/core/test_foreach.py index 228165d3e..6ec3bb1fd 100644 --- a/tests/units/components/core/test_foreach.py +++ b/tests/units/components/core/test_foreach.py @@ -12,7 +12,8 @@ from reflex.components.core.foreach import ( ) from reflex.components.radix.themes.layout.box import box from reflex.components.radix.themes.typography.text import text -from reflex.state import BaseState, ComponentState +from reflex.istate.builtins import ComponentState +from reflex.state import BaseState from reflex.vars.base import Var from reflex.vars.number import NumberVar from reflex.vars.sequence import ArrayVar diff --git a/tests/units/components/core/test_html.py b/tests/units/components/core/test_html.py index 4847e1d5a..883531b0f 100644 --- a/tests/units/components/core/test_html.py +++ b/tests/units/components/core/test_html.py @@ -1,7 +1,7 @@ import pytest from reflex.components.core.html import Html -from reflex.state import State +from reflex.istate.builtins import State def test_html_no_children(): diff --git a/tests/units/components/core/test_upload.py b/tests/units/components/core/test_upload.py index 710baa161..09e6983a7 100644 --- a/tests/units/components/core/test_upload.py +++ b/tests/units/components/core/test_upload.py @@ -10,7 +10,7 @@ from reflex.components.core.upload import ( get_upload_url, ) from reflex.event import EventSpec -from reflex.state import State +from reflex.istate.builtins import State from reflex.vars.base import LiteralVar, Var diff --git a/tests/units/middleware/conftest.py b/tests/units/middleware/conftest.py index d786db652..356e8d6c3 100644 --- a/tests/units/middleware/conftest.py +++ b/tests/units/middleware/conftest.py @@ -1,7 +1,7 @@ import pytest from reflex.event import Event -from reflex.state import State +from reflex.istate.builtins import State def create_event(name): diff --git a/tests/units/middleware/test_hydrate_middleware.py b/tests/units/middleware/test_hydrate_middleware.py index 9ee8d8d25..255e2eb77 100644 --- a/tests/units/middleware/test_hydrate_middleware.py +++ b/tests/units/middleware/test_hydrate_middleware.py @@ -3,8 +3,9 @@ from __future__ import annotations import pytest from reflex.app import App +from reflex.istate.builtins import State from reflex.middleware.hydrate_middleware import HydrateMiddleware -from reflex.state import State, StateUpdate +from reflex.state import StateUpdate class TestState(State): @@ -38,7 +39,7 @@ async def test_preprocess_no_events(hydrate_middleware, event1, mocker): event1: An Event. mocker: pytest mock object. """ - mocker.patch("reflex.state.State.class_subclasses", {TestState}) + mocker.patch("reflex.istate.builtins.State.class_subclasses", {TestState}) state = State() update = await hydrate_middleware.preprocess( app=App(state=State), diff --git a/tests/units/states/upload.py b/tests/units/states/upload.py index f81e9f235..a00f9a68c 100644 --- a/tests/units/states/upload.py +++ b/tests/units/states/upload.py @@ -4,7 +4,8 @@ from pathlib import Path from typing import ClassVar, List import reflex as rx -from reflex.state import BaseState, State +from reflex.istate.builtins import State +from reflex.state import BaseState class UploadState(BaseState): diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 6bb81522f..c1cd25ee8 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -34,13 +34,12 @@ from reflex.components.base.fragment import Fragment from reflex.components.core.cond import Cond from reflex.components.radix.themes.typography.text import Text from reflex.event import Event +from reflex.istate.builtins import OnLoadInternalState, State from reflex.middleware import HydrateMiddleware from reflex.model import Model from reflex.state import ( BaseState, - OnLoadInternalState, RouterData, - State, StateManagerDisk, StateManagerMemory, StateManagerRedis, @@ -765,7 +764,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker): mocker: pytest mocker object. """ mocker.patch( - "reflex.state.State.class_subclasses", + "reflex.istate.builtins.State.class_subclasses", {state if state is FileUploadState else FileStateBase1}, ) state._tmp_path = tmp_path diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 89dd1fd3d..c309621e4 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -24,14 +24,13 @@ from reflex.base import Base from reflex.components.sonner.toast import Toaster from reflex.constants import CompileVars, RouteVar, SocketEvent from reflex.event import Event, EventHandler +from reflex.istate.builtins import OnLoadInternalState, State from reflex.state import ( BaseState, ImmutableStateError, LockExpiredError, MutableProxy, - OnLoadInternalState, RouterData, - State, StateManager, StateManagerDisk, StateManagerMemory, @@ -2779,7 +2778,8 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker): mocker: pytest mock object. """ mocker.patch( - "reflex.state.State.class_subclasses", {test_state, OnLoadInternalState} + "reflex.istate.builtins.State.class_subclasses", + {test_state, OnLoadInternalState}, ) app = app_module_mock.app = App( state=State, load_events={"index": [test_state.test_handler]} @@ -2825,7 +2825,8 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker): mocker: pytest mock object. """ mocker.patch( - "reflex.state.State.class_subclasses", {OnLoadState, OnLoadInternalState} + "reflex.istate.builtins.State.class_subclasses", + {OnLoadState, OnLoadInternalState}, ) app = app_module_mock.app = App( state=State, @@ -3231,7 +3232,8 @@ config = rx.Config( with chdir(proj_root): # reload config for each parameter to avoid stale values reflex.config.get_config(reload=True) - from reflex.state import State, StateManager + from reflex.istate.builtins import State + from reflex.state import StateManager state_manager = StateManager.create(state=State) assert state_manager.lock_expiration == expected_values[0] # type: ignore