RED-1052/rx.State as Base State (#2146)

This commit is contained in:
Elijah Ahianyo 2023-11-29 17:43:33 +00:00 committed by GitHub
parent f8395b1fd6
commit e3ee98098a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
49 changed files with 356 additions and 270 deletions

View File

@ -93,7 +93,7 @@ def BackgroundTask():
rx.button("Reset", on_click=State.reset_counter, id="reset"), rx.button("Reset", on_click=State.reset_counter, id="reset"),
) )
app = rx.App(state=State) app = rx.App(state=rx.State)
app.add_page(index) app.add_page(index)
app.compile() app.compile()

View File

@ -135,7 +135,7 @@ def CallScript():
yield rx.call_script("inline_counter = 0; external_counter = 0") yield rx.call_script("inline_counter = 0; external_counter = 0")
self.reset() self.reset()
app = rx.App(state=CallScriptState) app = rx.App(state=rx.State)
with open("assets/external.js", "w") as f: with open("assets/external.js", "w") as f:
f.write(external_scripts) f.write(external_scripts)

View File

@ -97,7 +97,7 @@ def ClientSide():
rx.box(ClientSideSubSubState.l1s, id="l1s"), rx.box(ClientSideSubSubState.l1s, id="l1s"),
) )
app = rx.App(state=ClientSideState) app = rx.App(state=rx.State)
app.add_page(index) app.add_page(index)
app.add_page(index, route="/foo") app.add_page(index, route="/foo")
app.compile() app.compile()
@ -263,7 +263,6 @@ async def test_client_side_state(
state_var_input.send_keys("c7") state_var_input.send_keys("c7")
input_value_input.send_keys("c7 value") input_value_input.send_keys("c7 value")
set_sub_state_button.click() set_sub_state_button.click()
state_var_input.send_keys("l1") state_var_input.send_keys("l1")
input_value_input.send_keys("l1 value") input_value_input.send_keys("l1 value")
set_sub_state_button.click() set_sub_state_button.click()
@ -276,7 +275,6 @@ async def test_client_side_state(
state_var_input.send_keys("l4") state_var_input.send_keys("l4")
input_value_input.send_keys("l4 value") input_value_input.send_keys("l4 value")
set_sub_state_button.click() set_sub_state_button.click()
state_var_input.send_keys("c1s") state_var_input.send_keys("c1s")
input_value_input.send_keys("c1s value") input_value_input.send_keys("c1s value")
set_sub_sub_state_button.click() set_sub_sub_state_button.click()
@ -285,28 +283,28 @@ async def test_client_side_state(
set_sub_sub_state_button.click() set_sub_sub_state_button.click()
exp_cookies = { exp_cookies = {
"client_side_state.client_side_sub_state.c1": { "state.client_side_state.client_side_sub_state.c1": {
"domain": "localhost", "domain": "localhost",
"httpOnly": False, "httpOnly": False,
"name": "client_side_state.client_side_sub_state.c1", "name": "state.client_side_state.client_side_sub_state.c1",
"path": "/", "path": "/",
"sameSite": "Lax", "sameSite": "Lax",
"secure": False, "secure": False,
"value": "c1%20value", "value": "c1%20value",
}, },
"client_side_state.client_side_sub_state.c2": { "state.client_side_state.client_side_sub_state.c2": {
"domain": "localhost", "domain": "localhost",
"httpOnly": False, "httpOnly": False,
"name": "client_side_state.client_side_sub_state.c2", "name": "state.client_side_state.client_side_sub_state.c2",
"path": "/", "path": "/",
"sameSite": "Lax", "sameSite": "Lax",
"secure": False, "secure": False,
"value": "c2%20value", "value": "c2%20value",
}, },
"client_side_state.client_side_sub_state.c4": { "state.client_side_state.client_side_sub_state.c4": {
"domain": "localhost", "domain": "localhost",
"httpOnly": False, "httpOnly": False,
"name": "client_side_state.client_side_sub_state.c4", "name": "state.client_side_state.client_side_sub_state.c4",
"path": "/", "path": "/",
"sameSite": "Strict", "sameSite": "Strict",
"secure": False, "secure": False,
@ -321,19 +319,19 @@ async def test_client_side_state(
"secure": False, "secure": False,
"value": "c6%20value", "value": "c6%20value",
}, },
"client_side_state.client_side_sub_state.c7": { "state.client_side_state.client_side_sub_state.c7": {
"domain": "localhost", "domain": "localhost",
"httpOnly": False, "httpOnly": False,
"name": "client_side_state.client_side_sub_state.c7", "name": "state.client_side_state.client_side_sub_state.c7",
"path": "/", "path": "/",
"sameSite": "Lax", "sameSite": "Lax",
"secure": False, "secure": False,
"value": "c7%20value", "value": "c7%20value",
}, },
"client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s": { "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s": {
"domain": "localhost", "domain": "localhost",
"httpOnly": False, "httpOnly": False,
"name": "client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s", "name": "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s",
"path": "/", "path": "/",
"sameSite": "Lax", "sameSite": "Lax",
"secure": False, "secure": False,
@ -354,40 +352,45 @@ async def test_client_side_state(
input_value_input.send_keys("c3 value") input_value_input.send_keys("c3 value")
set_sub_state_button.click() set_sub_state_button.click()
AppHarness._poll_for( AppHarness._poll_for(
lambda: "client_side_state.client_side_sub_state.c3" in cookie_info_map(driver) lambda: "state.client_side_state.client_side_sub_state.c3"
in cookie_info_map(driver)
) )
c3_cookie = cookie_info_map(driver)["client_side_state.client_side_sub_state.c3"] c3_cookie = cookie_info_map(driver)[
"state.client_side_state.client_side_sub_state.c3"
]
assert c3_cookie.pop("expiry") is not None assert c3_cookie.pop("expiry") is not None
assert c3_cookie == { assert c3_cookie == {
"domain": "localhost", "domain": "localhost",
"httpOnly": False, "httpOnly": False,
"name": "client_side_state.client_side_sub_state.c3", "name": "state.client_side_state.client_side_sub_state.c3",
"path": "/", "path": "/",
"sameSite": "Lax", "sameSite": "Lax",
"secure": False, "secure": False,
"value": "c3%20value", "value": "c3%20value",
} }
time.sleep(2) # wait for c3 to expire time.sleep(2) # wait for c3 to expire
assert "client_side_state.client_side_sub_state.c3" not in cookie_info_map(driver) assert "state.client_side_state.client_side_sub_state.c3" not in cookie_info_map(
driver
)
local_storage_items = local_storage.items() local_storage_items = local_storage.items()
local_storage_items.pop("chakra-ui-color-mode", None) local_storage_items.pop("chakra-ui-color-mode", None)
assert ( assert (
local_storage_items.pop("client_side_state.client_side_sub_state.l1") local_storage_items.pop("state.client_side_state.client_side_sub_state.l1")
== "l1 value" == "l1 value"
) )
assert ( assert (
local_storage_items.pop("client_side_state.client_side_sub_state.l2") local_storage_items.pop("state.client_side_state.client_side_sub_state.l2")
== "l2 value" == "l2 value"
) )
assert local_storage_items.pop("l3") == "l3 value" assert local_storage_items.pop("l3") == "l3 value"
assert ( assert (
local_storage_items.pop("client_side_state.client_side_sub_state.l4") local_storage_items.pop("state.client_side_state.client_side_sub_state.l4")
== "l4 value" == "l4 value"
) )
assert ( assert (
local_storage_items.pop( local_storage_items.pop(
"client_side_state.client_side_sub_state.client_side_sub_sub_state.l1s" "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.l1s"
) )
== "l1s value" == "l1s value"
) )
@ -482,12 +485,15 @@ async def test_client_side_state(
# make sure c5 cookie shows up on the `/foo` route # make sure c5 cookie shows up on the `/foo` route
AppHarness._poll_for( AppHarness._poll_for(
lambda: "client_side_state.client_side_sub_state.c5" in cookie_info_map(driver) lambda: "state.client_side_state.client_side_sub_state.c5"
in cookie_info_map(driver)
) )
assert cookie_info_map(driver)["client_side_state.client_side_sub_state.c5"] == { assert cookie_info_map(driver)[
"state.client_side_state.client_side_sub_state.c5"
] == {
"domain": "localhost", "domain": "localhost",
"httpOnly": False, "httpOnly": False,
"name": "client_side_state.client_side_sub_state.c5", "name": "state.client_side_state.client_side_sub_state.c5",
"path": "/foo/", "path": "/foo/",
"sameSite": "Lax", "sameSite": "Lax",
"secure": False, "secure": False,

View File

@ -19,7 +19,7 @@ def ConnectionBanner():
def index(): def index():
return rx.text("Hello World") return rx.text("Hello World")
app = rx.App(state=State) app = rx.App(state=rx.State)
app.add_page(index) app.add_page(index)
app.compile() app.compile()

View File

@ -56,7 +56,7 @@ def DynamicRoute():
def redirect_page(): def redirect_page():
return rx.fragment(rx.text("redirecting...")) return rx.fragment(rx.text("redirecting..."))
app = rx.App(state=DynamicState) app = rx.App(state=rx.State)
app.add_page(index) app.add_page(index)
app.add_page(index, route="/page/[page_id]", on_load=DynamicState.on_load) # type: ignore app.add_page(index, route="/page/[page_id]", on_load=DynamicState.on_load) # type: ignore
app.add_page(index, route="/static/x", on_load=DynamicState.on_load) # type: ignore app.add_page(index, route="/static/x", on_load=DynamicState.on_load) # type: ignore
@ -143,10 +143,12 @@ def poll_for_order(
return await dynamic_route.get_state(token) return await dynamic_route.get_state(token)
async def _check(): async def _check():
return (await _backend_state()).order == exp_order return (await _backend_state()).substates[
"dynamic_state"
].order == exp_order
await AppHarness._poll_for_async(_check) await AppHarness._poll_for_async(_check)
assert (await _backend_state()).order == exp_order assert (await _backend_state()).substates["dynamic_state"].order == exp_order
return _poll_for_order return _poll_for_order

View File

@ -130,7 +130,7 @@ def TestEventAction():
on_click=EventActionState.on_click("outer"), # type: ignore on_click=EventActionState.on_click("outer"), # type: ignore
) )
app = rx.App(state=EventActionState) app = rx.App(state=rx.State)
app.add_page(index) app.add_page(index)
app.compile() app.compile()
@ -211,10 +211,14 @@ def poll_for_order(
return await event_action.get_state(token) return await event_action.get_state(token)
async def _check(): async def _check():
return (await _backend_state()).order == exp_order return (await _backend_state()).substates[
"event_action_state"
].order == exp_order
await AppHarness._poll_for_async(_check) await AppHarness._poll_for_async(_check)
assert (await _backend_state()).order == exp_order assert (await _backend_state()).substates[
"event_action_state"
].order == exp_order
return _poll_for_order return _poll_for_order

View File

@ -122,7 +122,7 @@ def EventChain():
time.sleep(0.5) time.sleep(0.5)
self.interim_value = "final" self.interim_value = "final"
app = rx.App(state=State) app = rx.App(state=rx.State)
token_input = rx.input( token_input = rx.input(
value=State.router.session.client_token, is_read_only=True, id="token" value=State.router.session.client_token, is_read_only=True, id="token"
@ -401,12 +401,12 @@ async def test_event_chain_click(
btn.click() btn.click()
async def _has_all_events(): async def _has_all_events():
return len((await event_chain.get_state(token)).event_order) == len( return len(
exp_event_order (await event_chain.get_state(token)).substates["state"].event_order
) ) == len(exp_event_order)
await AppHarness._poll_for_async(_has_all_events) await AppHarness._poll_for_async(_has_all_events)
event_order = (await event_chain.get_state(token)).event_order event_order = (await event_chain.get_state(token)).substates["state"].event_order
assert event_order == exp_event_order assert event_order == exp_event_order
@ -453,12 +453,12 @@ async def test_event_chain_on_load(
token = assert_token(event_chain, driver) token = assert_token(event_chain, driver)
async def _has_all_events(): async def _has_all_events():
return len((await event_chain.get_state(token)).event_order) == len( return len(
exp_event_order (await event_chain.get_state(token)).substates["state"].event_order
) ) == len(exp_event_order)
await AppHarness._poll_for_async(_has_all_events) await AppHarness._poll_for_async(_has_all_events)
backend_state = await event_chain.get_state(token) backend_state = (await event_chain.get_state(token)).substates["state"]
assert backend_state.event_order == exp_event_order assert backend_state.event_order == exp_event_order
assert backend_state.is_hydrated is True assert backend_state.is_hydrated is True
@ -529,12 +529,12 @@ async def test_event_chain_on_mount(
unmount_button.click() unmount_button.click()
async def _has_all_events(): async def _has_all_events():
return len((await event_chain.get_state(token)).event_order) == len( return len(
exp_event_order (await event_chain.get_state(token)).substates["state"].event_order
) ) == len(exp_event_order)
await AppHarness._poll_for_async(_has_all_events) await AppHarness._poll_for_async(_has_all_events)
event_order = (await event_chain.get_state(token)).event_order event_order = (await event_chain.get_state(token)).substates["state"].event_order
assert event_order == exp_event_order assert event_order == exp_event_order

View File

@ -22,7 +22,7 @@ def FormSubmit():
def form_submit(self, form_data: dict): def form_submit(self, form_data: dict):
self.form_data = form_data self.form_data = form_data
app = rx.App(state=FormState) app = rx.App(state=rx.State)
@app.add_page @app.add_page
def index(): def index():
@ -75,7 +75,7 @@ def FormSubmitName():
def form_submit(self, form_data: dict): def form_submit(self, form_data: dict):
self.form_data = form_data self.form_data = form_data
app = rx.App(state=FormState) app = rx.App(state=rx.State)
@app.add_page @app.add_page
def index(): def index():
@ -210,7 +210,7 @@ async def test_submit(driver, form_submit: AppHarness):
submit_input.click() submit_input.click()
async def get_form_data(): async def get_form_data():
return (await form_submit.get_state(token)).form_data return (await form_submit.get_state(token)).substates["form_state"].form_data
# wait for the form data to arrive at the backend # wait for the form data to arrive at the backend
form_data = await AppHarness._poll_for_async(get_form_data) form_data = await AppHarness._poll_for_async(get_form_data)

View File

@ -16,7 +16,7 @@ def FullyControlledInput():
class State(rx.State): class State(rx.State):
text: str = "initial" text: str = "initial"
app = rx.App(state=State) app = rx.App(state=rx.State)
@app.add_page @app.add_page
def index(): def index():
@ -85,13 +85,15 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
debounce_input.send_keys("foo") debounce_input.send_keys("foo")
time.sleep(0.5) time.sleep(0.5)
assert debounce_input.get_attribute("value") == "ifoonitial" assert debounce_input.get_attribute("value") == "ifoonitial"
assert (await fully_controlled_input.get_state(token)).text == "ifoonitial" assert (await fully_controlled_input.get_state(token)).substates[
"state"
].text == "ifoonitial"
assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial" assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial"
# clear the input on the backend # clear the input on the backend
async with fully_controlled_input.modify_state(token) as state: async with fully_controlled_input.modify_state(token) as state:
state.text = "" state.substates["state"].text = ""
assert (await fully_controlled_input.get_state(token)).text == "" assert (await fully_controlled_input.get_state(token)).substates["state"].text == ""
assert ( assert (
fully_controlled_input.poll_for_value( fully_controlled_input.poll_for_value(
debounce_input, exp_not_equal="ifoonitial" debounce_input, exp_not_equal="ifoonitial"
@ -103,9 +105,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
debounce_input.send_keys("getting testing done") debounce_input.send_keys("getting testing done")
time.sleep(0.5) time.sleep(0.5)
assert debounce_input.get_attribute("value") == "getting testing done" assert debounce_input.get_attribute("value") == "getting testing done"
assert ( assert (await fully_controlled_input.get_state(token)).substates[
await fully_controlled_input.get_state(token) "state"
).text == "getting testing done" ].text == "getting testing done"
assert fully_controlled_input.poll_for_value(value_input) == "getting testing done" assert fully_controlled_input.poll_for_value(value_input) == "getting testing done"
# type into the on_change input # type into the on_change input
@ -113,7 +115,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
time.sleep(0.5) time.sleep(0.5)
assert debounce_input.get_attribute("value") == "overwrite the state" assert debounce_input.get_attribute("value") == "overwrite the state"
assert on_change_input.get_attribute("value") == "overwrite the state" assert on_change_input.get_attribute("value") == "overwrite the state"
assert (await fully_controlled_input.get_state(token)).text == "overwrite the state" assert (await fully_controlled_input.get_state(token)).substates[
"state"
].text == "overwrite the state"
assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state" assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state"
clear_button.click() clear_button.click()

View File

@ -42,7 +42,7 @@ def LoginSample():
rx.button("Do it", on_click=State.login, id="doit"), rx.button("Do it", on_click=State.login, id="doit"),
) )
app = rx.App(state=State) app = rx.App(state=rx.State)
app.add_page(index) app.add_page(index)
app.add_page(login) app.add_page(login)
app.compile() app.compile()
@ -137,6 +137,6 @@ def test_login_flow(
logout_button = driver.find_element(By.ID, "logout") logout_button = driver.find_element(By.ID, "logout")
logout_button.click() logout_button.click()
assert login_sample._poll_for(lambda: local_storage["state.auth_token"] == "") assert login_sample._poll_for(lambda: local_storage["state.state.auth_token"] == "")
with pytest.raises(NoSuchElementException): with pytest.raises(NoSuchElementException):
driver.find_element(By.ID, "auth-token") driver.find_element(By.ID, "auth-token")

View File

@ -81,7 +81,7 @@ def RadixThemesApp():
) )
app = rx.App( app = rx.App(
state=State, state=rx.State,
theme=rdxt.theme(rdxt.theme_panel(), accent_color="grass"), theme=rdxt.theme(rdxt.theme_panel(), accent_color="grass"),
) )
app.add_page(index) app.add_page(index)

View File

@ -33,7 +33,7 @@ def ServerSideEvent():
def set_value_return_c(self): def set_value_return_c(self):
return rx.set_value("c", "") return rx.set_value("c", "")
app = rx.App(state=SSState) app = rx.App(state=rx.State)
@app.add_page @app.add_page
def index(): def index():

View File

@ -26,7 +26,7 @@ def Table():
caption: str = "random caption" caption: str = "random caption"
app = rx.App(state=TableState) app = rx.App(state=rx.State)
@app.add_page @app.add_page
def index(): def index():

View File

@ -113,7 +113,7 @@ def UploadFile():
), ),
) )
app = rx.App(state=UploadState) app = rx.App(state=rx.State)
app.add_page(index) app.add_page(index)
app.compile() app.compile()
@ -192,7 +192,7 @@ async def test_upload_file(
# look up the backend state and assert on uploaded contents # look up the backend state and assert on uploaded contents
async def get_file_data(): async def get_file_data():
return (await upload_file.get_state(token))._file_data return (await upload_file.get_state(token)).substates["upload_state"]._file_data
file_data = await AppHarness._poll_for_async(get_file_data) file_data = await AppHarness._poll_for_async(get_file_data)
assert isinstance(file_data, dict) assert isinstance(file_data, dict)
@ -205,8 +205,8 @@ async def test_upload_file(
state = await upload_file.get_state(token) state = await upload_file.get_state(token)
if secondary: if secondary:
# only the secondary form tracks progress and chain events # only the secondary form tracks progress and chain events
assert state.event_order.count("upload_progress") == 1 assert state.substates["upload_state"].event_order.count("upload_progress") == 1
assert state.event_order.count("chain_event") == 1 assert state.substates["upload_state"].event_order.count("chain_event") == 1
@pytest.mark.asyncio @pytest.mark.asyncio
@ -251,7 +251,7 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
# look up the backend state and assert on uploaded contents # look up the backend state and assert on uploaded contents
async def get_file_data(): async def get_file_data():
return (await upload_file.get_state(token))._file_data return (await upload_file.get_state(token)).substates["upload_state"]._file_data
file_data = await AppHarness._poll_for_async(get_file_data) file_data = await AppHarness._poll_for_async(get_file_data)
assert isinstance(file_data, dict) assert isinstance(file_data, dict)
@ -349,7 +349,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive
# look up the backend state and assert on progress # look up the backend state and assert on progress
state = await upload_file.get_state(token) state = await upload_file.get_state(token)
assert state.progress_dicts assert state.substates["upload_state"].progress_dicts
assert exp_name not in state._file_data assert exp_name not in state.substates["upload_state"]._file_data
target_file.unlink() target_file.unlink()

View File

@ -30,7 +30,7 @@ def VarOperations():
dict2: dict = {3: 4} dict2: dict = {3: 4}
html_str: str = "<div>hello</div>" html_str: str = "<div>hello</div>"
app = rx.App(state=VarOperationState) app = rx.App(state=rx.State)
@app.add_page @app.add_page
def index(): def index():

View File

@ -57,6 +57,7 @@ from reflex.route import (
verify_route_validity, verify_route_validity,
) )
from reflex.state import ( from reflex.state import (
BaseState,
RouterData, RouterData,
State, State,
StateManager, StateManager,
@ -98,7 +99,7 @@ class App(Base):
socket_app: Optional[ASGIApp] = None socket_app: Optional[ASGIApp] = None
# The state class to use for the app. # The state class to use for the app.
state: Optional[Type[State]] = None state: Optional[Type[BaseState]] = None
# Class to manage many client states. # Class to manage many client states.
_state_manager: Optional[StateManager] = None _state_manager: Optional[StateManager] = None
@ -149,25 +150,24 @@ class App(Base):
"`connect_error_component` is deprecated, use `overlay_component` instead" "`connect_error_component` is deprecated, use `overlay_component` instead"
) )
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
state_subclasses = State.__subclasses__() state_subclasses = BaseState.__subclasses__()
inferred_state = state_subclasses[-1] if state_subclasses else None
is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
# Special case to allow test cases have multiple subclasses of rx.State. # Special case to allow test cases have multiple subclasses of rx.BaseState.
if not is_testing_env: if not is_testing_env:
# Only one State class is allowed. # Only one Base State class is allowed.
if len(state_subclasses) > 1: if len(state_subclasses) > 1:
raise ValueError( raise ValueError(
"rx.State has been subclassed multiple times. Only one subclass is allowed" "rx.BaseState cannot be subclassed multiple times. use rx.State instead"
) )
# verify that provided state is valid # verify that provided state is valid
if self.state and inferred_state and self.state is not inferred_state: if self.state and self.state is not State:
console.warn( console.warn(
f"Using substate ({self.state.__name__}) as root state in `rx.App` is currently not supported." f"Using substate ({self.state.__name__}) as root state in `rx.App` is currently not supported."
f" Defaulting to root state: ({inferred_state.__name__})" f" Defaulting to root state: ({State.__name__})"
) )
self.state = inferred_state self.state = State
# Get the config # Get the config
config = get_config() config = get_config()
@ -265,7 +265,7 @@ class App(Base):
raise ValueError("The state manager has not been initialized.") raise ValueError("The state manager has not been initialized.")
return self._state_manager return self._state_manager
async def preprocess(self, state: State, event: Event) -> StateUpdate | None: async def preprocess(self, state: BaseState, event: Event) -> StateUpdate | None:
"""Preprocess the event. """Preprocess the event.
This is where middleware can modify the event before it is processed. This is where middleware can modify the event before it is processed.
@ -290,7 +290,7 @@ class App(Base):
return out # type: ignore return out # type: ignore
async def postprocess( async def postprocess(
self, state: State, event: Event, update: StateUpdate self, state: BaseState, event: Event, update: StateUpdate
) -> StateUpdate: ) -> StateUpdate:
"""Postprocess the event. """Postprocess the event.
@ -764,7 +764,7 @@ class App(Base):
future.result() future.result()
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[State]: async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""Modify the state out of band. """Modify the state out of band.
Args: Args:
@ -792,7 +792,9 @@ class App(Base):
sid=state.router.session.session_id, sid=state.router.session.session_id,
) )
def _process_background(self, state: State, event: Event) -> asyncio.Task | None: def _process_background(
self, state: BaseState, event: Event
) -> asyncio.Task | None:
"""Process an event in the background and emit updates as they arrive. """Process an event in the background and emit updates as they arrive.
Args: Args:

View File

@ -33,6 +33,7 @@ from reflex.route import (
) )
from reflex.state import ( from reflex.state import (
State as State, State as State,
BaseState as BaseState,
StateManager as StateManager, StateManager as StateManager,
StateUpdate as StateUpdate, StateUpdate as StateUpdate,
) )
@ -69,7 +70,7 @@ class App(Base):
api: FastAPI api: FastAPI
sio: Optional[AsyncServer] sio: Optional[AsyncServer]
socket_app: Optional[ASGIApp] socket_app: Optional[ASGIApp]
state: Type[State] state: Type[BaseState]
state_manager: StateManager state_manager: StateManager
style: ComponentStyle style: ComponentStyle
middleware: List[Middleware] middleware: List[Middleware]

View File

@ -1,11 +1,42 @@
"""Define the base Reflex class.""" """Define the base Reflex class."""
from __future__ import annotations from __future__ import annotations
from typing import Any import os
from typing import Any, List, Type
import pydantic import pydantic
from pydantic import BaseModel
from pydantic.fields import ModelField from pydantic.fields import ModelField
from reflex import constants
def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None:
"""Ensure that the field's name does not shadow an existing attribute of the model.
Args:
bases: List of base models to check for shadowed attrs.
field_name: name of attribute
Raises:
NameError: If state var field shadows another in its parent state
"""
reload = os.getenv(constants.RELOAD_CONFIG) == "True"
for base in bases:
try:
if not reload and getattr(base, field_name, None):
pass
except TypeError as te:
raise NameError(
f'State var "{field_name}" in {base} has been shadowed by a substate var; '
f'use a different field name instead".'
) from te
# monkeypatch pydantic validate_field_name method to skip validating
# shadowed state vars when reloading app via utils.prerequisites.get_app(reload=True)
pydantic.main.validate_field_name = validate_field_name # type: ignore
class Base(pydantic.BaseModel): class Base(pydantic.BaseModel):
"""The base class subclassed by all Reflex classes. """The base class subclassed by all Reflex classes.

View File

@ -15,7 +15,7 @@ from reflex.components.component import (
StatefulComponent, StatefulComponent,
) )
from reflex.config import get_config from reflex.config import get_config
from reflex.state import State from reflex.state import BaseState
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportVar
@ -63,7 +63,7 @@ def _compile_theme(theme: dict) -> str:
return templates.THEME.render(theme=theme) return templates.THEME.render(theme=theme)
def _compile_contexts(state: Optional[Type[State]]) -> str: def _compile_contexts(state: Optional[Type[BaseState]]) -> str:
"""Compile the initial state and contexts. """Compile the initial state and contexts.
Args: Args:
@ -87,7 +87,7 @@ def _compile_contexts(state: Optional[Type[State]]) -> str:
def _compile_page( def _compile_page(
component: Component, component: Component,
state: Type[State], state: Type[BaseState],
) -> str: ) -> str:
"""Compile the component given the app state. """Compile the component given the app state.
@ -337,7 +337,7 @@ def compile_theme(style: ComponentStyle) -> tuple[str, str]:
return output_path, code return output_path, code
def compile_contexts(state: Optional[Type[State]]) -> tuple[str, str]: def compile_contexts(state: Optional[Type[BaseState]]) -> tuple[str, str]:
"""Compile the initial state / context. """Compile the initial state / context.
Args: Args:
@ -353,7 +353,7 @@ def compile_contexts(state: Optional[Type[State]]) -> tuple[str, str]:
def compile_page( def compile_page(
path: str, component: Component, state: Type[State] path: str, component: Component, state: Type[BaseState]
) -> tuple[str, str]: ) -> tuple[str, str]:
"""Compile a single page. """Compile a single page.

View File

@ -21,7 +21,7 @@ from reflex.components.base import (
Title, Title,
) )
from reflex.components.component import Component, ComponentStyle, CustomComponent from reflex.components.component import Component, ComponentStyle, CustomComponent
from reflex.state import Cookie, LocalStorage, State from reflex.state import BaseState, Cookie, LocalStorage
from reflex.style import Style from reflex.style import Style
from reflex.utils import console, format, imports, path_ops from reflex.utils import console, format, imports, path_ops
@ -128,7 +128,7 @@ def get_import_dict(lib: str, default: str = "", rest: list[str] | None = None)
} }
def compile_state(state: Type[State]) -> dict: def compile_state(state: Type[BaseState]) -> dict:
"""Compile the state of the app. """Compile the state of the app.
Args: Args:
@ -170,7 +170,7 @@ def _compile_client_storage_field(
def _compile_client_storage_recursive( def _compile_client_storage_recursive(
state: Type[State], state: Type[BaseState],
) -> tuple[dict[str, dict], dict[str, dict[str, str]]]: ) -> tuple[dict[str, dict], dict[str, dict[str, str]]]:
"""Compile the client-side storage for the given state recursively. """Compile the client-side storage for the given state recursively.
@ -208,7 +208,7 @@ def _compile_client_storage_recursive(
return cookies, local_storage return cookies, local_storage
def compile_client_storage(state: Type[State]) -> dict[str, dict]: def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
"""Compile the client-side storage for the given state. """Compile the client-side storage for the given state.
Args: Args:

View File

@ -6,6 +6,7 @@ from .base import (
LOCAL_STORAGE, LOCAL_STORAGE,
POLLING_MAX_HTTP_BUFFER_SIZE, POLLING_MAX_HTTP_BUFFER_SIZE,
PYTEST_CURRENT_TEST, PYTEST_CURRENT_TEST,
RELOAD_CONFIG,
SKIP_COMPILE_ENV_VAR, SKIP_COMPILE_ENV_VAR,
ColorMode, ColorMode,
Dirs, Dirs,
@ -85,6 +86,7 @@ __ALL__ = [
PYTEST_CURRENT_TEST, PYTEST_CURRENT_TEST,
PRODUCTION_BACKEND_URL, PRODUCTION_BACKEND_URL,
Reflex, Reflex,
RELOAD_CONFIG,
RequirementsTxt, RequirementsTxt,
RouteArgType, RouteArgType,
RouteRegex, RouteRegex,

View File

@ -173,3 +173,4 @@ SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE"
# Testing variables. # Testing variables.
# Testing os env set by pytest when running a test case. # Testing os env set by pytest when running a test case.
PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST" PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST"
RELOAD_CONFIG = "__REFLEX_RELOAD_CONFIG"

View File

@ -22,7 +22,7 @@ from reflex.utils.types import ArgsSpec
from reflex.vars import BaseVar, Var from reflex.vars import BaseVar, Var
if TYPE_CHECKING: if TYPE_CHECKING:
from reflex.state import State from reflex.state import BaseState
class Event(Base): class Event(Base):
@ -64,7 +64,7 @@ def background(fn):
def _no_chain_background_task( def _no_chain_background_task(
state_cls: Type["State"], name: str, fn: Callable state_cls: Type["BaseState"], name: str, fn: Callable
) -> Callable: ) -> Callable:
"""Protect against directly chaining a background task from another event handler. """Protect against directly chaining a background task from another event handler.

View File

@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
from reflex import constants from reflex import constants
from reflex.event import Event, fix_events, get_hydrate_event from reflex.event import Event, fix_events, get_hydrate_event
from reflex.middleware.middleware import Middleware from reflex.middleware.middleware import Middleware
from reflex.state import State, StateUpdate from reflex.state import BaseState, StateUpdate
from reflex.utils import format from reflex.utils import format
if TYPE_CHECKING: if TYPE_CHECKING:
@ -17,7 +17,7 @@ class HydrateMiddleware(Middleware):
"""Middleware to handle initial app hydration.""" """Middleware to handle initial app hydration."""
async def preprocess( async def preprocess(
self, app: App, state: State, event: Event self, app: App, state: BaseState, event: Event
) -> Optional[StateUpdate]: ) -> Optional[StateUpdate]:
"""Preprocess the event. """Preprocess the event.

View File

@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
from reflex.base import Base from reflex.base import Base
from reflex.event import Event from reflex.event import Event
from reflex.state import State, StateUpdate from reflex.state import BaseState, StateUpdate
if TYPE_CHECKING: if TYPE_CHECKING:
from reflex.app import App from reflex.app import App
@ -16,7 +16,7 @@ class Middleware(Base, ABC):
"""Middleware to preprocess and postprocess requests.""" """Middleware to preprocess and postprocess requests."""
async def preprocess( async def preprocess(
self, app: App, state: State, event: Event self, app: App, state: BaseState, event: Event
) -> Optional[StateUpdate]: ) -> Optional[StateUpdate]:
"""Preprocess the event. """Preprocess the event.
@ -31,7 +31,7 @@ class Middleware(Base, ABC):
return None return None
async def postprocess( async def postprocess(
self, app: App, state: State, event: Event, update: StateUpdate self, app: App, state: BaseState, event: Event, update: StateUpdate
) -> StateUpdate: ) -> StateUpdate:
"""Postprocess the event. """Postprocess the event.

View File

@ -7,6 +7,7 @@ import copy
import functools import functools
import inspect import inspect
import json import json
import os
import traceback import traceback
import urllib.parse import urllib.parse
import uuid import uuid
@ -81,7 +82,7 @@ class HeaderData(Base):
class PageData(Base): class PageData(Base):
"""An object containing page data.""" """An object containing page data."""
host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?) host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?)
path: str = "" path: str = ""
raw_path: str = "" raw_path: str = ""
full_path: str = "" full_path: str = ""
@ -152,7 +153,7 @@ RESERVED_BACKEND_VAR_NAMES = {
} }
class State(Base, ABC, extra=pydantic.Extra.allow): class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""The state of the app.""" """The state of the app."""
# A map from the var name to the var. # A map from the var name to the var.
@ -176,6 +177,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# The event handlers. # The event handlers.
event_handlers: ClassVar[Dict[str, EventHandler]] = {} event_handlers: ClassVar[Dict[str, EventHandler]] = {}
# A set of subclassses of this class.
class_subclasses: ClassVar[Set[Type[BaseState]]] = set()
# Mapping of var name to set of computed variables that depend on it # Mapping of var name to set of computed variables that depend on it
_computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {} _computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
@ -189,10 +193,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
_always_dirty_substates: ClassVar[Set[str]] = set() _always_dirty_substates: ClassVar[Set[str]] = set()
# The parent state. # The parent state.
parent_state: Optional[State] = None parent_state: Optional[BaseState] = None
# The substates of the state. # The substates of the state.
substates: Dict[str, State] = {} substates: Dict[str, BaseState] = {}
# The set of dirty vars. # The set of dirty vars.
dirty_vars: Set[str] = set() dirty_vars: Set[str] = set()
@ -209,10 +213,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# The router data for the current page # The router data for the current page
router: RouterData = RouterData() router: RouterData = RouterData()
# The hydrated bool. def __init__(self, *args, parent_state: BaseState | None = None, **kwargs):
is_hydrated: bool = False
def __init__(self, *args, parent_state: State | None = None, **kwargs):
"""Initialize the state. """Initialize the state.
Args: Args:
@ -220,28 +221,20 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
parent_state: The parent state. parent_state: The parent state.
**kwargs: The kwargs to pass to the Pydantic init method. **kwargs: The kwargs to pass to the Pydantic init method.
Raises:
ValueError: If a substate class shadows another.
""" """
kwargs["parent_state"] = parent_state kwargs["parent_state"] = parent_state
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# Setup the substates. # Setup the substates.
for substate in self.get_substates(): for substate in self.get_substates():
substate_name = substate.get_name() self.substates[substate.get_name()] = substate(parent_state=self)
if substate_name in self.substates:
raise ValueError(
f"The substate class '{substate_name}' has been defined multiple times. Shadowing "
f"substate classes is not allowed."
)
self.substates[substate_name] = substate(parent_state=self)
# Convert the event handlers to functions. # Convert the event handlers to functions.
self._init_event_handlers() self._init_event_handlers()
# Create a fresh copy of the backend variables for this instance # Create a fresh copy of the backend variables for this instance
self._backend_vars = copy.deepcopy(self.backend_vars) self._backend_vars = copy.deepcopy(self.backend_vars)
def _init_event_handlers(self, state: State | None = None): def _init_event_handlers(self, state: BaseState | None = None):
"""Initialize event handlers. """Initialize event handlers.
Allow event handlers to be called directly on the instance. This is Allow event handlers to be called directly on the instance. This is
@ -281,17 +274,44 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
Args: Args:
**kwargs: The kwargs to pass to the pydantic init_subclass method. **kwargs: The kwargs to pass to the pydantic init_subclass method.
Raises:
ValueError: If a substate class shadows another.
""" """
is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
# Event handlers should not shadow builtin state methods. # Event handlers should not shadow builtin state methods.
cls._check_overridden_methods() cls._check_overridden_methods()
# Reset subclass tracking for this class.
cls.class_subclasses = set()
# Get the parent vars. # Get the parent vars.
parent_state = cls.get_parent_state() parent_state = cls.get_parent_state()
if parent_state is not None: if parent_state is not None:
cls.inherited_vars = parent_state.vars cls.inherited_vars = parent_state.vars
cls.inherited_backend_vars = parent_state.backend_vars cls.inherited_backend_vars = parent_state.backend_vars
# Check if another substate class with the same name has already been defined.
if cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses):
if is_testing_env:
# Clear existing subclass with same name when app is reloaded via
# utils.prerequisites.get_app(reload=True)
parent_state.class_subclasses = set(
c
for c in parent_state.class_subclasses
if c.__name__ != cls.__name__
)
else:
# During normal operation, subclasses cannot have the same name, even if they are
# defined in different modules.
raise ValueError(
f"The substate class '{cls.__name__}' has been defined multiple times. "
"Shadowing substate classes is not allowed."
)
# Track this new subclass in the parent state's subclasses set.
parent_state.class_subclasses.add(cls)
cls.new_backend_vars = { cls.new_backend_vars = {
name: value name: value
for name, value in cls.__dict__.items() for name, value in cls.__dict__.items()
@ -437,7 +457,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
@classmethod @classmethod
@functools.lru_cache() @functools.lru_cache()
def get_parent_state(cls) -> Type[State] | None: def get_parent_state(cls) -> Type[BaseState] | None:
"""Get the parent state. """Get the parent state.
Returns: Returns:
@ -446,20 +466,19 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
parent_states = [ parent_states = [
base base
for base in cls.__bases__ for base in cls.__bases__
if types._issubclass(base, State) and base is not State if types._issubclass(base, BaseState) and base is not BaseState
] ]
assert len(parent_states) < 2, "Only one parent state is allowed." assert len(parent_states) < 2, "Only one parent state is allowed."
return parent_states[0] if len(parent_states) == 1 else None # type: ignore return parent_states[0] if len(parent_states) == 1 else None # type: ignore
@classmethod @classmethod
@functools.lru_cache() def get_substates(cls) -> set[Type[BaseState]]:
def get_substates(cls) -> set[Type[State]]:
"""Get the substates of the state. """Get the substates of the state.
Returns: Returns:
The substates of the state. The substates of the state.
""" """
return set(cls.__subclasses__()) return cls.class_subclasses
@classmethod @classmethod
@functools.lru_cache() @functools.lru_cache()
@ -487,7 +506,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
@classmethod @classmethod
@functools.lru_cache() @functools.lru_cache()
def get_class_substate(cls, path: Sequence[str]) -> Type[State]: def get_class_substate(cls, path: Sequence[str]) -> Type[BaseState]:
"""Get the class substate. """Get the class substate.
Args: Args:
@ -643,7 +662,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
""" """
return { return {
func[0]: func[1] func[0]: func[1]
for func in inspect.getmembers(State, predicate=inspect.isfunction) for func in inspect.getmembers(BaseState, predicate=inspect.isfunction)
if not func[0].startswith("__") if not func[0].startswith("__")
} }
@ -909,7 +928,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
for substate in self.substates.values(): for substate in self.substates.values():
substate._reset_client_storage() substate._reset_client_storage()
def get_substate(self, path: Sequence[str]) -> State | None: def get_substate(self, path: Sequence[str]) -> BaseState | None:
"""Get the substate. """Get the substate.
Args: Args:
@ -933,7 +952,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
def _get_event_handler( def _get_event_handler(
self, event: Event self, event: Event
) -> tuple[State | StateProxy, EventHandler]: ) -> tuple[BaseState | StateProxy, EventHandler]:
"""Get the event handler for the given event. """Get the event handler for the given event.
Args: Args:
@ -1050,7 +1069,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
) )
async def _process_event( async def _process_event(
self, handler: EventHandler, state: State | StateProxy, payload: Dict self, handler: EventHandler, state: BaseState | StateProxy, payload: Dict
) -> AsyncIterator[StateUpdate]: ) -> AsyncIterator[StateUpdate]:
"""Process event. """Process event.
@ -1263,7 +1282,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
d.update(substate_d) d.update(substate_d)
return d return d
async def __aenter__(self) -> State: async def __aenter__(self) -> BaseState:
"""Enter the async context manager protocol. """Enter the async context manager protocol.
This should not be used for the State class, but exists for This should not be used for the State class, but exists for
@ -1288,6 +1307,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
pass pass
class State(BaseState):
"""The app Base State."""
# The hydrated bool.
is_hydrated: bool = False
class StateProxy(wrapt.ObjectProxy): class StateProxy(wrapt.ObjectProxy):
"""Proxy of a state instance to control mutability of vars for a background task. """Proxy of a state instance to control mutability of vars for a background task.
@ -1455,10 +1481,10 @@ class StateManager(Base, ABC):
"""A class to manage many client states.""" """A class to manage many client states."""
# The state class to use. # The state class to use.
state: Type[State] state: Type[BaseState]
@classmethod @classmethod
def create(cls, state: Type[State]): def create(cls, state: Type[BaseState]):
"""Create a new state manager. """Create a new state manager.
Args: Args:
@ -1473,7 +1499,7 @@ class StateManager(Base, ABC):
return StateManagerMemory(state=state) return StateManagerMemory(state=state)
@abstractmethod @abstractmethod
async def get_state(self, token: str) -> State: async def get_state(self, token: str) -> BaseState:
"""Get the state for a token. """Get the state for a token.
Args: Args:
@ -1485,7 +1511,7 @@ class StateManager(Base, ABC):
pass pass
@abstractmethod @abstractmethod
async def set_state(self, token: str, state: State): async def set_state(self, token: str, state: BaseState):
"""Set the state for a token. """Set the state for a token.
Args: Args:
@ -1496,7 +1522,7 @@ class StateManager(Base, ABC):
@abstractmethod @abstractmethod
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[State]: async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""Modify the state for a token while holding exclusive lock. """Modify the state for a token while holding exclusive lock.
Args: Args:
@ -1512,7 +1538,7 @@ class StateManagerMemory(StateManager):
"""A state manager that stores states in memory.""" """A state manager that stores states in memory."""
# The mapping of client ids to states. # The mapping of client ids to states.
states: Dict[str, State] = {} states: Dict[str, BaseState] = {}
# The mutex ensures the dict of mutexes is updated exclusively # The mutex ensures the dict of mutexes is updated exclusively
_state_manager_lock = asyncio.Lock() _state_manager_lock = asyncio.Lock()
@ -1527,7 +1553,7 @@ class StateManagerMemory(StateManager):
"_states_locks": {"exclude": True}, "_states_locks": {"exclude": True},
} }
async def get_state(self, token: str) -> State: async def get_state(self, token: str) -> BaseState:
"""Get the state for a token. """Get the state for a token.
Args: Args:
@ -1540,7 +1566,7 @@ class StateManagerMemory(StateManager):
self.states[token] = self.state() self.states[token] = self.state()
return self.states[token] return self.states[token]
async def set_state(self, token: str, state: State): async def set_state(self, token: str, state: BaseState):
"""Set the state for a token. """Set the state for a token.
Args: Args:
@ -1550,7 +1576,7 @@ class StateManagerMemory(StateManager):
pass pass
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[State]: async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""Modify the state for a token while holding exclusive lock. """Modify the state for a token while holding exclusive lock.
Args: Args:
@ -1598,7 +1624,7 @@ class StateManagerRedis(StateManager):
b"evicted", b"evicted",
} }
async def get_state(self, token: str) -> State: async def get_state(self, token: str) -> BaseState:
"""Get the state for a token. """Get the state for a token.
Args: Args:
@ -1613,7 +1639,9 @@ class StateManagerRedis(StateManager):
return await self.get_state(token) return await self.get_state(token)
return cloudpickle.loads(redis_state) return cloudpickle.loads(redis_state)
async def set_state(self, token: str, state: State, lock_id: bytes | None = None): async def set_state(
self, token: str, state: BaseState, lock_id: bytes | None = None
):
"""Set the state for a token. """Set the state for a token.
Args: Args:
@ -1637,7 +1665,7 @@ class StateManagerRedis(StateManager):
await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration) await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[State]: async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""Modify the state for a token while holding exclusive lock. """Modify the state for a token while holding exclusive lock.
Args: Args:
@ -1879,7 +1907,7 @@ class MutableProxy(wrapt.ObjectProxy):
__mutable_types__ = (list, dict, set, Base) __mutable_types__ = (list, dict, set, Base)
def __init__(self, wrapped: Any, state: State, field_name: str): def __init__(self, wrapped: Any, state: BaseState, field_name: str):
"""Create a proxy for a mutable object that tracks changes. """Create a proxy for a mutable object that tracks changes.
Args: Args:

View File

@ -38,7 +38,7 @@ import reflex.utils.build
import reflex.utils.exec import reflex.utils.exec
import reflex.utils.prerequisites import reflex.utils.prerequisites
import reflex.utils.processes import reflex.utils.processes
from reflex.state import State, StateManagerMemory, StateManagerRedis from reflex.state import BaseState, State, StateManagerMemory, StateManagerRedis
try: try:
from selenium import webdriver # pyright: ignore [reportMissingImports] from selenium import webdriver # pyright: ignore [reportMissingImports]
@ -162,6 +162,9 @@ class AppHarness:
with chdir(self.app_path): with chdir(self.app_path):
# ensure config and app are reloaded when testing different app # ensure config and app are reloaded when testing different app
reflex.config.get_config(reload=True) reflex.config.get_config(reload=True)
# reset rx.State subclasses
State.class_subclasses.clear()
# self.app_module.app.
self.app_module = reflex.utils.prerequisites.get_app(reload=True) self.app_module = reflex.utils.prerequisites.get_app(reload=True)
self.app_instance = self.app_module.app self.app_instance = self.app_module.app
if isinstance(self.app_instance.state_manager, StateManagerRedis): if isinstance(self.app_instance.state_manager, StateManagerRedis):
@ -434,7 +437,7 @@ class AppHarness:
self._frontends.append(driver) self._frontends.append(driver)
return driver return driver
async def get_state(self, token: str) -> State: async def get_state(self, token: str) -> BaseState:
"""Get the state associated with the given token. """Get the state associated with the given token.
Args: Args:
@ -561,7 +564,7 @@ class AppHarness:
) )
return element.get_attribute("value") return element.get_attribute("value")
def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, reflex.State]: def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, BaseState]:
"""Poll app state_manager for any connected clients. """Poll app state_manager for any connected clients.
Args: Args:

View File

@ -124,11 +124,13 @@ def get_app(reload: bool = False) -> ModuleType:
Returns: Returns:
The app based on the default config. The app based on the default config.
""" """
os.environ[constants.RELOAD_CONFIG] = str(reload)
config = get_config() config = get_config()
module = ".".join([config.app_name, config.app_name]) module = ".".join([config.app_name, config.app_name])
sys.path.insert(0, os.getcwd()) sys.path.insert(0, os.getcwd())
app = __import__(module, fromlist=(constants.CompileVars.APP,)) app = __import__(module, fromlist=(constants.CompileVars.APP,))
if reload: if reload:
importlib.reload(app) importlib.reload(app)
return app return app

View File

@ -41,7 +41,7 @@ from reflex.utils import console, format, imports, serializers, types
from reflex.utils.imports import ImportDict, ImportVar from reflex.utils.imports import ImportDict, ImportVar
if TYPE_CHECKING: if TYPE_CHECKING:
from reflex.state import State from reflex.state import BaseState
# Set of unique variable names. # Set of unique variable names.
USED_VARIABLES = set() USED_VARIABLES = set()
@ -1472,7 +1472,7 @@ class Var:
) )
) )
def _var_set_state(self, state: Type[State] | str) -> Any: def _var_set_state(self, state: Type[BaseState] | str) -> Any:
"""Set the state of the var. """Set the state of the var.
Args: Args:
@ -1604,14 +1604,14 @@ class BaseVar(Var):
return setter return setter
return ".".join((self._var_data.state, setter)) return ".".join((self._var_data.state, setter))
def get_setter(self) -> Callable[[State, Any], None]: def get_setter(self) -> Callable[[BaseState, Any], None]:
"""Get the var's setter function. """Get the var's setter function.
Returns: Returns:
A function that that creates a setter for the var. A function that that creates a setter for the var.
""" """
def setter(state: State, value: Any): def setter(state: BaseState, value: Any):
"""Get the setter for the var. """Get the setter for the var.
Args: Args:
@ -1643,9 +1643,9 @@ class ComputedVar(Var, property):
def __init__( def __init__(
self, self,
fget: Callable[[State], Any], fget: Callable[[BaseState], Any],
fset: Callable[[State, Any], None] | None = None, fset: Callable[[BaseState, Any], None] | None = None,
fdel: Callable[[State], Any] | None = None, fdel: Callable[[BaseState], Any] | None = None,
doc: str | None = None, doc: str | None = None,
**kwargs, **kwargs,
): ):

View File

@ -5,6 +5,7 @@ from _typeshed import Incomplete
from reflex import constants as constants from reflex import constants as constants
from reflex.base import Base as Base from reflex.base import Base as Base
from reflex.state import State as State from reflex.state import State as State
from reflex.state import BaseState as BaseState
from reflex.utils import console as console, format as format, types as types from reflex.utils import console as console, format as format, types as types
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportVar
from types import FunctionType from types import FunctionType
@ -110,7 +111,7 @@ class Var:
def as_ref(self) -> Var: ... def as_ref(self) -> Var: ...
@property @property
def _var_full_name(self) -> str: ... def _var_full_name(self) -> str: ...
def _var_set_state(self, state: Type[State] | str) -> Any: ... def _var_set_state(self, state: Type[BaseState] | str) -> Any: ...
@dataclass(eq=False) @dataclass(eq=False)
class BaseVar(Var): class BaseVar(Var):
@ -123,7 +124,7 @@ class BaseVar(Var):
def __hash__(self) -> int: ... def __hash__(self) -> int: ...
def get_default_value(self) -> Any: ... def get_default_value(self) -> Any: ...
def get_setter_name(self, include_state: bool = ...) -> str: ... def get_setter_name(self, include_state: bool = ...) -> str: ...
def get_setter(self) -> Callable[[State, Any], None]: ... def get_setter(self) -> Callable[[BaseState, Any], None]: ...
@dataclass(init=False) @dataclass(init=False)
class ComputedVar(Var): class ComputedVar(Var):

View File

@ -1 +1,4 @@
"""Root directory for tests.""" """Root directory for tests."""
import os
from reflex import constants

View File

@ -2,7 +2,7 @@
import pytest import pytest
from reflex.components.base.script import Script from reflex.components.base.script import Script
from reflex.state import State from reflex.state import BaseState
def test_script_inline(): def test_script_inline():
@ -31,7 +31,7 @@ def test_script_neither():
Script.create() Script.create()
class EvState(State): class EvState(BaseState):
"""State for testing event handlers.""" """State for testing event handlers."""
def on_ready(self): def on_ready(self):

View File

@ -5,6 +5,7 @@ import pandas as pd
import pytest import pytest
import reflex as rx import reflex as rx
from reflex.state import BaseState
@pytest.fixture @pytest.fixture
@ -18,7 +19,7 @@ def data_table_state(request):
The data table state class. The data table state class.
""" """
class DataTableState(rx.State): class DataTableState(BaseState):
data = request.param["data"] data = request.param["data"]
columns = ["column1", "column2"] columns = ["column1", "column2"]
@ -33,7 +34,7 @@ def data_table_state2():
The data table state class. The data table state class.
""" """
class DataTableState(rx.State): class DataTableState(BaseState):
_data = pd.DataFrame() _data = pd.DataFrame()
@rx.var @rx.var
@ -51,7 +52,7 @@ def data_table_state3():
The data table state class. The data table state class.
""" """
class DataTableState(rx.State): class DataTableState(BaseState):
_data: List = [] _data: List = []
_columns: List = ["col1", "col2"] _columns: List = ["col1", "col2"]
@ -74,7 +75,7 @@ def data_table_state4():
The data table state class. The data table state class.
""" """
class DataTableState(rx.State): class DataTableState(BaseState):
_data: List = [] _data: List = []
_columns: List = ["col1", "col2"] _columns: List = ["col1", "col2"]

View File

@ -4,12 +4,12 @@ from typing import List, Tuple
import pytest import pytest
from reflex.components.datadisplay.table import Tbody, Tfoot, Thead from reflex.components.datadisplay.table import Tbody, Tfoot, Thead
from reflex.state import State from reflex.state import BaseState
PYTHON_GT_V38 = sys.version_info.major >= 3 and sys.version_info.minor > 8 PYTHON_GT_V38 = sys.version_info.major >= 3 and sys.version_info.minor > 8
class TableState(State): class TableState(BaseState):
"""Test State class.""" """Test State class."""
rows_List_List_str: List[List[str]] = [["random", "row"]] rows_List_List_str: List[List[str]] = [["random", "row"]]

View File

@ -3,6 +3,7 @@
import pytest import pytest
import reflex as rx import reflex as rx
from reflex.state import BaseState
from reflex.vars import BaseVar from reflex.vars import BaseVar
@ -24,7 +25,7 @@ def test_render_many_child():
_ = rx.debounce_input("foo", "bar").render() _ = rx.debounce_input("foo", "bar").render()
class S(rx.State): class S(BaseState):
"""Example state for debounce tests.""" """Example state for debounce tests."""
value: str = "" value: str = ""

View File

@ -15,12 +15,13 @@ from reflex.components.layout.responsive import (
tablet_only, tablet_only,
) )
from reflex.components.typography.text import Text from reflex.components.typography.text import Text
from reflex.state import BaseState
from reflex.vars import Var from reflex.vars import Var
@pytest.fixture @pytest.fixture
def cond_state(request): def cond_state(request):
class CondState(rx.State): class CondState(BaseState):
value: request.param["value_type"] = request.param["value"] # noqa value: request.param["value_type"] = request.param["value"] # noqa
return CondState return CondState

View File

@ -4,10 +4,10 @@ import pytest
from reflex.components import box, foreach, text from reflex.components import box, foreach, text
from reflex.components.layout import Foreach from reflex.components.layout import Foreach
from reflex.state import State from reflex.state import BaseState
class ForEachState(State): class ForEachState(BaseState):
"""A state for testing the ForEach component.""" """A state for testing the ForEach component."""
colors_list: List[str] = ["red", "yellow"] colors_list: List[str] = ["red", "yellow"]

View File

@ -14,7 +14,7 @@ from reflex.components.component import (
from reflex.components.layout.box import Box from reflex.components.layout.box import Box
from reflex.constants import EventTriggers from reflex.constants import EventTriggers
from reflex.event import EventChain, EventHandler from reflex.event import EventChain, EventHandler
from reflex.state import State from reflex.state import BaseState
from reflex.style import Style from reflex.style import Style
from reflex.utils import imports from reflex.utils import imports
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportVar
@ -23,7 +23,7 @@ from reflex.vars import Var, VarData
@pytest.fixture @pytest.fixture
def test_state(): def test_state():
class TestState(State): class TestState(BaseState):
num: int num: int
def do_something(self): def do_something(self):
@ -400,7 +400,7 @@ def test_get_event_triggers(component1, component2):
) )
class C1State(State): class C1State(BaseState):
"""State for testing C1 component.""" """State for testing C1 component."""
def mock_handler(self, _e, _bravo, _charlie): def mock_handler(self, _e, _bravo, _charlie):

View File

@ -8,7 +8,6 @@ from typing import Dict, Generator
import pytest import pytest
import reflex as rx
from reflex.app import App from reflex.app import App
from reflex.event import EventSpec from reflex.event import EventSpec
@ -225,23 +224,3 @@ def token() -> str:
A fresh/unique token string. A fresh/unique token string.
""" """
return str(uuid.uuid4()) return str(uuid.uuid4())
@pytest.fixture
def duplicate_substate():
"""Create a Test state that has duplicate child substates.
Returns:
The test state.
"""
class TestState(rx.State):
pass
class ChildTestState(TestState): # type: ignore # noqa
pass
class ChildTestState(TestState): # type: ignore # noqa
pass
return TestState

View File

@ -2,13 +2,14 @@ from typing import Any, Dict
import pytest import pytest
from reflex import constants
from reflex.app import App from reflex.app import App
from reflex.constants import CompileVars from reflex.constants import CompileVars
from reflex.middleware.hydrate_middleware import HydrateMiddleware from reflex.middleware.hydrate_middleware import HydrateMiddleware
from reflex.state import State, StateUpdate from reflex.state import BaseState, StateUpdate
def exp_is_hydrated(state: State) -> Dict[str, Any]: def exp_is_hydrated(state: BaseState) -> Dict[str, Any]:
"""Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware. """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
Args: Args:
@ -20,7 +21,7 @@ def exp_is_hydrated(state: State) -> Dict[str, Any]:
return {state.get_name(): {CompileVars.IS_HYDRATED: True}} return {state.get_name(): {CompileVars.IS_HYDRATED: True}}
class TestState(State): class TestState(BaseState):
"""A test state with no return in handler.""" """A test state with no return in handler."""
__test__ = False __test__ = False
@ -32,7 +33,7 @@ class TestState(State):
self.num += 1 self.num += 1
class TestState2(State): class TestState2(BaseState):
"""A test state with return in handler.""" """A test state with return in handler."""
__test__ = False __test__ = False
@ -54,7 +55,7 @@ class TestState2(State):
self.name = "random" self.name = "random"
class TestState3(State): class TestState3(BaseState):
"""A test state with async handler.""" """A test state with async handler."""
__test__ = False __test__ = False
@ -97,6 +98,9 @@ async def test_preprocess(
event_fixture: The event fixture(an Event). event_fixture: The event fixture(an Event).
expected: Expected delta. expected: Expected delta.
""" """
test_state.add_var(
constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
)
app = App(state=test_state, load_events={"index": [test_state.test_handler]}) app = App(state=test_state, load_events={"index": [test_state.test_handler]})
state = test_state() state = test_state()

View File

@ -1,10 +1,12 @@
"""Common rx.State subclasses for use in tests.""" """Common rx.BaseState subclasses for use in tests."""
import reflex as rx import reflex as rx
from reflex.state import BaseState
from .mutation import DictMutationTestState, ListMutationTestState, MutableTestState from .mutation import DictMutationTestState, ListMutationTestState, MutableTestState
from .upload import ( from .upload import (
ChildFileUploadState, ChildFileUploadState,
FileStateBase1, FileStateBase1,
FileStateBase2,
FileUploadState, FileUploadState,
GrandChildFileUploadState, GrandChildFileUploadState,
SubUploadState, SubUploadState,
@ -12,7 +14,7 @@ from .upload import (
) )
class GenState(rx.State): class GenState(BaseState):
"""A state with event handlers that generate multiple updates.""" """A state with event handlers that generate multiple updates."""
value: int value: int

View File

@ -3,9 +3,10 @@
from typing import Dict, List, Set, Union from typing import Dict, List, Set, Union
import reflex as rx import reflex as rx
from reflex.state import BaseState
class DictMutationTestState(rx.State): class DictMutationTestState(BaseState):
"""A state for testing ReflexDict mutation.""" """A state for testing ReflexDict mutation."""
# plain dict # plain dict
@ -62,7 +63,7 @@ class DictMutationTestState(rx.State):
self.friend_in_nested_dict["friend"]["age"] = 30 self.friend_in_nested_dict["friend"]["age"] = 30
class ListMutationTestState(rx.State): class ListMutationTestState(BaseState):
"""A state for testing ReflexList mutation.""" """A state for testing ReflexList mutation."""
# plain list # plain list
@ -144,7 +145,7 @@ class CustomVar(rx.Base):
custom: OtherBase = OtherBase() custom: OtherBase = OtherBase()
class MutableTestState(rx.State): class MutableTestState(BaseState):
"""A test state.""" """A test state."""
array: List[Union[str, List, Dict[str, str]]] = [ array: List[Union[str, List, Dict[str, str]]] = [

View File

@ -3,9 +3,10 @@ from pathlib import Path
from typing import ClassVar, List from typing import ClassVar, List
import reflex as rx import reflex as rx
from reflex.state import BaseState, State
class UploadState(rx.State): class UploadState(BaseState):
"""The base state for uploading a file.""" """The base state for uploading a file."""
async def handle_upload1(self, files: List[rx.UploadFile]): async def handle_upload1(self, files: List[rx.UploadFile]):
@ -17,7 +18,7 @@ class UploadState(rx.State):
pass pass
class BaseState(rx.State): class BaseState(BaseState):
"""The test base state.""" """The test base state."""
pass pass
@ -37,7 +38,7 @@ class SubUploadState(BaseState):
pass pass
class FileUploadState(rx.State): class FileUploadState(State):
"""The base state for uploading a file.""" """The base state for uploading a file."""
img_list: List[str] img_list: List[str]
@ -79,7 +80,7 @@ class FileUploadState(rx.State):
pass pass
class FileStateBase1(rx.State): class FileStateBase1(State):
"""The base state for a child FileUploadState.""" """The base state for a child FileUploadState."""
pass pass

View File

@ -28,7 +28,7 @@ from reflex.components import Box, Component, Cond, Fragment, Text
from reflex.event import Event, get_hydrate_event from reflex.event import Event, get_hydrate_event
from reflex.middleware import HydrateMiddleware from reflex.middleware import HydrateMiddleware
from reflex.model import Model from reflex.model import Model
from reflex.state import RouterData, State, StateManagerRedis, StateUpdate from reflex.state import BaseState, RouterData, State, StateManagerRedis, StateUpdate
from reflex.style import Style from reflex.style import Style
from reflex.utils import format from reflex.utils import format
from reflex.vars import ComputedVar from reflex.vars import ComputedVar
@ -43,7 +43,7 @@ from .states import (
) )
class EmptyState(State): class EmptyState(BaseState):
"""An empty state.""" """An empty state."""
pass pass
@ -77,14 +77,14 @@ def about_page():
return about return about
class ATestState(State): class ATestState(BaseState):
"""A simple state for testing.""" """A simple state for testing."""
var: int var: int
@pytest.fixture() @pytest.fixture()
def test_state() -> Type[State]: def test_state() -> Type[BaseState]:
"""A default state. """A default state.
Returns: Returns:
@ -94,14 +94,14 @@ def test_state() -> Type[State]:
@pytest.fixture() @pytest.fixture()
def redundant_test_state() -> Type[State]: def redundant_test_state() -> Type[BaseState]:
"""A default state. """A default state.
Returns: Returns:
A default state. A default state.
""" """
class RedundantTestState(State): class RedundantTestState(BaseState):
var: int var: int
return RedundantTestState return RedundantTestState
@ -198,12 +198,12 @@ def test_default_app(app: App):
def test_multiple_states_error(monkeypatch, test_state, redundant_test_state): def test_multiple_states_error(monkeypatch, test_state, redundant_test_state):
"""Test that an error is thrown when multiple classes subclass rx.State. """Test that an error is thrown when multiple classes subclass rx.BaseState.
Args: Args:
monkeypatch: Pytest monkeypatch object. monkeypatch: Pytest monkeypatch object.
test_state: A test state subclassing rx.State. test_state: A test state subclassing rx.BaseState.
redundant_test_state: Another test state subclassing rx.State. redundant_test_state: Another test state subclassing rx.BaseState.
""" """
monkeypatch.delenv(constants.PYTEST_CURRENT_TEST) monkeypatch.delenv(constants.PYTEST_CURRENT_TEST)
with pytest.raises(ValueError): with pytest.raises(ValueError):
@ -705,12 +705,12 @@ async def test_dict_mutation_detection__plain_list(
[ [
( (
FileUploadState, FileUploadState,
{"file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}}, {"state.file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}},
), ),
( (
ChildFileUploadState, ChildFileUploadState,
{ {
"file_state_base1.child_file_upload_state": { "state.file_state_base1.child_file_upload_state": {
"img_list": ["image1.jpg", "image2.jpg"] "img_list": ["image1.jpg", "image2.jpg"]
} }
}, },
@ -718,14 +718,14 @@ async def test_dict_mutation_detection__plain_list(
( (
GrandChildFileUploadState, GrandChildFileUploadState,
{ {
"file_state_base1.file_state_base2.grand_child_file_upload_state": { "state.file_state_base1.file_state_base2.grand_child_file_upload_state": {
"img_list": ["image1.jpg", "image2.jpg"] "img_list": ["image1.jpg", "image2.jpg"]
} }
}, },
), ),
], ],
) )
async def test_upload_file(tmp_path, state, delta, token: str): async def test_upload_file(tmp_path, state, delta, token: str, mocker):
"""Test that file upload works correctly. """Test that file upload works correctly.
Args: Args:
@ -733,10 +733,15 @@ async def test_upload_file(tmp_path, state, delta, token: str):
state: The state class. state: The state class.
delta: Expected delta delta: Expected delta
token: a Token. token: a Token.
mocker: pytest mocker object.
""" """
mocker.patch(
"reflex.state.State.class_subclasses",
{state if state is FileUploadState else FileStateBase1},
)
state._tmp_path = tmp_path state._tmp_path = tmp_path
# The App state must be the "root" of the state tree # The App state must be the "root" of the state tree
app = App(state=state if state is FileUploadState else FileStateBase1) app = App(state=State)
app.event_namespace.emit = AsyncMock() # type: ignore app.event_namespace.emit = AsyncMock() # type: ignore
current_state = await app.state_manager.get_state(token) current_state = await app.state_manager.get_state(token)
data = b"This is binary data" data = b"This is binary data"
@ -749,7 +754,7 @@ async def test_upload_file(tmp_path, state, delta, token: str):
request_mock = unittest.mock.Mock() request_mock = unittest.mock.Mock()
request_mock.headers = { request_mock.headers = {
"reflex-client-token": token, "reflex-client-token": token,
"reflex-event-handler": f"{state_name}.multi_handle_upload", "reflex-event-handler": f"state.{state_name}.multi_handle_upload",
} }
file1 = UploadFile( file1 = UploadFile(
@ -851,7 +856,7 @@ async def test_upload_file_background(state, tmp_path, token):
await app.state_manager.redis.close() await app.state_manager.redis.close()
class DynamicState(State): class DynamicState(BaseState):
"""State class for testing dynamic route var. """State class for testing dynamic route var.
This is defined at module level because event handlers cannot be addressed This is defined at module level because event handlers cannot be addressed
@ -891,9 +896,7 @@ class DynamicState(State):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_dynamic_route_var_route_change_completed_on_load( async def test_dynamic_route_var_route_change_completed_on_load(
index_page, index_page, windows_platform: bool, token: str, mocker
windows_platform: bool,
token: str,
): ):
"""Create app with dynamic route var, and simulate navigation. """Create app with dynamic route var, and simulate navigation.
@ -904,7 +907,12 @@ async def test_dynamic_route_var_route_change_completed_on_load(
index_page: The index page. index_page: The index page.
windows_platform: Whether the system is windows. windows_platform: Whether the system is windows.
token: a Token. token: a Token.
mocker: pytest mocker object.
""" """
mocker.patch("reflex.state.State.class_subclasses", {DynamicState})
DynamicState.add_var(
constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
)
arg_name = "dynamic" arg_name = "dynamic"
route = f"/test/[{arg_name}]" route = f"/test/[{arg_name}]"
if windows_platform: if windows_platform:

View File

@ -4,7 +4,7 @@ import pytest
from reflex import event from reflex import event
from reflex.event import Event, EventHandler, EventSpec, fix_events from reflex.event import Event, EventHandler, EventSpec, fix_events
from reflex.state import State from reflex.state import BaseState
from reflex.utils import format from reflex.utils import format
from reflex.vars import Var from reflex.vars import Var
@ -303,7 +303,7 @@ def test_event_actions():
def test_event_actions_on_state(): def test_event_actions_on_state():
class EventActionState(State): class EventActionState(BaseState):
def handler(self): def handler(self):
pass pass

View File

@ -19,11 +19,11 @@ from reflex.base import Base
from reflex.constants import CompileVars, RouteVar, SocketEvent from reflex.constants import CompileVars, RouteVar, SocketEvent
from reflex.event import Event, EventHandler from reflex.event import Event, EventHandler
from reflex.state import ( from reflex.state import (
BaseState,
ImmutableStateError, ImmutableStateError,
LockExpiredError, LockExpiredError,
MutableProxy, MutableProxy,
RouterData, RouterData,
State,
StateManager, StateManager,
StateManagerMemory, StateManagerMemory,
StateManagerRedis, StateManagerRedis,
@ -75,7 +75,7 @@ class Object(Base):
prop2: str = "hello" prop2: str = "hello"
class TestState(State): class TestState(BaseState):
"""A test state.""" """A test state."""
# Set this class as not test one # Set this class as not test one
@ -148,7 +148,7 @@ class GrandchildState(ChildState):
pass pass
class DateTimeState(State): class DateTimeState(BaseState):
"""A State with some datetime fields.""" """A State with some datetime fields."""
d: datetime.date = datetime.date.fromisoformat("1989-11-09") d: datetime.date = datetime.date.fromisoformat("1989-11-09")
@ -253,7 +253,6 @@ def test_class_vars(test_state):
""" """
cls = type(test_state) cls = type(test_state)
assert set(cls.vars.keys()) == { assert set(cls.vars.keys()) == {
CompileVars.IS_HYDRATED, # added by hydrate_middleware to all State
"router", "router",
"num1", "num1",
"num2", "num2",
@ -641,7 +640,6 @@ def test_reset(test_state, child_state):
"obj", "obj",
"upper", "upper",
"complex", "complex",
"is_hydrated",
"fig", "fig",
"key", "key",
"sum", "sum",
@ -837,7 +835,7 @@ def test_get_query_params(test_state):
def test_add_var(): def test_add_var():
class DynamicState(State): class DynamicState(BaseState):
pass pass
ds1 = DynamicState() ds1 = DynamicState()
@ -870,7 +868,7 @@ def test_add_var_default_handlers(test_state):
assert isinstance(test_state.event_handlers["set_rand_int"], EventHandler) assert isinstance(test_state.event_handlers["set_rand_int"], EventHandler)
class InterdependentState(State): class InterdependentState(BaseState):
"""A state with 3 vars and 3 computed vars. """A state with 3 vars and 3 computed vars.
x: a variable that no computed var depends on x: a variable that no computed var depends on
@ -915,7 +913,7 @@ class InterdependentState(State):
@pytest.fixture @pytest.fixture
def interdependent_state() -> State: def interdependent_state() -> BaseState:
"""A state with varying dependency between vars. """A state with varying dependency between vars.
Returns: Returns:
@ -988,7 +986,7 @@ def test_per_state_backend_var(interdependent_state):
def test_child_state(): def test_child_state():
"""Test that the child state computed vars can reference parent state vars.""" """Test that the child state computed vars can reference parent state vars."""
class MainState(State): class MainState(BaseState):
v: int = 2 v: int = 2
class ChildState(MainState): class ChildState(MainState):
@ -1006,7 +1004,7 @@ def test_child_state():
def test_conditional_computed_vars(): def test_conditional_computed_vars():
"""Test that computed vars can have conditionals.""" """Test that computed vars can have conditionals."""
class MainState(State): class MainState(BaseState):
flag: bool = False flag: bool = False
t1: str = "a" t1: str = "a"
t2: str = "b" t2: str = "b"
@ -1051,7 +1049,7 @@ def test_event_handlers_convert_to_fns(test_state, child_state):
def test_event_handlers_call_other_handlers(): def test_event_handlers_call_other_handlers():
"""Test that event handlers can call other event handlers.""" """Test that event handlers can call other event handlers."""
class MainState(State): class MainState(BaseState):
v: int = 0 v: int = 0
def set_v(self, v: int): def set_v(self, v: int):
@ -1077,7 +1075,7 @@ def test_computed_var_cached():
"""Test that a ComputedVar doesn't recalculate when accessed.""" """Test that a ComputedVar doesn't recalculate when accessed."""
comp_v_calls = 0 comp_v_calls = 0
class ComputedState(State): class ComputedState(BaseState):
v: int = 0 v: int = 0
@rx.cached_var @rx.cached_var
@ -1102,7 +1100,7 @@ def test_computed_var_cached():
def test_computed_var_cached_depends_on_non_cached(): def test_computed_var_cached_depends_on_non_cached():
"""Test that a cached_var is recalculated if it depends on non-cached ComputedVar.""" """Test that a cached_var is recalculated if it depends on non-cached ComputedVar."""
class ComputedState(State): class ComputedState(BaseState):
v: int = 0 v: int = 0
@rx.var @rx.var
@ -1144,7 +1142,7 @@ def test_computed_var_depends_on_parent_non_cached():
"""Child state cached_var that depends on parent state un cached var is always recalculated.""" """Child state cached_var that depends on parent state un cached var is always recalculated."""
counter = 0 counter = 0
class ParentState(State): class ParentState(BaseState):
@rx.var @rx.var
def no_cache_v(self) -> int: def no_cache_v(self) -> int:
nonlocal counter nonlocal counter
@ -1165,21 +1163,18 @@ def test_computed_var_depends_on_parent_non_cached():
dict1 = ps.dict() dict1 = ps.dict()
assert dict1[ps.get_full_name()] == { assert dict1[ps.get_full_name()] == {
"no_cache_v": 1, "no_cache_v": 1,
CompileVars.IS_HYDRATED: False,
"router": formatted_router, "router": formatted_router,
} }
assert dict1[cs.get_full_name()] == {"dep_v": 2} assert dict1[cs.get_full_name()] == {"dep_v": 2}
dict2 = ps.dict() dict2 = ps.dict()
assert dict2[ps.get_full_name()] == { assert dict2[ps.get_full_name()] == {
"no_cache_v": 3, "no_cache_v": 3,
CompileVars.IS_HYDRATED: False,
"router": formatted_router, "router": formatted_router,
} }
assert dict2[cs.get_full_name()] == {"dep_v": 4} assert dict2[cs.get_full_name()] == {"dep_v": 4}
dict3 = ps.dict() dict3 = ps.dict()
assert dict3[ps.get_full_name()] == { assert dict3[ps.get_full_name()] == {
"no_cache_v": 5, "no_cache_v": 5,
CompileVars.IS_HYDRATED: False,
"router": formatted_router, "router": formatted_router,
} }
assert dict3[cs.get_full_name()] == {"dep_v": 6} assert dict3[cs.get_full_name()] == {"dep_v": 6}
@ -1195,7 +1190,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
""" """
counter = 0 counter = 0
class HandlerState(State): class HandlerState(BaseState):
x: int = 42 x: int = 42
def handler(self): def handler(self):
@ -1226,7 +1221,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
def test_computed_var_dependencies(): def test_computed_var_dependencies():
"""Test that a ComputedVar correctly tracks its dependencies.""" """Test that a ComputedVar correctly tracks its dependencies."""
class ComputedState(State): class ComputedState(BaseState):
v: int = 0 v: int = 0
w: int = 0 w: int = 0
x: int = 0 x: int = 0
@ -1293,7 +1288,7 @@ def test_computed_var_dependencies():
def test_backend_method(): def test_backend_method():
"""A method with leading underscore should be callable from event handler.""" """A method with leading underscore should be callable from event handler."""
class BackendMethodState(State): class BackendMethodState(BaseState):
def _be_method(self): def _be_method(self):
return True return True
@ -1369,7 +1364,7 @@ def test_error_on_state_method_shadow():
"""Test that an error is thrown when an event handler shadows a state method.""" """Test that an error is thrown when an event handler shadows a state method."""
with pytest.raises(NameError) as err: with pytest.raises(NameError) as err:
class InvalidTest(rx.State): class InvalidTest(BaseState):
def reset(self): def reset(self):
pass pass
@ -1382,7 +1377,7 @@ def test_error_on_state_method_shadow():
def test_state_with_invalid_yield(): def test_state_with_invalid_yield():
"""Test that an error is thrown when a state yields an invalid value.""" """Test that an error is thrown when a state yields an invalid value."""
class StateWithInvalidYield(rx.State): class StateWithInvalidYield(BaseState):
"""A state that yields an invalid value.""" """A state that yields an invalid value."""
def invalid_handler(self): def invalid_handler(self):
@ -1666,7 +1661,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
assert mcall.kwargs["to"] == grandchild_state.get_sid() assert mcall.kwargs["to"] == grandchild_state.get_sid()
class BackgroundTaskState(State): class BackgroundTaskState(BaseState):
"""A state with a background task.""" """A state with a background task."""
order: List[str] = [] order: List[str] = []
@ -2192,9 +2187,20 @@ def test_mutable_copy_vars(mutable_state, copy_func):
assert not isinstance(var_copy, MutableProxy) assert not isinstance(var_copy, MutableProxy)
def test_duplicate_substate_class(duplicate_substate): def test_duplicate_substate_class(mocker):
mocker.patch("reflex.state.os.environ", {})
with pytest.raises(ValueError): with pytest.raises(ValueError):
duplicate_substate()
class TestState(BaseState):
pass
class ChildTestState(TestState): # type: ignore # noqa
pass
class ChildTestState(TestState): # type: ignore # noqa
pass
return TestState
class Foo(Base): class Foo(Base):
@ -2206,7 +2212,7 @@ class Foo(Base):
def test_json_dumps_with_mutables(): def test_json_dumps_with_mutables():
"""Test that json.dumps works with Base vars inside mutable types.""" """Test that json.dumps works with Base vars inside mutable types."""
class MutableContainsBase(State): class MutableContainsBase(BaseState):
items: List[Foo] = [Foo()] items: List[Foo] = [Foo()]
dict_val = MutableContainsBase().dict() dict_val = MutableContainsBase().dict()
@ -2216,7 +2222,7 @@ def test_json_dumps_with_mutables():
f_formatted_router = str(formatted_router).replace("'", '"') f_formatted_router = str(formatted_router).replace("'", '"')
assert ( assert (
val val
== f'{{"{MutableContainsBase.get_full_name()}": {{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}}}' == f'{{"{MutableContainsBase.get_full_name()}": {{"items": {f_items}, "router": {f_formatted_router}}}}}'
) )
@ -2225,7 +2231,7 @@ def test_reset_with_mutables():
default = [[0, 0], [0, 1], [1, 1]] default = [[0, 0], [0, 1], [1, 1]]
copied_default = copy.deepcopy(default) copied_default = copy.deepcopy(default)
class MutableResetState(State): class MutableResetState(BaseState):
items: List[List[int]] = default items: List[List[int]] = default
instance = MutableResetState() instance = MutableResetState()
@ -2273,7 +2279,7 @@ class Custom3(Base):
def test_state_union_optional(): def test_state_union_optional():
"""Test that state can be defined with Union and Optional vars.""" """Test that state can be defined with Union and Optional vars."""
class UnionState(State): class UnionState(BaseState):
int_float: Union[int, float] = 0 int_float: Union[int, float] = 0
opt_int: Optional[int] opt_int: Optional[int]
c3: Optional[Custom3] c3: Optional[Custom3]

View File

@ -6,7 +6,7 @@ import pytest
from pandas import DataFrame from pandas import DataFrame
from reflex.base import Base from reflex.base import Base
from reflex.state import State from reflex.state import BaseState
from reflex.vars import ( from reflex.vars import (
BaseVar, BaseVar,
ComputedVar, ComputedVar,
@ -24,12 +24,6 @@ test_vars = [
] ]
class BaseState(State):
"""A Test State."""
val: str = "key"
@pytest.fixture @pytest.fixture
def TestObj(): def TestObj():
class TestObj(Base): class TestObj(Base):
@ -41,7 +35,7 @@ def TestObj():
@pytest.fixture @pytest.fixture
def ParentState(TestObj): def ParentState(TestObj):
class ParentState(State): class ParentState(BaseState):
foo: int foo: int
bar: int bar: int
@ -74,7 +68,7 @@ def GrandChildState(ChildState, TestObj):
@pytest.fixture @pytest.fixture
def StateWithAnyVar(TestObj): def StateWithAnyVar(TestObj):
class StateWithAnyVar(State): class StateWithAnyVar(BaseState):
@ComputedVar @ComputedVar
def var_without_annotation(self) -> typing.Any: def var_without_annotation(self) -> typing.Any:
return TestObj return TestObj
@ -84,7 +78,7 @@ def StateWithAnyVar(TestObj):
@pytest.fixture @pytest.fixture
def StateWithCorrectVarAnnotation(): def StateWithCorrectVarAnnotation():
class StateWithCorrectVarAnnotation(State): class StateWithCorrectVarAnnotation(BaseState):
@ComputedVar @ComputedVar
def var_with_annotation(self) -> str: def var_with_annotation(self) -> str:
return "Correct annotation" return "Correct annotation"
@ -94,7 +88,7 @@ def StateWithCorrectVarAnnotation():
@pytest.fixture @pytest.fixture
def StateWithWrongVarAnnotation(TestObj): def StateWithWrongVarAnnotation(TestObj):
class StateWithWrongVarAnnotation(State): class StateWithWrongVarAnnotation(BaseState):
@ComputedVar @ComputedVar
def var_with_annotation(self) -> str: def var_with_annotation(self) -> str:
return TestObj return TestObj

View File

@ -528,7 +528,6 @@ formatted_router = {
}, },
"dt": "1989-11-09 18:53:00+01:00", "dt": "1989-11-09 18:53:00+01:00",
"fig": [], "fig": [],
"is_hydrated": False,
"key": "", "key": "",
"map_key": "a", "map_key": "a",
"mapping": {"a": [1, 2, 3], "b": [4, 5, 6]}, "mapping": {"a": [1, 2, 3], "b": [4, 5, 6]},
@ -553,7 +552,6 @@ formatted_router = {
DateTimeState.get_full_name(): { DateTimeState.get_full_name(): {
"d": "1989-11-09", "d": "1989-11-09",
"dt": "1989-11-09 18:53:00+01:00", "dt": "1989-11-09 18:53:00+01:00",
"is_hydrated": False,
"t": "18:53:00+01:00", "t": "18:53:00+01:00",
"td": "11 days, 0:11:00", "td": "11 days, 0:11:00",
"router": formatted_router, "router": formatted_router,

View File

@ -10,7 +10,7 @@ from packaging import version
from reflex import constants from reflex import constants
from reflex.base import Base from reflex.base import Base
from reflex.event import EventHandler from reflex.event import EventHandler
from reflex.state import State from reflex.state import BaseState
from reflex.utils import ( from reflex.utils import (
build, build,
prerequisites, prerequisites,
@ -43,7 +43,7 @@ V056 = version.parse("0.5.6")
VMAXPLUS1 = version.parse(get_above_max_version()) VMAXPLUS1 = version.parse(get_above_max_version())
class ExampleTestState(State): class ExampleTestState(BaseState):
"""Test state class.""" """Test state class."""
def test_event_handler(self): def test_event_handler(self):