RED-1052/rx.State as Base State (#2146)
This commit is contained in:
parent
f8395b1fd6
commit
e3ee98098a
@ -93,7 +93,7 @@ def BackgroundTask():
|
||||
rx.button("Reset", on_click=State.reset_counter, id="reset"),
|
||||
)
|
||||
|
||||
app = rx.App(state=State)
|
||||
app = rx.App(state=rx.State)
|
||||
app.add_page(index)
|
||||
app.compile()
|
||||
|
||||
|
@ -135,7 +135,7 @@ def CallScript():
|
||||
yield rx.call_script("inline_counter = 0; external_counter = 0")
|
||||
self.reset()
|
||||
|
||||
app = rx.App(state=CallScriptState)
|
||||
app = rx.App(state=rx.State)
|
||||
with open("assets/external.js", "w") as f:
|
||||
f.write(external_scripts)
|
||||
|
||||
|
@ -97,7 +97,7 @@ def ClientSide():
|
||||
rx.box(ClientSideSubSubState.l1s, id="l1s"),
|
||||
)
|
||||
|
||||
app = rx.App(state=ClientSideState)
|
||||
app = rx.App(state=rx.State)
|
||||
app.add_page(index)
|
||||
app.add_page(index, route="/foo")
|
||||
app.compile()
|
||||
@ -263,7 +263,6 @@ async def test_client_side_state(
|
||||
state_var_input.send_keys("c7")
|
||||
input_value_input.send_keys("c7 value")
|
||||
set_sub_state_button.click()
|
||||
|
||||
state_var_input.send_keys("l1")
|
||||
input_value_input.send_keys("l1 value")
|
||||
set_sub_state_button.click()
|
||||
@ -276,7 +275,6 @@ async def test_client_side_state(
|
||||
state_var_input.send_keys("l4")
|
||||
input_value_input.send_keys("l4 value")
|
||||
set_sub_state_button.click()
|
||||
|
||||
state_var_input.send_keys("c1s")
|
||||
input_value_input.send_keys("c1s value")
|
||||
set_sub_sub_state_button.click()
|
||||
@ -285,28 +283,28 @@ async def test_client_side_state(
|
||||
set_sub_sub_state_button.click()
|
||||
|
||||
exp_cookies = {
|
||||
"client_side_state.client_side_sub_state.c1": {
|
||||
"state.client_side_state.client_side_sub_state.c1": {
|
||||
"domain": "localhost",
|
||||
"httpOnly": False,
|
||||
"name": "client_side_state.client_side_sub_state.c1",
|
||||
"name": "state.client_side_state.client_side_sub_state.c1",
|
||||
"path": "/",
|
||||
"sameSite": "Lax",
|
||||
"secure": False,
|
||||
"value": "c1%20value",
|
||||
},
|
||||
"client_side_state.client_side_sub_state.c2": {
|
||||
"state.client_side_state.client_side_sub_state.c2": {
|
||||
"domain": "localhost",
|
||||
"httpOnly": False,
|
||||
"name": "client_side_state.client_side_sub_state.c2",
|
||||
"name": "state.client_side_state.client_side_sub_state.c2",
|
||||
"path": "/",
|
||||
"sameSite": "Lax",
|
||||
"secure": False,
|
||||
"value": "c2%20value",
|
||||
},
|
||||
"client_side_state.client_side_sub_state.c4": {
|
||||
"state.client_side_state.client_side_sub_state.c4": {
|
||||
"domain": "localhost",
|
||||
"httpOnly": False,
|
||||
"name": "client_side_state.client_side_sub_state.c4",
|
||||
"name": "state.client_side_state.client_side_sub_state.c4",
|
||||
"path": "/",
|
||||
"sameSite": "Strict",
|
||||
"secure": False,
|
||||
@ -321,19 +319,19 @@ async def test_client_side_state(
|
||||
"secure": False,
|
||||
"value": "c6%20value",
|
||||
},
|
||||
"client_side_state.client_side_sub_state.c7": {
|
||||
"state.client_side_state.client_side_sub_state.c7": {
|
||||
"domain": "localhost",
|
||||
"httpOnly": False,
|
||||
"name": "client_side_state.client_side_sub_state.c7",
|
||||
"name": "state.client_side_state.client_side_sub_state.c7",
|
||||
"path": "/",
|
||||
"sameSite": "Lax",
|
||||
"secure": False,
|
||||
"value": "c7%20value",
|
||||
},
|
||||
"client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s": {
|
||||
"state.client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s": {
|
||||
"domain": "localhost",
|
||||
"httpOnly": False,
|
||||
"name": "client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s",
|
||||
"name": "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s",
|
||||
"path": "/",
|
||||
"sameSite": "Lax",
|
||||
"secure": False,
|
||||
@ -354,40 +352,45 @@ async def test_client_side_state(
|
||||
input_value_input.send_keys("c3 value")
|
||||
set_sub_state_button.click()
|
||||
AppHarness._poll_for(
|
||||
lambda: "client_side_state.client_side_sub_state.c3" in cookie_info_map(driver)
|
||||
lambda: "state.client_side_state.client_side_sub_state.c3"
|
||||
in cookie_info_map(driver)
|
||||
)
|
||||
c3_cookie = cookie_info_map(driver)["client_side_state.client_side_sub_state.c3"]
|
||||
c3_cookie = cookie_info_map(driver)[
|
||||
"state.client_side_state.client_side_sub_state.c3"
|
||||
]
|
||||
assert c3_cookie.pop("expiry") is not None
|
||||
assert c3_cookie == {
|
||||
"domain": "localhost",
|
||||
"httpOnly": False,
|
||||
"name": "client_side_state.client_side_sub_state.c3",
|
||||
"name": "state.client_side_state.client_side_sub_state.c3",
|
||||
"path": "/",
|
||||
"sameSite": "Lax",
|
||||
"secure": False,
|
||||
"value": "c3%20value",
|
||||
}
|
||||
time.sleep(2) # wait for c3 to expire
|
||||
assert "client_side_state.client_side_sub_state.c3" not in cookie_info_map(driver)
|
||||
assert "state.client_side_state.client_side_sub_state.c3" not in cookie_info_map(
|
||||
driver
|
||||
)
|
||||
|
||||
local_storage_items = local_storage.items()
|
||||
local_storage_items.pop("chakra-ui-color-mode", None)
|
||||
assert (
|
||||
local_storage_items.pop("client_side_state.client_side_sub_state.l1")
|
||||
local_storage_items.pop("state.client_side_state.client_side_sub_state.l1")
|
||||
== "l1 value"
|
||||
)
|
||||
assert (
|
||||
local_storage_items.pop("client_side_state.client_side_sub_state.l2")
|
||||
local_storage_items.pop("state.client_side_state.client_side_sub_state.l2")
|
||||
== "l2 value"
|
||||
)
|
||||
assert local_storage_items.pop("l3") == "l3 value"
|
||||
assert (
|
||||
local_storage_items.pop("client_side_state.client_side_sub_state.l4")
|
||||
local_storage_items.pop("state.client_side_state.client_side_sub_state.l4")
|
||||
== "l4 value"
|
||||
)
|
||||
assert (
|
||||
local_storage_items.pop(
|
||||
"client_side_state.client_side_sub_state.client_side_sub_sub_state.l1s"
|
||||
"state.client_side_state.client_side_sub_state.client_side_sub_sub_state.l1s"
|
||||
)
|
||||
== "l1s value"
|
||||
)
|
||||
@ -482,12 +485,15 @@ async def test_client_side_state(
|
||||
|
||||
# make sure c5 cookie shows up on the `/foo` route
|
||||
AppHarness._poll_for(
|
||||
lambda: "client_side_state.client_side_sub_state.c5" in cookie_info_map(driver)
|
||||
lambda: "state.client_side_state.client_side_sub_state.c5"
|
||||
in cookie_info_map(driver)
|
||||
)
|
||||
assert cookie_info_map(driver)["client_side_state.client_side_sub_state.c5"] == {
|
||||
assert cookie_info_map(driver)[
|
||||
"state.client_side_state.client_side_sub_state.c5"
|
||||
] == {
|
||||
"domain": "localhost",
|
||||
"httpOnly": False,
|
||||
"name": "client_side_state.client_side_sub_state.c5",
|
||||
"name": "state.client_side_state.client_side_sub_state.c5",
|
||||
"path": "/foo/",
|
||||
"sameSite": "Lax",
|
||||
"secure": False,
|
||||
|
@ -19,7 +19,7 @@ def ConnectionBanner():
|
||||
def index():
|
||||
return rx.text("Hello World")
|
||||
|
||||
app = rx.App(state=State)
|
||||
app = rx.App(state=rx.State)
|
||||
app.add_page(index)
|
||||
app.compile()
|
||||
|
||||
|
@ -56,7 +56,7 @@ def DynamicRoute():
|
||||
def redirect_page():
|
||||
return rx.fragment(rx.text("redirecting..."))
|
||||
|
||||
app = rx.App(state=DynamicState)
|
||||
app = rx.App(state=rx.State)
|
||||
app.add_page(index)
|
||||
app.add_page(index, route="/page/[page_id]", on_load=DynamicState.on_load) # type: ignore
|
||||
app.add_page(index, route="/static/x", on_load=DynamicState.on_load) # type: ignore
|
||||
@ -143,10 +143,12 @@ def poll_for_order(
|
||||
return await dynamic_route.get_state(token)
|
||||
|
||||
async def _check():
|
||||
return (await _backend_state()).order == exp_order
|
||||
return (await _backend_state()).substates[
|
||||
"dynamic_state"
|
||||
].order == exp_order
|
||||
|
||||
await AppHarness._poll_for_async(_check)
|
||||
assert (await _backend_state()).order == exp_order
|
||||
assert (await _backend_state()).substates["dynamic_state"].order == exp_order
|
||||
|
||||
return _poll_for_order
|
||||
|
||||
|
@ -130,7 +130,7 @@ def TestEventAction():
|
||||
on_click=EventActionState.on_click("outer"), # type: ignore
|
||||
)
|
||||
|
||||
app = rx.App(state=EventActionState)
|
||||
app = rx.App(state=rx.State)
|
||||
app.add_page(index)
|
||||
app.compile()
|
||||
|
||||
@ -211,10 +211,14 @@ def poll_for_order(
|
||||
return await event_action.get_state(token)
|
||||
|
||||
async def _check():
|
||||
return (await _backend_state()).order == exp_order
|
||||
return (await _backend_state()).substates[
|
||||
"event_action_state"
|
||||
].order == exp_order
|
||||
|
||||
await AppHarness._poll_for_async(_check)
|
||||
assert (await _backend_state()).order == exp_order
|
||||
assert (await _backend_state()).substates[
|
||||
"event_action_state"
|
||||
].order == exp_order
|
||||
|
||||
return _poll_for_order
|
||||
|
||||
|
@ -122,7 +122,7 @@ def EventChain():
|
||||
time.sleep(0.5)
|
||||
self.interim_value = "final"
|
||||
|
||||
app = rx.App(state=State)
|
||||
app = rx.App(state=rx.State)
|
||||
|
||||
token_input = rx.input(
|
||||
value=State.router.session.client_token, is_read_only=True, id="token"
|
||||
@ -401,12 +401,12 @@ async def test_event_chain_click(
|
||||
btn.click()
|
||||
|
||||
async def _has_all_events():
|
||||
return len((await event_chain.get_state(token)).event_order) == len(
|
||||
exp_event_order
|
||||
)
|
||||
return len(
|
||||
(await event_chain.get_state(token)).substates["state"].event_order
|
||||
) == len(exp_event_order)
|
||||
|
||||
await AppHarness._poll_for_async(_has_all_events)
|
||||
event_order = (await event_chain.get_state(token)).event_order
|
||||
event_order = (await event_chain.get_state(token)).substates["state"].event_order
|
||||
assert event_order == exp_event_order
|
||||
|
||||
|
||||
@ -453,12 +453,12 @@ async def test_event_chain_on_load(
|
||||
token = assert_token(event_chain, driver)
|
||||
|
||||
async def _has_all_events():
|
||||
return len((await event_chain.get_state(token)).event_order) == len(
|
||||
exp_event_order
|
||||
)
|
||||
return len(
|
||||
(await event_chain.get_state(token)).substates["state"].event_order
|
||||
) == len(exp_event_order)
|
||||
|
||||
await AppHarness._poll_for_async(_has_all_events)
|
||||
backend_state = await event_chain.get_state(token)
|
||||
backend_state = (await event_chain.get_state(token)).substates["state"]
|
||||
assert backend_state.event_order == exp_event_order
|
||||
assert backend_state.is_hydrated is True
|
||||
|
||||
@ -529,12 +529,12 @@ async def test_event_chain_on_mount(
|
||||
unmount_button.click()
|
||||
|
||||
async def _has_all_events():
|
||||
return len((await event_chain.get_state(token)).event_order) == len(
|
||||
exp_event_order
|
||||
)
|
||||
return len(
|
||||
(await event_chain.get_state(token)).substates["state"].event_order
|
||||
) == len(exp_event_order)
|
||||
|
||||
await AppHarness._poll_for_async(_has_all_events)
|
||||
event_order = (await event_chain.get_state(token)).event_order
|
||||
event_order = (await event_chain.get_state(token)).substates["state"].event_order
|
||||
assert event_order == exp_event_order
|
||||
|
||||
|
||||
|
@ -22,7 +22,7 @@ def FormSubmit():
|
||||
def form_submit(self, form_data: dict):
|
||||
self.form_data = form_data
|
||||
|
||||
app = rx.App(state=FormState)
|
||||
app = rx.App(state=rx.State)
|
||||
|
||||
@app.add_page
|
||||
def index():
|
||||
@ -75,7 +75,7 @@ def FormSubmitName():
|
||||
def form_submit(self, form_data: dict):
|
||||
self.form_data = form_data
|
||||
|
||||
app = rx.App(state=FormState)
|
||||
app = rx.App(state=rx.State)
|
||||
|
||||
@app.add_page
|
||||
def index():
|
||||
@ -210,7 +210,7 @@ async def test_submit(driver, form_submit: AppHarness):
|
||||
submit_input.click()
|
||||
|
||||
async def get_form_data():
|
||||
return (await form_submit.get_state(token)).form_data
|
||||
return (await form_submit.get_state(token)).substates["form_state"].form_data
|
||||
|
||||
# wait for the form data to arrive at the backend
|
||||
form_data = await AppHarness._poll_for_async(get_form_data)
|
||||
|
@ -16,7 +16,7 @@ def FullyControlledInput():
|
||||
class State(rx.State):
|
||||
text: str = "initial"
|
||||
|
||||
app = rx.App(state=State)
|
||||
app = rx.App(state=rx.State)
|
||||
|
||||
@app.add_page
|
||||
def index():
|
||||
@ -85,13 +85,15 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
|
||||
debounce_input.send_keys("foo")
|
||||
time.sleep(0.5)
|
||||
assert debounce_input.get_attribute("value") == "ifoonitial"
|
||||
assert (await fully_controlled_input.get_state(token)).text == "ifoonitial"
|
||||
assert (await fully_controlled_input.get_state(token)).substates[
|
||||
"state"
|
||||
].text == "ifoonitial"
|
||||
assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial"
|
||||
|
||||
# clear the input on the backend
|
||||
async with fully_controlled_input.modify_state(token) as state:
|
||||
state.text = ""
|
||||
assert (await fully_controlled_input.get_state(token)).text == ""
|
||||
state.substates["state"].text = ""
|
||||
assert (await fully_controlled_input.get_state(token)).substates["state"].text == ""
|
||||
assert (
|
||||
fully_controlled_input.poll_for_value(
|
||||
debounce_input, exp_not_equal="ifoonitial"
|
||||
@ -103,9 +105,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
|
||||
debounce_input.send_keys("getting testing done")
|
||||
time.sleep(0.5)
|
||||
assert debounce_input.get_attribute("value") == "getting testing done"
|
||||
assert (
|
||||
await fully_controlled_input.get_state(token)
|
||||
).text == "getting testing done"
|
||||
assert (await fully_controlled_input.get_state(token)).substates[
|
||||
"state"
|
||||
].text == "getting testing done"
|
||||
assert fully_controlled_input.poll_for_value(value_input) == "getting testing done"
|
||||
|
||||
# type into the on_change input
|
||||
@ -113,7 +115,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
|
||||
time.sleep(0.5)
|
||||
assert debounce_input.get_attribute("value") == "overwrite the state"
|
||||
assert on_change_input.get_attribute("value") == "overwrite the state"
|
||||
assert (await fully_controlled_input.get_state(token)).text == "overwrite the state"
|
||||
assert (await fully_controlled_input.get_state(token)).substates[
|
||||
"state"
|
||||
].text == "overwrite the state"
|
||||
assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state"
|
||||
|
||||
clear_button.click()
|
||||
|
@ -42,7 +42,7 @@ def LoginSample():
|
||||
rx.button("Do it", on_click=State.login, id="doit"),
|
||||
)
|
||||
|
||||
app = rx.App(state=State)
|
||||
app = rx.App(state=rx.State)
|
||||
app.add_page(index)
|
||||
app.add_page(login)
|
||||
app.compile()
|
||||
@ -137,6 +137,6 @@ def test_login_flow(
|
||||
logout_button = driver.find_element(By.ID, "logout")
|
||||
logout_button.click()
|
||||
|
||||
assert login_sample._poll_for(lambda: local_storage["state.auth_token"] == "")
|
||||
assert login_sample._poll_for(lambda: local_storage["state.state.auth_token"] == "")
|
||||
with pytest.raises(NoSuchElementException):
|
||||
driver.find_element(By.ID, "auth-token")
|
||||
|
@ -81,7 +81,7 @@ def RadixThemesApp():
|
||||
)
|
||||
|
||||
app = rx.App(
|
||||
state=State,
|
||||
state=rx.State,
|
||||
theme=rdxt.theme(rdxt.theme_panel(), accent_color="grass"),
|
||||
)
|
||||
app.add_page(index)
|
||||
|
@ -33,7 +33,7 @@ def ServerSideEvent():
|
||||
def set_value_return_c(self):
|
||||
return rx.set_value("c", "")
|
||||
|
||||
app = rx.App(state=SSState)
|
||||
app = rx.App(state=rx.State)
|
||||
|
||||
@app.add_page
|
||||
def index():
|
||||
|
@ -26,7 +26,7 @@ def Table():
|
||||
|
||||
caption: str = "random caption"
|
||||
|
||||
app = rx.App(state=TableState)
|
||||
app = rx.App(state=rx.State)
|
||||
|
||||
@app.add_page
|
||||
def index():
|
||||
|
@ -113,7 +113,7 @@ def UploadFile():
|
||||
),
|
||||
)
|
||||
|
||||
app = rx.App(state=UploadState)
|
||||
app = rx.App(state=rx.State)
|
||||
app.add_page(index)
|
||||
app.compile()
|
||||
|
||||
@ -192,7 +192,7 @@ async def test_upload_file(
|
||||
|
||||
# look up the backend state and assert on uploaded contents
|
||||
async def get_file_data():
|
||||
return (await upload_file.get_state(token))._file_data
|
||||
return (await upload_file.get_state(token)).substates["upload_state"]._file_data
|
||||
|
||||
file_data = await AppHarness._poll_for_async(get_file_data)
|
||||
assert isinstance(file_data, dict)
|
||||
@ -205,8 +205,8 @@ async def test_upload_file(
|
||||
state = await upload_file.get_state(token)
|
||||
if secondary:
|
||||
# only the secondary form tracks progress and chain events
|
||||
assert state.event_order.count("upload_progress") == 1
|
||||
assert state.event_order.count("chain_event") == 1
|
||||
assert state.substates["upload_state"].event_order.count("upload_progress") == 1
|
||||
assert state.substates["upload_state"].event_order.count("chain_event") == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -251,7 +251,7 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
|
||||
|
||||
# look up the backend state and assert on uploaded contents
|
||||
async def get_file_data():
|
||||
return (await upload_file.get_state(token))._file_data
|
||||
return (await upload_file.get_state(token)).substates["upload_state"]._file_data
|
||||
|
||||
file_data = await AppHarness._poll_for_async(get_file_data)
|
||||
assert isinstance(file_data, dict)
|
||||
@ -349,7 +349,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive
|
||||
|
||||
# look up the backend state and assert on progress
|
||||
state = await upload_file.get_state(token)
|
||||
assert state.progress_dicts
|
||||
assert exp_name not in state._file_data
|
||||
assert state.substates["upload_state"].progress_dicts
|
||||
assert exp_name not in state.substates["upload_state"]._file_data
|
||||
|
||||
target_file.unlink()
|
||||
|
@ -30,7 +30,7 @@ def VarOperations():
|
||||
dict2: dict = {3: 4}
|
||||
html_str: str = "<div>hello</div>"
|
||||
|
||||
app = rx.App(state=VarOperationState)
|
||||
app = rx.App(state=rx.State)
|
||||
|
||||
@app.add_page
|
||||
def index():
|
||||
|
@ -57,6 +57,7 @@ from reflex.route import (
|
||||
verify_route_validity,
|
||||
)
|
||||
from reflex.state import (
|
||||
BaseState,
|
||||
RouterData,
|
||||
State,
|
||||
StateManager,
|
||||
@ -98,7 +99,7 @@ class App(Base):
|
||||
socket_app: Optional[ASGIApp] = None
|
||||
|
||||
# The state class to use for the app.
|
||||
state: Optional[Type[State]] = None
|
||||
state: Optional[Type[BaseState]] = None
|
||||
|
||||
# Class to manage many client states.
|
||||
_state_manager: Optional[StateManager] = None
|
||||
@ -149,25 +150,24 @@ class App(Base):
|
||||
"`connect_error_component` is deprecated, use `overlay_component` instead"
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
state_subclasses = State.__subclasses__()
|
||||
inferred_state = state_subclasses[-1] if state_subclasses else None
|
||||
state_subclasses = BaseState.__subclasses__()
|
||||
is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
|
||||
|
||||
# Special case to allow test cases have multiple subclasses of rx.State.
|
||||
# Special case to allow test cases have multiple subclasses of rx.BaseState.
|
||||
if not is_testing_env:
|
||||
# Only one State class is allowed.
|
||||
# Only one Base State class is allowed.
|
||||
if len(state_subclasses) > 1:
|
||||
raise ValueError(
|
||||
"rx.State has been subclassed multiple times. Only one subclass is allowed"
|
||||
"rx.BaseState cannot be subclassed multiple times. use rx.State instead"
|
||||
)
|
||||
|
||||
# verify that provided state is valid
|
||||
if self.state and inferred_state and self.state is not inferred_state:
|
||||
if self.state and self.state is not State:
|
||||
console.warn(
|
||||
f"Using substate ({self.state.__name__}) as root state in `rx.App` is currently not supported."
|
||||
f" Defaulting to root state: ({inferred_state.__name__})"
|
||||
f" Defaulting to root state: ({State.__name__})"
|
||||
)
|
||||
self.state = inferred_state
|
||||
self.state = State
|
||||
# Get the config
|
||||
config = get_config()
|
||||
|
||||
@ -265,7 +265,7 @@ class App(Base):
|
||||
raise ValueError("The state manager has not been initialized.")
|
||||
return self._state_manager
|
||||
|
||||
async def preprocess(self, state: State, event: Event) -> StateUpdate | None:
|
||||
async def preprocess(self, state: BaseState, event: Event) -> StateUpdate | None:
|
||||
"""Preprocess the event.
|
||||
|
||||
This is where middleware can modify the event before it is processed.
|
||||
@ -290,7 +290,7 @@ class App(Base):
|
||||
return out # type: ignore
|
||||
|
||||
async def postprocess(
|
||||
self, state: State, event: Event, update: StateUpdate
|
||||
self, state: BaseState, event: Event, update: StateUpdate
|
||||
) -> StateUpdate:
|
||||
"""Postprocess the event.
|
||||
|
||||
@ -764,7 +764,7 @@ class App(Base):
|
||||
future.result()
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def modify_state(self, token: str) -> AsyncIterator[State]:
|
||||
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
||||
"""Modify the state out of band.
|
||||
|
||||
Args:
|
||||
@ -792,7 +792,9 @@ class App(Base):
|
||||
sid=state.router.session.session_id,
|
||||
)
|
||||
|
||||
def _process_background(self, state: State, event: Event) -> asyncio.Task | None:
|
||||
def _process_background(
|
||||
self, state: BaseState, event: Event
|
||||
) -> asyncio.Task | None:
|
||||
"""Process an event in the background and emit updates as they arrive.
|
||||
|
||||
Args:
|
||||
|
@ -33,6 +33,7 @@ from reflex.route import (
|
||||
)
|
||||
from reflex.state import (
|
||||
State as State,
|
||||
BaseState as BaseState,
|
||||
StateManager as StateManager,
|
||||
StateUpdate as StateUpdate,
|
||||
)
|
||||
@ -69,7 +70,7 @@ class App(Base):
|
||||
api: FastAPI
|
||||
sio: Optional[AsyncServer]
|
||||
socket_app: Optional[ASGIApp]
|
||||
state: Type[State]
|
||||
state: Type[BaseState]
|
||||
state_manager: StateManager
|
||||
style: ComponentStyle
|
||||
middleware: List[Middleware]
|
||||
|
@ -1,11 +1,42 @@
|
||||
"""Define the base Reflex class."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
import os
|
||||
from typing import Any, List, Type
|
||||
|
||||
import pydantic
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import ModelField
|
||||
|
||||
from reflex import constants
|
||||
|
||||
|
||||
def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None:
|
||||
"""Ensure that the field's name does not shadow an existing attribute of the model.
|
||||
|
||||
Args:
|
||||
bases: List of base models to check for shadowed attrs.
|
||||
field_name: name of attribute
|
||||
|
||||
Raises:
|
||||
NameError: If state var field shadows another in its parent state
|
||||
"""
|
||||
reload = os.getenv(constants.RELOAD_CONFIG) == "True"
|
||||
for base in bases:
|
||||
try:
|
||||
if not reload and getattr(base, field_name, None):
|
||||
pass
|
||||
except TypeError as te:
|
||||
raise NameError(
|
||||
f'State var "{field_name}" in {base} has been shadowed by a substate var; '
|
||||
f'use a different field name instead".'
|
||||
) from te
|
||||
|
||||
|
||||
# monkeypatch pydantic validate_field_name method to skip validating
|
||||
# shadowed state vars when reloading app via utils.prerequisites.get_app(reload=True)
|
||||
pydantic.main.validate_field_name = validate_field_name # type: ignore
|
||||
|
||||
|
||||
class Base(pydantic.BaseModel):
|
||||
"""The base class subclassed by all Reflex classes.
|
||||
|
@ -15,7 +15,7 @@ from reflex.components.component import (
|
||||
StatefulComponent,
|
||||
)
|
||||
from reflex.config import get_config
|
||||
from reflex.state import State
|
||||
from reflex.state import BaseState
|
||||
from reflex.utils.imports import ImportVar
|
||||
|
||||
|
||||
@ -63,7 +63,7 @@ def _compile_theme(theme: dict) -> str:
|
||||
return templates.THEME.render(theme=theme)
|
||||
|
||||
|
||||
def _compile_contexts(state: Optional[Type[State]]) -> str:
|
||||
def _compile_contexts(state: Optional[Type[BaseState]]) -> str:
|
||||
"""Compile the initial state and contexts.
|
||||
|
||||
Args:
|
||||
@ -87,7 +87,7 @@ def _compile_contexts(state: Optional[Type[State]]) -> str:
|
||||
|
||||
def _compile_page(
|
||||
component: Component,
|
||||
state: Type[State],
|
||||
state: Type[BaseState],
|
||||
) -> str:
|
||||
"""Compile the component given the app state.
|
||||
|
||||
@ -337,7 +337,7 @@ def compile_theme(style: ComponentStyle) -> tuple[str, str]:
|
||||
return output_path, code
|
||||
|
||||
|
||||
def compile_contexts(state: Optional[Type[State]]) -> tuple[str, str]:
|
||||
def compile_contexts(state: Optional[Type[BaseState]]) -> tuple[str, str]:
|
||||
"""Compile the initial state / context.
|
||||
|
||||
Args:
|
||||
@ -353,7 +353,7 @@ def compile_contexts(state: Optional[Type[State]]) -> tuple[str, str]:
|
||||
|
||||
|
||||
def compile_page(
|
||||
path: str, component: Component, state: Type[State]
|
||||
path: str, component: Component, state: Type[BaseState]
|
||||
) -> tuple[str, str]:
|
||||
"""Compile a single page.
|
||||
|
||||
|
@ -21,7 +21,7 @@ from reflex.components.base import (
|
||||
Title,
|
||||
)
|
||||
from reflex.components.component import Component, ComponentStyle, CustomComponent
|
||||
from reflex.state import Cookie, LocalStorage, State
|
||||
from reflex.state import BaseState, Cookie, LocalStorage
|
||||
from reflex.style import Style
|
||||
from reflex.utils import console, format, imports, path_ops
|
||||
|
||||
@ -128,7 +128,7 @@ def get_import_dict(lib: str, default: str = "", rest: list[str] | None = None)
|
||||
}
|
||||
|
||||
|
||||
def compile_state(state: Type[State]) -> dict:
|
||||
def compile_state(state: Type[BaseState]) -> dict:
|
||||
"""Compile the state of the app.
|
||||
|
||||
Args:
|
||||
@ -170,7 +170,7 @@ def _compile_client_storage_field(
|
||||
|
||||
|
||||
def _compile_client_storage_recursive(
|
||||
state: Type[State],
|
||||
state: Type[BaseState],
|
||||
) -> tuple[dict[str, dict], dict[str, dict[str, str]]]:
|
||||
"""Compile the client-side storage for the given state recursively.
|
||||
|
||||
@ -208,7 +208,7 @@ def _compile_client_storage_recursive(
|
||||
return cookies, local_storage
|
||||
|
||||
|
||||
def compile_client_storage(state: Type[State]) -> dict[str, dict]:
|
||||
def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
|
||||
"""Compile the client-side storage for the given state.
|
||||
|
||||
Args:
|
||||
|
@ -6,6 +6,7 @@ from .base import (
|
||||
LOCAL_STORAGE,
|
||||
POLLING_MAX_HTTP_BUFFER_SIZE,
|
||||
PYTEST_CURRENT_TEST,
|
||||
RELOAD_CONFIG,
|
||||
SKIP_COMPILE_ENV_VAR,
|
||||
ColorMode,
|
||||
Dirs,
|
||||
@ -85,6 +86,7 @@ __ALL__ = [
|
||||
PYTEST_CURRENT_TEST,
|
||||
PRODUCTION_BACKEND_URL,
|
||||
Reflex,
|
||||
RELOAD_CONFIG,
|
||||
RequirementsTxt,
|
||||
RouteArgType,
|
||||
RouteRegex,
|
||||
|
@ -173,3 +173,4 @@ SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE"
|
||||
# Testing variables.
|
||||
# Testing os env set by pytest when running a test case.
|
||||
PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST"
|
||||
RELOAD_CONFIG = "__REFLEX_RELOAD_CONFIG"
|
||||
|
@ -22,7 +22,7 @@ from reflex.utils.types import ArgsSpec
|
||||
from reflex.vars import BaseVar, Var
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from reflex.state import State
|
||||
from reflex.state import BaseState
|
||||
|
||||
|
||||
class Event(Base):
|
||||
@ -64,7 +64,7 @@ def background(fn):
|
||||
|
||||
|
||||
def _no_chain_background_task(
|
||||
state_cls: Type["State"], name: str, fn: Callable
|
||||
state_cls: Type["BaseState"], name: str, fn: Callable
|
||||
) -> Callable:
|
||||
"""Protect against directly chaining a background task from another event handler.
|
||||
|
||||
|
@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
from reflex import constants
|
||||
from reflex.event import Event, fix_events, get_hydrate_event
|
||||
from reflex.middleware.middleware import Middleware
|
||||
from reflex.state import State, StateUpdate
|
||||
from reflex.state import BaseState, StateUpdate
|
||||
from reflex.utils import format
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -17,7 +17,7 @@ class HydrateMiddleware(Middleware):
|
||||
"""Middleware to handle initial app hydration."""
|
||||
|
||||
async def preprocess(
|
||||
self, app: App, state: State, event: Event
|
||||
self, app: App, state: BaseState, event: Event
|
||||
) -> Optional[StateUpdate]:
|
||||
"""Preprocess the event.
|
||||
|
||||
|
@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from reflex.base import Base
|
||||
from reflex.event import Event
|
||||
from reflex.state import State, StateUpdate
|
||||
from reflex.state import BaseState, StateUpdate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from reflex.app import App
|
||||
@ -16,7 +16,7 @@ class Middleware(Base, ABC):
|
||||
"""Middleware to preprocess and postprocess requests."""
|
||||
|
||||
async def preprocess(
|
||||
self, app: App, state: State, event: Event
|
||||
self, app: App, state: BaseState, event: Event
|
||||
) -> Optional[StateUpdate]:
|
||||
"""Preprocess the event.
|
||||
|
||||
@ -31,7 +31,7 @@ class Middleware(Base, ABC):
|
||||
return None
|
||||
|
||||
async def postprocess(
|
||||
self, app: App, state: State, event: Event, update: StateUpdate
|
||||
self, app: App, state: BaseState, event: Event, update: StateUpdate
|
||||
) -> StateUpdate:
|
||||
"""Postprocess the event.
|
||||
|
||||
|
112
reflex/state.py
112
reflex/state.py
@ -7,6 +7,7 @@ import copy
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
import urllib.parse
|
||||
import uuid
|
||||
@ -81,7 +82,7 @@ class HeaderData(Base):
|
||||
class PageData(Base):
|
||||
"""An object containing page data."""
|
||||
|
||||
host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?)
|
||||
host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?)
|
||||
path: str = ""
|
||||
raw_path: str = ""
|
||||
full_path: str = ""
|
||||
@ -152,7 +153,7 @@ RESERVED_BACKEND_VAR_NAMES = {
|
||||
}
|
||||
|
||||
|
||||
class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
"""The state of the app."""
|
||||
|
||||
# A map from the var name to the var.
|
||||
@ -176,6 +177,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# The event handlers.
|
||||
event_handlers: ClassVar[Dict[str, EventHandler]] = {}
|
||||
|
||||
# A set of subclassses of this class.
|
||||
class_subclasses: ClassVar[Set[Type[BaseState]]] = set()
|
||||
|
||||
# Mapping of var name to set of computed variables that depend on it
|
||||
_computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
|
||||
|
||||
@ -189,10 +193,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
_always_dirty_substates: ClassVar[Set[str]] = set()
|
||||
|
||||
# The parent state.
|
||||
parent_state: Optional[State] = None
|
||||
parent_state: Optional[BaseState] = None
|
||||
|
||||
# The substates of the state.
|
||||
substates: Dict[str, State] = {}
|
||||
substates: Dict[str, BaseState] = {}
|
||||
|
||||
# The set of dirty vars.
|
||||
dirty_vars: Set[str] = set()
|
||||
@ -209,10 +213,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# The router data for the current page
|
||||
router: RouterData = RouterData()
|
||||
|
||||
# The hydrated bool.
|
||||
is_hydrated: bool = False
|
||||
|
||||
def __init__(self, *args, parent_state: State | None = None, **kwargs):
|
||||
def __init__(self, *args, parent_state: BaseState | None = None, **kwargs):
|
||||
"""Initialize the state.
|
||||
|
||||
Args:
|
||||
@ -220,28 +221,20 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
parent_state: The parent state.
|
||||
**kwargs: The kwargs to pass to the Pydantic init method.
|
||||
|
||||
Raises:
|
||||
ValueError: If a substate class shadows another.
|
||||
"""
|
||||
kwargs["parent_state"] = parent_state
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Setup the substates.
|
||||
for substate in self.get_substates():
|
||||
substate_name = substate.get_name()
|
||||
if substate_name in self.substates:
|
||||
raise ValueError(
|
||||
f"The substate class '{substate_name}' has been defined multiple times. Shadowing "
|
||||
f"substate classes is not allowed."
|
||||
)
|
||||
self.substates[substate_name] = substate(parent_state=self)
|
||||
self.substates[substate.get_name()] = substate(parent_state=self)
|
||||
# Convert the event handlers to functions.
|
||||
self._init_event_handlers()
|
||||
|
||||
# Create a fresh copy of the backend variables for this instance
|
||||
self._backend_vars = copy.deepcopy(self.backend_vars)
|
||||
|
||||
def _init_event_handlers(self, state: State | None = None):
|
||||
def _init_event_handlers(self, state: BaseState | None = None):
|
||||
"""Initialize event handlers.
|
||||
|
||||
Allow event handlers to be called directly on the instance. This is
|
||||
@ -281,17 +274,44 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
|
||||
Args:
|
||||
**kwargs: The kwargs to pass to the pydantic init_subclass method.
|
||||
|
||||
Raises:
|
||||
ValueError: If a substate class shadows another.
|
||||
"""
|
||||
is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
|
||||
super().__init_subclass__(**kwargs)
|
||||
# Event handlers should not shadow builtin state methods.
|
||||
cls._check_overridden_methods()
|
||||
|
||||
# Reset subclass tracking for this class.
|
||||
cls.class_subclasses = set()
|
||||
|
||||
# Get the parent vars.
|
||||
parent_state = cls.get_parent_state()
|
||||
if parent_state is not None:
|
||||
cls.inherited_vars = parent_state.vars
|
||||
cls.inherited_backend_vars = parent_state.backend_vars
|
||||
|
||||
# Check if another substate class with the same name has already been defined.
|
||||
if cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses):
|
||||
if is_testing_env:
|
||||
# Clear existing subclass with same name when app is reloaded via
|
||||
# utils.prerequisites.get_app(reload=True)
|
||||
parent_state.class_subclasses = set(
|
||||
c
|
||||
for c in parent_state.class_subclasses
|
||||
if c.__name__ != cls.__name__
|
||||
)
|
||||
else:
|
||||
# During normal operation, subclasses cannot have the same name, even if they are
|
||||
# defined in different modules.
|
||||
raise ValueError(
|
||||
f"The substate class '{cls.__name__}' has been defined multiple times. "
|
||||
"Shadowing substate classes is not allowed."
|
||||
)
|
||||
# Track this new subclass in the parent state's subclasses set.
|
||||
parent_state.class_subclasses.add(cls)
|
||||
|
||||
cls.new_backend_vars = {
|
||||
name: value
|
||||
for name, value in cls.__dict__.items()
|
||||
@ -437,7 +457,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
|
||||
@classmethod
|
||||
@functools.lru_cache()
|
||||
def get_parent_state(cls) -> Type[State] | None:
|
||||
def get_parent_state(cls) -> Type[BaseState] | None:
|
||||
"""Get the parent state.
|
||||
|
||||
Returns:
|
||||
@ -446,20 +466,19 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
parent_states = [
|
||||
base
|
||||
for base in cls.__bases__
|
||||
if types._issubclass(base, State) and base is not State
|
||||
if types._issubclass(base, BaseState) and base is not BaseState
|
||||
]
|
||||
assert len(parent_states) < 2, "Only one parent state is allowed."
|
||||
return parent_states[0] if len(parent_states) == 1 else None # type: ignore
|
||||
|
||||
@classmethod
|
||||
@functools.lru_cache()
|
||||
def get_substates(cls) -> set[Type[State]]:
|
||||
def get_substates(cls) -> set[Type[BaseState]]:
|
||||
"""Get the substates of the state.
|
||||
|
||||
Returns:
|
||||
The substates of the state.
|
||||
"""
|
||||
return set(cls.__subclasses__())
|
||||
return cls.class_subclasses
|
||||
|
||||
@classmethod
|
||||
@functools.lru_cache()
|
||||
@ -487,7 +506,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
|
||||
@classmethod
|
||||
@functools.lru_cache()
|
||||
def get_class_substate(cls, path: Sequence[str]) -> Type[State]:
|
||||
def get_class_substate(cls, path: Sequence[str]) -> Type[BaseState]:
|
||||
"""Get the class substate.
|
||||
|
||||
Args:
|
||||
@ -643,7 +662,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
"""
|
||||
return {
|
||||
func[0]: func[1]
|
||||
for func in inspect.getmembers(State, predicate=inspect.isfunction)
|
||||
for func in inspect.getmembers(BaseState, predicate=inspect.isfunction)
|
||||
if not func[0].startswith("__")
|
||||
}
|
||||
|
||||
@ -909,7 +928,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
for substate in self.substates.values():
|
||||
substate._reset_client_storage()
|
||||
|
||||
def get_substate(self, path: Sequence[str]) -> State | None:
|
||||
def get_substate(self, path: Sequence[str]) -> BaseState | None:
|
||||
"""Get the substate.
|
||||
|
||||
Args:
|
||||
@ -933,7 +952,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
|
||||
def _get_event_handler(
|
||||
self, event: Event
|
||||
) -> tuple[State | StateProxy, EventHandler]:
|
||||
) -> tuple[BaseState | StateProxy, EventHandler]:
|
||||
"""Get the event handler for the given event.
|
||||
|
||||
Args:
|
||||
@ -1050,7 +1069,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
)
|
||||
|
||||
async def _process_event(
|
||||
self, handler: EventHandler, state: State | StateProxy, payload: Dict
|
||||
self, handler: EventHandler, state: BaseState | StateProxy, payload: Dict
|
||||
) -> AsyncIterator[StateUpdate]:
|
||||
"""Process event.
|
||||
|
||||
@ -1263,7 +1282,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
d.update(substate_d)
|
||||
return d
|
||||
|
||||
async def __aenter__(self) -> State:
|
||||
async def __aenter__(self) -> BaseState:
|
||||
"""Enter the async context manager protocol.
|
||||
|
||||
This should not be used for the State class, but exists for
|
||||
@ -1288,6 +1307,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
pass
|
||||
|
||||
|
||||
class State(BaseState):
|
||||
"""The app Base State."""
|
||||
|
||||
# The hydrated bool.
|
||||
is_hydrated: bool = False
|
||||
|
||||
|
||||
class StateProxy(wrapt.ObjectProxy):
|
||||
"""Proxy of a state instance to control mutability of vars for a background task.
|
||||
|
||||
@ -1455,10 +1481,10 @@ class StateManager(Base, ABC):
|
||||
"""A class to manage many client states."""
|
||||
|
||||
# The state class to use.
|
||||
state: Type[State]
|
||||
state: Type[BaseState]
|
||||
|
||||
@classmethod
|
||||
def create(cls, state: Type[State]):
|
||||
def create(cls, state: Type[BaseState]):
|
||||
"""Create a new state manager.
|
||||
|
||||
Args:
|
||||
@ -1473,7 +1499,7 @@ class StateManager(Base, ABC):
|
||||
return StateManagerMemory(state=state)
|
||||
|
||||
@abstractmethod
|
||||
async def get_state(self, token: str) -> State:
|
||||
async def get_state(self, token: str) -> BaseState:
|
||||
"""Get the state for a token.
|
||||
|
||||
Args:
|
||||
@ -1485,7 +1511,7 @@ class StateManager(Base, ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_state(self, token: str, state: State):
|
||||
async def set_state(self, token: str, state: BaseState):
|
||||
"""Set the state for a token.
|
||||
|
||||
Args:
|
||||
@ -1496,7 +1522,7 @@ class StateManager(Base, ABC):
|
||||
|
||||
@abstractmethod
|
||||
@contextlib.asynccontextmanager
|
||||
async def modify_state(self, token: str) -> AsyncIterator[State]:
|
||||
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
||||
"""Modify the state for a token while holding exclusive lock.
|
||||
|
||||
Args:
|
||||
@ -1512,7 +1538,7 @@ class StateManagerMemory(StateManager):
|
||||
"""A state manager that stores states in memory."""
|
||||
|
||||
# The mapping of client ids to states.
|
||||
states: Dict[str, State] = {}
|
||||
states: Dict[str, BaseState] = {}
|
||||
|
||||
# The mutex ensures the dict of mutexes is updated exclusively
|
||||
_state_manager_lock = asyncio.Lock()
|
||||
@ -1527,7 +1553,7 @@ class StateManagerMemory(StateManager):
|
||||
"_states_locks": {"exclude": True},
|
||||
}
|
||||
|
||||
async def get_state(self, token: str) -> State:
|
||||
async def get_state(self, token: str) -> BaseState:
|
||||
"""Get the state for a token.
|
||||
|
||||
Args:
|
||||
@ -1540,7 +1566,7 @@ class StateManagerMemory(StateManager):
|
||||
self.states[token] = self.state()
|
||||
return self.states[token]
|
||||
|
||||
async def set_state(self, token: str, state: State):
|
||||
async def set_state(self, token: str, state: BaseState):
|
||||
"""Set the state for a token.
|
||||
|
||||
Args:
|
||||
@ -1550,7 +1576,7 @@ class StateManagerMemory(StateManager):
|
||||
pass
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def modify_state(self, token: str) -> AsyncIterator[State]:
|
||||
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
||||
"""Modify the state for a token while holding exclusive lock.
|
||||
|
||||
Args:
|
||||
@ -1598,7 +1624,7 @@ class StateManagerRedis(StateManager):
|
||||
b"evicted",
|
||||
}
|
||||
|
||||
async def get_state(self, token: str) -> State:
|
||||
async def get_state(self, token: str) -> BaseState:
|
||||
"""Get the state for a token.
|
||||
|
||||
Args:
|
||||
@ -1613,7 +1639,9 @@ class StateManagerRedis(StateManager):
|
||||
return await self.get_state(token)
|
||||
return cloudpickle.loads(redis_state)
|
||||
|
||||
async def set_state(self, token: str, state: State, lock_id: bytes | None = None):
|
||||
async def set_state(
|
||||
self, token: str, state: BaseState, lock_id: bytes | None = None
|
||||
):
|
||||
"""Set the state for a token.
|
||||
|
||||
Args:
|
||||
@ -1637,7 +1665,7 @@ class StateManagerRedis(StateManager):
|
||||
await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def modify_state(self, token: str) -> AsyncIterator[State]:
|
||||
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
||||
"""Modify the state for a token while holding exclusive lock.
|
||||
|
||||
Args:
|
||||
@ -1879,7 +1907,7 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
|
||||
__mutable_types__ = (list, dict, set, Base)
|
||||
|
||||
def __init__(self, wrapped: Any, state: State, field_name: str):
|
||||
def __init__(self, wrapped: Any, state: BaseState, field_name: str):
|
||||
"""Create a proxy for a mutable object that tracks changes.
|
||||
|
||||
Args:
|
||||
|
@ -38,7 +38,7 @@ import reflex.utils.build
|
||||
import reflex.utils.exec
|
||||
import reflex.utils.prerequisites
|
||||
import reflex.utils.processes
|
||||
from reflex.state import State, StateManagerMemory, StateManagerRedis
|
||||
from reflex.state import BaseState, State, StateManagerMemory, StateManagerRedis
|
||||
|
||||
try:
|
||||
from selenium import webdriver # pyright: ignore [reportMissingImports]
|
||||
@ -162,6 +162,9 @@ class AppHarness:
|
||||
with chdir(self.app_path):
|
||||
# ensure config and app are reloaded when testing different app
|
||||
reflex.config.get_config(reload=True)
|
||||
# reset rx.State subclasses
|
||||
State.class_subclasses.clear()
|
||||
# self.app_module.app.
|
||||
self.app_module = reflex.utils.prerequisites.get_app(reload=True)
|
||||
self.app_instance = self.app_module.app
|
||||
if isinstance(self.app_instance.state_manager, StateManagerRedis):
|
||||
@ -434,7 +437,7 @@ class AppHarness:
|
||||
self._frontends.append(driver)
|
||||
return driver
|
||||
|
||||
async def get_state(self, token: str) -> State:
|
||||
async def get_state(self, token: str) -> BaseState:
|
||||
"""Get the state associated with the given token.
|
||||
|
||||
Args:
|
||||
@ -561,7 +564,7 @@ class AppHarness:
|
||||
)
|
||||
return element.get_attribute("value")
|
||||
|
||||
def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, reflex.State]:
|
||||
def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, BaseState]:
|
||||
"""Poll app state_manager for any connected clients.
|
||||
|
||||
Args:
|
||||
|
@ -124,11 +124,13 @@ def get_app(reload: bool = False) -> ModuleType:
|
||||
Returns:
|
||||
The app based on the default config.
|
||||
"""
|
||||
os.environ[constants.RELOAD_CONFIG] = str(reload)
|
||||
config = get_config()
|
||||
module = ".".join([config.app_name, config.app_name])
|
||||
sys.path.insert(0, os.getcwd())
|
||||
app = __import__(module, fromlist=(constants.CompileVars.APP,))
|
||||
if reload:
|
||||
|
||||
importlib.reload(app)
|
||||
return app
|
||||
|
||||
|
@ -41,7 +41,7 @@ from reflex.utils import console, format, imports, serializers, types
|
||||
from reflex.utils.imports import ImportDict, ImportVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from reflex.state import State
|
||||
from reflex.state import BaseState
|
||||
|
||||
# Set of unique variable names.
|
||||
USED_VARIABLES = set()
|
||||
@ -1472,7 +1472,7 @@ class Var:
|
||||
)
|
||||
)
|
||||
|
||||
def _var_set_state(self, state: Type[State] | str) -> Any:
|
||||
def _var_set_state(self, state: Type[BaseState] | str) -> Any:
|
||||
"""Set the state of the var.
|
||||
|
||||
Args:
|
||||
@ -1604,14 +1604,14 @@ class BaseVar(Var):
|
||||
return setter
|
||||
return ".".join((self._var_data.state, setter))
|
||||
|
||||
def get_setter(self) -> Callable[[State, Any], None]:
|
||||
def get_setter(self) -> Callable[[BaseState, Any], None]:
|
||||
"""Get the var's setter function.
|
||||
|
||||
Returns:
|
||||
A function that that creates a setter for the var.
|
||||
"""
|
||||
|
||||
def setter(state: State, value: Any):
|
||||
def setter(state: BaseState, value: Any):
|
||||
"""Get the setter for the var.
|
||||
|
||||
Args:
|
||||
@ -1643,9 +1643,9 @@ class ComputedVar(Var, property):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fget: Callable[[State], Any],
|
||||
fset: Callable[[State, Any], None] | None = None,
|
||||
fdel: Callable[[State], Any] | None = None,
|
||||
fget: Callable[[BaseState], Any],
|
||||
fset: Callable[[BaseState, Any], None] | None = None,
|
||||
fdel: Callable[[BaseState], Any] | None = None,
|
||||
doc: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -5,6 +5,7 @@ from _typeshed import Incomplete
|
||||
from reflex import constants as constants
|
||||
from reflex.base import Base as Base
|
||||
from reflex.state import State as State
|
||||
from reflex.state import BaseState as BaseState
|
||||
from reflex.utils import console as console, format as format, types as types
|
||||
from reflex.utils.imports import ImportVar
|
||||
from types import FunctionType
|
||||
@ -110,7 +111,7 @@ class Var:
|
||||
def as_ref(self) -> Var: ...
|
||||
@property
|
||||
def _var_full_name(self) -> str: ...
|
||||
def _var_set_state(self, state: Type[State] | str) -> Any: ...
|
||||
def _var_set_state(self, state: Type[BaseState] | str) -> Any: ...
|
||||
|
||||
@dataclass(eq=False)
|
||||
class BaseVar(Var):
|
||||
@ -123,7 +124,7 @@ class BaseVar(Var):
|
||||
def __hash__(self) -> int: ...
|
||||
def get_default_value(self) -> Any: ...
|
||||
def get_setter_name(self, include_state: bool = ...) -> str: ...
|
||||
def get_setter(self) -> Callable[[State, Any], None]: ...
|
||||
def get_setter(self) -> Callable[[BaseState, Any], None]: ...
|
||||
|
||||
@dataclass(init=False)
|
||||
class ComputedVar(Var):
|
||||
|
@ -1 +1,4 @@
|
||||
"""Root directory for tests."""
|
||||
import os
|
||||
|
||||
from reflex import constants
|
||||
|
@ -2,7 +2,7 @@
|
||||
import pytest
|
||||
|
||||
from reflex.components.base.script import Script
|
||||
from reflex.state import State
|
||||
from reflex.state import BaseState
|
||||
|
||||
|
||||
def test_script_inline():
|
||||
@ -31,7 +31,7 @@ def test_script_neither():
|
||||
Script.create()
|
||||
|
||||
|
||||
class EvState(State):
|
||||
class EvState(BaseState):
|
||||
"""State for testing event handlers."""
|
||||
|
||||
def on_ready(self):
|
||||
|
@ -5,6 +5,7 @@ import pandas as pd
|
||||
import pytest
|
||||
|
||||
import reflex as rx
|
||||
from reflex.state import BaseState
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -18,7 +19,7 @@ def data_table_state(request):
|
||||
The data table state class.
|
||||
"""
|
||||
|
||||
class DataTableState(rx.State):
|
||||
class DataTableState(BaseState):
|
||||
data = request.param["data"]
|
||||
columns = ["column1", "column2"]
|
||||
|
||||
@ -33,7 +34,7 @@ def data_table_state2():
|
||||
The data table state class.
|
||||
"""
|
||||
|
||||
class DataTableState(rx.State):
|
||||
class DataTableState(BaseState):
|
||||
_data = pd.DataFrame()
|
||||
|
||||
@rx.var
|
||||
@ -51,7 +52,7 @@ def data_table_state3():
|
||||
The data table state class.
|
||||
"""
|
||||
|
||||
class DataTableState(rx.State):
|
||||
class DataTableState(BaseState):
|
||||
_data: List = []
|
||||
_columns: List = ["col1", "col2"]
|
||||
|
||||
@ -74,7 +75,7 @@ def data_table_state4():
|
||||
The data table state class.
|
||||
"""
|
||||
|
||||
class DataTableState(rx.State):
|
||||
class DataTableState(BaseState):
|
||||
_data: List = []
|
||||
_columns: List = ["col1", "col2"]
|
||||
|
||||
|
@ -4,12 +4,12 @@ from typing import List, Tuple
|
||||
import pytest
|
||||
|
||||
from reflex.components.datadisplay.table import Tbody, Tfoot, Thead
|
||||
from reflex.state import State
|
||||
from reflex.state import BaseState
|
||||
|
||||
PYTHON_GT_V38 = sys.version_info.major >= 3 and sys.version_info.minor > 8
|
||||
|
||||
|
||||
class TableState(State):
|
||||
class TableState(BaseState):
|
||||
"""Test State class."""
|
||||
|
||||
rows_List_List_str: List[List[str]] = [["random", "row"]]
|
||||
|
@ -3,6 +3,7 @@
|
||||
import pytest
|
||||
|
||||
import reflex as rx
|
||||
from reflex.state import BaseState
|
||||
from reflex.vars import BaseVar
|
||||
|
||||
|
||||
@ -24,7 +25,7 @@ def test_render_many_child():
|
||||
_ = rx.debounce_input("foo", "bar").render()
|
||||
|
||||
|
||||
class S(rx.State):
|
||||
class S(BaseState):
|
||||
"""Example state for debounce tests."""
|
||||
|
||||
value: str = ""
|
||||
|
@ -15,12 +15,13 @@ from reflex.components.layout.responsive import (
|
||||
tablet_only,
|
||||
)
|
||||
from reflex.components.typography.text import Text
|
||||
from reflex.state import BaseState
|
||||
from reflex.vars import Var
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cond_state(request):
|
||||
class CondState(rx.State):
|
||||
class CondState(BaseState):
|
||||
value: request.param["value_type"] = request.param["value"] # noqa
|
||||
|
||||
return CondState
|
||||
|
@ -4,10 +4,10 @@ import pytest
|
||||
|
||||
from reflex.components import box, foreach, text
|
||||
from reflex.components.layout import Foreach
|
||||
from reflex.state import State
|
||||
from reflex.state import BaseState
|
||||
|
||||
|
||||
class ForEachState(State):
|
||||
class ForEachState(BaseState):
|
||||
"""A state for testing the ForEach component."""
|
||||
|
||||
colors_list: List[str] = ["red", "yellow"]
|
||||
|
@ -14,7 +14,7 @@ from reflex.components.component import (
|
||||
from reflex.components.layout.box import Box
|
||||
from reflex.constants import EventTriggers
|
||||
from reflex.event import EventChain, EventHandler
|
||||
from reflex.state import State
|
||||
from reflex.state import BaseState
|
||||
from reflex.style import Style
|
||||
from reflex.utils import imports
|
||||
from reflex.utils.imports import ImportVar
|
||||
@ -23,7 +23,7 @@ from reflex.vars import Var, VarData
|
||||
|
||||
@pytest.fixture
|
||||
def test_state():
|
||||
class TestState(State):
|
||||
class TestState(BaseState):
|
||||
num: int
|
||||
|
||||
def do_something(self):
|
||||
@ -400,7 +400,7 @@ def test_get_event_triggers(component1, component2):
|
||||
)
|
||||
|
||||
|
||||
class C1State(State):
|
||||
class C1State(BaseState):
|
||||
"""State for testing C1 component."""
|
||||
|
||||
def mock_handler(self, _e, _bravo, _charlie):
|
||||
|
@ -8,7 +8,6 @@ from typing import Dict, Generator
|
||||
|
||||
import pytest
|
||||
|
||||
import reflex as rx
|
||||
from reflex.app import App
|
||||
from reflex.event import EventSpec
|
||||
|
||||
@ -225,23 +224,3 @@ def token() -> str:
|
||||
A fresh/unique token string.
|
||||
"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def duplicate_substate():
|
||||
"""Create a Test state that has duplicate child substates.
|
||||
|
||||
Returns:
|
||||
The test state.
|
||||
"""
|
||||
|
||||
class TestState(rx.State):
|
||||
pass
|
||||
|
||||
class ChildTestState(TestState): # type: ignore # noqa
|
||||
pass
|
||||
|
||||
class ChildTestState(TestState): # type: ignore # noqa
|
||||
pass
|
||||
|
||||
return TestState
|
||||
|
@ -2,13 +2,14 @@ from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from reflex import constants
|
||||
from reflex.app import App
|
||||
from reflex.constants import CompileVars
|
||||
from reflex.middleware.hydrate_middleware import HydrateMiddleware
|
||||
from reflex.state import State, StateUpdate
|
||||
from reflex.state import BaseState, StateUpdate
|
||||
|
||||
|
||||
def exp_is_hydrated(state: State) -> Dict[str, Any]:
|
||||
def exp_is_hydrated(state: BaseState) -> Dict[str, Any]:
|
||||
"""Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
|
||||
|
||||
Args:
|
||||
@ -20,7 +21,7 @@ def exp_is_hydrated(state: State) -> Dict[str, Any]:
|
||||
return {state.get_name(): {CompileVars.IS_HYDRATED: True}}
|
||||
|
||||
|
||||
class TestState(State):
|
||||
class TestState(BaseState):
|
||||
"""A test state with no return in handler."""
|
||||
|
||||
__test__ = False
|
||||
@ -32,7 +33,7 @@ class TestState(State):
|
||||
self.num += 1
|
||||
|
||||
|
||||
class TestState2(State):
|
||||
class TestState2(BaseState):
|
||||
"""A test state with return in handler."""
|
||||
|
||||
__test__ = False
|
||||
@ -54,7 +55,7 @@ class TestState2(State):
|
||||
self.name = "random"
|
||||
|
||||
|
||||
class TestState3(State):
|
||||
class TestState3(BaseState):
|
||||
"""A test state with async handler."""
|
||||
|
||||
__test__ = False
|
||||
@ -97,6 +98,9 @@ async def test_preprocess(
|
||||
event_fixture: The event fixture(an Event).
|
||||
expected: Expected delta.
|
||||
"""
|
||||
test_state.add_var(
|
||||
constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
|
||||
)
|
||||
app = App(state=test_state, load_events={"index": [test_state.test_handler]})
|
||||
state = test_state()
|
||||
|
||||
|
@ -1,10 +1,12 @@
|
||||
"""Common rx.State subclasses for use in tests."""
|
||||
"""Common rx.BaseState subclasses for use in tests."""
|
||||
import reflex as rx
|
||||
from reflex.state import BaseState
|
||||
|
||||
from .mutation import DictMutationTestState, ListMutationTestState, MutableTestState
|
||||
from .upload import (
|
||||
ChildFileUploadState,
|
||||
FileStateBase1,
|
||||
FileStateBase2,
|
||||
FileUploadState,
|
||||
GrandChildFileUploadState,
|
||||
SubUploadState,
|
||||
@ -12,7 +14,7 @@ from .upload import (
|
||||
)
|
||||
|
||||
|
||||
class GenState(rx.State):
|
||||
class GenState(BaseState):
|
||||
"""A state with event handlers that generate multiple updates."""
|
||||
|
||||
value: int
|
||||
|
@ -3,9 +3,10 @@
|
||||
from typing import Dict, List, Set, Union
|
||||
|
||||
import reflex as rx
|
||||
from reflex.state import BaseState
|
||||
|
||||
|
||||
class DictMutationTestState(rx.State):
|
||||
class DictMutationTestState(BaseState):
|
||||
"""A state for testing ReflexDict mutation."""
|
||||
|
||||
# plain dict
|
||||
@ -62,7 +63,7 @@ class DictMutationTestState(rx.State):
|
||||
self.friend_in_nested_dict["friend"]["age"] = 30
|
||||
|
||||
|
||||
class ListMutationTestState(rx.State):
|
||||
class ListMutationTestState(BaseState):
|
||||
"""A state for testing ReflexList mutation."""
|
||||
|
||||
# plain list
|
||||
@ -144,7 +145,7 @@ class CustomVar(rx.Base):
|
||||
custom: OtherBase = OtherBase()
|
||||
|
||||
|
||||
class MutableTestState(rx.State):
|
||||
class MutableTestState(BaseState):
|
||||
"""A test state."""
|
||||
|
||||
array: List[Union[str, List, Dict[str, str]]] = [
|
||||
|
@ -3,9 +3,10 @@ from pathlib import Path
|
||||
from typing import ClassVar, List
|
||||
|
||||
import reflex as rx
|
||||
from reflex.state import BaseState, State
|
||||
|
||||
|
||||
class UploadState(rx.State):
|
||||
class UploadState(BaseState):
|
||||
"""The base state for uploading a file."""
|
||||
|
||||
async def handle_upload1(self, files: List[rx.UploadFile]):
|
||||
@ -17,7 +18,7 @@ class UploadState(rx.State):
|
||||
pass
|
||||
|
||||
|
||||
class BaseState(rx.State):
|
||||
class BaseState(BaseState):
|
||||
"""The test base state."""
|
||||
|
||||
pass
|
||||
@ -37,7 +38,7 @@ class SubUploadState(BaseState):
|
||||
pass
|
||||
|
||||
|
||||
class FileUploadState(rx.State):
|
||||
class FileUploadState(State):
|
||||
"""The base state for uploading a file."""
|
||||
|
||||
img_list: List[str]
|
||||
@ -79,7 +80,7 @@ class FileUploadState(rx.State):
|
||||
pass
|
||||
|
||||
|
||||
class FileStateBase1(rx.State):
|
||||
class FileStateBase1(State):
|
||||
"""The base state for a child FileUploadState."""
|
||||
|
||||
pass
|
||||
|
@ -28,7 +28,7 @@ from reflex.components import Box, Component, Cond, Fragment, Text
|
||||
from reflex.event import Event, get_hydrate_event
|
||||
from reflex.middleware import HydrateMiddleware
|
||||
from reflex.model import Model
|
||||
from reflex.state import RouterData, State, StateManagerRedis, StateUpdate
|
||||
from reflex.state import BaseState, RouterData, State, StateManagerRedis, StateUpdate
|
||||
from reflex.style import Style
|
||||
from reflex.utils import format
|
||||
from reflex.vars import ComputedVar
|
||||
@ -43,7 +43,7 @@ from .states import (
|
||||
)
|
||||
|
||||
|
||||
class EmptyState(State):
|
||||
class EmptyState(BaseState):
|
||||
"""An empty state."""
|
||||
|
||||
pass
|
||||
@ -77,14 +77,14 @@ def about_page():
|
||||
return about
|
||||
|
||||
|
||||
class ATestState(State):
|
||||
class ATestState(BaseState):
|
||||
"""A simple state for testing."""
|
||||
|
||||
var: int
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_state() -> Type[State]:
|
||||
def test_state() -> Type[BaseState]:
|
||||
"""A default state.
|
||||
|
||||
Returns:
|
||||
@ -94,14 +94,14 @@ def test_state() -> Type[State]:
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def redundant_test_state() -> Type[State]:
|
||||
def redundant_test_state() -> Type[BaseState]:
|
||||
"""A default state.
|
||||
|
||||
Returns:
|
||||
A default state.
|
||||
"""
|
||||
|
||||
class RedundantTestState(State):
|
||||
class RedundantTestState(BaseState):
|
||||
var: int
|
||||
|
||||
return RedundantTestState
|
||||
@ -198,12 +198,12 @@ def test_default_app(app: App):
|
||||
|
||||
|
||||
def test_multiple_states_error(monkeypatch, test_state, redundant_test_state):
|
||||
"""Test that an error is thrown when multiple classes subclass rx.State.
|
||||
"""Test that an error is thrown when multiple classes subclass rx.BaseState.
|
||||
|
||||
Args:
|
||||
monkeypatch: Pytest monkeypatch object.
|
||||
test_state: A test state subclassing rx.State.
|
||||
redundant_test_state: Another test state subclassing rx.State.
|
||||
test_state: A test state subclassing rx.BaseState.
|
||||
redundant_test_state: Another test state subclassing rx.BaseState.
|
||||
"""
|
||||
monkeypatch.delenv(constants.PYTEST_CURRENT_TEST)
|
||||
with pytest.raises(ValueError):
|
||||
@ -705,12 +705,12 @@ async def test_dict_mutation_detection__plain_list(
|
||||
[
|
||||
(
|
||||
FileUploadState,
|
||||
{"file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}},
|
||||
{"state.file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}},
|
||||
),
|
||||
(
|
||||
ChildFileUploadState,
|
||||
{
|
||||
"file_state_base1.child_file_upload_state": {
|
||||
"state.file_state_base1.child_file_upload_state": {
|
||||
"img_list": ["image1.jpg", "image2.jpg"]
|
||||
}
|
||||
},
|
||||
@ -718,14 +718,14 @@ async def test_dict_mutation_detection__plain_list(
|
||||
(
|
||||
GrandChildFileUploadState,
|
||||
{
|
||||
"file_state_base1.file_state_base2.grand_child_file_upload_state": {
|
||||
"state.file_state_base1.file_state_base2.grand_child_file_upload_state": {
|
||||
"img_list": ["image1.jpg", "image2.jpg"]
|
||||
}
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_upload_file(tmp_path, state, delta, token: str):
|
||||
async def test_upload_file(tmp_path, state, delta, token: str, mocker):
|
||||
"""Test that file upload works correctly.
|
||||
|
||||
Args:
|
||||
@ -733,10 +733,15 @@ async def test_upload_file(tmp_path, state, delta, token: str):
|
||||
state: The state class.
|
||||
delta: Expected delta
|
||||
token: a Token.
|
||||
mocker: pytest mocker object.
|
||||
"""
|
||||
mocker.patch(
|
||||
"reflex.state.State.class_subclasses",
|
||||
{state if state is FileUploadState else FileStateBase1},
|
||||
)
|
||||
state._tmp_path = tmp_path
|
||||
# The App state must be the "root" of the state tree
|
||||
app = App(state=state if state is FileUploadState else FileStateBase1)
|
||||
app = App(state=State)
|
||||
app.event_namespace.emit = AsyncMock() # type: ignore
|
||||
current_state = await app.state_manager.get_state(token)
|
||||
data = b"This is binary data"
|
||||
@ -749,7 +754,7 @@ async def test_upload_file(tmp_path, state, delta, token: str):
|
||||
request_mock = unittest.mock.Mock()
|
||||
request_mock.headers = {
|
||||
"reflex-client-token": token,
|
||||
"reflex-event-handler": f"{state_name}.multi_handle_upload",
|
||||
"reflex-event-handler": f"state.{state_name}.multi_handle_upload",
|
||||
}
|
||||
|
||||
file1 = UploadFile(
|
||||
@ -851,7 +856,7 @@ async def test_upload_file_background(state, tmp_path, token):
|
||||
await app.state_manager.redis.close()
|
||||
|
||||
|
||||
class DynamicState(State):
|
||||
class DynamicState(BaseState):
|
||||
"""State class for testing dynamic route var.
|
||||
|
||||
This is defined at module level because event handlers cannot be addressed
|
||||
@ -891,9 +896,7 @@ class DynamicState(State):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
index_page,
|
||||
windows_platform: bool,
|
||||
token: str,
|
||||
index_page, windows_platform: bool, token: str, mocker
|
||||
):
|
||||
"""Create app with dynamic route var, and simulate navigation.
|
||||
|
||||
@ -904,7 +907,12 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
index_page: The index page.
|
||||
windows_platform: Whether the system is windows.
|
||||
token: a Token.
|
||||
mocker: pytest mocker object.
|
||||
"""
|
||||
mocker.patch("reflex.state.State.class_subclasses", {DynamicState})
|
||||
DynamicState.add_var(
|
||||
constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
|
||||
)
|
||||
arg_name = "dynamic"
|
||||
route = f"/test/[{arg_name}]"
|
||||
if windows_platform:
|
||||
|
@ -4,7 +4,7 @@ import pytest
|
||||
|
||||
from reflex import event
|
||||
from reflex.event import Event, EventHandler, EventSpec, fix_events
|
||||
from reflex.state import State
|
||||
from reflex.state import BaseState
|
||||
from reflex.utils import format
|
||||
from reflex.vars import Var
|
||||
|
||||
@ -303,7 +303,7 @@ def test_event_actions():
|
||||
|
||||
|
||||
def test_event_actions_on_state():
|
||||
class EventActionState(State):
|
||||
class EventActionState(BaseState):
|
||||
def handler(self):
|
||||
pass
|
||||
|
||||
|
@ -19,11 +19,11 @@ from reflex.base import Base
|
||||
from reflex.constants import CompileVars, RouteVar, SocketEvent
|
||||
from reflex.event import Event, EventHandler
|
||||
from reflex.state import (
|
||||
BaseState,
|
||||
ImmutableStateError,
|
||||
LockExpiredError,
|
||||
MutableProxy,
|
||||
RouterData,
|
||||
State,
|
||||
StateManager,
|
||||
StateManagerMemory,
|
||||
StateManagerRedis,
|
||||
@ -75,7 +75,7 @@ class Object(Base):
|
||||
prop2: str = "hello"
|
||||
|
||||
|
||||
class TestState(State):
|
||||
class TestState(BaseState):
|
||||
"""A test state."""
|
||||
|
||||
# Set this class as not test one
|
||||
@ -148,7 +148,7 @@ class GrandchildState(ChildState):
|
||||
pass
|
||||
|
||||
|
||||
class DateTimeState(State):
|
||||
class DateTimeState(BaseState):
|
||||
"""A State with some datetime fields."""
|
||||
|
||||
d: datetime.date = datetime.date.fromisoformat("1989-11-09")
|
||||
@ -253,7 +253,6 @@ def test_class_vars(test_state):
|
||||
"""
|
||||
cls = type(test_state)
|
||||
assert set(cls.vars.keys()) == {
|
||||
CompileVars.IS_HYDRATED, # added by hydrate_middleware to all State
|
||||
"router",
|
||||
"num1",
|
||||
"num2",
|
||||
@ -641,7 +640,6 @@ def test_reset(test_state, child_state):
|
||||
"obj",
|
||||
"upper",
|
||||
"complex",
|
||||
"is_hydrated",
|
||||
"fig",
|
||||
"key",
|
||||
"sum",
|
||||
@ -837,7 +835,7 @@ def test_get_query_params(test_state):
|
||||
|
||||
|
||||
def test_add_var():
|
||||
class DynamicState(State):
|
||||
class DynamicState(BaseState):
|
||||
pass
|
||||
|
||||
ds1 = DynamicState()
|
||||
@ -870,7 +868,7 @@ def test_add_var_default_handlers(test_state):
|
||||
assert isinstance(test_state.event_handlers["set_rand_int"], EventHandler)
|
||||
|
||||
|
||||
class InterdependentState(State):
|
||||
class InterdependentState(BaseState):
|
||||
"""A state with 3 vars and 3 computed vars.
|
||||
|
||||
x: a variable that no computed var depends on
|
||||
@ -915,7 +913,7 @@ class InterdependentState(State):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def interdependent_state() -> State:
|
||||
def interdependent_state() -> BaseState:
|
||||
"""A state with varying dependency between vars.
|
||||
|
||||
Returns:
|
||||
@ -988,7 +986,7 @@ def test_per_state_backend_var(interdependent_state):
|
||||
def test_child_state():
|
||||
"""Test that the child state computed vars can reference parent state vars."""
|
||||
|
||||
class MainState(State):
|
||||
class MainState(BaseState):
|
||||
v: int = 2
|
||||
|
||||
class ChildState(MainState):
|
||||
@ -1006,7 +1004,7 @@ def test_child_state():
|
||||
def test_conditional_computed_vars():
|
||||
"""Test that computed vars can have conditionals."""
|
||||
|
||||
class MainState(State):
|
||||
class MainState(BaseState):
|
||||
flag: bool = False
|
||||
t1: str = "a"
|
||||
t2: str = "b"
|
||||
@ -1051,7 +1049,7 @@ def test_event_handlers_convert_to_fns(test_state, child_state):
|
||||
def test_event_handlers_call_other_handlers():
|
||||
"""Test that event handlers can call other event handlers."""
|
||||
|
||||
class MainState(State):
|
||||
class MainState(BaseState):
|
||||
v: int = 0
|
||||
|
||||
def set_v(self, v: int):
|
||||
@ -1077,7 +1075,7 @@ def test_computed_var_cached():
|
||||
"""Test that a ComputedVar doesn't recalculate when accessed."""
|
||||
comp_v_calls = 0
|
||||
|
||||
class ComputedState(State):
|
||||
class ComputedState(BaseState):
|
||||
v: int = 0
|
||||
|
||||
@rx.cached_var
|
||||
@ -1102,7 +1100,7 @@ def test_computed_var_cached():
|
||||
def test_computed_var_cached_depends_on_non_cached():
|
||||
"""Test that a cached_var is recalculated if it depends on non-cached ComputedVar."""
|
||||
|
||||
class ComputedState(State):
|
||||
class ComputedState(BaseState):
|
||||
v: int = 0
|
||||
|
||||
@rx.var
|
||||
@ -1144,7 +1142,7 @@ def test_computed_var_depends_on_parent_non_cached():
|
||||
"""Child state cached_var that depends on parent state un cached var is always recalculated."""
|
||||
counter = 0
|
||||
|
||||
class ParentState(State):
|
||||
class ParentState(BaseState):
|
||||
@rx.var
|
||||
def no_cache_v(self) -> int:
|
||||
nonlocal counter
|
||||
@ -1165,21 +1163,18 @@ def test_computed_var_depends_on_parent_non_cached():
|
||||
dict1 = ps.dict()
|
||||
assert dict1[ps.get_full_name()] == {
|
||||
"no_cache_v": 1,
|
||||
CompileVars.IS_HYDRATED: False,
|
||||
"router": formatted_router,
|
||||
}
|
||||
assert dict1[cs.get_full_name()] == {"dep_v": 2}
|
||||
dict2 = ps.dict()
|
||||
assert dict2[ps.get_full_name()] == {
|
||||
"no_cache_v": 3,
|
||||
CompileVars.IS_HYDRATED: False,
|
||||
"router": formatted_router,
|
||||
}
|
||||
assert dict2[cs.get_full_name()] == {"dep_v": 4}
|
||||
dict3 = ps.dict()
|
||||
assert dict3[ps.get_full_name()] == {
|
||||
"no_cache_v": 5,
|
||||
CompileVars.IS_HYDRATED: False,
|
||||
"router": formatted_router,
|
||||
}
|
||||
assert dict3[cs.get_full_name()] == {"dep_v": 6}
|
||||
@ -1195,7 +1190,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
|
||||
"""
|
||||
counter = 0
|
||||
|
||||
class HandlerState(State):
|
||||
class HandlerState(BaseState):
|
||||
x: int = 42
|
||||
|
||||
def handler(self):
|
||||
@ -1226,7 +1221,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
|
||||
def test_computed_var_dependencies():
|
||||
"""Test that a ComputedVar correctly tracks its dependencies."""
|
||||
|
||||
class ComputedState(State):
|
||||
class ComputedState(BaseState):
|
||||
v: int = 0
|
||||
w: int = 0
|
||||
x: int = 0
|
||||
@ -1293,7 +1288,7 @@ def test_computed_var_dependencies():
|
||||
def test_backend_method():
|
||||
"""A method with leading underscore should be callable from event handler."""
|
||||
|
||||
class BackendMethodState(State):
|
||||
class BackendMethodState(BaseState):
|
||||
def _be_method(self):
|
||||
return True
|
||||
|
||||
@ -1369,7 +1364,7 @@ def test_error_on_state_method_shadow():
|
||||
"""Test that an error is thrown when an event handler shadows a state method."""
|
||||
with pytest.raises(NameError) as err:
|
||||
|
||||
class InvalidTest(rx.State):
|
||||
class InvalidTest(BaseState):
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
@ -1382,7 +1377,7 @@ def test_error_on_state_method_shadow():
|
||||
def test_state_with_invalid_yield():
|
||||
"""Test that an error is thrown when a state yields an invalid value."""
|
||||
|
||||
class StateWithInvalidYield(rx.State):
|
||||
class StateWithInvalidYield(BaseState):
|
||||
"""A state that yields an invalid value."""
|
||||
|
||||
def invalid_handler(self):
|
||||
@ -1666,7 +1661,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
|
||||
assert mcall.kwargs["to"] == grandchild_state.get_sid()
|
||||
|
||||
|
||||
class BackgroundTaskState(State):
|
||||
class BackgroundTaskState(BaseState):
|
||||
"""A state with a background task."""
|
||||
|
||||
order: List[str] = []
|
||||
@ -2192,9 +2187,20 @@ def test_mutable_copy_vars(mutable_state, copy_func):
|
||||
assert not isinstance(var_copy, MutableProxy)
|
||||
|
||||
|
||||
def test_duplicate_substate_class(duplicate_substate):
|
||||
def test_duplicate_substate_class(mocker):
|
||||
mocker.patch("reflex.state.os.environ", {})
|
||||
with pytest.raises(ValueError):
|
||||
duplicate_substate()
|
||||
|
||||
class TestState(BaseState):
|
||||
pass
|
||||
|
||||
class ChildTestState(TestState): # type: ignore # noqa
|
||||
pass
|
||||
|
||||
class ChildTestState(TestState): # type: ignore # noqa
|
||||
pass
|
||||
|
||||
return TestState
|
||||
|
||||
|
||||
class Foo(Base):
|
||||
@ -2206,7 +2212,7 @@ class Foo(Base):
|
||||
def test_json_dumps_with_mutables():
|
||||
"""Test that json.dumps works with Base vars inside mutable types."""
|
||||
|
||||
class MutableContainsBase(State):
|
||||
class MutableContainsBase(BaseState):
|
||||
items: List[Foo] = [Foo()]
|
||||
|
||||
dict_val = MutableContainsBase().dict()
|
||||
@ -2216,7 +2222,7 @@ def test_json_dumps_with_mutables():
|
||||
f_formatted_router = str(formatted_router).replace("'", '"')
|
||||
assert (
|
||||
val
|
||||
== f'{{"{MutableContainsBase.get_full_name()}": {{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}}}'
|
||||
== f'{{"{MutableContainsBase.get_full_name()}": {{"items": {f_items}, "router": {f_formatted_router}}}}}'
|
||||
)
|
||||
|
||||
|
||||
@ -2225,7 +2231,7 @@ def test_reset_with_mutables():
|
||||
default = [[0, 0], [0, 1], [1, 1]]
|
||||
copied_default = copy.deepcopy(default)
|
||||
|
||||
class MutableResetState(State):
|
||||
class MutableResetState(BaseState):
|
||||
items: List[List[int]] = default
|
||||
|
||||
instance = MutableResetState()
|
||||
@ -2273,7 +2279,7 @@ class Custom3(Base):
|
||||
def test_state_union_optional():
|
||||
"""Test that state can be defined with Union and Optional vars."""
|
||||
|
||||
class UnionState(State):
|
||||
class UnionState(BaseState):
|
||||
int_float: Union[int, float] = 0
|
||||
opt_int: Optional[int]
|
||||
c3: Optional[Custom3]
|
||||
|
@ -6,7 +6,7 @@ import pytest
|
||||
from pandas import DataFrame
|
||||
|
||||
from reflex.base import Base
|
||||
from reflex.state import State
|
||||
from reflex.state import BaseState
|
||||
from reflex.vars import (
|
||||
BaseVar,
|
||||
ComputedVar,
|
||||
@ -24,12 +24,6 @@ test_vars = [
|
||||
]
|
||||
|
||||
|
||||
class BaseState(State):
|
||||
"""A Test State."""
|
||||
|
||||
val: str = "key"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def TestObj():
|
||||
class TestObj(Base):
|
||||
@ -41,7 +35,7 @@ def TestObj():
|
||||
|
||||
@pytest.fixture
|
||||
def ParentState(TestObj):
|
||||
class ParentState(State):
|
||||
class ParentState(BaseState):
|
||||
foo: int
|
||||
bar: int
|
||||
|
||||
@ -74,7 +68,7 @@ def GrandChildState(ChildState, TestObj):
|
||||
|
||||
@pytest.fixture
|
||||
def StateWithAnyVar(TestObj):
|
||||
class StateWithAnyVar(State):
|
||||
class StateWithAnyVar(BaseState):
|
||||
@ComputedVar
|
||||
def var_without_annotation(self) -> typing.Any:
|
||||
return TestObj
|
||||
@ -84,7 +78,7 @@ def StateWithAnyVar(TestObj):
|
||||
|
||||
@pytest.fixture
|
||||
def StateWithCorrectVarAnnotation():
|
||||
class StateWithCorrectVarAnnotation(State):
|
||||
class StateWithCorrectVarAnnotation(BaseState):
|
||||
@ComputedVar
|
||||
def var_with_annotation(self) -> str:
|
||||
return "Correct annotation"
|
||||
@ -94,7 +88,7 @@ def StateWithCorrectVarAnnotation():
|
||||
|
||||
@pytest.fixture
|
||||
def StateWithWrongVarAnnotation(TestObj):
|
||||
class StateWithWrongVarAnnotation(State):
|
||||
class StateWithWrongVarAnnotation(BaseState):
|
||||
@ComputedVar
|
||||
def var_with_annotation(self) -> str:
|
||||
return TestObj
|
||||
|
@ -528,7 +528,6 @@ formatted_router = {
|
||||
},
|
||||
"dt": "1989-11-09 18:53:00+01:00",
|
||||
"fig": [],
|
||||
"is_hydrated": False,
|
||||
"key": "",
|
||||
"map_key": "a",
|
||||
"mapping": {"a": [1, 2, 3], "b": [4, 5, 6]},
|
||||
@ -553,7 +552,6 @@ formatted_router = {
|
||||
DateTimeState.get_full_name(): {
|
||||
"d": "1989-11-09",
|
||||
"dt": "1989-11-09 18:53:00+01:00",
|
||||
"is_hydrated": False,
|
||||
"t": "18:53:00+01:00",
|
||||
"td": "11 days, 0:11:00",
|
||||
"router": formatted_router,
|
||||
|
@ -10,7 +10,7 @@ from packaging import version
|
||||
from reflex import constants
|
||||
from reflex.base import Base
|
||||
from reflex.event import EventHandler
|
||||
from reflex.state import State
|
||||
from reflex.state import BaseState
|
||||
from reflex.utils import (
|
||||
build,
|
||||
prerequisites,
|
||||
@ -43,7 +43,7 @@ V056 = version.parse("0.5.6")
|
||||
VMAXPLUS1 = version.parse(get_above_max_version())
|
||||
|
||||
|
||||
class ExampleTestState(State):
|
||||
class ExampleTestState(BaseState):
|
||||
"""Test state class."""
|
||||
|
||||
def test_event_handler(self):
|
||||
|
Loading…
Reference in New Issue
Block a user