diff --git a/tests/integration/test_component_state.py b/tests/integration/test_component_state.py index dfd8a9d74..7b35f8116 100644 --- a/tests/integration/test_component_state.py +++ b/tests/integration/test_component_state.py @@ -5,6 +5,7 @@ from typing import Generator import pytest from selenium.webdriver.common.by import By +from reflex.state import State, _substate_key from reflex.testing import AppHarness from . import utils @@ -96,6 +97,7 @@ async def test_component_state_app(component_state_app: AppHarness): ss = utils.SessionStorage(driver) assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found" + root_state_token = _substate_key(ss.get("token"), State) count_a = driver.find_element(By.ID, "count-a") count_b = driver.find_element(By.ID, "count-b") @@ -106,7 +108,7 @@ async def test_component_state_app(component_state_app: AppHarness): # Check that backend vars in mixins are okay a_state_name = driver.find_element(By.ID, "a_state_name").text b_state_name = driver.find_element(By.ID, "b_state_name").text - root_state = await component_state_app.get_state(ss.get("token")) + root_state = await component_state_app.get_state(root_state_token) a_state = root_state.substates[a_state_name] b_state = root_state.substates[b_state_name] assert a_state._backend_vars == a_state.backend_vars @@ -126,7 +128,7 @@ async def test_component_state_app(component_state_app: AppHarness): button_inc_a.click() assert component_state_app.poll_for_content(count_a, exp_not_equal="2") == "3" - root_state = await component_state_app.get_state(ss.get("token")) + root_state = await component_state_app.get_state(root_state_token) a_state = root_state.substates[a_state_name] b_state = root_state.substates[b_state_name] assert a_state._backend_vars != a_state.backend_vars @@ -142,7 +144,7 @@ async def test_component_state_app(component_state_app: AppHarness): button_b.click() assert component_state_app.poll_for_content(count_b, exp_not_equal="1") == "2" - root_state = await component_state_app.get_state(ss.get("token")) + root_state = await component_state_app.get_state(root_state_token) a_state = root_state.substates[a_state_name] b_state = root_state.substates[b_state_name] assert b_state._backend_vars != b_state.backend_vars