Remove Default state (#1978)

This commit is contained in:
Elijah Ahianyo 2023-10-18 16:57:27 +00:00 committed by GitHub
parent b4bb849388
commit b652d40ee5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 82 additions and 46 deletions

View File

@ -8,7 +8,9 @@
{% block export %} {% block export %}
export default function Component() { export default function Component() {
{% if state_name %}
const {{state_name}} = useContext(StateContext) const {{state_name}} = useContext(StateContext)
{% endif %}
const {{const.router}} = useRouter() const {{const.router}} = useRouter()
const [ {{const.color_mode}}, {{const.toggle_color_mode}} ] = useContext(ColorModeContext) const [ {{const.color_mode}}, {{const.toggle_color_mode}} ] = useContext(ColorModeContext)
const focusRef = useRef(); const focusRef = useRef();

View File

@ -1,14 +1,29 @@
import { createContext, useState } from "react" import { createContext, useState } from "react"
import { Event, hydrateClientStorage, useEventLoop } from "/utils/state.js" import { Event, hydrateClientStorage, useEventLoop } from "/utils/state.js"
{% if initial_state %}
export const initialState = {{ initial_state|json_dumps }} export const initialState = {{ initial_state|json_dumps }}
{% else %}
export const initialState = {}
{% endif %}
export const ColorModeContext = createContext(null); export const ColorModeContext = createContext(null);
export const StateContext = createContext(null); export const StateContext = createContext(null);
export const EventLoopContext = createContext(null); export const EventLoopContext = createContext(null);
{% if client_storage %}
export const clientStorage = {{ client_storage|json_dumps }} export const clientStorage = {{ client_storage|json_dumps }}
{% else %}
export const clientStorage = {}
{% endif %}
{% if state_name %}
export const initialEvents = () => [ export const initialEvents = () => [
Event('{{state_name}}.{{const.hydrate}}', hydrateClientStorage(clientStorage)), Event('{{state_name}}.{{const.hydrate}}', hydrateClientStorage(clientStorage)),
] ]
{% else %}
export const initialEvents = () => []
{% endif %}
export const isDevMode = {{ is_dev_mode|json_dumps }} export const isDevMode = {{ is_dev_mode|json_dumps }}
export function EventLoopProvider({ children }) { export function EventLoopProvider({ children }) {

View File

@ -53,11 +53,9 @@ from reflex.route import (
verify_route_validity, verify_route_validity,
) )
from reflex.state import ( from reflex.state import (
DefaultState,
RouterData, RouterData,
State, State,
StateManager, StateManager,
StateManagerMemory,
StateUpdate, StateUpdate,
) )
from reflex.utils import console, format, prerequisites, types from reflex.utils import console, format, prerequisites, types
@ -96,10 +94,10 @@ class App(Base):
socket_app: Optional[ASGIApp] = None socket_app: Optional[ASGIApp] = None
# The state class to use for the app. # The state class to use for the app.
state: Type[State] = DefaultState state: Optional[Type[State]] = None
# Class to manage many client states. # Class to manage many client states.
state_manager: StateManager = StateManagerMemory(state=DefaultState) _state_manager: Optional[StateManager] = None
# The styling to apply to each component. # The styling to apply to each component.
style: ComponentStyle = {} style: ComponentStyle = {}
@ -148,19 +146,19 @@ class App(Base):
) )
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
state_subclasses = State.__subclasses__() state_subclasses = State.__subclasses__()
inferred_state = state_subclasses[-1] inferred_state = state_subclasses[-1] if state_subclasses else None
is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
# Special case to allow test cases have multiple subclasses of rx.State. # Special case to allow test cases have multiple subclasses of rx.State.
if not is_testing_env: if not is_testing_env:
# Only the default state and the client state should be allowed as subclasses. # Only one State class is allowed.
if len(state_subclasses) > 2: if len(state_subclasses) > 1:
raise ValueError( raise ValueError(
"rx.State has been subclassed multiple times. Only one subclass is allowed" "rx.State has been subclassed multiple times. Only one subclass is allowed"
) )
# verify that provided state is valid # verify that provided state is valid
if self.state not in [DefaultState, inferred_state]: if self.state and inferred_state and self.state is not inferred_state:
console.warn( console.warn(
f"Using substate ({self.state.__name__}) as root state in `rx.App` is currently not supported." f"Using substate ({self.state.__name__}) as root state in `rx.App` is currently not supported."
f" Defaulting to root state: ({inferred_state.__name__})" f" Defaulting to root state: ({inferred_state.__name__})"
@ -172,15 +170,15 @@ class App(Base):
# Add middleware. # Add middleware.
self.middleware.append(HydrateMiddleware()) self.middleware.append(HydrateMiddleware())
# Set up the state manager.
self.state_manager = StateManager.create(state=self.state)
# Set up the API. # Set up the API.
self.api = FastAPI() self.api = FastAPI()
self.add_cors() self.add_cors()
self.add_default_endpoints() self.add_default_endpoints()
if self.state is not DefaultState: if self.state:
# Set up the state manager.
self._state_manager = StateManager.create(state=self.state)
# Set up the Socket.IO AsyncServer. # Set up the Socket.IO AsyncServer.
self.sio = AsyncServer( self.sio = AsyncServer(
async_mode="asgi", async_mode="asgi",
@ -212,10 +210,7 @@ class App(Base):
self.setup_admin_dash() self.setup_admin_dash()
# If a State is not used and no overlay_component is specified, do not render the connection modal # If a State is not used and no overlay_component is specified, do not render the connection modal
if ( if self.state is None and self.overlay_component is default_overlay_component:
self.state is DefaultState
and self.overlay_component is default_overlay_component
):
self.overlay_component = None self.overlay_component = None
def __repr__(self) -> str: def __repr__(self) -> str:
@ -224,7 +219,7 @@ class App(Base):
Returns: Returns:
The string representation of the app. The string representation of the app.
""" """
return f"<App state={self.state.__name__}>" return f"<App state={self.state.__name__ if self.state else None}>"
def __call__(self) -> FastAPI: def __call__(self) -> FastAPI:
"""Run the backend api instance. """Run the backend api instance.
@ -252,6 +247,20 @@ class App(Base):
allow_origins=["*"], allow_origins=["*"],
) )
@property
def state_manager(self) -> StateManager:
"""Get the state manager.
Returns:
The initialized state manager.
Raises:
ValueError: if the state has not been initialized.
"""
if self._state_manager is None:
raise ValueError("The state manager has not been initialized.")
return self._state_manager
async def preprocess(self, state: State, event: Event) -> StateUpdate | None: async def preprocess(self, state: State, event: Event) -> StateUpdate | None:
"""Preprocess the event. """Preprocess the event.
@ -385,7 +394,8 @@ class App(Base):
verify_route_validity(route) verify_route_validity(route)
# Apply dynamic args to the route. # Apply dynamic args to the route.
self.state.setup_dynamic_args(get_route_args(route)) if self.state:
self.state.setup_dynamic_args(get_route_args(route))
# Generate the component if it is a callable. # Generate the component if it is a callable.
component = self._generate_component(component) component = self._generate_component(component)
@ -715,6 +725,7 @@ class App(Base):
""" """
if self.event_namespace is None: if self.event_namespace is None:
raise RuntimeError("App has not been initialized yet.") raise RuntimeError("App has not been initialized yet.")
# Get exclusive access to the state. # Get exclusive access to the state.
async with self.state_manager.modify_state(token) as state: async with self.state_manager.modify_state(token) as state:
# No other event handler can modify the state while in this context. # No other event handler can modify the state while in this context.
@ -862,6 +873,7 @@ def upload(app: App):
for file in files: for file in files:
assert file.filename is not None assert file.filename is not None
file.filename = file.filename.split(":")[-1] file.filename = file.filename.split(":")[-1]
# Get the state for the session. # Get the state for the session.
async with app.state_manager.modify_state(token) as state: async with app.state_manager.modify_state(token) as state:
# get the current session ID # get the current session ID

View File

@ -32,7 +32,6 @@ from reflex.route import (
verify_route_validity as verify_route_validity, verify_route_validity as verify_route_validity,
) )
from reflex.state import ( from reflex.state import (
DefaultState as DefaultState,
State as State, State as State,
StateManager as StateManager, StateManager as StateManager,
StateUpdate as StateUpdate, StateUpdate as StateUpdate,

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import os import os
from pathlib import Path from pathlib import Path
from typing import Type from typing import Optional, Type
from reflex import constants from reflex import constants
from reflex.compiler import templates, utils from reflex.compiler import templates, utils
@ -89,7 +89,7 @@ def _compile_theme(theme: dict) -> str:
return templates.THEME.render(theme=theme) return templates.THEME.render(theme=theme)
def _compile_contexts(state: Type[State]) -> str: def _compile_contexts(state: Optional[Type[State]]) -> str:
"""Compile the initial state and contexts. """Compile the initial state and contexts.
Args: Args:
@ -98,11 +98,16 @@ def _compile_contexts(state: Type[State]) -> str:
Returns: Returns:
The compiled context file. The compiled context file.
""" """
return templates.CONTEXT.render( is_dev_mode = os.environ.get("REFLEX_ENV_MODE", "dev") == "dev"
initial_state=utils.compile_state(state), return (
state_name=state.get_name(), templates.CONTEXT.render(
client_storage=utils.compile_client_storage(state), initial_state=utils.compile_state(state),
is_dev_mode=os.environ.get("REFLEX_ENV_MODE", "dev") == "dev", state_name=state.get_name(),
client_storage=utils.compile_client_storage(state),
is_dev_mode=is_dev_mode,
)
if state
else templates.CONTEXT.render(is_dev_mode=is_dev_mode)
) )
@ -125,13 +130,15 @@ def _compile_page(
imports = utils.compile_imports(imports) imports = utils.compile_imports(imports)
# Compile the code to render the component. # Compile the code to render the component.
kwargs = {"state_name": state.get_name()} if state else {}
return templates.PAGE.render( return templates.PAGE.render(
imports=imports, imports=imports,
dynamic_imports=component.get_dynamic_imports(), dynamic_imports=component.get_dynamic_imports(),
custom_codes=component.get_custom_code(), custom_codes=component.get_custom_code(),
state_name=state.get_name(),
hooks=component.get_hooks(), hooks=component.get_hooks(),
render=component.render(), render=component.render(),
**kwargs,
) )
@ -296,7 +303,7 @@ def compile_theme(style: ComponentStyle) -> tuple[str, str]:
return output_path, code return output_path, code
def compile_contexts(state: Type[State]) -> tuple[str, str]: def compile_contexts(state: Optional[Type[State]]) -> tuple[str, str]:
"""Compile the initial state / context. """Compile the initial state / context.
Args: Args:

View File

@ -1368,12 +1368,6 @@ class StateProxy(wrapt.ObjectProxy):
super().__setattr__(name, value) super().__setattr__(name, value)
class DefaultState(State):
"""The default empty state."""
pass
class StateUpdate(Base): class StateUpdate(Base):
"""A state update sent to the frontend.""" """A state update sent to the frontend."""
@ -1394,7 +1388,7 @@ class StateManager(Base, ABC):
state: Type[State] state: Type[State]
@classmethod @classmethod
def create(cls, state: Type[State] = DefaultState): def create(cls, state: Type[State]):
"""Create a new state manager. """Create a new state manager.
Args: Args:

View File

@ -495,13 +495,13 @@ class AppHarness:
if isinstance(self.state_manager, StateManagerRedis): if isinstance(self.state_manager, StateManagerRedis):
# Temporarily replace the app's state manager with our own, since # Temporarily replace the app's state manager with our own, since
# the redis connection is on the backend_thread event loop # the redis connection is on the backend_thread event loop
self.app_instance.state_manager = self.state_manager self.app_instance._state_manager = self.state_manager
try: try:
async with self.app_instance.modify_state(token) as state: async with self.app_instance.modify_state(token) as state:
yield state yield state
finally: finally:
if isinstance(self.state_manager, StateManagerRedis): if isinstance(self.state_manager, StateManagerRedis):
self.app_instance.state_manager = app_state_manager self.app_instance._state_manager = app_state_manager
await self.state_manager.redis.close() await self.state_manager.redis.close()
def poll_for_content( def poll_for_content(

View File

@ -25,7 +25,6 @@ from reflex import AdminDash, constants
from reflex.app import ( from reflex.app import (
App, App,
ComponentCallable, ComponentCallable,
DefaultState,
default_overlay_component, default_overlay_component,
process, process,
upload, upload,
@ -49,6 +48,12 @@ from .states import (
) )
class EmptyState(State):
"""An empty state."""
pass
@pytest.fixture @pytest.fixture
def index_page(): def index_page():
"""An index page. """An index page.
@ -192,7 +197,6 @@ def test_default_app(app: App):
Args: Args:
app: The app to test. app: The app to test.
""" """
assert app.state() == DefaultState()
assert app.middleware == [HydrateMiddleware()] assert app.middleware == [HydrateMiddleware()]
assert app.style == Style() assert app.style == Style()
assert app.admin_dash is None assert app.admin_dash is None
@ -240,14 +244,14 @@ def test_add_page_set_route(app: App, index_page, windows_platform: bool):
assert set(app.pages.keys()) == {"test"} assert set(app.pages.keys()) == {"test"}
def test_add_page_set_route_dynamic(app: App, index_page, windows_platform: bool): def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
"""Test adding a page with dynamic route variable to an app. """Test adding a page with dynamic route variable to an app.
Args: Args:
app: The app to test.
index_page: The index page. index_page: The index page.
windows_platform: Whether the system is windows. windows_platform: Whether the system is windows.
""" """
app = App(state=EmptyState)
route = "/test/[dynamic]" route = "/test/[dynamic]"
if windows_platform: if windows_platform:
route.lstrip("/").replace("/", "\\") route.lstrip("/").replace("/", "\\")
@ -255,7 +259,7 @@ def test_add_page_set_route_dynamic(app: App, index_page, windows_platform: bool
app.add_page(index_page, route=route) app.add_page(index_page, route=route)
assert set(app.pages.keys()) == {"test/[dynamic]"} assert set(app.pages.keys()) == {"test/[dynamic]"}
assert "dynamic" in app.state.computed_vars assert "dynamic" in app.state.computed_vars
assert app.state.computed_vars["dynamic"]._deps(objclass=DefaultState) == { assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
constants.ROUTER constants.ROUTER
} }
assert constants.ROUTER in app.state().computed_var_dependencies assert constants.ROUTER in app.state().computed_var_dependencies
@ -1093,9 +1097,9 @@ async def test_process_events(mocker, token: str):
@pytest.mark.parametrize( @pytest.mark.parametrize(
("state", "overlay_component", "exp_page_child"), ("state", "overlay_component", "exp_page_child"),
[ [
(DefaultState, default_overlay_component, None), (None, default_overlay_component, None),
(DefaultState, None, None), (None, None, None),
(DefaultState, Text.create("foo"), Text), (None, Text.create("foo"), Text),
(State, default_overlay_component, Fragment), (State, default_overlay_component, Fragment),
(State, None, None), (State, None, None),
(State, Text.create("foo"), Text), (State, Text.create("foo"), Text),

View File

@ -16,7 +16,10 @@ def test_app_harness(tmp_path):
def BasicApp(): def BasicApp():
import reflex as rx import reflex as rx
app = rx.App() class State(rx.State):
pass
app = rx.App(state=State)
app.add_page(lambda: rx.text("Basic App"), route="/", title="index") app.add_page(lambda: rx.text("Basic App"), route="/", title="index")
app.compile() app.compile()