diff --git a/reflex/.templates/jinja/web/pages/index.js.jinja2 b/reflex/.templates/jinja/web/pages/index.js.jinja2 index 16610de30..6f73c70c4 100644 --- a/reflex/.templates/jinja/web/pages/index.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/index.js.jinja2 @@ -8,7 +8,9 @@ {% block export %} export default function Component() { +{% if state_name %} const {{state_name}} = useContext(StateContext) +{% endif %} const {{const.router}} = useRouter() const [ {{const.color_mode}}, {{const.toggle_color_mode}} ] = useContext(ColorModeContext) const focusRef = useRef(); diff --git a/reflex/.templates/jinja/web/utils/context.js.jinja2 b/reflex/.templates/jinja/web/utils/context.js.jinja2 index 274718a56..c931b7515 100644 --- a/reflex/.templates/jinja/web/utils/context.js.jinja2 +++ b/reflex/.templates/jinja/web/utils/context.js.jinja2 @@ -1,14 +1,29 @@ import { createContext, useState } from "react" import { Event, hydrateClientStorage, useEventLoop } from "/utils/state.js" +{% if initial_state %} export const initialState = {{ initial_state|json_dumps }} +{% else %} +export const initialState = {} +{% endif %} + export const ColorModeContext = createContext(null); export const StateContext = createContext(null); export const EventLoopContext = createContext(null); +{% if client_storage %} export const clientStorage = {{ client_storage|json_dumps }} +{% else %} +export const clientStorage = {} +{% endif %} + +{% if state_name %} export const initialEvents = () => [ Event('{{state_name}}.{{const.hydrate}}', hydrateClientStorage(clientStorage)), ] +{% else %} +export const initialEvents = () => [] +{% endif %} + export const isDevMode = {{ is_dev_mode|json_dumps }} export function EventLoopProvider({ children }) { diff --git a/reflex/app.py b/reflex/app.py index 4cca15cc1..e30f8d1a5 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -53,11 +53,9 @@ from reflex.route import ( verify_route_validity, ) from reflex.state import ( - DefaultState, RouterData, State, StateManager, - StateManagerMemory, StateUpdate, ) from reflex.utils import console, format, prerequisites, types @@ -96,10 +94,10 @@ class App(Base): socket_app: Optional[ASGIApp] = None # The state class to use for the app. - state: Type[State] = DefaultState + state: Optional[Type[State]] = None # Class to manage many client states. - state_manager: StateManager = StateManagerMemory(state=DefaultState) + _state_manager: Optional[StateManager] = None # The styling to apply to each component. style: ComponentStyle = {} @@ -148,19 +146,19 @@ class App(Base): ) super().__init__(*args, **kwargs) state_subclasses = State.__subclasses__() - inferred_state = state_subclasses[-1] + inferred_state = state_subclasses[-1] if state_subclasses else None is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ # Special case to allow test cases have multiple subclasses of rx.State. if not is_testing_env: - # Only the default state and the client state should be allowed as subclasses. - if len(state_subclasses) > 2: + # Only one State class is allowed. + if len(state_subclasses) > 1: raise ValueError( "rx.State has been subclassed multiple times. Only one subclass is allowed" ) # verify that provided state is valid - if self.state not in [DefaultState, inferred_state]: + if self.state and inferred_state and self.state is not inferred_state: console.warn( f"Using substate ({self.state.__name__}) as root state in `rx.App` is currently not supported." f" Defaulting to root state: ({inferred_state.__name__})" @@ -172,15 +170,15 @@ class App(Base): # Add middleware. self.middleware.append(HydrateMiddleware()) - # Set up the state manager. - self.state_manager = StateManager.create(state=self.state) - # Set up the API. self.api = FastAPI() self.add_cors() self.add_default_endpoints() - if self.state is not DefaultState: + if self.state: + # Set up the state manager. + self._state_manager = StateManager.create(state=self.state) + # Set up the Socket.IO AsyncServer. self.sio = AsyncServer( async_mode="asgi", @@ -212,10 +210,7 @@ class App(Base): self.setup_admin_dash() # If a State is not used and no overlay_component is specified, do not render the connection modal - if ( - self.state is DefaultState - and self.overlay_component is default_overlay_component - ): + if self.state is None and self.overlay_component is default_overlay_component: self.overlay_component = None def __repr__(self) -> str: @@ -224,7 +219,7 @@ class App(Base): Returns: The string representation of the app. """ - return f"" + return f"" def __call__(self) -> FastAPI: """Run the backend api instance. @@ -252,6 +247,20 @@ class App(Base): allow_origins=["*"], ) + @property + def state_manager(self) -> StateManager: + """Get the state manager. + + Returns: + The initialized state manager. + + Raises: + ValueError: if the state has not been initialized. + """ + if self._state_manager is None: + raise ValueError("The state manager has not been initialized.") + return self._state_manager + async def preprocess(self, state: State, event: Event) -> StateUpdate | None: """Preprocess the event. @@ -385,7 +394,8 @@ class App(Base): verify_route_validity(route) # Apply dynamic args to the route. - self.state.setup_dynamic_args(get_route_args(route)) + if self.state: + self.state.setup_dynamic_args(get_route_args(route)) # Generate the component if it is a callable. component = self._generate_component(component) @@ -715,6 +725,7 @@ class App(Base): """ if self.event_namespace is None: raise RuntimeError("App has not been initialized yet.") + # Get exclusive access to the state. async with self.state_manager.modify_state(token) as state: # No other event handler can modify the state while in this context. @@ -862,6 +873,7 @@ def upload(app: App): for file in files: assert file.filename is not None file.filename = file.filename.split(":")[-1] + # Get the state for the session. async with app.state_manager.modify_state(token) as state: # get the current session ID diff --git a/reflex/app.pyi b/reflex/app.pyi index 74ea757d2..f7e63727a 100644 --- a/reflex/app.pyi +++ b/reflex/app.pyi @@ -32,7 +32,6 @@ from reflex.route import ( verify_route_validity as verify_route_validity, ) from reflex.state import ( - DefaultState as DefaultState, State as State, StateManager as StateManager, StateUpdate as StateUpdate, diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 371908567..7ee5c6ded 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -3,7 +3,7 @@ from __future__ import annotations import os from pathlib import Path -from typing import Type +from typing import Optional, Type from reflex import constants from reflex.compiler import templates, utils @@ -89,7 +89,7 @@ def _compile_theme(theme: dict) -> str: return templates.THEME.render(theme=theme) -def _compile_contexts(state: Type[State]) -> str: +def _compile_contexts(state: Optional[Type[State]]) -> str: """Compile the initial state and contexts. Args: @@ -98,11 +98,16 @@ def _compile_contexts(state: Type[State]) -> str: Returns: The compiled context file. """ - return templates.CONTEXT.render( - initial_state=utils.compile_state(state), - state_name=state.get_name(), - client_storage=utils.compile_client_storage(state), - is_dev_mode=os.environ.get("REFLEX_ENV_MODE", "dev") == "dev", + is_dev_mode = os.environ.get("REFLEX_ENV_MODE", "dev") == "dev" + return ( + templates.CONTEXT.render( + initial_state=utils.compile_state(state), + state_name=state.get_name(), + client_storage=utils.compile_client_storage(state), + is_dev_mode=is_dev_mode, + ) + if state + else templates.CONTEXT.render(is_dev_mode=is_dev_mode) ) @@ -125,13 +130,15 @@ def _compile_page( imports = utils.compile_imports(imports) # Compile the code to render the component. + kwargs = {"state_name": state.get_name()} if state else {} + return templates.PAGE.render( imports=imports, dynamic_imports=component.get_dynamic_imports(), custom_codes=component.get_custom_code(), - state_name=state.get_name(), hooks=component.get_hooks(), render=component.render(), + **kwargs, ) @@ -296,7 +303,7 @@ def compile_theme(style: ComponentStyle) -> tuple[str, str]: return output_path, code -def compile_contexts(state: Type[State]) -> tuple[str, str]: +def compile_contexts(state: Optional[Type[State]]) -> tuple[str, str]: """Compile the initial state / context. Args: diff --git a/reflex/state.py b/reflex/state.py index bfb0a14ca..e7f4113b9 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1368,12 +1368,6 @@ class StateProxy(wrapt.ObjectProxy): super().__setattr__(name, value) -class DefaultState(State): - """The default empty state.""" - - pass - - class StateUpdate(Base): """A state update sent to the frontend.""" @@ -1394,7 +1388,7 @@ class StateManager(Base, ABC): state: Type[State] @classmethod - def create(cls, state: Type[State] = DefaultState): + def create(cls, state: Type[State]): """Create a new state manager. Args: diff --git a/reflex/testing.py b/reflex/testing.py index e9962ab8a..8d204c3d5 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -495,13 +495,13 @@ class AppHarness: if isinstance(self.state_manager, StateManagerRedis): # Temporarily replace the app's state manager with our own, since # the redis connection is on the backend_thread event loop - self.app_instance.state_manager = self.state_manager + self.app_instance._state_manager = self.state_manager try: async with self.app_instance.modify_state(token) as state: yield state finally: if isinstance(self.state_manager, StateManagerRedis): - self.app_instance.state_manager = app_state_manager + self.app_instance._state_manager = app_state_manager await self.state_manager.redis.close() def poll_for_content( diff --git a/tests/test_app.py b/tests/test_app.py index 5ea59e1e0..349503fd6 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -25,7 +25,6 @@ from reflex import AdminDash, constants from reflex.app import ( App, ComponentCallable, - DefaultState, default_overlay_component, process, upload, @@ -49,6 +48,12 @@ from .states import ( ) +class EmptyState(State): + """An empty state.""" + + pass + + @pytest.fixture def index_page(): """An index page. @@ -192,7 +197,6 @@ def test_default_app(app: App): Args: app: The app to test. """ - assert app.state() == DefaultState() assert app.middleware == [HydrateMiddleware()] assert app.style == Style() assert app.admin_dash is None @@ -240,14 +244,14 @@ def test_add_page_set_route(app: App, index_page, windows_platform: bool): assert set(app.pages.keys()) == {"test"} -def test_add_page_set_route_dynamic(app: App, index_page, windows_platform: bool): +def test_add_page_set_route_dynamic(index_page, windows_platform: bool): """Test adding a page with dynamic route variable to an app. Args: - app: The app to test. index_page: The index page. windows_platform: Whether the system is windows. """ + app = App(state=EmptyState) route = "/test/[dynamic]" if windows_platform: route.lstrip("/").replace("/", "\\") @@ -255,7 +259,7 @@ def test_add_page_set_route_dynamic(app: App, index_page, windows_platform: bool app.add_page(index_page, route=route) assert set(app.pages.keys()) == {"test/[dynamic]"} assert "dynamic" in app.state.computed_vars - assert app.state.computed_vars["dynamic"]._deps(objclass=DefaultState) == { + assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == { constants.ROUTER } assert constants.ROUTER in app.state().computed_var_dependencies @@ -1093,9 +1097,9 @@ async def test_process_events(mocker, token: str): @pytest.mark.parametrize( ("state", "overlay_component", "exp_page_child"), [ - (DefaultState, default_overlay_component, None), - (DefaultState, None, None), - (DefaultState, Text.create("foo"), Text), + (None, default_overlay_component, None), + (None, None, None), + (None, Text.create("foo"), Text), (State, default_overlay_component, Fragment), (State, None, None), (State, Text.create("foo"), Text), diff --git a/tests/test_testing.py b/tests/test_testing.py index e24c7224f..ff87534ba 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -16,7 +16,10 @@ def test_app_harness(tmp_path): def BasicApp(): import reflex as rx - app = rx.App() + class State(rx.State): + pass + + app = rx.App(state=State) app.add_page(lambda: rx.text("Basic App"), route="/", title="index") app.compile()