diff --git a/integration/test_background_task.py b/integration/test_background_task.py
index 493dc5180..bc70ff01f 100644
--- a/integration/test_background_task.py
+++ b/integration/test_background_task.py
@@ -93,7 +93,7 @@ def BackgroundTask():
rx.button("Reset", on_click=State.reset_counter, id="reset"),
)
- app = rx.App(state=State)
+ app = rx.App(state=rx.State)
app.add_page(index)
app.compile()
diff --git a/integration/test_call_script.py b/integration/test_call_script.py
index 8b24ce19b..95bdf37c9 100644
--- a/integration/test_call_script.py
+++ b/integration/test_call_script.py
@@ -135,7 +135,7 @@ def CallScript():
yield rx.call_script("inline_counter = 0; external_counter = 0")
self.reset()
- app = rx.App(state=CallScriptState)
+ app = rx.App(state=rx.State)
with open("assets/external.js", "w") as f:
f.write(external_scripts)
diff --git a/integration/test_client_storage.py b/integration/test_client_storage.py
index 90b98f16d..a9a311d8c 100644
--- a/integration/test_client_storage.py
+++ b/integration/test_client_storage.py
@@ -97,7 +97,7 @@ def ClientSide():
rx.box(ClientSideSubSubState.l1s, id="l1s"),
)
- app = rx.App(state=ClientSideState)
+ app = rx.App(state=rx.State)
app.add_page(index)
app.add_page(index, route="/foo")
app.compile()
@@ -263,7 +263,6 @@ async def test_client_side_state(
state_var_input.send_keys("c7")
input_value_input.send_keys("c7 value")
set_sub_state_button.click()
-
state_var_input.send_keys("l1")
input_value_input.send_keys("l1 value")
set_sub_state_button.click()
@@ -276,7 +275,6 @@ async def test_client_side_state(
state_var_input.send_keys("l4")
input_value_input.send_keys("l4 value")
set_sub_state_button.click()
-
state_var_input.send_keys("c1s")
input_value_input.send_keys("c1s value")
set_sub_sub_state_button.click()
@@ -285,28 +283,28 @@ async def test_client_side_state(
set_sub_sub_state_button.click()
exp_cookies = {
- "client_side_state.client_side_sub_state.c1": {
+ "state.client_side_state.client_side_sub_state.c1": {
"domain": "localhost",
"httpOnly": False,
- "name": "client_side_state.client_side_sub_state.c1",
+ "name": "state.client_side_state.client_side_sub_state.c1",
"path": "/",
"sameSite": "Lax",
"secure": False,
"value": "c1%20value",
},
- "client_side_state.client_side_sub_state.c2": {
+ "state.client_side_state.client_side_sub_state.c2": {
"domain": "localhost",
"httpOnly": False,
- "name": "client_side_state.client_side_sub_state.c2",
+ "name": "state.client_side_state.client_side_sub_state.c2",
"path": "/",
"sameSite": "Lax",
"secure": False,
"value": "c2%20value",
},
- "client_side_state.client_side_sub_state.c4": {
+ "state.client_side_state.client_side_sub_state.c4": {
"domain": "localhost",
"httpOnly": False,
- "name": "client_side_state.client_side_sub_state.c4",
+ "name": "state.client_side_state.client_side_sub_state.c4",
"path": "/",
"sameSite": "Strict",
"secure": False,
@@ -321,19 +319,19 @@ async def test_client_side_state(
"secure": False,
"value": "c6%20value",
},
- "client_side_state.client_side_sub_state.c7": {
+ "state.client_side_state.client_side_sub_state.c7": {
"domain": "localhost",
"httpOnly": False,
- "name": "client_side_state.client_side_sub_state.c7",
+ "name": "state.client_side_state.client_side_sub_state.c7",
"path": "/",
"sameSite": "Lax",
"secure": False,
"value": "c7%20value",
},
- "client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s": {
+ "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s": {
"domain": "localhost",
"httpOnly": False,
- "name": "client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s",
+ "name": "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s",
"path": "/",
"sameSite": "Lax",
"secure": False,
@@ -354,40 +352,45 @@ async def test_client_side_state(
input_value_input.send_keys("c3 value")
set_sub_state_button.click()
AppHarness._poll_for(
- lambda: "client_side_state.client_side_sub_state.c3" in cookie_info_map(driver)
+ lambda: "state.client_side_state.client_side_sub_state.c3"
+ in cookie_info_map(driver)
)
- c3_cookie = cookie_info_map(driver)["client_side_state.client_side_sub_state.c3"]
+ c3_cookie = cookie_info_map(driver)[
+ "state.client_side_state.client_side_sub_state.c3"
+ ]
assert c3_cookie.pop("expiry") is not None
assert c3_cookie == {
"domain": "localhost",
"httpOnly": False,
- "name": "client_side_state.client_side_sub_state.c3",
+ "name": "state.client_side_state.client_side_sub_state.c3",
"path": "/",
"sameSite": "Lax",
"secure": False,
"value": "c3%20value",
}
time.sleep(2) # wait for c3 to expire
- assert "client_side_state.client_side_sub_state.c3" not in cookie_info_map(driver)
+ assert "state.client_side_state.client_side_sub_state.c3" not in cookie_info_map(
+ driver
+ )
local_storage_items = local_storage.items()
local_storage_items.pop("chakra-ui-color-mode", None)
assert (
- local_storage_items.pop("client_side_state.client_side_sub_state.l1")
+ local_storage_items.pop("state.client_side_state.client_side_sub_state.l1")
== "l1 value"
)
assert (
- local_storage_items.pop("client_side_state.client_side_sub_state.l2")
+ local_storage_items.pop("state.client_side_state.client_side_sub_state.l2")
== "l2 value"
)
assert local_storage_items.pop("l3") == "l3 value"
assert (
- local_storage_items.pop("client_side_state.client_side_sub_state.l4")
+ local_storage_items.pop("state.client_side_state.client_side_sub_state.l4")
== "l4 value"
)
assert (
local_storage_items.pop(
- "client_side_state.client_side_sub_state.client_side_sub_sub_state.l1s"
+ "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.l1s"
)
== "l1s value"
)
@@ -482,12 +485,15 @@ async def test_client_side_state(
# make sure c5 cookie shows up on the `/foo` route
AppHarness._poll_for(
- lambda: "client_side_state.client_side_sub_state.c5" in cookie_info_map(driver)
+ lambda: "state.client_side_state.client_side_sub_state.c5"
+ in cookie_info_map(driver)
)
- assert cookie_info_map(driver)["client_side_state.client_side_sub_state.c5"] == {
+ assert cookie_info_map(driver)[
+ "state.client_side_state.client_side_sub_state.c5"
+ ] == {
"domain": "localhost",
"httpOnly": False,
- "name": "client_side_state.client_side_sub_state.c5",
+ "name": "state.client_side_state.client_side_sub_state.c5",
"path": "/foo/",
"sameSite": "Lax",
"secure": False,
diff --git a/integration/test_connection_banner.py b/integration/test_connection_banner.py
index c078df1f7..0468e6b53 100644
--- a/integration/test_connection_banner.py
+++ b/integration/test_connection_banner.py
@@ -19,7 +19,7 @@ def ConnectionBanner():
def index():
return rx.text("Hello World")
- app = rx.App(state=State)
+ app = rx.App(state=rx.State)
app.add_page(index)
app.compile()
diff --git a/integration/test_dynamic_routes.py b/integration/test_dynamic_routes.py
index 5ab11b693..782e625ee 100644
--- a/integration/test_dynamic_routes.py
+++ b/integration/test_dynamic_routes.py
@@ -56,7 +56,7 @@ def DynamicRoute():
def redirect_page():
return rx.fragment(rx.text("redirecting..."))
- app = rx.App(state=DynamicState)
+ app = rx.App(state=rx.State)
app.add_page(index)
app.add_page(index, route="/page/[page_id]", on_load=DynamicState.on_load) # type: ignore
app.add_page(index, route="/static/x", on_load=DynamicState.on_load) # type: ignore
@@ -143,10 +143,12 @@ def poll_for_order(
return await dynamic_route.get_state(token)
async def _check():
- return (await _backend_state()).order == exp_order
+ return (await _backend_state()).substates[
+ "dynamic_state"
+ ].order == exp_order
await AppHarness._poll_for_async(_check)
- assert (await _backend_state()).order == exp_order
+ assert (await _backend_state()).substates["dynamic_state"].order == exp_order
return _poll_for_order
diff --git a/integration/test_event_actions.py b/integration/test_event_actions.py
index 8f5e0788b..da444a5cf 100644
--- a/integration/test_event_actions.py
+++ b/integration/test_event_actions.py
@@ -130,7 +130,7 @@ def TestEventAction():
on_click=EventActionState.on_click("outer"), # type: ignore
)
- app = rx.App(state=EventActionState)
+ app = rx.App(state=rx.State)
app.add_page(index)
app.compile()
@@ -211,10 +211,14 @@ def poll_for_order(
return await event_action.get_state(token)
async def _check():
- return (await _backend_state()).order == exp_order
+ return (await _backend_state()).substates[
+ "event_action_state"
+ ].order == exp_order
await AppHarness._poll_for_async(_check)
- assert (await _backend_state()).order == exp_order
+ assert (await _backend_state()).substates[
+ "event_action_state"
+ ].order == exp_order
return _poll_for_order
diff --git a/integration/test_event_chain.py b/integration/test_event_chain.py
index 5fbf7cc14..7d003635e 100644
--- a/integration/test_event_chain.py
+++ b/integration/test_event_chain.py
@@ -122,7 +122,7 @@ def EventChain():
time.sleep(0.5)
self.interim_value = "final"
- app = rx.App(state=State)
+ app = rx.App(state=rx.State)
token_input = rx.input(
value=State.router.session.client_token, is_read_only=True, id="token"
@@ -401,12 +401,12 @@ async def test_event_chain_click(
btn.click()
async def _has_all_events():
- return len((await event_chain.get_state(token)).event_order) == len(
- exp_event_order
- )
+ return len(
+ (await event_chain.get_state(token)).substates["state"].event_order
+ ) == len(exp_event_order)
await AppHarness._poll_for_async(_has_all_events)
- event_order = (await event_chain.get_state(token)).event_order
+ event_order = (await event_chain.get_state(token)).substates["state"].event_order
assert event_order == exp_event_order
@@ -453,12 +453,12 @@ async def test_event_chain_on_load(
token = assert_token(event_chain, driver)
async def _has_all_events():
- return len((await event_chain.get_state(token)).event_order) == len(
- exp_event_order
- )
+ return len(
+ (await event_chain.get_state(token)).substates["state"].event_order
+ ) == len(exp_event_order)
await AppHarness._poll_for_async(_has_all_events)
- backend_state = await event_chain.get_state(token)
+ backend_state = (await event_chain.get_state(token)).substates["state"]
assert backend_state.event_order == exp_event_order
assert backend_state.is_hydrated is True
@@ -529,12 +529,12 @@ async def test_event_chain_on_mount(
unmount_button.click()
async def _has_all_events():
- return len((await event_chain.get_state(token)).event_order) == len(
- exp_event_order
- )
+ return len(
+ (await event_chain.get_state(token)).substates["state"].event_order
+ ) == len(exp_event_order)
await AppHarness._poll_for_async(_has_all_events)
- event_order = (await event_chain.get_state(token)).event_order
+ event_order = (await event_chain.get_state(token)).substates["state"].event_order
assert event_order == exp_event_order
diff --git a/integration/test_form_submit.py b/integration/test_form_submit.py
index 4a9f6c2d1..cc36b5f25 100644
--- a/integration/test_form_submit.py
+++ b/integration/test_form_submit.py
@@ -22,7 +22,7 @@ def FormSubmit():
def form_submit(self, form_data: dict):
self.form_data = form_data
- app = rx.App(state=FormState)
+ app = rx.App(state=rx.State)
@app.add_page
def index():
@@ -75,7 +75,7 @@ def FormSubmitName():
def form_submit(self, form_data: dict):
self.form_data = form_data
- app = rx.App(state=FormState)
+ app = rx.App(state=rx.State)
@app.add_page
def index():
@@ -210,7 +210,7 @@ async def test_submit(driver, form_submit: AppHarness):
submit_input.click()
async def get_form_data():
- return (await form_submit.get_state(token)).form_data
+ return (await form_submit.get_state(token)).substates["form_state"].form_data
# wait for the form data to arrive at the backend
form_data = await AppHarness._poll_for_async(get_form_data)
diff --git a/integration/test_input.py b/integration/test_input.py
index d24f76815..4a5179850 100644
--- a/integration/test_input.py
+++ b/integration/test_input.py
@@ -16,7 +16,7 @@ def FullyControlledInput():
class State(rx.State):
text: str = "initial"
- app = rx.App(state=State)
+ app = rx.App(state=rx.State)
@app.add_page
def index():
@@ -85,13 +85,15 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
debounce_input.send_keys("foo")
time.sleep(0.5)
assert debounce_input.get_attribute("value") == "ifoonitial"
- assert (await fully_controlled_input.get_state(token)).text == "ifoonitial"
+ assert (await fully_controlled_input.get_state(token)).substates[
+ "state"
+ ].text == "ifoonitial"
assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial"
# clear the input on the backend
async with fully_controlled_input.modify_state(token) as state:
- state.text = ""
- assert (await fully_controlled_input.get_state(token)).text == ""
+ state.substates["state"].text = ""
+ assert (await fully_controlled_input.get_state(token)).substates["state"].text == ""
assert (
fully_controlled_input.poll_for_value(
debounce_input, exp_not_equal="ifoonitial"
@@ -103,9 +105,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
debounce_input.send_keys("getting testing done")
time.sleep(0.5)
assert debounce_input.get_attribute("value") == "getting testing done"
- assert (
- await fully_controlled_input.get_state(token)
- ).text == "getting testing done"
+ assert (await fully_controlled_input.get_state(token)).substates[
+ "state"
+ ].text == "getting testing done"
assert fully_controlled_input.poll_for_value(value_input) == "getting testing done"
# type into the on_change input
@@ -113,7 +115,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
time.sleep(0.5)
assert debounce_input.get_attribute("value") == "overwrite the state"
assert on_change_input.get_attribute("value") == "overwrite the state"
- assert (await fully_controlled_input.get_state(token)).text == "overwrite the state"
+ assert (await fully_controlled_input.get_state(token)).substates[
+ "state"
+ ].text == "overwrite the state"
assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state"
clear_button.click()
diff --git a/integration/test_login_flow.py b/integration/test_login_flow.py
index 68e0a864d..f53635743 100644
--- a/integration/test_login_flow.py
+++ b/integration/test_login_flow.py
@@ -42,7 +42,7 @@ def LoginSample():
rx.button("Do it", on_click=State.login, id="doit"),
)
- app = rx.App(state=State)
+ app = rx.App(state=rx.State)
app.add_page(index)
app.add_page(login)
app.compile()
@@ -137,6 +137,6 @@ def test_login_flow(
logout_button = driver.find_element(By.ID, "logout")
logout_button.click()
- assert login_sample._poll_for(lambda: local_storage["state.auth_token"] == "")
+ assert login_sample._poll_for(lambda: local_storage["state.state.auth_token"] == "")
with pytest.raises(NoSuchElementException):
driver.find_element(By.ID, "auth-token")
diff --git a/integration/test_radix_themes.py b/integration/test_radix_themes.py
index 8f07786c5..d4731d20e 100644
--- a/integration/test_radix_themes.py
+++ b/integration/test_radix_themes.py
@@ -81,7 +81,7 @@ def RadixThemesApp():
)
app = rx.App(
- state=State,
+ state=rx.State,
theme=rdxt.theme(rdxt.theme_panel(), accent_color="grass"),
)
app.add_page(index)
diff --git a/integration/test_server_side_event.py b/integration/test_server_side_event.py
index b24f47368..31a38ba36 100644
--- a/integration/test_server_side_event.py
+++ b/integration/test_server_side_event.py
@@ -33,7 +33,7 @@ def ServerSideEvent():
def set_value_return_c(self):
return rx.set_value("c", "")
- app = rx.App(state=SSState)
+ app = rx.App(state=rx.State)
@app.add_page
def index():
diff --git a/integration/test_table.py b/integration/test_table.py
index e947514a0..00e6a8b22 100644
--- a/integration/test_table.py
+++ b/integration/test_table.py
@@ -26,7 +26,7 @@ def Table():
caption: str = "random caption"
- app = rx.App(state=TableState)
+ app = rx.App(state=rx.State)
@app.add_page
def index():
diff --git a/integration/test_upload.py b/integration/test_upload.py
index 648c68be5..13fb28d36 100644
--- a/integration/test_upload.py
+++ b/integration/test_upload.py
@@ -113,7 +113,7 @@ def UploadFile():
),
)
- app = rx.App(state=UploadState)
+ app = rx.App(state=rx.State)
app.add_page(index)
app.compile()
@@ -192,7 +192,7 @@ async def test_upload_file(
# look up the backend state and assert on uploaded contents
async def get_file_data():
- return (await upload_file.get_state(token))._file_data
+ return (await upload_file.get_state(token)).substates["upload_state"]._file_data
file_data = await AppHarness._poll_for_async(get_file_data)
assert isinstance(file_data, dict)
@@ -205,8 +205,8 @@ async def test_upload_file(
state = await upload_file.get_state(token)
if secondary:
# only the secondary form tracks progress and chain events
- assert state.event_order.count("upload_progress") == 1
- assert state.event_order.count("chain_event") == 1
+ assert state.substates["upload_state"].event_order.count("upload_progress") == 1
+ assert state.substates["upload_state"].event_order.count("chain_event") == 1
@pytest.mark.asyncio
@@ -251,7 +251,7 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
# look up the backend state and assert on uploaded contents
async def get_file_data():
- return (await upload_file.get_state(token))._file_data
+ return (await upload_file.get_state(token)).substates["upload_state"]._file_data
file_data = await AppHarness._poll_for_async(get_file_data)
assert isinstance(file_data, dict)
@@ -349,7 +349,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive
# look up the backend state and assert on progress
state = await upload_file.get_state(token)
- assert state.progress_dicts
- assert exp_name not in state._file_data
+ assert state.substates["upload_state"].progress_dicts
+ assert exp_name not in state.substates["upload_state"]._file_data
target_file.unlink()
diff --git a/integration/test_var_operations.py b/integration/test_var_operations.py
index f327b662c..374344a87 100644
--- a/integration/test_var_operations.py
+++ b/integration/test_var_operations.py
@@ -30,7 +30,7 @@ def VarOperations():
dict2: dict = {3: 4}
html_str: str = "
hello
"
- app = rx.App(state=VarOperationState)
+ app = rx.App(state=rx.State)
@app.add_page
def index():
diff --git a/reflex/app.py b/reflex/app.py
index 68ad8b639..33b679378 100644
--- a/reflex/app.py
+++ b/reflex/app.py
@@ -57,6 +57,7 @@ from reflex.route import (
verify_route_validity,
)
from reflex.state import (
+ BaseState,
RouterData,
State,
StateManager,
@@ -98,7 +99,7 @@ class App(Base):
socket_app: Optional[ASGIApp] = None
# The state class to use for the app.
- state: Optional[Type[State]] = None
+ state: Optional[Type[BaseState]] = None
# Class to manage many client states.
_state_manager: Optional[StateManager] = None
@@ -149,25 +150,24 @@ class App(Base):
"`connect_error_component` is deprecated, use `overlay_component` instead"
)
super().__init__(*args, **kwargs)
- state_subclasses = State.__subclasses__()
- inferred_state = state_subclasses[-1] if state_subclasses else None
+ state_subclasses = BaseState.__subclasses__()
is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
- # Special case to allow test cases have multiple subclasses of rx.State.
+ # Special case to allow test cases have multiple subclasses of rx.BaseState.
if not is_testing_env:
- # Only one State class is allowed.
+ # Only one Base State class is allowed.
if len(state_subclasses) > 1:
raise ValueError(
- "rx.State has been subclassed multiple times. Only one subclass is allowed"
+ "rx.BaseState cannot be subclassed multiple times. use rx.State instead"
)
# verify that provided state is valid
- if self.state and inferred_state and self.state is not inferred_state:
+ if self.state and self.state is not State:
console.warn(
f"Using substate ({self.state.__name__}) as root state in `rx.App` is currently not supported."
- f" Defaulting to root state: ({inferred_state.__name__})"
+ f" Defaulting to root state: ({State.__name__})"
)
- self.state = inferred_state
+ self.state = State
# Get the config
config = get_config()
@@ -265,7 +265,7 @@ class App(Base):
raise ValueError("The state manager has not been initialized.")
return self._state_manager
- async def preprocess(self, state: State, event: Event) -> StateUpdate | None:
+ async def preprocess(self, state: BaseState, event: Event) -> StateUpdate | None:
"""Preprocess the event.
This is where middleware can modify the event before it is processed.
@@ -290,7 +290,7 @@ class App(Base):
return out # type: ignore
async def postprocess(
- self, state: State, event: Event, update: StateUpdate
+ self, state: BaseState, event: Event, update: StateUpdate
) -> StateUpdate:
"""Postprocess the event.
@@ -764,7 +764,7 @@ class App(Base):
future.result()
@contextlib.asynccontextmanager
- async def modify_state(self, token: str) -> AsyncIterator[State]:
+ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""Modify the state out of band.
Args:
@@ -792,7 +792,9 @@ class App(Base):
sid=state.router.session.session_id,
)
- def _process_background(self, state: State, event: Event) -> asyncio.Task | None:
+ def _process_background(
+ self, state: BaseState, event: Event
+ ) -> asyncio.Task | None:
"""Process an event in the background and emit updates as they arrive.
Args:
diff --git a/reflex/app.pyi b/reflex/app.pyi
index f7e63727a..667ebf52a 100644
--- a/reflex/app.pyi
+++ b/reflex/app.pyi
@@ -33,6 +33,7 @@ from reflex.route import (
)
from reflex.state import (
State as State,
+ BaseState as BaseState,
StateManager as StateManager,
StateUpdate as StateUpdate,
)
@@ -69,7 +70,7 @@ class App(Base):
api: FastAPI
sio: Optional[AsyncServer]
socket_app: Optional[ASGIApp]
- state: Type[State]
+ state: Type[BaseState]
state_manager: StateManager
style: ComponentStyle
middleware: List[Middleware]
diff --git a/reflex/base.py b/reflex/base.py
index 62bc70855..9efea60e0 100644
--- a/reflex/base.py
+++ b/reflex/base.py
@@ -1,11 +1,42 @@
"""Define the base Reflex class."""
from __future__ import annotations
-from typing import Any
+import os
+from typing import Any, List, Type
import pydantic
+from pydantic import BaseModel
from pydantic.fields import ModelField
+from reflex import constants
+
+
+def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None:
+ """Ensure that the field's name does not shadow an existing attribute of the model.
+
+ Args:
+ bases: List of base models to check for shadowed attrs.
+ field_name: name of attribute
+
+ Raises:
+ NameError: If state var field shadows another in its parent state
+ """
+ reload = os.getenv(constants.RELOAD_CONFIG) == "True"
+ for base in bases:
+ try:
+ if not reload and getattr(base, field_name, None):
+ pass
+ except TypeError as te:
+ raise NameError(
+ f'State var "{field_name}" in {base} has been shadowed by a substate var; '
+ f'use a different field name instead".'
+ ) from te
+
+
+# monkeypatch pydantic validate_field_name method to skip validating
+# shadowed state vars when reloading app via utils.prerequisites.get_app(reload=True)
+pydantic.main.validate_field_name = validate_field_name # type: ignore
+
class Base(pydantic.BaseModel):
"""The base class subclassed by all Reflex classes.
diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py
index 1e5215872..1740af50a 100644
--- a/reflex/compiler/compiler.py
+++ b/reflex/compiler/compiler.py
@@ -15,7 +15,7 @@ from reflex.components.component import (
StatefulComponent,
)
from reflex.config import get_config
-from reflex.state import State
+from reflex.state import BaseState
from reflex.utils.imports import ImportVar
@@ -63,7 +63,7 @@ def _compile_theme(theme: dict) -> str:
return templates.THEME.render(theme=theme)
-def _compile_contexts(state: Optional[Type[State]]) -> str:
+def _compile_contexts(state: Optional[Type[BaseState]]) -> str:
"""Compile the initial state and contexts.
Args:
@@ -87,7 +87,7 @@ def _compile_contexts(state: Optional[Type[State]]) -> str:
def _compile_page(
component: Component,
- state: Type[State],
+ state: Type[BaseState],
) -> str:
"""Compile the component given the app state.
@@ -337,7 +337,7 @@ def compile_theme(style: ComponentStyle) -> tuple[str, str]:
return output_path, code
-def compile_contexts(state: Optional[Type[State]]) -> tuple[str, str]:
+def compile_contexts(state: Optional[Type[BaseState]]) -> tuple[str, str]:
"""Compile the initial state / context.
Args:
@@ -353,7 +353,7 @@ def compile_contexts(state: Optional[Type[State]]) -> tuple[str, str]:
def compile_page(
- path: str, component: Component, state: Type[State]
+ path: str, component: Component, state: Type[BaseState]
) -> tuple[str, str]:
"""Compile a single page.
diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py
index 3a866051a..0a4df7a4e 100644
--- a/reflex/compiler/utils.py
+++ b/reflex/compiler/utils.py
@@ -21,7 +21,7 @@ from reflex.components.base import (
Title,
)
from reflex.components.component import Component, ComponentStyle, CustomComponent
-from reflex.state import Cookie, LocalStorage, State
+from reflex.state import BaseState, Cookie, LocalStorage
from reflex.style import Style
from reflex.utils import console, format, imports, path_ops
@@ -128,7 +128,7 @@ def get_import_dict(lib: str, default: str = "", rest: list[str] | None = None)
}
-def compile_state(state: Type[State]) -> dict:
+def compile_state(state: Type[BaseState]) -> dict:
"""Compile the state of the app.
Args:
@@ -170,7 +170,7 @@ def _compile_client_storage_field(
def _compile_client_storage_recursive(
- state: Type[State],
+ state: Type[BaseState],
) -> tuple[dict[str, dict], dict[str, dict[str, str]]]:
"""Compile the client-side storage for the given state recursively.
@@ -208,7 +208,7 @@ def _compile_client_storage_recursive(
return cookies, local_storage
-def compile_client_storage(state: Type[State]) -> dict[str, dict]:
+def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
"""Compile the client-side storage for the given state.
Args:
diff --git a/reflex/constants/__init__.py b/reflex/constants/__init__.py
index bfe112e63..45aaa7248 100644
--- a/reflex/constants/__init__.py
+++ b/reflex/constants/__init__.py
@@ -6,6 +6,7 @@ from .base import (
LOCAL_STORAGE,
POLLING_MAX_HTTP_BUFFER_SIZE,
PYTEST_CURRENT_TEST,
+ RELOAD_CONFIG,
SKIP_COMPILE_ENV_VAR,
ColorMode,
Dirs,
@@ -85,6 +86,7 @@ __ALL__ = [
PYTEST_CURRENT_TEST,
PRODUCTION_BACKEND_URL,
Reflex,
+ RELOAD_CONFIG,
RequirementsTxt,
RouteArgType,
RouteRegex,
diff --git a/reflex/constants/base.py b/reflex/constants/base.py
index 0b28e18cb..82cb39da8 100644
--- a/reflex/constants/base.py
+++ b/reflex/constants/base.py
@@ -173,3 +173,4 @@ SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE"
# Testing variables.
# Testing os env set by pytest when running a test case.
PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST"
+RELOAD_CONFIG = "__REFLEX_RELOAD_CONFIG"
diff --git a/reflex/event.py b/reflex/event.py
index cbaf65b0c..f0ad0957b 100644
--- a/reflex/event.py
+++ b/reflex/event.py
@@ -22,7 +22,7 @@ from reflex.utils.types import ArgsSpec
from reflex.vars import BaseVar, Var
if TYPE_CHECKING:
- from reflex.state import State
+ from reflex.state import BaseState
class Event(Base):
@@ -64,7 +64,7 @@ def background(fn):
def _no_chain_background_task(
- state_cls: Type["State"], name: str, fn: Callable
+ state_cls: Type["BaseState"], name: str, fn: Callable
) -> Callable:
"""Protect against directly chaining a background task from another event handler.
diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py
index 38d5fb14f..6108a90c4 100644
--- a/reflex/middleware/hydrate_middleware.py
+++ b/reflex/middleware/hydrate_middleware.py
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
from reflex import constants
from reflex.event import Event, fix_events, get_hydrate_event
from reflex.middleware.middleware import Middleware
-from reflex.state import State, StateUpdate
+from reflex.state import BaseState, StateUpdate
from reflex.utils import format
if TYPE_CHECKING:
@@ -17,7 +17,7 @@ class HydrateMiddleware(Middleware):
"""Middleware to handle initial app hydration."""
async def preprocess(
- self, app: App, state: State, event: Event
+ self, app: App, state: BaseState, event: Event
) -> Optional[StateUpdate]:
"""Preprocess the event.
diff --git a/reflex/middleware/middleware.py b/reflex/middleware/middleware.py
index 726d81621..f522ff861 100644
--- a/reflex/middleware/middleware.py
+++ b/reflex/middleware/middleware.py
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
from reflex.base import Base
from reflex.event import Event
-from reflex.state import State, StateUpdate
+from reflex.state import BaseState, StateUpdate
if TYPE_CHECKING:
from reflex.app import App
@@ -16,7 +16,7 @@ class Middleware(Base, ABC):
"""Middleware to preprocess and postprocess requests."""
async def preprocess(
- self, app: App, state: State, event: Event
+ self, app: App, state: BaseState, event: Event
) -> Optional[StateUpdate]:
"""Preprocess the event.
@@ -31,7 +31,7 @@ class Middleware(Base, ABC):
return None
async def postprocess(
- self, app: App, state: State, event: Event, update: StateUpdate
+ self, app: App, state: BaseState, event: Event, update: StateUpdate
) -> StateUpdate:
"""Postprocess the event.
diff --git a/reflex/state.py b/reflex/state.py
index 881a6b47c..f6b10e849 100644
--- a/reflex/state.py
+++ b/reflex/state.py
@@ -7,6 +7,7 @@ import copy
import functools
import inspect
import json
+import os
import traceback
import urllib.parse
import uuid
@@ -81,7 +82,7 @@ class HeaderData(Base):
class PageData(Base):
"""An object containing page data."""
- host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?)
+ host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?)
path: str = ""
raw_path: str = ""
full_path: str = ""
@@ -152,7 +153,7 @@ RESERVED_BACKEND_VAR_NAMES = {
}
-class State(Base, ABC, extra=pydantic.Extra.allow):
+class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""The state of the app."""
# A map from the var name to the var.
@@ -176,6 +177,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# The event handlers.
event_handlers: ClassVar[Dict[str, EventHandler]] = {}
+ # A set of subclassses of this class.
+ class_subclasses: ClassVar[Set[Type[BaseState]]] = set()
+
# Mapping of var name to set of computed variables that depend on it
_computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
@@ -189,10 +193,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
_always_dirty_substates: ClassVar[Set[str]] = set()
# The parent state.
- parent_state: Optional[State] = None
+ parent_state: Optional[BaseState] = None
# The substates of the state.
- substates: Dict[str, State] = {}
+ substates: Dict[str, BaseState] = {}
# The set of dirty vars.
dirty_vars: Set[str] = set()
@@ -209,10 +213,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# The router data for the current page
router: RouterData = RouterData()
- # The hydrated bool.
- is_hydrated: bool = False
-
- def __init__(self, *args, parent_state: State | None = None, **kwargs):
+ def __init__(self, *args, parent_state: BaseState | None = None, **kwargs):
"""Initialize the state.
Args:
@@ -220,28 +221,20 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
parent_state: The parent state.
**kwargs: The kwargs to pass to the Pydantic init method.
- Raises:
- ValueError: If a substate class shadows another.
"""
kwargs["parent_state"] = parent_state
super().__init__(*args, **kwargs)
# Setup the substates.
for substate in self.get_substates():
- substate_name = substate.get_name()
- if substate_name in self.substates:
- raise ValueError(
- f"The substate class '{substate_name}' has been defined multiple times. Shadowing "
- f"substate classes is not allowed."
- )
- self.substates[substate_name] = substate(parent_state=self)
+ self.substates[substate.get_name()] = substate(parent_state=self)
# Convert the event handlers to functions.
self._init_event_handlers()
# Create a fresh copy of the backend variables for this instance
self._backend_vars = copy.deepcopy(self.backend_vars)
- def _init_event_handlers(self, state: State | None = None):
+ def _init_event_handlers(self, state: BaseState | None = None):
"""Initialize event handlers.
Allow event handlers to be called directly on the instance. This is
@@ -281,17 +274,44 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
Args:
**kwargs: The kwargs to pass to the pydantic init_subclass method.
+
+ Raises:
+ ValueError: If a substate class shadows another.
"""
+ is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
super().__init_subclass__(**kwargs)
# Event handlers should not shadow builtin state methods.
cls._check_overridden_methods()
+ # Reset subclass tracking for this class.
+ cls.class_subclasses = set()
+
# Get the parent vars.
parent_state = cls.get_parent_state()
if parent_state is not None:
cls.inherited_vars = parent_state.vars
cls.inherited_backend_vars = parent_state.backend_vars
+ # Check if another substate class with the same name has already been defined.
+ if cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses):
+ if is_testing_env:
+ # Clear existing subclass with same name when app is reloaded via
+ # utils.prerequisites.get_app(reload=True)
+ parent_state.class_subclasses = set(
+ c
+ for c in parent_state.class_subclasses
+ if c.__name__ != cls.__name__
+ )
+ else:
+ # During normal operation, subclasses cannot have the same name, even if they are
+ # defined in different modules.
+ raise ValueError(
+ f"The substate class '{cls.__name__}' has been defined multiple times. "
+ "Shadowing substate classes is not allowed."
+ )
+ # Track this new subclass in the parent state's subclasses set.
+ parent_state.class_subclasses.add(cls)
+
cls.new_backend_vars = {
name: value
for name, value in cls.__dict__.items()
@@ -437,7 +457,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
@classmethod
@functools.lru_cache()
- def get_parent_state(cls) -> Type[State] | None:
+ def get_parent_state(cls) -> Type[BaseState] | None:
"""Get the parent state.
Returns:
@@ -446,20 +466,19 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
parent_states = [
base
for base in cls.__bases__
- if types._issubclass(base, State) and base is not State
+ if types._issubclass(base, BaseState) and base is not BaseState
]
assert len(parent_states) < 2, "Only one parent state is allowed."
return parent_states[0] if len(parent_states) == 1 else None # type: ignore
@classmethod
- @functools.lru_cache()
- def get_substates(cls) -> set[Type[State]]:
+ def get_substates(cls) -> set[Type[BaseState]]:
"""Get the substates of the state.
Returns:
The substates of the state.
"""
- return set(cls.__subclasses__())
+ return cls.class_subclasses
@classmethod
@functools.lru_cache()
@@ -487,7 +506,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
@classmethod
@functools.lru_cache()
- def get_class_substate(cls, path: Sequence[str]) -> Type[State]:
+ def get_class_substate(cls, path: Sequence[str]) -> Type[BaseState]:
"""Get the class substate.
Args:
@@ -643,7 +662,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
"""
return {
func[0]: func[1]
- for func in inspect.getmembers(State, predicate=inspect.isfunction)
+ for func in inspect.getmembers(BaseState, predicate=inspect.isfunction)
if not func[0].startswith("__")
}
@@ -909,7 +928,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
for substate in self.substates.values():
substate._reset_client_storage()
- def get_substate(self, path: Sequence[str]) -> State | None:
+ def get_substate(self, path: Sequence[str]) -> BaseState | None:
"""Get the substate.
Args:
@@ -933,7 +952,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
def _get_event_handler(
self, event: Event
- ) -> tuple[State | StateProxy, EventHandler]:
+ ) -> tuple[BaseState | StateProxy, EventHandler]:
"""Get the event handler for the given event.
Args:
@@ -1050,7 +1069,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
)
async def _process_event(
- self, handler: EventHandler, state: State | StateProxy, payload: Dict
+ self, handler: EventHandler, state: BaseState | StateProxy, payload: Dict
) -> AsyncIterator[StateUpdate]:
"""Process event.
@@ -1263,7 +1282,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
d.update(substate_d)
return d
- async def __aenter__(self) -> State:
+ async def __aenter__(self) -> BaseState:
"""Enter the async context manager protocol.
This should not be used for the State class, but exists for
@@ -1288,6 +1307,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
pass
+class State(BaseState):
+ """The app Base State."""
+
+ # The hydrated bool.
+ is_hydrated: bool = False
+
+
class StateProxy(wrapt.ObjectProxy):
"""Proxy of a state instance to control mutability of vars for a background task.
@@ -1455,10 +1481,10 @@ class StateManager(Base, ABC):
"""A class to manage many client states."""
# The state class to use.
- state: Type[State]
+ state: Type[BaseState]
@classmethod
- def create(cls, state: Type[State]):
+ def create(cls, state: Type[BaseState]):
"""Create a new state manager.
Args:
@@ -1473,7 +1499,7 @@ class StateManager(Base, ABC):
return StateManagerMemory(state=state)
@abstractmethod
- async def get_state(self, token: str) -> State:
+ async def get_state(self, token: str) -> BaseState:
"""Get the state for a token.
Args:
@@ -1485,7 +1511,7 @@ class StateManager(Base, ABC):
pass
@abstractmethod
- async def set_state(self, token: str, state: State):
+ async def set_state(self, token: str, state: BaseState):
"""Set the state for a token.
Args:
@@ -1496,7 +1522,7 @@ class StateManager(Base, ABC):
@abstractmethod
@contextlib.asynccontextmanager
- async def modify_state(self, token: str) -> AsyncIterator[State]:
+ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""Modify the state for a token while holding exclusive lock.
Args:
@@ -1512,7 +1538,7 @@ class StateManagerMemory(StateManager):
"""A state manager that stores states in memory."""
# The mapping of client ids to states.
- states: Dict[str, State] = {}
+ states: Dict[str, BaseState] = {}
# The mutex ensures the dict of mutexes is updated exclusively
_state_manager_lock = asyncio.Lock()
@@ -1527,7 +1553,7 @@ class StateManagerMemory(StateManager):
"_states_locks": {"exclude": True},
}
- async def get_state(self, token: str) -> State:
+ async def get_state(self, token: str) -> BaseState:
"""Get the state for a token.
Args:
@@ -1540,7 +1566,7 @@ class StateManagerMemory(StateManager):
self.states[token] = self.state()
return self.states[token]
- async def set_state(self, token: str, state: State):
+ async def set_state(self, token: str, state: BaseState):
"""Set the state for a token.
Args:
@@ -1550,7 +1576,7 @@ class StateManagerMemory(StateManager):
pass
@contextlib.asynccontextmanager
- async def modify_state(self, token: str) -> AsyncIterator[State]:
+ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""Modify the state for a token while holding exclusive lock.
Args:
@@ -1598,7 +1624,7 @@ class StateManagerRedis(StateManager):
b"evicted",
}
- async def get_state(self, token: str) -> State:
+ async def get_state(self, token: str) -> BaseState:
"""Get the state for a token.
Args:
@@ -1613,7 +1639,9 @@ class StateManagerRedis(StateManager):
return await self.get_state(token)
return cloudpickle.loads(redis_state)
- async def set_state(self, token: str, state: State, lock_id: bytes | None = None):
+ async def set_state(
+ self, token: str, state: BaseState, lock_id: bytes | None = None
+ ):
"""Set the state for a token.
Args:
@@ -1637,7 +1665,7 @@ class StateManagerRedis(StateManager):
await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
@contextlib.asynccontextmanager
- async def modify_state(self, token: str) -> AsyncIterator[State]:
+ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""Modify the state for a token while holding exclusive lock.
Args:
@@ -1879,7 +1907,7 @@ class MutableProxy(wrapt.ObjectProxy):
__mutable_types__ = (list, dict, set, Base)
- def __init__(self, wrapped: Any, state: State, field_name: str):
+ def __init__(self, wrapped: Any, state: BaseState, field_name: str):
"""Create a proxy for a mutable object that tracks changes.
Args:
diff --git a/reflex/testing.py b/reflex/testing.py
index a6bb3d400..2cc20e6a4 100644
--- a/reflex/testing.py
+++ b/reflex/testing.py
@@ -38,7 +38,7 @@ import reflex.utils.build
import reflex.utils.exec
import reflex.utils.prerequisites
import reflex.utils.processes
-from reflex.state import State, StateManagerMemory, StateManagerRedis
+from reflex.state import BaseState, State, StateManagerMemory, StateManagerRedis
try:
from selenium import webdriver # pyright: ignore [reportMissingImports]
@@ -162,6 +162,9 @@ class AppHarness:
with chdir(self.app_path):
# ensure config and app are reloaded when testing different app
reflex.config.get_config(reload=True)
+ # reset rx.State subclasses
+ State.class_subclasses.clear()
+ # self.app_module.app.
self.app_module = reflex.utils.prerequisites.get_app(reload=True)
self.app_instance = self.app_module.app
if isinstance(self.app_instance.state_manager, StateManagerRedis):
@@ -434,7 +437,7 @@ class AppHarness:
self._frontends.append(driver)
return driver
- async def get_state(self, token: str) -> State:
+ async def get_state(self, token: str) -> BaseState:
"""Get the state associated with the given token.
Args:
@@ -561,7 +564,7 @@ class AppHarness:
)
return element.get_attribute("value")
- def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, reflex.State]:
+ def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, BaseState]:
"""Poll app state_manager for any connected clients.
Args:
diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py
index 17195f224..f0e53b2d2 100644
--- a/reflex/utils/prerequisites.py
+++ b/reflex/utils/prerequisites.py
@@ -124,11 +124,13 @@ def get_app(reload: bool = False) -> ModuleType:
Returns:
The app based on the default config.
"""
+ os.environ[constants.RELOAD_CONFIG] = str(reload)
config = get_config()
module = ".".join([config.app_name, config.app_name])
sys.path.insert(0, os.getcwd())
app = __import__(module, fromlist=(constants.CompileVars.APP,))
if reload:
+
importlib.reload(app)
return app
diff --git a/reflex/vars.py b/reflex/vars.py
index caec26f87..6bfd6c57d 100644
--- a/reflex/vars.py
+++ b/reflex/vars.py
@@ -41,7 +41,7 @@ from reflex.utils import console, format, imports, serializers, types
from reflex.utils.imports import ImportDict, ImportVar
if TYPE_CHECKING:
- from reflex.state import State
+ from reflex.state import BaseState
# Set of unique variable names.
USED_VARIABLES = set()
@@ -1472,7 +1472,7 @@ class Var:
)
)
- def _var_set_state(self, state: Type[State] | str) -> Any:
+ def _var_set_state(self, state: Type[BaseState] | str) -> Any:
"""Set the state of the var.
Args:
@@ -1604,14 +1604,14 @@ class BaseVar(Var):
return setter
return ".".join((self._var_data.state, setter))
- def get_setter(self) -> Callable[[State, Any], None]:
+ def get_setter(self) -> Callable[[BaseState, Any], None]:
"""Get the var's setter function.
Returns:
A function that that creates a setter for the var.
"""
- def setter(state: State, value: Any):
+ def setter(state: BaseState, value: Any):
"""Get the setter for the var.
Args:
@@ -1643,9 +1643,9 @@ class ComputedVar(Var, property):
def __init__(
self,
- fget: Callable[[State], Any],
- fset: Callable[[State, Any], None] | None = None,
- fdel: Callable[[State], Any] | None = None,
+ fget: Callable[[BaseState], Any],
+ fset: Callable[[BaseState, Any], None] | None = None,
+ fdel: Callable[[BaseState], Any] | None = None,
doc: str | None = None,
**kwargs,
):
diff --git a/reflex/vars.pyi b/reflex/vars.pyi
index c4c96af90..6ec1bd987 100644
--- a/reflex/vars.pyi
+++ b/reflex/vars.pyi
@@ -5,6 +5,7 @@ from _typeshed import Incomplete
from reflex import constants as constants
from reflex.base import Base as Base
from reflex.state import State as State
+from reflex.state import BaseState as BaseState
from reflex.utils import console as console, format as format, types as types
from reflex.utils.imports import ImportVar
from types import FunctionType
@@ -110,7 +111,7 @@ class Var:
def as_ref(self) -> Var: ...
@property
def _var_full_name(self) -> str: ...
- def _var_set_state(self, state: Type[State] | str) -> Any: ...
+ def _var_set_state(self, state: Type[BaseState] | str) -> Any: ...
@dataclass(eq=False)
class BaseVar(Var):
@@ -123,7 +124,7 @@ class BaseVar(Var):
def __hash__(self) -> int: ...
def get_default_value(self) -> Any: ...
def get_setter_name(self, include_state: bool = ...) -> str: ...
- def get_setter(self) -> Callable[[State, Any], None]: ...
+ def get_setter(self) -> Callable[[BaseState, Any], None]: ...
@dataclass(init=False)
class ComputedVar(Var):
diff --git a/tests/__init__.py b/tests/__init__.py
index b318bba63..b4d8570e5 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1 +1,4 @@
"""Root directory for tests."""
+import os
+
+from reflex import constants
diff --git a/tests/components/base/test_script.py b/tests/components/base/test_script.py
index cc16ab718..7ccdc0634 100644
--- a/tests/components/base/test_script.py
+++ b/tests/components/base/test_script.py
@@ -2,7 +2,7 @@
import pytest
from reflex.components.base.script import Script
-from reflex.state import State
+from reflex.state import BaseState
def test_script_inline():
@@ -31,7 +31,7 @@ def test_script_neither():
Script.create()
-class EvState(State):
+class EvState(BaseState):
"""State for testing event handlers."""
def on_ready(self):
diff --git a/tests/components/datadisplay/conftest.py b/tests/components/datadisplay/conftest.py
index c0e61c437..93796ed23 100644
--- a/tests/components/datadisplay/conftest.py
+++ b/tests/components/datadisplay/conftest.py
@@ -5,6 +5,7 @@ import pandas as pd
import pytest
import reflex as rx
+from reflex.state import BaseState
@pytest.fixture
@@ -18,7 +19,7 @@ def data_table_state(request):
The data table state class.
"""
- class DataTableState(rx.State):
+ class DataTableState(BaseState):
data = request.param["data"]
columns = ["column1", "column2"]
@@ -33,7 +34,7 @@ def data_table_state2():
The data table state class.
"""
- class DataTableState(rx.State):
+ class DataTableState(BaseState):
_data = pd.DataFrame()
@rx.var
@@ -51,7 +52,7 @@ def data_table_state3():
The data table state class.
"""
- class DataTableState(rx.State):
+ class DataTableState(BaseState):
_data: List = []
_columns: List = ["col1", "col2"]
@@ -74,7 +75,7 @@ def data_table_state4():
The data table state class.
"""
- class DataTableState(rx.State):
+ class DataTableState(BaseState):
_data: List = []
_columns: List = ["col1", "col2"]
diff --git a/tests/components/datadisplay/test_table.py b/tests/components/datadisplay/test_table.py
index 94e39f6e5..212f1fb59 100644
--- a/tests/components/datadisplay/test_table.py
+++ b/tests/components/datadisplay/test_table.py
@@ -4,12 +4,12 @@ from typing import List, Tuple
import pytest
from reflex.components.datadisplay.table import Tbody, Tfoot, Thead
-from reflex.state import State
+from reflex.state import BaseState
PYTHON_GT_V38 = sys.version_info.major >= 3 and sys.version_info.minor > 8
-class TableState(State):
+class TableState(BaseState):
"""Test State class."""
rows_List_List_str: List[List[str]] = [["random", "row"]]
diff --git a/tests/components/forms/test_debounce.py b/tests/components/forms/test_debounce.py
index 1965fabf5..97cfa8648 100644
--- a/tests/components/forms/test_debounce.py
+++ b/tests/components/forms/test_debounce.py
@@ -3,6 +3,7 @@
import pytest
import reflex as rx
+from reflex.state import BaseState
from reflex.vars import BaseVar
@@ -24,7 +25,7 @@ def test_render_many_child():
_ = rx.debounce_input("foo", "bar").render()
-class S(rx.State):
+class S(BaseState):
"""Example state for debounce tests."""
value: str = ""
diff --git a/tests/components/layout/test_cond.py b/tests/components/layout/test_cond.py
index 00cf4de7d..08c6883c4 100644
--- a/tests/components/layout/test_cond.py
+++ b/tests/components/layout/test_cond.py
@@ -15,12 +15,13 @@ from reflex.components.layout.responsive import (
tablet_only,
)
from reflex.components.typography.text import Text
+from reflex.state import BaseState
from reflex.vars import Var
@pytest.fixture
def cond_state(request):
- class CondState(rx.State):
+ class CondState(BaseState):
value: request.param["value_type"] = request.param["value"] # noqa
return CondState
diff --git a/tests/components/layout/test_foreach.py b/tests/components/layout/test_foreach.py
index aacdf3638..71ae36b23 100644
--- a/tests/components/layout/test_foreach.py
+++ b/tests/components/layout/test_foreach.py
@@ -4,10 +4,10 @@ import pytest
from reflex.components import box, foreach, text
from reflex.components.layout import Foreach
-from reflex.state import State
+from reflex.state import BaseState
-class ForEachState(State):
+class ForEachState(BaseState):
"""A state for testing the ForEach component."""
colors_list: List[str] = ["red", "yellow"]
diff --git a/tests/components/test_component.py b/tests/components/test_component.py
index d90efd658..ac5b23130 100644
--- a/tests/components/test_component.py
+++ b/tests/components/test_component.py
@@ -14,7 +14,7 @@ from reflex.components.component import (
from reflex.components.layout.box import Box
from reflex.constants import EventTriggers
from reflex.event import EventChain, EventHandler
-from reflex.state import State
+from reflex.state import BaseState
from reflex.style import Style
from reflex.utils import imports
from reflex.utils.imports import ImportVar
@@ -23,7 +23,7 @@ from reflex.vars import Var, VarData
@pytest.fixture
def test_state():
- class TestState(State):
+ class TestState(BaseState):
num: int
def do_something(self):
@@ -400,7 +400,7 @@ def test_get_event_triggers(component1, component2):
)
-class C1State(State):
+class C1State(BaseState):
"""State for testing C1 component."""
def mock_handler(self, _e, _bravo, _charlie):
diff --git a/tests/conftest.py b/tests/conftest.py
index d2dd301f4..e5dddc470 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -8,7 +8,6 @@ from typing import Dict, Generator
import pytest
-import reflex as rx
from reflex.app import App
from reflex.event import EventSpec
@@ -225,23 +224,3 @@ def token() -> str:
A fresh/unique token string.
"""
return str(uuid.uuid4())
-
-
-@pytest.fixture
-def duplicate_substate():
- """Create a Test state that has duplicate child substates.
-
- Returns:
- The test state.
- """
-
- class TestState(rx.State):
- pass
-
- class ChildTestState(TestState): # type: ignore # noqa
- pass
-
- class ChildTestState(TestState): # type: ignore # noqa
- pass
-
- return TestState
diff --git a/tests/middleware/test_hydrate_middleware.py b/tests/middleware/test_hydrate_middleware.py
index 7767dcf8b..2f21557f0 100644
--- a/tests/middleware/test_hydrate_middleware.py
+++ b/tests/middleware/test_hydrate_middleware.py
@@ -2,13 +2,14 @@ from typing import Any, Dict
import pytest
+from reflex import constants
from reflex.app import App
from reflex.constants import CompileVars
from reflex.middleware.hydrate_middleware import HydrateMiddleware
-from reflex.state import State, StateUpdate
+from reflex.state import BaseState, StateUpdate
-def exp_is_hydrated(state: State) -> Dict[str, Any]:
+def exp_is_hydrated(state: BaseState) -> Dict[str, Any]:
"""Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
Args:
@@ -20,7 +21,7 @@ def exp_is_hydrated(state: State) -> Dict[str, Any]:
return {state.get_name(): {CompileVars.IS_HYDRATED: True}}
-class TestState(State):
+class TestState(BaseState):
"""A test state with no return in handler."""
__test__ = False
@@ -32,7 +33,7 @@ class TestState(State):
self.num += 1
-class TestState2(State):
+class TestState2(BaseState):
"""A test state with return in handler."""
__test__ = False
@@ -54,7 +55,7 @@ class TestState2(State):
self.name = "random"
-class TestState3(State):
+class TestState3(BaseState):
"""A test state with async handler."""
__test__ = False
@@ -97,6 +98,9 @@ async def test_preprocess(
event_fixture: The event fixture(an Event).
expected: Expected delta.
"""
+ test_state.add_var(
+ constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
+ )
app = App(state=test_state, load_events={"index": [test_state.test_handler]})
state = test_state()
diff --git a/tests/states/__init__.py b/tests/states/__init__.py
index 2c007172a..11e891ab4 100644
--- a/tests/states/__init__.py
+++ b/tests/states/__init__.py
@@ -1,10 +1,12 @@
-"""Common rx.State subclasses for use in tests."""
+"""Common rx.BaseState subclasses for use in tests."""
import reflex as rx
+from reflex.state import BaseState
from .mutation import DictMutationTestState, ListMutationTestState, MutableTestState
from .upload import (
ChildFileUploadState,
FileStateBase1,
+ FileStateBase2,
FileUploadState,
GrandChildFileUploadState,
SubUploadState,
@@ -12,7 +14,7 @@ from .upload import (
)
-class GenState(rx.State):
+class GenState(BaseState):
"""A state with event handlers that generate multiple updates."""
value: int
diff --git a/tests/states/mutation.py b/tests/states/mutation.py
index b3d98301f..5825b6d12 100644
--- a/tests/states/mutation.py
+++ b/tests/states/mutation.py
@@ -3,9 +3,10 @@
from typing import Dict, List, Set, Union
import reflex as rx
+from reflex.state import BaseState
-class DictMutationTestState(rx.State):
+class DictMutationTestState(BaseState):
"""A state for testing ReflexDict mutation."""
# plain dict
@@ -62,7 +63,7 @@ class DictMutationTestState(rx.State):
self.friend_in_nested_dict["friend"]["age"] = 30
-class ListMutationTestState(rx.State):
+class ListMutationTestState(BaseState):
"""A state for testing ReflexList mutation."""
# plain list
@@ -144,7 +145,7 @@ class CustomVar(rx.Base):
custom: OtherBase = OtherBase()
-class MutableTestState(rx.State):
+class MutableTestState(BaseState):
"""A test state."""
array: List[Union[str, List, Dict[str, str]]] = [
diff --git a/tests/states/upload.py b/tests/states/upload.py
index ec2585dd1..8abe11c24 100644
--- a/tests/states/upload.py
+++ b/tests/states/upload.py
@@ -3,9 +3,10 @@ from pathlib import Path
from typing import ClassVar, List
import reflex as rx
+from reflex.state import BaseState, State
-class UploadState(rx.State):
+class UploadState(BaseState):
"""The base state for uploading a file."""
async def handle_upload1(self, files: List[rx.UploadFile]):
@@ -17,7 +18,7 @@ class UploadState(rx.State):
pass
-class BaseState(rx.State):
+class BaseState(BaseState):
"""The test base state."""
pass
@@ -37,7 +38,7 @@ class SubUploadState(BaseState):
pass
-class FileUploadState(rx.State):
+class FileUploadState(State):
"""The base state for uploading a file."""
img_list: List[str]
@@ -79,7 +80,7 @@ class FileUploadState(rx.State):
pass
-class FileStateBase1(rx.State):
+class FileStateBase1(State):
"""The base state for a child FileUploadState."""
pass
diff --git a/tests/test_app.py b/tests/test_app.py
index 6aa73974e..7d667ac1e 100644
--- a/tests/test_app.py
+++ b/tests/test_app.py
@@ -28,7 +28,7 @@ from reflex.components import Box, Component, Cond, Fragment, Text
from reflex.event import Event, get_hydrate_event
from reflex.middleware import HydrateMiddleware
from reflex.model import Model
-from reflex.state import RouterData, State, StateManagerRedis, StateUpdate
+from reflex.state import BaseState, RouterData, State, StateManagerRedis, StateUpdate
from reflex.style import Style
from reflex.utils import format
from reflex.vars import ComputedVar
@@ -43,7 +43,7 @@ from .states import (
)
-class EmptyState(State):
+class EmptyState(BaseState):
"""An empty state."""
pass
@@ -77,14 +77,14 @@ def about_page():
return about
-class ATestState(State):
+class ATestState(BaseState):
"""A simple state for testing."""
var: int
@pytest.fixture()
-def test_state() -> Type[State]:
+def test_state() -> Type[BaseState]:
"""A default state.
Returns:
@@ -94,14 +94,14 @@ def test_state() -> Type[State]:
@pytest.fixture()
-def redundant_test_state() -> Type[State]:
+def redundant_test_state() -> Type[BaseState]:
"""A default state.
Returns:
A default state.
"""
- class RedundantTestState(State):
+ class RedundantTestState(BaseState):
var: int
return RedundantTestState
@@ -198,12 +198,12 @@ def test_default_app(app: App):
def test_multiple_states_error(monkeypatch, test_state, redundant_test_state):
- """Test that an error is thrown when multiple classes subclass rx.State.
+ """Test that an error is thrown when multiple classes subclass rx.BaseState.
Args:
monkeypatch: Pytest monkeypatch object.
- test_state: A test state subclassing rx.State.
- redundant_test_state: Another test state subclassing rx.State.
+ test_state: A test state subclassing rx.BaseState.
+ redundant_test_state: Another test state subclassing rx.BaseState.
"""
monkeypatch.delenv(constants.PYTEST_CURRENT_TEST)
with pytest.raises(ValueError):
@@ -705,12 +705,12 @@ async def test_dict_mutation_detection__plain_list(
[
(
FileUploadState,
- {"file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}},
+ {"state.file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}},
),
(
ChildFileUploadState,
{
- "file_state_base1.child_file_upload_state": {
+ "state.file_state_base1.child_file_upload_state": {
"img_list": ["image1.jpg", "image2.jpg"]
}
},
@@ -718,14 +718,14 @@ async def test_dict_mutation_detection__plain_list(
(
GrandChildFileUploadState,
{
- "file_state_base1.file_state_base2.grand_child_file_upload_state": {
+ "state.file_state_base1.file_state_base2.grand_child_file_upload_state": {
"img_list": ["image1.jpg", "image2.jpg"]
}
},
),
],
)
-async def test_upload_file(tmp_path, state, delta, token: str):
+async def test_upload_file(tmp_path, state, delta, token: str, mocker):
"""Test that file upload works correctly.
Args:
@@ -733,10 +733,15 @@ async def test_upload_file(tmp_path, state, delta, token: str):
state: The state class.
delta: Expected delta
token: a Token.
+ mocker: pytest mocker object.
"""
+ mocker.patch(
+ "reflex.state.State.class_subclasses",
+ {state if state is FileUploadState else FileStateBase1},
+ )
state._tmp_path = tmp_path
# The App state must be the "root" of the state tree
- app = App(state=state if state is FileUploadState else FileStateBase1)
+ app = App(state=State)
app.event_namespace.emit = AsyncMock() # type: ignore
current_state = await app.state_manager.get_state(token)
data = b"This is binary data"
@@ -749,7 +754,7 @@ async def test_upload_file(tmp_path, state, delta, token: str):
request_mock = unittest.mock.Mock()
request_mock.headers = {
"reflex-client-token": token,
- "reflex-event-handler": f"{state_name}.multi_handle_upload",
+ "reflex-event-handler": f"state.{state_name}.multi_handle_upload",
}
file1 = UploadFile(
@@ -851,7 +856,7 @@ async def test_upload_file_background(state, tmp_path, token):
await app.state_manager.redis.close()
-class DynamicState(State):
+class DynamicState(BaseState):
"""State class for testing dynamic route var.
This is defined at module level because event handlers cannot be addressed
@@ -891,9 +896,7 @@ class DynamicState(State):
@pytest.mark.asyncio
async def test_dynamic_route_var_route_change_completed_on_load(
- index_page,
- windows_platform: bool,
- token: str,
+ index_page, windows_platform: bool, token: str, mocker
):
"""Create app with dynamic route var, and simulate navigation.
@@ -904,7 +907,12 @@ async def test_dynamic_route_var_route_change_completed_on_load(
index_page: The index page.
windows_platform: Whether the system is windows.
token: a Token.
+ mocker: pytest mocker object.
"""
+ mocker.patch("reflex.state.State.class_subclasses", {DynamicState})
+ DynamicState.add_var(
+ constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
+ )
arg_name = "dynamic"
route = f"/test/[{arg_name}]"
if windows_platform:
diff --git a/tests/test_event.py b/tests/test_event.py
index 284012b13..2326f0920 100644
--- a/tests/test_event.py
+++ b/tests/test_event.py
@@ -4,7 +4,7 @@ import pytest
from reflex import event
from reflex.event import Event, EventHandler, EventSpec, fix_events
-from reflex.state import State
+from reflex.state import BaseState
from reflex.utils import format
from reflex.vars import Var
@@ -303,7 +303,7 @@ def test_event_actions():
def test_event_actions_on_state():
- class EventActionState(State):
+ class EventActionState(BaseState):
def handler(self):
pass
diff --git a/tests/test_state.py b/tests/test_state.py
index 591026d2c..b4ca0c87a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -19,11 +19,11 @@ from reflex.base import Base
from reflex.constants import CompileVars, RouteVar, SocketEvent
from reflex.event import Event, EventHandler
from reflex.state import (
+ BaseState,
ImmutableStateError,
LockExpiredError,
MutableProxy,
RouterData,
- State,
StateManager,
StateManagerMemory,
StateManagerRedis,
@@ -75,7 +75,7 @@ class Object(Base):
prop2: str = "hello"
-class TestState(State):
+class TestState(BaseState):
"""A test state."""
# Set this class as not test one
@@ -148,7 +148,7 @@ class GrandchildState(ChildState):
pass
-class DateTimeState(State):
+class DateTimeState(BaseState):
"""A State with some datetime fields."""
d: datetime.date = datetime.date.fromisoformat("1989-11-09")
@@ -253,7 +253,6 @@ def test_class_vars(test_state):
"""
cls = type(test_state)
assert set(cls.vars.keys()) == {
- CompileVars.IS_HYDRATED, # added by hydrate_middleware to all State
"router",
"num1",
"num2",
@@ -641,7 +640,6 @@ def test_reset(test_state, child_state):
"obj",
"upper",
"complex",
- "is_hydrated",
"fig",
"key",
"sum",
@@ -837,7 +835,7 @@ def test_get_query_params(test_state):
def test_add_var():
- class DynamicState(State):
+ class DynamicState(BaseState):
pass
ds1 = DynamicState()
@@ -870,7 +868,7 @@ def test_add_var_default_handlers(test_state):
assert isinstance(test_state.event_handlers["set_rand_int"], EventHandler)
-class InterdependentState(State):
+class InterdependentState(BaseState):
"""A state with 3 vars and 3 computed vars.
x: a variable that no computed var depends on
@@ -915,7 +913,7 @@ class InterdependentState(State):
@pytest.fixture
-def interdependent_state() -> State:
+def interdependent_state() -> BaseState:
"""A state with varying dependency between vars.
Returns:
@@ -988,7 +986,7 @@ def test_per_state_backend_var(interdependent_state):
def test_child_state():
"""Test that the child state computed vars can reference parent state vars."""
- class MainState(State):
+ class MainState(BaseState):
v: int = 2
class ChildState(MainState):
@@ -1006,7 +1004,7 @@ def test_child_state():
def test_conditional_computed_vars():
"""Test that computed vars can have conditionals."""
- class MainState(State):
+ class MainState(BaseState):
flag: bool = False
t1: str = "a"
t2: str = "b"
@@ -1051,7 +1049,7 @@ def test_event_handlers_convert_to_fns(test_state, child_state):
def test_event_handlers_call_other_handlers():
"""Test that event handlers can call other event handlers."""
- class MainState(State):
+ class MainState(BaseState):
v: int = 0
def set_v(self, v: int):
@@ -1077,7 +1075,7 @@ def test_computed_var_cached():
"""Test that a ComputedVar doesn't recalculate when accessed."""
comp_v_calls = 0
- class ComputedState(State):
+ class ComputedState(BaseState):
v: int = 0
@rx.cached_var
@@ -1102,7 +1100,7 @@ def test_computed_var_cached():
def test_computed_var_cached_depends_on_non_cached():
"""Test that a cached_var is recalculated if it depends on non-cached ComputedVar."""
- class ComputedState(State):
+ class ComputedState(BaseState):
v: int = 0
@rx.var
@@ -1144,7 +1142,7 @@ def test_computed_var_depends_on_parent_non_cached():
"""Child state cached_var that depends on parent state un cached var is always recalculated."""
counter = 0
- class ParentState(State):
+ class ParentState(BaseState):
@rx.var
def no_cache_v(self) -> int:
nonlocal counter
@@ -1165,21 +1163,18 @@ def test_computed_var_depends_on_parent_non_cached():
dict1 = ps.dict()
assert dict1[ps.get_full_name()] == {
"no_cache_v": 1,
- CompileVars.IS_HYDRATED: False,
"router": formatted_router,
}
assert dict1[cs.get_full_name()] == {"dep_v": 2}
dict2 = ps.dict()
assert dict2[ps.get_full_name()] == {
"no_cache_v": 3,
- CompileVars.IS_HYDRATED: False,
"router": formatted_router,
}
assert dict2[cs.get_full_name()] == {"dep_v": 4}
dict3 = ps.dict()
assert dict3[ps.get_full_name()] == {
"no_cache_v": 5,
- CompileVars.IS_HYDRATED: False,
"router": formatted_router,
}
assert dict3[cs.get_full_name()] == {"dep_v": 6}
@@ -1195,7 +1190,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
"""
counter = 0
- class HandlerState(State):
+ class HandlerState(BaseState):
x: int = 42
def handler(self):
@@ -1226,7 +1221,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
def test_computed_var_dependencies():
"""Test that a ComputedVar correctly tracks its dependencies."""
- class ComputedState(State):
+ class ComputedState(BaseState):
v: int = 0
w: int = 0
x: int = 0
@@ -1293,7 +1288,7 @@ def test_computed_var_dependencies():
def test_backend_method():
"""A method with leading underscore should be callable from event handler."""
- class BackendMethodState(State):
+ class BackendMethodState(BaseState):
def _be_method(self):
return True
@@ -1369,7 +1364,7 @@ def test_error_on_state_method_shadow():
"""Test that an error is thrown when an event handler shadows a state method."""
with pytest.raises(NameError) as err:
- class InvalidTest(rx.State):
+ class InvalidTest(BaseState):
def reset(self):
pass
@@ -1382,7 +1377,7 @@ def test_error_on_state_method_shadow():
def test_state_with_invalid_yield():
"""Test that an error is thrown when a state yields an invalid value."""
- class StateWithInvalidYield(rx.State):
+ class StateWithInvalidYield(BaseState):
"""A state that yields an invalid value."""
def invalid_handler(self):
@@ -1666,7 +1661,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
assert mcall.kwargs["to"] == grandchild_state.get_sid()
-class BackgroundTaskState(State):
+class BackgroundTaskState(BaseState):
"""A state with a background task."""
order: List[str] = []
@@ -2192,9 +2187,20 @@ def test_mutable_copy_vars(mutable_state, copy_func):
assert not isinstance(var_copy, MutableProxy)
-def test_duplicate_substate_class(duplicate_substate):
+def test_duplicate_substate_class(mocker):
+ mocker.patch("reflex.state.os.environ", {})
with pytest.raises(ValueError):
- duplicate_substate()
+
+ class TestState(BaseState):
+ pass
+
+ class ChildTestState(TestState): # type: ignore # noqa
+ pass
+
+ class ChildTestState(TestState): # type: ignore # noqa
+ pass
+
+ return TestState
class Foo(Base):
@@ -2206,7 +2212,7 @@ class Foo(Base):
def test_json_dumps_with_mutables():
"""Test that json.dumps works with Base vars inside mutable types."""
- class MutableContainsBase(State):
+ class MutableContainsBase(BaseState):
items: List[Foo] = [Foo()]
dict_val = MutableContainsBase().dict()
@@ -2216,7 +2222,7 @@ def test_json_dumps_with_mutables():
f_formatted_router = str(formatted_router).replace("'", '"')
assert (
val
- == f'{{"{MutableContainsBase.get_full_name()}": {{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}}}'
+ == f'{{"{MutableContainsBase.get_full_name()}": {{"items": {f_items}, "router": {f_formatted_router}}}}}'
)
@@ -2225,7 +2231,7 @@ def test_reset_with_mutables():
default = [[0, 0], [0, 1], [1, 1]]
copied_default = copy.deepcopy(default)
- class MutableResetState(State):
+ class MutableResetState(BaseState):
items: List[List[int]] = default
instance = MutableResetState()
@@ -2273,7 +2279,7 @@ class Custom3(Base):
def test_state_union_optional():
"""Test that state can be defined with Union and Optional vars."""
- class UnionState(State):
+ class UnionState(BaseState):
int_float: Union[int, float] = 0
opt_int: Optional[int]
c3: Optional[Custom3]
diff --git a/tests/test_var.py b/tests/test_var.py
index c57c7eb23..8e7358bf5 100644
--- a/tests/test_var.py
+++ b/tests/test_var.py
@@ -6,7 +6,7 @@ import pytest
from pandas import DataFrame
from reflex.base import Base
-from reflex.state import State
+from reflex.state import BaseState
from reflex.vars import (
BaseVar,
ComputedVar,
@@ -24,12 +24,6 @@ test_vars = [
]
-class BaseState(State):
- """A Test State."""
-
- val: str = "key"
-
-
@pytest.fixture
def TestObj():
class TestObj(Base):
@@ -41,7 +35,7 @@ def TestObj():
@pytest.fixture
def ParentState(TestObj):
- class ParentState(State):
+ class ParentState(BaseState):
foo: int
bar: int
@@ -74,7 +68,7 @@ def GrandChildState(ChildState, TestObj):
@pytest.fixture
def StateWithAnyVar(TestObj):
- class StateWithAnyVar(State):
+ class StateWithAnyVar(BaseState):
@ComputedVar
def var_without_annotation(self) -> typing.Any:
return TestObj
@@ -84,7 +78,7 @@ def StateWithAnyVar(TestObj):
@pytest.fixture
def StateWithCorrectVarAnnotation():
- class StateWithCorrectVarAnnotation(State):
+ class StateWithCorrectVarAnnotation(BaseState):
@ComputedVar
def var_with_annotation(self) -> str:
return "Correct annotation"
@@ -94,7 +88,7 @@ def StateWithCorrectVarAnnotation():
@pytest.fixture
def StateWithWrongVarAnnotation(TestObj):
- class StateWithWrongVarAnnotation(State):
+ class StateWithWrongVarAnnotation(BaseState):
@ComputedVar
def var_with_annotation(self) -> str:
return TestObj
diff --git a/tests/utils/test_format.py b/tests/utils/test_format.py
index 257f37514..e528536a1 100644
--- a/tests/utils/test_format.py
+++ b/tests/utils/test_format.py
@@ -528,7 +528,6 @@ formatted_router = {
},
"dt": "1989-11-09 18:53:00+01:00",
"fig": [],
- "is_hydrated": False,
"key": "",
"map_key": "a",
"mapping": {"a": [1, 2, 3], "b": [4, 5, 6]},
@@ -553,7 +552,6 @@ formatted_router = {
DateTimeState.get_full_name(): {
"d": "1989-11-09",
"dt": "1989-11-09 18:53:00+01:00",
- "is_hydrated": False,
"t": "18:53:00+01:00",
"td": "11 days, 0:11:00",
"router": formatted_router,
diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py
index b64f5f188..82085dd74 100644
--- a/tests/utils/test_utils.py
+++ b/tests/utils/test_utils.py
@@ -10,7 +10,7 @@ from packaging import version
from reflex import constants
from reflex.base import Base
from reflex.event import EventHandler
-from reflex.state import State
+from reflex.state import BaseState
from reflex.utils import (
build,
prerequisites,
@@ -43,7 +43,7 @@ V056 = version.parse("0.5.6")
VMAXPLUS1 = version.parse(get_above_max_version())
-class ExampleTestState(State):
+class ExampleTestState(BaseState):
"""Test state class."""
def test_event_handler(self):