Fix AppHarness tests (#1987)

* test_client_storage: remove race conditions for cookie assignment

Poll for default timeout for cookies to appear in the controlled browser.

* Remove use of deprecated get_token and get_sid in core

Both reflex.app and reflex.state were still using deprecated methods, which
were throwing unsolvable warnings for end users.

* Remove deprecated router functions from integration tests

Mostly removing custom "token" var and replacing with
router.session.client_token.

Also replacing `get_query_params` and `get_current_page` usage as well.

* fix upload tests

Cannot pass substate as main app state, since it blocks us from accessing
"inherited vars"

* state: do NOT reset `router` to default

When calling `.reset` to reset state vars, do NOT reset the router data, as
that could mess up internal event processing.
This commit is contained in:
Masen Furer 2023-10-17 16:46:13 -07:00 committed by GitHub
parent 317b883ec8
commit f6a7eed359
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 178 additions and 154 deletions

View File

@ -57,13 +57,11 @@ def BackgroundTask():
async def non_blocking_pause(self): async def non_blocking_pause(self):
await asyncio.sleep(0.02) await asyncio.sleep(0.02)
@rx.cached_var
def token(self) -> str:
return self.get_token()
def index() -> rx.Component: def index() -> rx.Component:
return rx.vstack( return rx.vstack(
rx.input(id="token", value=State.token, is_read_only=True), rx.input(
id="token", value=State.router.session.client_token, is_read_only=True
),
rx.heading(State.counter, id="counter"), rx.heading(State.counter, id="counter"),
rx.input( rx.input(
id="iterations", id="iterations",

View File

@ -21,10 +21,6 @@ def ClientSide():
state_var: str = "" state_var: str = ""
input_value: str = "" input_value: str = ""
@rx.var
def token(self) -> str:
return self.get_token()
class ClientSideSubState(ClientSideState): class ClientSideSubState(ClientSideState):
# cookies with default settings # cookies with default settings
c1: str = rx.Cookie() c1: str = rx.Cookie()
@ -59,7 +55,11 @@ def ClientSide():
def index(): def index():
return rx.fragment( return rx.fragment(
rx.input(value=ClientSideState.token, is_read_only=True, id="token"), rx.input(
value=ClientSideState.router.session.client_token,
is_read_only=True,
id="token",
),
rx.input( rx.input(
placeholder="state var", placeholder="state var",
value=ClientSideState.state_var, value=ClientSideState.state_var,
@ -284,63 +284,68 @@ async def test_client_side_state(
input_value_input.send_keys("l1s value") input_value_input.send_keys("l1s value")
set_sub_sub_state_button.click() set_sub_sub_state_button.click()
exp_cookies = {
"client_side_state.client_side_sub_state.c1": {
"domain": "localhost",
"httpOnly": False,
"name": "client_side_state.client_side_sub_state.c1",
"path": "/",
"sameSite": "Lax",
"secure": False,
"value": "c1%20value",
},
"client_side_state.client_side_sub_state.c2": {
"domain": "localhost",
"httpOnly": False,
"name": "client_side_state.client_side_sub_state.c2",
"path": "/",
"sameSite": "Lax",
"secure": False,
"value": "c2%20value",
},
"client_side_state.client_side_sub_state.c4": {
"domain": "localhost",
"httpOnly": False,
"name": "client_side_state.client_side_sub_state.c4",
"path": "/",
"sameSite": "Strict",
"secure": False,
"value": "c4%20value",
},
"c6": {
"domain": "localhost",
"httpOnly": False,
"name": "c6",
"path": "/",
"sameSite": "Lax",
"secure": False,
"value": "c6%20value",
},
"client_side_state.client_side_sub_state.c7": {
"domain": "localhost",
"httpOnly": False,
"name": "client_side_state.client_side_sub_state.c7",
"path": "/",
"sameSite": "Lax",
"secure": False,
"value": "c7%20value",
},
"client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s": {
"domain": "localhost",
"httpOnly": False,
"name": "client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s",
"path": "/",
"sameSite": "Lax",
"secure": False,
"value": "c1s%20value",
},
}
AppHarness._poll_for(
lambda: all(cookie_key in cookie_info_map(driver) for cookie_key in exp_cookies)
)
cookies = cookie_info_map(driver) cookies = cookie_info_map(driver)
assert cookies.pop("client_side_state.client_side_sub_state.c1") == { for exp_cookie_key, exp_cookie_data in exp_cookies.items():
"domain": "localhost", assert cookies.pop(exp_cookie_key) == exp_cookie_data
"httpOnly": False,
"name": "client_side_state.client_side_sub_state.c1",
"path": "/",
"sameSite": "Lax",
"secure": False,
"value": "c1%20value",
}
assert cookies.pop("client_side_state.client_side_sub_state.c2") == {
"domain": "localhost",
"httpOnly": False,
"name": "client_side_state.client_side_sub_state.c2",
"path": "/",
"sameSite": "Lax",
"secure": False,
"value": "c2%20value",
}
assert cookies.pop("client_side_state.client_side_sub_state.c4") == {
"domain": "localhost",
"httpOnly": False,
"name": "client_side_state.client_side_sub_state.c4",
"path": "/",
"sameSite": "Strict",
"secure": False,
"value": "c4%20value",
}
assert cookies.pop("c6") == {
"domain": "localhost",
"httpOnly": False,
"name": "c6",
"path": "/",
"sameSite": "Lax",
"secure": False,
"value": "c6%20value",
}
assert cookies.pop("client_side_state.client_side_sub_state.c7") == {
"domain": "localhost",
"httpOnly": False,
"name": "client_side_state.client_side_sub_state.c7",
"path": "/",
"sameSite": "Lax",
"secure": False,
"value": "c7%20value",
}
assert cookies.pop(
"client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s"
) == {
"domain": "localhost",
"httpOnly": False,
"name": "client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s",
"path": "/",
"sameSite": "Lax",
"secure": False,
"value": "c1s%20value",
}
# assert all cookies have been popped for this page # assert all cookies have been popped for this page
assert not cookies assert not cookies
@ -476,6 +481,9 @@ async def test_client_side_state(
assert l1s.text == "l1s value" assert l1s.text == "l1s value"
# make sure c5 cookie shows up on the `/foo` route # make sure c5 cookie shows up on the `/foo` route
AppHarness._poll_for(
lambda: "client_side_state.client_side_sub_state.c5" in cookie_info_map(driver)
)
assert cookie_info_map(driver)["client_side_state.client_side_sub_state.c5"] == { assert cookie_info_map(driver)["client_side_state.client_side_sub_state.c5"] == {
"domain": "localhost", "domain": "localhost",
"httpOnly": False, "httpOnly": False,

View File

@ -19,12 +19,10 @@ def DynamicRoute():
page_id: str = "" page_id: str = ""
def on_load(self): def on_load(self):
self.order.append( self.order.append(f"{self.router.page.path}-{self.page_id or 'no page id'}")
f"{self.get_current_page()}-{self.page_id or 'no page id'}"
)
def on_load_redir(self): def on_load_redir(self):
query_params = self.get_query_params() query_params = self.router.page.params
self.order.append(f"on_load_redir-{query_params}") self.order.append(f"on_load_redir-{query_params}")
return rx.redirect(f"/page/{query_params['page_id']}") return rx.redirect(f"/page/{query_params['page_id']}")
@ -35,13 +33,13 @@ def DynamicRoute():
except ValueError: except ValueError:
return "0" return "0"
@rx.var
def token(self) -> str:
return self.get_token()
def index(): def index():
return rx.fragment( return rx.fragment(
rx.input(value=DynamicState.token, is_read_only=True, id="token"), # type: ignore rx.input(
value=DynamicState.router.session.client_token,
is_read_only=True,
id="token",
),
rx.input(value=DynamicState.page_id, is_read_only=True, id="page_id"), rx.input(value=DynamicState.page_id, is_read_only=True, id="page_id"),
rx.link("index", href="/", id="link_index"), rx.link("index", href="/", id="link_index"),
rx.link("page_X", href="/static/x", id="link_page_x"), rx.link("page_X", href="/static/x", id="link_page_x"),
@ -212,7 +210,7 @@ async def test_on_load_navigate(
with poll_for_navigation(driver): with poll_for_navigation(driver):
driver.get(f"{driver.current_url}?foo=bar") driver.get(f"{driver.current_url}?foo=bar")
await poll_for_order(exp_order) await poll_for_order(exp_order)
assert (await dynamic_route.get_state(token)).get_query_params()["foo"] == "bar" assert (await dynamic_route.get_state(token)).router.page.params["foo"] == "bar"
# hit a 404 and ensure we still hydrate # hit a 404 and ensure we still hydrate
exp_order += ["/404-no page id"] exp_order += ["/404-no page id"]

View File

@ -24,10 +24,6 @@ def EventChain():
event_order: list[str] = [] event_order: list[str] = []
interim_value: str = "" interim_value: str = ""
@rx.var
def token(self) -> str:
return self.get_token()
def event_no_args(self): def event_no_args(self):
self.event_order.append("event_no_args") self.event_order.append("event_no_args")
@ -128,10 +124,14 @@ def EventChain():
app = rx.App(state=State) app = rx.App(state=State)
token_input = rx.input(
value=State.router.session.client_token, is_read_only=True, id="token"
)
@app.add_page @app.add_page
def index(): def index():
return rx.fragment( return rx.fragment(
rx.input(value=State.token, is_read_only=True, id="token"), token_input,
rx.input(value=State.interim_value, is_read_only=True, id="interim_value"), rx.input(value=State.interim_value, is_read_only=True, id="interim_value"),
rx.button( rx.button(
"Return Event", "Return Event",
@ -203,13 +203,13 @@ def EventChain():
def on_load_return_chain(): def on_load_return_chain():
return rx.fragment( return rx.fragment(
rx.text("return"), rx.text("return"),
rx.input(value=State.token, readonly=True, id="token"), token_input,
) )
def on_load_yield_chain(): def on_load_yield_chain():
return rx.fragment( return rx.fragment(
rx.text("yield"), rx.text("yield"),
rx.input(value=State.token, readonly=True, id="token"), token_input,
) )
def on_mount_return_chain(): def on_mount_return_chain():
@ -219,7 +219,7 @@ def EventChain():
on_mount=State.on_load_return_chain, on_mount=State.on_load_return_chain,
on_unmount=lambda: State.event_arg("unmount"), # type: ignore on_unmount=lambda: State.event_arg("unmount"), # type: ignore
), ),
rx.input(value=State.token, readonly=True, id="token"), token_input,
rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"), rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"),
) )
@ -233,7 +233,7 @@ def EventChain():
], ],
on_unmount=State.event_no_args, on_unmount=State.event_no_args,
), ),
rx.input(value=State.token, readonly=True, id="token"), token_input,
rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"), rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"),
) )
@ -280,6 +280,27 @@ def driver(event_chain: AppHarness) -> Generator[WebDriver, None, None]:
driver.quit() driver.quit()
def assert_token(event_chain: AppHarness, driver: WebDriver) -> str:
"""Get the token associated with backend state.
Args:
event_chain: harness for EventChain app.
driver: WebDriver instance.
Returns:
The token visible in the driver browser.
"""
assert event_chain.app_instance is not None
token_input = driver.find_element(By.ID, "token")
assert token_input
# wait for the backend connection to send the token
token = event_chain.poll_for_value(token_input)
assert token is not None
return token
@pytest.mark.parametrize( @pytest.mark.parametrize(
("button_id", "exp_event_order"), ("button_id", "exp_event_order"),
[ [
@ -375,14 +396,8 @@ async def test_event_chain_click(
button_id: the ID of the button to click button_id: the ID of the button to click
exp_event_order: the expected events recorded in the State exp_event_order: the expected events recorded in the State
""" """
token_input = driver.find_element(By.ID, "token") token = assert_token(event_chain, driver)
btn = driver.find_element(By.ID, button_id) btn = driver.find_element(By.ID, button_id)
assert token_input
assert btn
token = event_chain.poll_for_value(token_input)
assert token is not None
btn.click() btn.click()
async def _has_all_events(): async def _has_all_events():
@ -435,11 +450,7 @@ async def test_event_chain_on_load(
""" """
assert event_chain.frontend_url is not None assert event_chain.frontend_url is not None
driver.get(event_chain.frontend_url + uri) driver.get(event_chain.frontend_url + uri)
token_input = driver.find_element(By.ID, "token") token = assert_token(event_chain, driver)
assert token_input
token = event_chain.poll_for_value(token_input)
assert token is not None
async def _has_all_events(): async def _has_all_events():
return len((await event_chain.get_state(token)).event_order) == len( return len((await event_chain.get_state(token)).event_order) == len(
@ -511,11 +522,7 @@ async def test_event_chain_on_mount(
""" """
assert event_chain.frontend_url is not None assert event_chain.frontend_url is not None
driver.get(event_chain.frontend_url + uri) driver.get(event_chain.frontend_url + uri)
token_input = driver.find_element(By.ID, "token") token = assert_token(event_chain, driver)
assert token_input
token = event_chain.poll_for_value(token_input)
assert token is not None
unmount_button = driver.find_element(By.ID, "unmount") unmount_button = driver.find_element(By.ID, "unmount")
assert unmount_button assert unmount_button
@ -546,9 +553,8 @@ def test_yield_state_update(event_chain: AppHarness, driver: WebDriver, button_i
driver: selenium WebDriver open to the app driver: selenium WebDriver open to the app
button_id: the ID of the button to click button_id: the ID of the button to click
""" """
token_input = driver.find_element(By.ID, "token")
interim_value_input = driver.find_element(By.ID, "interim_value") interim_value_input = driver.find_element(By.ID, "interim_value")
assert event_chain.poll_for_value(token_input) assert_token(event_chain, driver)
btn = driver.find_element(By.ID, button_id) btn = driver.find_element(By.ID, button_id)
btn.click() btn.click()

View File

@ -19,16 +19,16 @@ def FormSubmit():
def form_submit(self, form_data: dict): def form_submit(self, form_data: dict):
self.form_data = form_data self.form_data = form_data
@rx.var
def token(self) -> str:
return self.get_token()
app = rx.App(state=FormState) app = rx.App(state=FormState)
@app.add_page @app.add_page
def index(): def index():
return rx.vstack( return rx.vstack(
rx.input(value=FormState.token, is_read_only=True, id="token"), rx.input(
value=FormState.router.session.client_token,
is_read_only=True,
id="token",
),
rx.form( rx.form(
rx.vstack( rx.vstack(
rx.input(id="name_input"), rx.input(id="name_input"),

View File

@ -16,22 +16,20 @@ def FullyControlledInput():
class State(rx.State): class State(rx.State):
text: str = "initial" text: str = "initial"
@rx.var
def token(self) -> str:
return self.get_token()
app = rx.App(state=State) app = rx.App(state=State)
@app.add_page @app.add_page
def index(): def index():
return rx.fragment( return rx.fragment(
rx.input(value=State.token, is_read_only=True, id="token"), rx.input(
value=State.router.session.client_token, is_read_only=True, id="token"
),
rx.input( rx.input(
id="debounce_input_input", id="debounce_input_input",
on_change=State.set_text, # type: ignore on_change=State.set_text, # type: ignore
value=State.text, value=State.text,
), ),
rx.input(value=State.text, id="value_input"), rx.input(value=State.text, id="value_input", is_read_only=True),
rx.input(on_change=State.set_text, id="on_change_input"), # type: ignore rx.input(on_change=State.set_text, id="on_change_input"), # type: ignore
rx.button("CLEAR", on_click=rx.set_value("on_change_input", "")), rx.button("CLEAR", on_click=rx.set_value("on_change_input", "")),
) )

View File

@ -21,13 +21,11 @@ def RadixThemesApp():
v: str = "" v: str = ""
checked: bool = False checked: bool = False
@rx.var
def token(self) -> str:
return self.get_token()
def index() -> rx.Component: def index() -> rx.Component:
return rdxt.box( return rdxt.box(
rdxt.text_field(id="token", value=State.token, read_only=True), rdxt.text_field(
id="token", value=State.router.session.client_token, read_only=True
),
rdxt.text_field(id="tf-bare", value=State.v, on_change=State.set_v), # type: ignore rdxt.text_field(id="tf-bare", value=State.v, on_change=State.set_v), # type: ignore
rdxt.text_field_root( rdxt.text_field_root(
rdxt.text_field_slot("🧸"), rdxt.text_field_slot("🧸"),

View File

@ -33,16 +33,14 @@ def ServerSideEvent():
def set_value_return_c(self): def set_value_return_c(self):
return rx.set_value("c", "") return rx.set_value("c", "")
@rx.var
def token(self) -> str:
return self.get_token()
app = rx.App(state=SSState) app = rx.App(state=SSState)
@app.add_page @app.add_page
def index(): def index():
return rx.fragment( return rx.fragment(
rx.input(id="token", value=SSState.token, is_read_only=True), rx.input(
id="token", value=SSState.router.session.client_token, is_read_only=True
),
rx.input(default_value="a", id="a"), rx.input(default_value="a", id="a"),
rx.input(default_value="b", id="b"), rx.input(default_value="b", id="b"),
rx.input(default_value="c", id="c"), rx.input(default_value="c", id="c"),

View File

@ -26,16 +26,16 @@ def Table():
caption: str = "random caption" caption: str = "random caption"
@rx.var
def token(self) -> str:
return self.get_token()
app = rx.App(state=TableState) app = rx.App(state=TableState)
@app.add_page @app.add_page
def index(): def index():
return rx.center( return rx.center(
rx.input(id="token", value=TableState.token, is_read_only=True), rx.input(
id="token",
value=TableState.router.session.client_token,
is_read_only=True,
),
rx.table_container( rx.table_container(
rx.table( rx.table(
headers=TableState.headers, headers=TableState.headers,

View File

@ -22,13 +22,13 @@ def UploadFile():
upload_data = await file.read() upload_data = await file.read()
self._file_data[file.filename or ""] = upload_data.decode("utf-8") self._file_data[file.filename or ""] = upload_data.decode("utf-8")
@rx.var
def token(self) -> str:
return self.get_token()
def index(): def index():
return rx.vstack( return rx.vstack(
rx.input(value=UploadState.token, is_read_only=True, id="token"), rx.input(
value=UploadState.router.session.client_token,
is_read_only=True,
id="token",
),
rx.upload( rx.upload(
rx.vstack( rx.vstack(
rx.button("Select File"), rx.button("Select File"),

View File

@ -29,16 +29,16 @@ def VarOperations():
dict1: dict = {1: 2} dict1: dict = {1: 2}
dict2: dict = {3: 4} dict2: dict = {3: 4}
@rx.var
def token(self) -> str:
return self.get_token()
app = rx.App(state=VarOperationState) app = rx.App(state=VarOperationState)
@app.add_page @app.add_page
def index(): def index():
return rx.vstack( return rx.vstack(
rx.input(id="token", value=VarOperationState.token, is_read_only=True), rx.input(
id="token",
value=VarOperationState.router.session.client_token,
is_read_only=True,
),
# INT INT # INT INT
rx.text( rx.text(
VarOperationState.int_var1 + VarOperationState.int_var2, VarOperationState.int_var1 + VarOperationState.int_var2,

View File

@ -725,7 +725,7 @@ class App(Base):
state._clean() state._clean()
await self.event_namespace.emit_update( await self.event_namespace.emit_update(
update=StateUpdate(delta=delta), update=StateUpdate(delta=delta),
sid=state.get_sid(), sid=state.router.session.session_id,
) )
def _process_background(self, state: State, event: Event) -> asyncio.Task | None: def _process_background(self, state: State, event: Event) -> asyncio.Task | None:
@ -761,7 +761,7 @@ class App(Base):
# Send the update to the client. # Send the update to the client.
await self.event_namespace.emit_update( await self.event_namespace.emit_update(
update=update, update=update,
sid=state.get_sid(), sid=state.router.session.session_id,
) )
task = asyncio.create_task(_coro()) task = asyncio.create_task(_coro())
@ -865,7 +865,7 @@ def upload(app: App):
# Get the state for the session. # Get the state for the session.
async with app.state_manager.modify_state(token) as state: async with app.state_manager.modify_state(token) as state:
# get the current session ID # get the current session ID
sid = state.get_sid() sid = state.router.session.session_id
# get the current state(parent state/substate) # get the current state(parent state/substate)
path = handler.split(".")[:-1] path = handler.split(".")[:-1]
current_state = state.get_substate(path) current_state = state.get_substate(path)

View File

@ -818,6 +818,8 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Reset the base vars. # Reset the base vars.
fields = self.get_fields() fields = self.get_fields()
for prop_name in self.base_vars: for prop_name in self.base_vars:
if prop_name == constants.ROUTER:
continue # never reset the router data
setattr(self, prop_name, copy.deepcopy(fields[prop_name].default)) setattr(self, prop_name, copy.deepcopy(fields[prop_name].default))
# Recursively reset the substates. # Recursively reset the substates.
@ -961,7 +963,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
Returns: Returns:
The valid StateUpdate containing the events and final flag. The valid StateUpdate containing the events and final flag.
""" """
token = self.get_token() token = self.router.session.client_token
# Convert valid EventHandler and EventSpec into Event # Convert valid EventHandler and EventSpec into Event
fixed_events = fix_events(self._check_valid(handler, events), token) fixed_events = fix_events(self._check_valid(handler, events), token)
@ -1273,7 +1275,9 @@ class StateProxy(wrapt.ObjectProxy):
Returns: Returns:
This StateProxy instance in mutable mode. This StateProxy instance in mutable mode.
""" """
self._self_actx = self._self_app.modify_state(self.__wrapped__.get_token()) self._self_actx = self._self_app.modify_state(
self.__wrapped__.router.session.client_token
)
mutable_state = await self._self_actx.__aenter__() mutable_state = await self._self_actx.__aenter__()
super().__setattr__( super().__setattr__(
"__wrapped__", mutable_state.get_substate(self._self_substate_path) "__wrapped__", mutable_state.get_substate(self._self_substate_path)

View File

@ -4,6 +4,7 @@ import reflex as rx
from .mutation import DictMutationTestState, ListMutationTestState, MutableTestState from .mutation import DictMutationTestState, ListMutationTestState, MutableTestState
from .upload import ( from .upload import (
ChildFileUploadState, ChildFileUploadState,
FileStateBase1,
FileUploadState, FileUploadState,
GrandChildFileUploadState, GrandChildFileUploadState,
SubUploadState, SubUploadState,

View File

@ -42,6 +42,7 @@ from reflex.vars import ComputedVar
from .conftest import chdir from .conftest import chdir
from .states import ( from .states import (
ChildFileUploadState, ChildFileUploadState,
FileStateBase1,
FileUploadState, FileUploadState,
GenState, GenState,
GrandChildFileUploadState, GrandChildFileUploadState,
@ -735,7 +736,7 @@ async def test_upload_file(tmp_path, state, delta, token: str):
token: a Token. token: a Token.
""" """
state._tmp_path = tmp_path state._tmp_path = tmp_path
app = App(state=state) app = App(state=state if state is FileUploadState else FileStateBase1)
app.event_namespace.emit = AsyncMock() # type: ignore app.event_namespace.emit = AsyncMock() # type: ignore
current_state = await app.state_manager.get_state(token) current_state = await app.state_manager.get_state(token)
data = b"This is binary data" data = b"This is binary data"
@ -744,12 +745,17 @@ async def test_upload_file(tmp_path, state, delta, token: str):
bio = io.BytesIO() bio = io.BytesIO()
bio.write(data) bio.write(data)
if state is FileUploadState:
handler_prefix = f"{token}:{state.get_name()}"
else:
handler_prefix = f"{token}:{state.get_full_name().partition('.')[2]}"
file1 = UploadFile( file1 = UploadFile(
filename=f"{token}:{state.get_name()}.multi_handle_upload:True:image1.jpg", filename=f"{handler_prefix}.multi_handle_upload:True:image1.jpg",
file=bio, file=bio,
) )
file2 = UploadFile( file2 = UploadFile(
filename=f"{token}:{state.get_name()}.multi_handle_upload:True:image2.jpg", filename=f"{handler_prefix}.multi_handle_upload:True:image2.jpg",
file=bio, file=bio,
) )
upload_fn = upload(app) upload_fn = upload(app)
@ -759,7 +765,11 @@ async def test_upload_file(tmp_path, state, delta, token: str):
app.event_namespace.emit.assert_called_with( # type: ignore app.event_namespace.emit.assert_called_with( # type: ignore
"event", state_update.json(), to=current_state.get_sid() "event", state_update.json(), to=current_state.get_sid()
) )
assert (await app.state_manager.get_state(token)).dict()["img_list"] == [ current_state = await app.state_manager.get_state(token)
state_dict = current_state.dict()
for substate in state.get_full_name().split(".")[1:]:
state_dict = state_dict[substate]
assert state_dict["img_list"] == [
"image1.jpg", "image1.jpg",
"image2.jpg", "image2.jpg",
] ]
@ -788,14 +798,20 @@ async def test_upload_file_without_annotation(state, tmp_path, token):
bio.write(data) bio.write(data)
state._tmp_path = tmp_path state._tmp_path = tmp_path
app = App(state=state) app = App(state=state if state is FileUploadState else FileStateBase1)
if state is FileUploadState:
state_name = state.get_name()
else:
state_name = state.get_full_name().partition(".")[2]
handler_prefix = f"{token}:{state_name}"
file1 = UploadFile( file1 = UploadFile(
filename=f"{token}:{state.get_name()}.handle_upload2:True:image1.jpg", filename=f"{handler_prefix}.handle_upload2:True:image1.jpg",
file=bio, file=bio,
) )
file2 = UploadFile( file2 = UploadFile(
filename=f"{token}:{state.get_name()}.handle_upload2:True:image2.jpg", filename=f"{handler_prefix}.handle_upload2:True:image2.jpg",
file=bio, file=bio,
) )
fn = upload(app) fn = upload(app)
@ -803,7 +819,7 @@ async def test_upload_file_without_annotation(state, tmp_path, token):
await fn([file1, file2]) await fn([file1, file2])
assert ( assert (
err.value.args[0] err.value.args[0]
== f"`{state.get_name()}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]" == f"`{state_name}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
) )
if isinstance(app.state_manager, StateManagerRedis): if isinstance(app.state_manager, StateManagerRedis):

View File

@ -643,7 +643,6 @@ def test_reset(test_state, child_state):
"map_key", "map_key",
"mapping", "mapping",
"dt", "dt",
"router",
} }
# The dirty vars should be reset. # The dirty vars should be reset.