From 3039b54a7534fc46e030c0b30534d3ab5d685365 Mon Sep 17 00:00:00 2001 From: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com> Date: Thu, 11 Jul 2024 20:13:57 +0200 Subject: [PATCH] add module prefix to generated state names (#3214) * add module prefix to state names * fix state names in test_app * update state names in test_state * fix state names in test_var * fix state name in test_component * fix state names in test_format * fix state names in test_foreach * fix state names in test_cond * fix state names in test_datatable * fix state names in test_colors * fix state names in test_script * fix state names in test_match * fix state name in event1 fixture * fix pyright and darglint * fix state names in state_tree * fix state names in redis only test * fix state names in test_client_storage * fix state name in js template * add `get_state_name` and `get_full_state_name` helpers for `AppHarness` * fix state names in test_dynamic_routes * use new state name helpers in test_client_storage * fix state names in test_event_actions * fix state names in test_event_chain * fix state names in test_upload * fix state name in test_login_flow * fix state names in test_input * fix state names in test_form_submit * ruff * validate state module names * wtf is going on here? * remove comments leftover from refactoring * adjust new test_add_style_embedded_vars * fix state name in state.js * fix integration/test_client_state.py new SessionStorage feature was added with more full state names that need to be formatted in * fix pre-commit issues in test_client_storage.py * adjust test_computed_vars * adjust safe-guards * fix redis tests with new exception state --------- Co-authored-by: Masen Furer --- integration/test_client_storage.py | 99 +++----- integration/test_computed_vars.py | 8 +- integration/test_dynamic_routes.py | 11 +- integration/test_event_actions.py | 12 +- integration/test_event_chain.py | 18 +- integration/test_exception_handlers.py | 2 + integration/test_form_submit.py | 7 +- integration/test_input.py | 13 +- integration/test_login_flow.py | 5 +- integration/test_upload.py | 24 +- reflex/.templates/web/utils/state.js | 4 +- reflex/compiler/templates.py | 2 +- reflex/constants/compiler.py | 14 +- reflex/state.py | 30 ++- reflex/testing.py | 28 +++ tests/components/base/test_script.py | 6 +- tests/components/core/test_colors.py | 23 +- tests/components/core/test_cond.py | 4 +- tests/components/core/test_foreach.py | 22 +- tests/components/core/test_match.py | 18 +- .../components/datadisplay/test_datatable.py | 12 +- tests/components/test_component.py | 6 +- tests/middleware/conftest.py | 3 +- tests/test_app.py | 206 +++++++--------- tests/test_state.py | 220 +++++++++++------- tests/test_state_tree.py | 24 +- tests/test_var.py | 58 +++-- tests/utils/test_format.py | 13 +- 28 files changed, 479 insertions(+), 413 deletions(-) diff --git a/integration/test_client_storage.py b/integration/test_client_storage.py index 03dbd1681..b69009a09 100644 --- a/integration/test_client_storage.py +++ b/integration/test_client_storage.py @@ -229,7 +229,8 @@ async def test_client_side_state( local_storage: Local storage helper. session_storage: Session storage helper. """ - assert client_side.app_instance is not None + app = client_side.app_instance + assert app is not None assert client_side.frontend_url is not None def poll_for_token(): @@ -333,29 +334,37 @@ async def test_client_side_state( set_sub_sub("l1s", "l1s value") set_sub_sub("s1s", "s1s value") + state_name = client_side.get_full_state_name(["_client_side_state"]) + sub_state_name = client_side.get_full_state_name( + ["_client_side_state", "_client_side_sub_state"] + ) + sub_sub_state_name = client_side.get_full_state_name( + ["_client_side_state", "_client_side_sub_state", "_client_side_sub_sub_state"] + ) + exp_cookies = { - "state.client_side_state.client_side_sub_state.c1": { + f"{sub_state_name}.c1": { "domain": "localhost", "httpOnly": False, - "name": "state.client_side_state.client_side_sub_state.c1", + "name": f"{sub_state_name}.c1", "path": "/", "sameSite": "Lax", "secure": False, "value": "c1%20value", }, - "state.client_side_state.client_side_sub_state.c2": { + f"{sub_state_name}.c2": { "domain": "localhost", "httpOnly": False, - "name": "state.client_side_state.client_side_sub_state.c2", + "name": f"{sub_state_name}.c2", "path": "/", "sameSite": "Lax", "secure": False, "value": "c2%20value", }, - "state.client_side_state.client_side_sub_state.c4": { + f"{sub_state_name}.c4": { "domain": "localhost", "httpOnly": False, - "name": "state.client_side_state.client_side_sub_state.c4", + "name": f"{sub_state_name}.c4", "path": "/", "sameSite": "Strict", "secure": False, @@ -370,19 +379,19 @@ async def test_client_side_state( "secure": False, "value": "c6%20value", }, - "state.client_side_state.client_side_sub_state.c7": { + f"{sub_state_name}.c7": { "domain": "localhost", "httpOnly": False, - "name": "state.client_side_state.client_side_sub_state.c7", + "name": f"{sub_state_name}.c7", "path": "/", "sameSite": "Lax", "secure": False, "value": "c7%20value", }, - "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s": { + f"{sub_sub_state_name}.c1s": { "domain": "localhost", "httpOnly": False, - "name": "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s", + "name": f"{sub_sub_state_name}.c1s", "path": "/", "sameSite": "Lax", "secure": False, @@ -400,18 +409,13 @@ async def test_client_side_state( # Test cookie with expiry by itself to avoid timing flakiness set_sub("c3", "c3 value") - AppHarness._poll_for( - lambda: "state.client_side_state.client_side_sub_state.c3" - in cookie_info_map(driver) - ) - c3_cookie = cookie_info_map(driver)[ - "state.client_side_state.client_side_sub_state.c3" - ] + AppHarness._poll_for(lambda: f"{sub_state_name}.c3" in cookie_info_map(driver)) + c3_cookie = cookie_info_map(driver)[f"{sub_state_name}.c3"] assert c3_cookie.pop("expiry") is not None assert c3_cookie == { "domain": "localhost", "httpOnly": False, - "name": "state.client_side_state.client_side_sub_state.c3", + "name": f"{sub_state_name}.c3", "path": "/", "sameSite": "Lax", "secure": False, @@ -420,52 +424,24 @@ async def test_client_side_state( time.sleep(2) # wait for c3 to expire if not isinstance(driver, Firefox): # Note: Firefox does not remove expired cookies Bug 576347 - assert ( - "state.client_side_state.client_side_sub_state.c3" - not in cookie_info_map(driver) - ) + assert f"{sub_state_name}.c3" not in cookie_info_map(driver) local_storage_items = local_storage.items() local_storage_items.pop("chakra-ui-color-mode", None) local_storage_items.pop("last_compiled_time", None) - assert ( - local_storage_items.pop("state.client_side_state.client_side_sub_state.l1") - == "l1 value" - ) - assert ( - local_storage_items.pop("state.client_side_state.client_side_sub_state.l2") - == "l2 value" - ) + assert local_storage_items.pop(f"{sub_state_name}.l1") == "l1 value" + assert local_storage_items.pop(f"{sub_state_name}.l2") == "l2 value" assert local_storage_items.pop("l3") == "l3 value" - assert ( - local_storage_items.pop("state.client_side_state.client_side_sub_state.l4") - == "l4 value" - ) - assert ( - local_storage_items.pop( - "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.l1s" - ) - == "l1s value" - ) + assert local_storage_items.pop(f"{sub_state_name}.l4") == "l4 value" + assert local_storage_items.pop(f"{sub_sub_state_name}.l1s") == "l1s value" assert not local_storage_items session_storage_items = session_storage.items() session_storage_items.pop("token", None) - assert ( - session_storage_items.pop("state.client_side_state.client_side_sub_state.s1") - == "s1 value" - ) - assert ( - session_storage_items.pop("state.client_side_state.client_side_sub_state.s2") - == "s2 value" - ) + assert session_storage_items.pop(f"{sub_state_name}.s1") == "s1 value" + assert session_storage_items.pop(f"{sub_state_name}.s2") == "s2 value" assert session_storage_items.pop("s3") == "s3 value" - assert ( - session_storage_items.pop( - "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.s1s" - ) - == "s1s value" - ) + assert session_storage_items.pop(f"{sub_sub_state_name}.s1s") == "s1s value" assert not session_storage_items assert c1.text == "c1 value" @@ -528,7 +504,7 @@ async def test_client_side_state( assert s1s.text == "s1s value" # reset the backend state to force refresh from client storage - async with client_side.modify_state(f"{token}_state.client_side_state") as state: + async with client_side.modify_state(f"{token}_{state_name}") as state: state.reset() driver.refresh() @@ -576,16 +552,11 @@ async def test_client_side_state( assert s1s.text == "s1s value" # make sure c5 cookie shows up on the `/foo` route - AppHarness._poll_for( - lambda: "state.client_side_state.client_side_sub_state.c5" - in cookie_info_map(driver) - ) - assert cookie_info_map(driver)[ - "state.client_side_state.client_side_sub_state.c5" - ] == { + AppHarness._poll_for(lambda: f"{sub_state_name}.c5" in cookie_info_map(driver)) + assert cookie_info_map(driver)[f"{sub_state_name}.c5"] == { "domain": "localhost", "httpOnly": False, - "name": "state.client_side_state.client_side_sub_state.c5", + "name": f"{sub_state_name}.c5", "path": "/foo/", "sameSite": "Lax", "secure": False, diff --git a/integration/test_computed_vars.py b/integration/test_computed_vars.py index 1b86f0ac7..28f774de5 100644 --- a/integration/test_computed_vars.py +++ b/integration/test_computed_vars.py @@ -183,8 +183,10 @@ async def test_computed_vars( """ assert computed_vars.app_instance is not None - token = f"{token}_state.state" - state = (await computed_vars.get_state(token)).substates["state"] + state_name = computed_vars.get_state_name("_state") + full_state_name = computed_vars.get_full_state_name(["_state"]) + token = f"{token}_{full_state_name}" + state = (await computed_vars.get_state(token)).substates[state_name] assert state is not None assert state.count1_backend == 0 assert state._count1_backend == 0 @@ -236,7 +238,7 @@ async def test_computed_vars( computed_vars.poll_for_content(depends_on_count, timeout=2, exp_not_equal="0") == "1" ) - state = (await computed_vars.get_state(token)).substates["state"] + state = (await computed_vars.get_state(token)).substates[state_name] assert state is not None assert state.count1_backend == 1 assert count1_backend.text == "" diff --git a/integration/test_dynamic_routes.py b/integration/test_dynamic_routes.py index 570a21af6..e3686ee1a 100644 --- a/integration/test_dynamic_routes.py +++ b/integration/test_dynamic_routes.py @@ -154,18 +154,20 @@ def poll_for_order( Returns: An async function that polls for the order list to match the expected order. """ + dynamic_state_name = dynamic_route.get_state_name("_dynamic_state") + dynamic_state_full_name = dynamic_route.get_full_state_name(["_dynamic_state"]) async def _poll_for_order(exp_order: list[str]): async def _backend_state(): - return await dynamic_route.get_state(f"{token}_state.dynamic_state") + return await dynamic_route.get_state(f"{token}_{dynamic_state_full_name}") async def _check(): return (await _backend_state()).substates[ - "dynamic_state" + dynamic_state_name ].order == exp_order await AppHarness._poll_for_async(_check) - assert (await _backend_state()).substates["dynamic_state"].order == exp_order + assert (await _backend_state()).substates[dynamic_state_name].order == exp_order return _poll_for_order @@ -185,6 +187,7 @@ async def test_on_load_navigate( token: The token visible in the driver browser. poll_for_order: function that polls for the order list to match the expected order. """ + dynamic_state_full_name = dynamic_route.get_full_state_name(["_dynamic_state"]) assert dynamic_route.app_instance is not None is_prod = isinstance(dynamic_route, AppHarnessProd) link = driver.find_element(By.ID, "link_page_next") @@ -234,7 +237,7 @@ async def test_on_load_navigate( driver.get(f"{driver.current_url}?foo=bar") await poll_for_order(exp_order) assert ( - await dynamic_route.get_state(f"{token}_state.dynamic_state") + await dynamic_route.get_state(f"{token}_{dynamic_state_full_name}") ).router.page.params["foo"] == "bar" # hit a 404 and ensure we still hydrate diff --git a/integration/test_event_actions.py b/integration/test_event_actions.py index 1e9be064b..e2704fa70 100644 --- a/integration/test_event_actions.py +++ b/integration/test_event_actions.py @@ -229,20 +229,18 @@ def poll_for_order( Returns: An async function that polls for the order list to match the expected order. """ + state_name = event_action.get_state_name("_event_action_state") + state_full_name = event_action.get_full_state_name(["_event_action_state"]) async def _poll_for_order(exp_order: list[str]): async def _backend_state(): - return await event_action.get_state(f"{token}_state.event_action_state") + return await event_action.get_state(f"{token}_{state_full_name}") async def _check(): - return (await _backend_state()).substates[ - "event_action_state" - ].order == exp_order + return (await _backend_state()).substates[state_name].order == exp_order await AppHarness._poll_for_async(_check) - assert (await _backend_state()).substates[ - "event_action_state" - ].order == exp_order + assert (await _backend_state()).substates[state_name].order == exp_order return _poll_for_order diff --git a/integration/test_event_chain.py b/integration/test_event_chain.py index 2a686d5c1..b0feb3e46 100644 --- a/integration/test_event_chain.py +++ b/integration/test_event_chain.py @@ -301,7 +301,8 @@ def assert_token(event_chain: AppHarness, driver: WebDriver) -> str: token = event_chain.poll_for_value(token_input) assert token is not None - return f"{token}_state.state" + state_name = event_chain.get_full_state_name(["_state"]) + return f"{token}_{state_name}" @pytest.mark.parametrize( @@ -400,16 +401,17 @@ async def test_event_chain_click( exp_event_order: the expected events recorded in the State """ token = assert_token(event_chain, driver) + state_name = event_chain.get_state_name("_state") btn = driver.find_element(By.ID, button_id) btn.click() async def _has_all_events(): return len( - (await event_chain.get_state(token)).substates["state"].event_order + (await event_chain.get_state(token)).substates[state_name].event_order ) == len(exp_event_order) await AppHarness._poll_for_async(_has_all_events) - event_order = (await event_chain.get_state(token)).substates["state"].event_order + event_order = (await event_chain.get_state(token)).substates[state_name].event_order assert event_order == exp_event_order @@ -454,14 +456,15 @@ async def test_event_chain_on_load( assert event_chain.frontend_url is not None driver.get(event_chain.frontend_url + uri) token = assert_token(event_chain, driver) + state_name = event_chain.get_state_name("_state") async def _has_all_events(): return len( - (await event_chain.get_state(token)).substates["state"].event_order + (await event_chain.get_state(token)).substates[state_name].event_order ) == len(exp_event_order) await AppHarness._poll_for_async(_has_all_events) - backend_state = (await event_chain.get_state(token)).substates["state"] + backend_state = (await event_chain.get_state(token)).substates[state_name] assert backend_state.event_order == exp_event_order assert backend_state.is_hydrated is True @@ -526,6 +529,7 @@ async def test_event_chain_on_mount( assert event_chain.frontend_url is not None driver.get(event_chain.frontend_url + uri) token = assert_token(event_chain, driver) + state_name = event_chain.get_state_name("_state") unmount_button = driver.find_element(By.ID, "unmount") assert unmount_button @@ -533,11 +537,11 @@ async def test_event_chain_on_mount( async def _has_all_events(): return len( - (await event_chain.get_state(token)).substates["state"].event_order + (await event_chain.get_state(token)).substates[state_name].event_order ) == len(exp_event_order) await AppHarness._poll_for_async(_has_all_events) - event_order = (await event_chain.get_state(token)).substates["state"].event_order + event_order = (await event_chain.get_state(token)).substates[state_name].event_order assert event_order == exp_event_order diff --git a/integration/test_exception_handlers.py b/integration/test_exception_handlers.py index 8ba1faa89..00683c48b 100644 --- a/integration/test_exception_handlers.py +++ b/integration/test_exception_handlers.py @@ -21,6 +21,8 @@ def TestApp(): class TestAppConfig(rx.Config): """Config for the TestApp app.""" + pass + class TestAppState(rx.State): """State for the TestApp app.""" diff --git a/integration/test_form_submit.py b/integration/test_form_submit.py index 572476dc3..c9eb45146 100644 --- a/integration/test_form_submit.py +++ b/integration/test_form_submit.py @@ -252,10 +252,13 @@ async def test_submit(driver, form_submit: AppHarness): submit_input = driver.find_element(By.CLASS_NAME, "rt-Button") submit_input.click() + state_name = form_submit.get_state_name("_form_state") + full_state_name = form_submit.get_full_state_name(["_form_state"]) + async def get_form_data(): return ( - (await form_submit.get_state(f"{token}_state.form_state")) - .substates["form_state"] + (await form_submit.get_state(f"{token}_{full_state_name}")) + .substates[state_name] .form_data ) diff --git a/integration/test_input.py b/integration/test_input.py index 0dd8df6c6..4679104a4 100644 --- a/integration/test_input.py +++ b/integration/test_input.py @@ -86,9 +86,12 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): token = fully_controlled_input.poll_for_value(token_input) assert token + state_name = fully_controlled_input.get_state_name("_state") + full_state_name = fully_controlled_input.get_full_state_name(["_state"]) + async def get_state_text(): - state = await fully_controlled_input.get_state(f"{token}_state.state") - return state.substates["state"].text + state = await fully_controlled_input.get_state(f"{token}_{full_state_name}") + return state.substates[state_name].text # ensure defaults are set correctly assert ( @@ -138,8 +141,10 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): assert fully_controlled_input.poll_for_value(plain_value_input) == "ifoonitial" # clear the input on the backend - async with fully_controlled_input.modify_state(f"{token}_state.state") as state: - state.substates["state"].text = "" + async with fully_controlled_input.modify_state( + f"{token}_{full_state_name}" + ) as state: + state.substates[state_name].text = "" assert await get_state_text() == "" assert ( fully_controlled_input.poll_for_value( diff --git a/integration/test_login_flow.py b/integration/test_login_flow.py index bb1db536f..7d583e433 100644 --- a/integration/test_login_flow.py +++ b/integration/test_login_flow.py @@ -137,6 +137,9 @@ def test_login_flow( logout_button = driver.find_element(By.ID, "logout") logout_button.click() - assert login_sample._poll_for(lambda: local_storage["state.state.auth_token"] == "") + state_name = login_sample.get_full_state_name(["_state"]) + assert login_sample._poll_for( + lambda: local_storage[f"{state_name}.auth_token"] == "" + ) with pytest.raises(NoSuchElementException): driver.find_element(By.ID, "auth-token") diff --git a/integration/test_upload.py b/integration/test_upload.py index 1ddad530e..77199ca4e 100644 --- a/integration/test_upload.py +++ b/integration/test_upload.py @@ -174,7 +174,9 @@ async def test_upload_file( # wait for the backend connection to send the token token = upload_file.poll_for_value(token_input) assert token is not None - substate_token = f"{token}_state.upload_state" + full_state_name = upload_file.get_full_state_name(["_upload_state"]) + state_name = upload_file.get_state_name("_upload_state") + substate_token = f"{token}_{full_state_name}" suffix = "_secondary" if secondary else "" @@ -197,7 +199,7 @@ async def test_upload_file( async def get_file_data(): return ( (await upload_file.get_state(substate_token)) - .substates["upload_state"] + .substates[state_name] ._file_data ) @@ -212,8 +214,8 @@ async def test_upload_file( state = await upload_file.get_state(substate_token) if secondary: # only the secondary form tracks progress and chain events - assert state.substates["upload_state"].event_order.count("upload_progress") == 1 - assert state.substates["upload_state"].event_order.count("chain_event") == 1 + assert state.substates[state_name].event_order.count("upload_progress") == 1 + assert state.substates[state_name].event_order.count("chain_event") == 1 @pytest.mark.asyncio @@ -231,7 +233,9 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver): # wait for the backend connection to send the token token = upload_file.poll_for_value(token_input) assert token is not None - substate_token = f"{token}_state.upload_state" + full_state_name = upload_file.get_full_state_name(["_upload_state"]) + state_name = upload_file.get_state_name("_upload_state") + substate_token = f"{token}_{full_state_name}" upload_box = driver.find_element(By.XPATH, "//input[@type='file']") assert upload_box @@ -261,7 +265,7 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver): async def get_file_data(): return ( (await upload_file.get_state(substate_token)) - .substates["upload_state"] + .substates[state_name] ._file_data ) @@ -343,7 +347,9 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive # wait for the backend connection to send the token token = upload_file.poll_for_value(token_input) assert token is not None - substate_token = f"{token}_state.upload_state" + state_name = upload_file.get_state_name("_upload_state") + state_full_name = upload_file.get_full_state_name(["_upload_state"]) + substate_token = f"{token}_{state_full_name}" upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[1] upload_button = driver.find_element(By.ID, f"upload_button_secondary") @@ -362,7 +368,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive # look up the backend state and assert on progress state = await upload_file.get_state(substate_token) - assert state.substates["upload_state"].progress_dicts - assert exp_name not in state.substates["upload_state"]._file_data + assert state.substates[state_name].progress_dicts + assert exp_name not in state.substates[state_name]._file_data target_file.unlink() diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index f7fbf17f1..be536c1d7 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -117,7 +117,7 @@ export const isStateful = () => { if (event_queue.length === 0) { return false; } - return event_queue.some(event => event.name.startsWith("state")); + return event_queue.some(event => event.name.startsWith("reflex___state")); } /** @@ -763,7 +763,7 @@ export const useEventLoop = ( const vars = {}; vars[storage_to_state_map[e.key]] = e.newValue; const event = Event( - `${state_name}.update_vars_internal_state.update_vars_internal`, + `${state_name}.reflex___state____update_vars_internal_state.update_vars_internal`, { vars: vars } ); addEvents([event], e); diff --git a/reflex/compiler/templates.py b/reflex/compiler/templates.py index 7e900a68d..c868a0cbb 100644 --- a/reflex/compiler/templates.py +++ b/reflex/compiler/templates.py @@ -44,7 +44,7 @@ class ReflexJinjaEnvironment(Environment): "hydrate": constants.CompileVars.HYDRATE, "on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL, "update_vars_internal": constants.CompileVars.UPDATE_VARS_INTERNAL, - "frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE, + "frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL, } diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index 83bef4429..1de3fc263 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -62,11 +62,17 @@ class CompileVars(SimpleNamespace): # The name of the function for converting a dict to an event. TO_EVENT = "Event" # The name of the internal on_load event. - ON_LOAD_INTERNAL = "on_load_internal_state.on_load_internal" + ON_LOAD_INTERNAL = "reflex___state____on_load_internal_state.on_load_internal" # The name of the internal event to update generic state vars. - UPDATE_VARS_INTERNAL = "update_vars_internal_state.update_vars_internal" + UPDATE_VARS_INTERNAL = ( + "reflex___state____update_vars_internal_state.update_vars_internal" + ) # The name of the frontend event exception state - FRONTEND_EXCEPTION_STATE = "state.frontend_event_exception_state" + FRONTEND_EXCEPTION_STATE = "reflex___state____frontend_event_exception_state" + # The full name of the frontend exception state + FRONTEND_EXCEPTION_STATE_FULL = ( + f"reflex___state____state.{FRONTEND_EXCEPTION_STATE}" + ) class PageNames(SimpleNamespace): @@ -129,7 +135,7 @@ class Hooks(SimpleNamespace): FRONTEND_ERRORS = f""" const logFrontendError = (error, info) => {{ if (process.env.NODE_ENV === "production") {{ - addEvents([Event("{CompileVars.FRONTEND_EXCEPTION_STATE}.handle_frontend_exception", {{ + addEvents([Event("{CompileVars.FRONTEND_EXCEPTION_STATE_FULL}.handle_frontend_exception", {{ stack: error.stack, }})]) }} diff --git a/reflex/state.py b/reflex/state.py index c8d970a4f..49b5bd4a4 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -426,6 +426,21 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if isinstance(v, ComputedVar) ] + @classmethod + def _validate_module_name(cls) -> None: + """Check if the module name is valid. + + Reflex uses ___ as state name module separator. + + Raises: + NameError: If the module name is invalid. + """ + if "___" in cls.__module__: + raise NameError( + "The module name of a State class cannot contain '___'. " + "Please rename the module." + ) + @classmethod def __init_subclass__(cls, mixin: bool = False, **kwargs): """Do some magic for the subclass initialization. @@ -445,8 +460,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if mixin: return + # Validate the module name. + cls._validate_module_name() + # Event handlers should not shadow builtin state methods. cls._check_overridden_methods() + # Computed vars should not shadow builtin state props. cls._check_overriden_basevars() @@ -463,20 +482,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): cls.inherited_backend_vars = parent_state.backend_vars # Check if another substate class with the same name has already been defined. - if cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses): + if cls.get_name() in set( + c.get_name() for c in parent_state.class_subclasses + ): if is_testing_env(): # Clear existing subclass with same name when app is reloaded via # utils.prerequisites.get_app(reload=True) parent_state.class_subclasses = set( c for c in parent_state.class_subclasses - if c.__name__ != cls.__name__ + if c.get_name() != cls.get_name() ) else: # During normal operation, subclasses cannot have the same name, even if they are # defined in different modules. raise StateValueError( - f"The substate class '{cls.__name__}' has been defined multiple times. " + f"The substate class '{cls.get_name()}' has been defined multiple times. " "Shadowing substate classes is not allowed." ) # Track this new subclass in the parent state's subclasses set. @@ -759,7 +780,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): Returns: The name of the state. """ - return format.to_snake_case(cls.__name__) + module = cls.__module__.replace(".", "___") + return format.to_snake_case(f"{module}___{cls.__name__}") @classmethod @functools.lru_cache() diff --git a/reflex/testing.py b/reflex/testing.py index 135286b4a..bde218c5e 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -40,6 +40,7 @@ import reflex import reflex.reflex import reflex.utils.build import reflex.utils.exec +import reflex.utils.format import reflex.utils.prerequisites import reflex.utils.processes from reflex.state import ( @@ -177,6 +178,33 @@ class AppHarness: app_module_path=root / app_name / f"{app_name}.py", ) + def get_state_name(self, state_cls_name: str) -> str: + """Get the state name for the given state class name. + + Args: + state_cls_name: The state class name + + Returns: + The state name + """ + return reflex.utils.format.to_snake_case( + f"{self.app_name}___{self.app_name}___" + state_cls_name + ) + + def get_full_state_name(self, path: List[str]) -> str: + """Get the full state name for the given state class name. + + Args: + path: A list of state class names + + Returns: + The full state name + """ + # NOTE: using State.get_name() somehow causes trouble here + # path = [State.get_name()] + [self.get_state_name(p) for p in path] + path = ["reflex___state____state"] + [self.get_state_name(p) for p in path] + return ".".join(path) + def _get_globals_from_signature(self, func: Any) -> dict[str, Any]: """Get the globals from a function or module object. diff --git a/tests/components/base/test_script.py b/tests/components/base/test_script.py index 16239791b..e06914258 100644 --- a/tests/components/base/test_script.py +++ b/tests/components/base/test_script.py @@ -58,14 +58,14 @@ def test_script_event_handler(): ) render_dict = component.render() assert ( - 'onReady={(_e) => addEvents([Event("ev_state.on_ready", {})], (_e), {})}' + f'onReady={{(_e) => addEvents([Event("{EvState.get_full_name()}.on_ready", {{}})], (_e), {{}})}}' in render_dict["props"] ) assert ( - 'onLoad={(_e) => addEvents([Event("ev_state.on_load", {})], (_e), {})}' + f'onLoad={{(_e) => addEvents([Event("{EvState.get_full_name()}.on_load", {{}})], (_e), {{}})}}' in render_dict["props"] ) assert ( - 'onError={(_e) => addEvents([Event("ev_state.on_error", {})], (_e), {})}' + f'onError={{(_e) => addEvents([Event("{EvState.get_full_name()}.on_error", {{}})], (_e), {{}})}}' in render_dict["props"] ) diff --git a/tests/components/core/test_colors.py b/tests/components/core/test_colors.py index 28078059f..0c14c8a7c 100644 --- a/tests/components/core/test_colors.py +++ b/tests/components/core/test_colors.py @@ -14,6 +14,9 @@ class ColorState(rx.State): shade: int = 4 +color_state_name = ColorState.get_full_name().replace(".", "__") + + def create_color_var(color): return Var.create(color) @@ -26,27 +29,27 @@ def create_color_var(color): (create_color_var(rx.color("mint", 3, True)), "var(--mint-a3)"), ( create_color_var(rx.color(ColorState.color, ColorState.shade)), # type: ignore - "var(--${state__color_state.color}-${state__color_state.shade})", + f"var(--${{{color_state_name}.color}}-${{{color_state_name}.shade}})", ), ( create_color_var(rx.color(f"{ColorState.color}", f"{ColorState.shade}")), # type: ignore - "var(--${state__color_state.color}-${state__color_state.shade})", + f"var(--${{{color_state_name}.color}}-${{{color_state_name}.shade}})", ), ( create_color_var( rx.color(f"{ColorState.color_part}ato", f"{ColorState.shade}") # type: ignore ), - "var(--${state__color_state.color_part}ato-${state__color_state.shade})", + f"var(--${{{color_state_name}.color_part}}ato-${{{color_state_name}.shade}})", ), ( create_color_var(f'{rx.color(ColorState.color, f"{ColorState.shade}")}'), # type: ignore - "var(--${state__color_state.color}-${state__color_state.shade})", + f"var(--${{{color_state_name}.color}}-${{{color_state_name}.shade}})", ), ( create_color_var( f'{rx.color(f"{ColorState.color}", f"{ColorState.shade}")}' # type: ignore ), - "var(--${state__color_state.color}-${state__color_state.shade})", + f"var(--${{{color_state_name}.color}}-${{{color_state_name}.shade}})", ), ], ) @@ -68,7 +71,7 @@ def test_color(color, expected): ), ( rx.cond(True, rx.color(ColorState.color), rx.color(ColorState.color, 5)), # type: ignore - "{isTrue(true) ? `var(--${state__color_state.color}-7)` : `var(--${state__color_state.color}-5)`}", + f"{{isTrue(true) ? `var(--${{{color_state_name}.color}}-7)` : `var(--${{{color_state_name}.color}}-5)`}}", ), ( rx.match( @@ -79,7 +82,7 @@ def test_color(color, expected): ), "{(() => { switch (JSON.stringify(`condition`)) {case JSON.stringify(`first`): return (`var(--mint-7)`);" " break;case JSON.stringify(`second`): return (`var(--tomato-5)`); break;default: " - "return (`var(--${state__color_state.color}-2)`); break;};})()}", + f"return (`var(--${{{color_state_name}.color}}-2)`); break;}};}})()}}", ), ( rx.match( @@ -89,9 +92,9 @@ def test_color(color, expected): rx.color(ColorState.color, 2), # type: ignore ), "{(() => { switch (JSON.stringify(`condition`)) {case JSON.stringify(`first`): " - "return (`var(--${state__color_state.color}-7)`); break;case JSON.stringify(`second`): " - "return (`var(--${state__color_state.color}-5)`); break;default: " - "return (`var(--${state__color_state.color}-2)`); break;};})()}", + f"return (`var(--${{{color_state_name}.color}}-7)`); break;case JSON.stringify(`second`): " + f"return (`var(--${{{color_state_name}.color}}-5)`); break;default: " + f"return (`var(--${{{color_state_name}.color}}-2)`); break;}};}})()}}", ), ], ) diff --git a/tests/components/core/test_cond.py b/tests/components/core/test_cond.py index 4bfa902af..e18be2217 100644 --- a/tests/components/core/test_cond.py +++ b/tests/components/core/test_cond.py @@ -34,7 +34,7 @@ def test_f_string_cond_interpolation(): ], indirect=True, ) -def test_validate_cond(cond_state: Var): +def test_validate_cond(cond_state: BaseState): """Test if cond can be a rx.Var with any values. Args: @@ -49,7 +49,7 @@ def test_validate_cond(cond_state: Var): assert cond_dict["name"] == "Fragment" [condition] = cond_dict["children"] - assert condition["cond_state"] == "isTrue(cond_state.value)" + assert condition["cond_state"] == f"isTrue({cond_state.get_full_name()}.value)" # true value true_value = condition["true_value"] diff --git a/tests/components/core/test_foreach.py b/tests/components/core/test_foreach.py index 3fe38def9..6e4fa6b5e 100644 --- a/tests/components/core/test_foreach.py +++ b/tests/components/core/test_foreach.py @@ -134,7 +134,7 @@ seen_index_vars = set() ForEachState.colors_list, display_color, { - "iterable_state": "for_each_state.colors_list", + "iterable_state": f"{ForEachState.get_full_name()}.colors_list", "iterable_type": "list", }, ), @@ -142,7 +142,7 @@ seen_index_vars = set() ForEachState.colors_dict_list, display_color_name, { - "iterable_state": "for_each_state.colors_dict_list", + "iterable_state": f"{ForEachState.get_full_name()}.colors_dict_list", "iterable_type": "list", }, ), @@ -150,7 +150,7 @@ seen_index_vars = set() ForEachState.colors_nested_dict_list, display_shade, { - "iterable_state": "for_each_state.colors_nested_dict_list", + "iterable_state": f"{ForEachState.get_full_name()}.colors_nested_dict_list", "iterable_type": "list", }, ), @@ -158,7 +158,7 @@ seen_index_vars = set() ForEachState.primary_color, display_primary_colors, { - "iterable_state": "for_each_state.primary_color", + "iterable_state": f"{ForEachState.get_full_name()}.primary_color", "iterable_type": "dict", }, ), @@ -166,7 +166,7 @@ seen_index_vars = set() ForEachState.color_with_shades, display_color_with_shades, { - "iterable_state": "for_each_state.color_with_shades", + "iterable_state": f"{ForEachState.get_full_name()}.color_with_shades", "iterable_type": "dict", }, ), @@ -174,7 +174,7 @@ seen_index_vars = set() ForEachState.nested_colors_with_shades, display_nested_color_with_shades, { - "iterable_state": "for_each_state.nested_colors_with_shades", + "iterable_state": f"{ForEachState.get_full_name()}.nested_colors_with_shades", "iterable_type": "dict", }, ), @@ -182,7 +182,7 @@ seen_index_vars = set() ForEachState.nested_colors_with_shades, display_nested_color_with_shades_v2, { - "iterable_state": "for_each_state.nested_colors_with_shades", + "iterable_state": f"{ForEachState.get_full_name()}.nested_colors_with_shades", "iterable_type": "dict", }, ), @@ -190,7 +190,7 @@ seen_index_vars = set() ForEachState.color_tuple, display_color_tuple, { - "iterable_state": "for_each_state.color_tuple", + "iterable_state": f"{ForEachState.get_full_name()}.color_tuple", "iterable_type": "tuple", }, ), @@ -198,7 +198,7 @@ seen_index_vars = set() ForEachState.colors_set, display_colors_set, { - "iterable_state": "for_each_state.colors_set", + "iterable_state": f"{ForEachState.get_full_name()}.colors_set", "iterable_type": "set", }, ), @@ -206,7 +206,7 @@ seen_index_vars = set() ForEachState.nested_colors_list, lambda el, i: display_nested_list_element(el, i), { - "iterable_state": "for_each_state.nested_colors_list", + "iterable_state": f"{ForEachState.get_full_name()}.nested_colors_list", "iterable_type": "list", }, ), @@ -214,7 +214,7 @@ seen_index_vars = set() ForEachState.color_index_tuple, display_color_index_tuple, { - "iterable_state": "for_each_state.color_index_tuple", + "iterable_state": f"{ForEachState.get_full_name()}.color_index_tuple", "iterable_type": "tuple", }, ), diff --git a/tests/components/core/test_match.py b/tests/components/core/test_match.py index 5a0e46e9e..883386ebd 100644 --- a/tests/components/core/test_match.py +++ b/tests/components/core/test_match.py @@ -35,7 +35,7 @@ def test_match_components(): [match_child] = match_dict["children"] assert match_child["name"] == "match" - assert str(match_child["cond"]) == "{match_state.value}" + assert str(match_child["cond"]) == f"{{{MatchState.get_name()}.value}}" match_cases = match_child["match_cases"] assert len(match_cases) == 6 @@ -72,7 +72,7 @@ def test_match_components(): assert fifth_return_value_render["name"] == "RadixThemesText" assert fifth_return_value_render["children"][0]["contents"] == "{`fifth value`}" - assert match_cases[5][0]._var_name == "((match_state.num) + (1))" + assert match_cases[5][0]._var_name == f"(({MatchState.get_name()}.num) + (1))" assert match_cases[5][0]._var_type == int fifth_return_value_render = match_cases[5][1].render() assert fifth_return_value_render["name"] == "RadixThemesText" @@ -99,11 +99,11 @@ def test_match_components(): (MatchState.string, f"{MatchState.value} - string"), "default value", ), - "(() => { switch (JSON.stringify(match_state.value)) {case JSON.stringify(1): return (`first`); break;case JSON.stringify(2): case JSON.stringify(3): return " + f"(() => {{ switch (JSON.stringify({MatchState.get_name()}.value)) {{case JSON.stringify(1): return (`first`); break;case JSON.stringify(2): case JSON.stringify(3): return " "(`second value`); break;case JSON.stringify([1, 2]): return (`third-value`); break;case JSON.stringify(`random`): " 'return (`fourth_value`); break;case JSON.stringify({"foo": "bar"}): return (`fifth value`); ' - "break;case JSON.stringify(((match_state.num) + (1))): return (`sixth value`); break;case JSON.stringify(`${match_state.value} - string`): " - "return (match_state.string); break;case JSON.stringify(match_state.string): return (`${match_state.value} - string`); break;default: " + f"break;case JSON.stringify((({MatchState.get_name()}.num) + (1))): return (`sixth value`); break;case JSON.stringify(`${{{MatchState.get_name()}.value}} - string`): " + f"return ({MatchState.get_name()}.string); break;case JSON.stringify({MatchState.get_name()}.string): return (`${{{MatchState.get_name()}.value}} - string`); break;default: " "return (`default value`); break;};})()", ), ( @@ -118,12 +118,12 @@ def test_match_components(): (MatchState.string, f"{MatchState.value} - string"), MatchState.string, ), - "(() => { switch (JSON.stringify(match_state.value)) {case JSON.stringify(1): return (`first`); break;case JSON.stringify(2): case JSON.stringify(3): return " + f"(() => {{ switch (JSON.stringify({MatchState.get_name()}.value)) {{case JSON.stringify(1): return (`first`); break;case JSON.stringify(2): case JSON.stringify(3): return " "(`second value`); break;case JSON.stringify([1, 2]): return (`third-value`); break;case JSON.stringify(`random`): " 'return (`fourth_value`); break;case JSON.stringify({"foo": "bar"}): return (`fifth value`); ' - "break;case JSON.stringify(((match_state.num) + (1))): return (`sixth value`); break;case JSON.stringify(`${match_state.value} - string`): " - "return (match_state.string); break;case JSON.stringify(match_state.string): return (`${match_state.value} - string`); break;default: " - "return (match_state.string); break;};})()", + f"break;case JSON.stringify((({MatchState.get_name()}.num) + (1))): return (`sixth value`); break;case JSON.stringify(`${{{MatchState.get_name()}.value}} - string`): " + f"return ({MatchState.get_name()}.string); break;case JSON.stringify({MatchState.get_name()}.string): return (`${{{MatchState.get_name()}.value}} - string`); break;default: " + f"return ({MatchState.get_name()}.string); break;}};}})()", ), ], ) diff --git a/tests/components/datadisplay/test_datatable.py b/tests/components/datadisplay/test_datatable.py index 2557be62b..b3d31ea32 100644 --- a/tests/components/datadisplay/test_datatable.py +++ b/tests/components/datadisplay/test_datatable.py @@ -16,14 +16,14 @@ from reflex.utils.serializers import serialize, serialize_dataframe [["foo", "bar"], ["foo1", "bar1"]], columns=["column1", "column2"] ) }, - "data_table_state.data", + "data", ), - pytest.param({"data": ["foo", "bar"]}, "data_table_state"), - pytest.param({"data": [["foo", "bar"], ["foo1", "bar1"]]}, "data_table_state"), + pytest.param({"data": ["foo", "bar"]}, ""), + pytest.param({"data": [["foo", "bar"], ["foo1", "bar1"]]}, ""), ], indirect=["data_table_state"], ) -def test_validate_data_table(data_table_state: rx.Var, expected): +def test_validate_data_table(data_table_state: rx.State, expected): """Test the str/render function. Args: @@ -40,6 +40,10 @@ def test_validate_data_table(data_table_state: rx.Var, expected): data_table_dict = data_table_component.render() + # prefix expected with state name + state_name = data_table_state.get_name() + expected = f"{state_name}.{expected}" if expected else state_name + assert data_table_dict["props"] == [ f"columns={{{expected}.columns}}", f"data={{{expected}.data}}", diff --git a/tests/components/test_component.py b/tests/components/test_component.py index b014088ba..64354ada9 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -825,7 +825,7 @@ def test_component_event_trigger_arbitrary_args(): assert comp.render()["props"][0] == ( "onFoo={(__e,_alpha,_bravo,_charlie) => addEvents(" - '[Event("c1_state.mock_handler", {_e:__e.target.value,_bravo:_bravo["nested"],_charlie:((_charlie.custom) + (42))})], ' + f'[Event("{C1State.get_full_name()}.mock_handler", {{_e:__e.target.value,_bravo:_bravo["nested"],_charlie:((_charlie.custom) + (42))}})], ' "(__e,_alpha,_bravo,_charlie), {})}" ) @@ -2037,14 +2037,14 @@ def test_add_style_embedded_vars(test_state: BaseState): page._add_style_recursive(Style()) assert ( - "const test_state = useContext(StateContexts.test_state)" + f"const {test_state.get_name()} = useContext(StateContexts.{test_state.get_name()})" in page._get_all_hooks_internal() ) assert "useText" in page._get_all_hooks_internal() assert "useParent" in page._get_all_hooks_internal() assert ( str(page).count( - 'css={{"fakeParent": "parent", "color": "var(--plum-10)", "fake": "text", "margin": `${test_state.num}%`}}' + f'css={{{{"fakeParent": "parent", "color": "var(--plum-10)", "fake": "text", "margin": `${{{test_state.get_name()}.num}}%`}}}}' ) == 1 ) diff --git a/tests/middleware/conftest.py b/tests/middleware/conftest.py index 5a1897110..d786db652 100644 --- a/tests/middleware/conftest.py +++ b/tests/middleware/conftest.py @@ -1,6 +1,7 @@ import pytest from reflex.event import Event +from reflex.state import State def create_event(name): @@ -21,4 +22,4 @@ def create_event(name): @pytest.fixture def event1(): - return create_event("state.hydrate") + return create_event(f"{State.get_name()}.hydrate") diff --git a/tests/test_app.py b/tests/test_app.py index f41d46c7c..489ace511 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -485,20 +485,12 @@ async def test_dynamic_var_event(test_state: Type[ATestState], token: str): pytest.param( [ ( - "list_mutation_test_state.make_friend", - { - "list_mutation_test_state": { - "plain_friends": ["Tommy", "another-fd"] - } - }, + "make_friend", + {"plain_friends": ["Tommy", "another-fd"]}, ), ( - "list_mutation_test_state.change_first_friend", - { - "list_mutation_test_state": { - "plain_friends": ["Jenny", "another-fd"] - } - }, + "change_first_friend", + {"plain_friends": ["Jenny", "another-fd"]}, ), ], id="append then __setitem__", @@ -506,12 +498,12 @@ async def test_dynamic_var_event(test_state: Type[ATestState], token: str): pytest.param( [ ( - "list_mutation_test_state.unfriend_first_friend", - {"list_mutation_test_state": {"plain_friends": []}}, + "unfriend_first_friend", + {"plain_friends": []}, ), ( - "list_mutation_test_state.make_friend", - {"list_mutation_test_state": {"plain_friends": ["another-fd"]}}, + "make_friend", + {"plain_friends": ["another-fd"]}, ), ], id="delitem then append", @@ -519,24 +511,20 @@ async def test_dynamic_var_event(test_state: Type[ATestState], token: str): pytest.param( [ ( - "list_mutation_test_state.make_friends_with_colleagues", - { - "list_mutation_test_state": { - "plain_friends": ["Tommy", "Peter", "Jimmy"] - } - }, + "make_friends_with_colleagues", + {"plain_friends": ["Tommy", "Peter", "Jimmy"]}, ), ( - "list_mutation_test_state.remove_tommy", - {"list_mutation_test_state": {"plain_friends": ["Peter", "Jimmy"]}}, + "remove_tommy", + {"plain_friends": ["Peter", "Jimmy"]}, ), ( - "list_mutation_test_state.remove_last_friend", - {"list_mutation_test_state": {"plain_friends": ["Peter"]}}, + "remove_last_friend", + {"plain_friends": ["Peter"]}, ), ( - "list_mutation_test_state.unfriend_all_friends", - {"list_mutation_test_state": {"plain_friends": []}}, + "unfriend_all_friends", + {"plain_friends": []}, ), ], id="extend, remove, pop, clear", @@ -544,28 +532,16 @@ async def test_dynamic_var_event(test_state: Type[ATestState], token: str): pytest.param( [ ( - "list_mutation_test_state.add_jimmy_to_second_group", - { - "list_mutation_test_state": { - "friends_in_nested_list": [["Tommy"], ["Jenny", "Jimmy"]] - } - }, + "add_jimmy_to_second_group", + {"friends_in_nested_list": [["Tommy"], ["Jenny", "Jimmy"]]}, ), ( - "list_mutation_test_state.remove_first_person_from_first_group", - { - "list_mutation_test_state": { - "friends_in_nested_list": [[], ["Jenny", "Jimmy"]] - } - }, + "remove_first_person_from_first_group", + {"friends_in_nested_list": [[], ["Jenny", "Jimmy"]]}, ), ( - "list_mutation_test_state.remove_first_group", - { - "list_mutation_test_state": { - "friends_in_nested_list": [["Jenny", "Jimmy"]] - } - }, + "remove_first_group", + {"friends_in_nested_list": [["Jenny", "Jimmy"]]}, ), ], id="nested list", @@ -573,24 +549,16 @@ async def test_dynamic_var_event(test_state: Type[ATestState], token: str): pytest.param( [ ( - "list_mutation_test_state.add_jimmy_to_tommy_friends", - { - "list_mutation_test_state": { - "friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]} - } - }, + "add_jimmy_to_tommy_friends", + {"friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]}}, ), ( - "list_mutation_test_state.remove_jenny_from_tommy", - { - "list_mutation_test_state": { - "friends_in_dict": {"Tommy": ["Jimmy"]} - } - }, + "remove_jenny_from_tommy", + {"friends_in_dict": {"Tommy": ["Jimmy"]}}, ), ( - "list_mutation_test_state.tommy_has_no_fds", - {"list_mutation_test_state": {"friends_in_dict": {"Tommy": []}}}, + "tommy_has_no_fds", + {"friends_in_dict": {"Tommy": []}}, ), ], id="list in dict", @@ -614,12 +582,14 @@ async def test_list_mutation_detection__plain_list( result = await list_mutation_state._process( Event( token=token, - name=event_name, + name=f"{list_mutation_state.get_name()}.{event_name}", router_data={"pathname": "/", "query": {}}, payload={}, ) ).__anext__() + # prefix keys in expected_delta with the state name + expected_delta = {list_mutation_state.get_name(): expected_delta} assert result.delta == expected_delta @@ -630,24 +600,16 @@ async def test_list_mutation_detection__plain_list( pytest.param( [ ( - "dict_mutation_test_state.add_age", - { - "dict_mutation_test_state": { - "details": {"name": "Tommy", "age": 20} - } - }, + "add_age", + {"details": {"name": "Tommy", "age": 20}}, ), ( - "dict_mutation_test_state.change_name", - { - "dict_mutation_test_state": { - "details": {"name": "Jenny", "age": 20} - } - }, + "change_name", + {"details": {"name": "Jenny", "age": 20}}, ), ( - "dict_mutation_test_state.remove_last_detail", - {"dict_mutation_test_state": {"details": {"name": "Jenny"}}}, + "remove_last_detail", + {"details": {"name": "Jenny"}}, ), ], id="update then __setitem__", @@ -655,12 +617,12 @@ async def test_list_mutation_detection__plain_list( pytest.param( [ ( - "dict_mutation_test_state.clear_details", - {"dict_mutation_test_state": {"details": {}}}, + "clear_details", + {"details": {}}, ), ( - "dict_mutation_test_state.add_age", - {"dict_mutation_test_state": {"details": {"age": 20}}}, + "add_age", + {"details": {"age": 20}}, ), ], id="delitem then update", @@ -668,20 +630,16 @@ async def test_list_mutation_detection__plain_list( pytest.param( [ ( - "dict_mutation_test_state.add_age", - { - "dict_mutation_test_state": { - "details": {"name": "Tommy", "age": 20} - } - }, + "add_age", + {"details": {"name": "Tommy", "age": 20}}, ), ( - "dict_mutation_test_state.remove_name", - {"dict_mutation_test_state": {"details": {"age": 20}}}, + "remove_name", + {"details": {"age": 20}}, ), ( - "dict_mutation_test_state.pop_out_age", - {"dict_mutation_test_state": {"details": {}}}, + "pop_out_age", + {"details": {}}, ), ], id="add, remove, pop", @@ -689,22 +647,16 @@ async def test_list_mutation_detection__plain_list( pytest.param( [ ( - "dict_mutation_test_state.remove_home_address", - { - "dict_mutation_test_state": { - "address": [{}, {"work": "work address"}] - } - }, + "remove_home_address", + {"address": [{}, {"work": "work address"}]}, ), ( - "dict_mutation_test_state.add_street_to_home_address", + "add_street_to_home_address", { - "dict_mutation_test_state": { - "address": [ - {"street": "street address"}, - {"work": "work address"}, - ] - } + "address": [ + {"street": "street address"}, + {"work": "work address"}, + ] }, ), ], @@ -713,34 +665,26 @@ async def test_list_mutation_detection__plain_list( pytest.param( [ ( - "dict_mutation_test_state.change_friend_name", + "change_friend_name", { - "dict_mutation_test_state": { - "friend_in_nested_dict": { - "name": "Nikhil", - "friend": {"name": "Tommy"}, - } + "friend_in_nested_dict": { + "name": "Nikhil", + "friend": {"name": "Tommy"}, } }, ), ( - "dict_mutation_test_state.add_friend_age", + "add_friend_age", { - "dict_mutation_test_state": { - "friend_in_nested_dict": { - "name": "Nikhil", - "friend": {"name": "Tommy", "age": 30}, - } + "friend_in_nested_dict": { + "name": "Nikhil", + "friend": {"name": "Tommy", "age": 30}, } }, ), ( - "dict_mutation_test_state.remove_friend", - { - "dict_mutation_test_state": { - "friend_in_nested_dict": {"name": "Nikhil"} - } - }, + "remove_friend", + {"friend_in_nested_dict": {"name": "Nikhil"}}, ), ], id="nested dict", @@ -764,12 +708,15 @@ async def test_dict_mutation_detection__plain_list( result = await dict_mutation_state._process( Event( token=token, - name=event_name, + name=f"{dict_mutation_state.get_name()}.{event_name}", router_data={"pathname": "/", "query": {}}, payload={}, ) ).__anext__() + # prefix keys in expected_delta with the state name + expected_delta = {dict_mutation_state.get_name(): expected_delta} + assert result.delta == expected_delta @@ -779,12 +726,16 @@ async def test_dict_mutation_detection__plain_list( [ ( FileUploadState, - {"state.file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}}, + { + FileUploadState.get_full_name(): { + "img_list": ["image1.jpg", "image2.jpg"] + } + }, ), ( ChildFileUploadState, { - "state.file_state_base1.child_file_upload_state": { + ChildFileUploadState.get_full_name(): { "img_list": ["image1.jpg", "image2.jpg"] } }, @@ -792,7 +743,7 @@ async def test_dict_mutation_detection__plain_list( ( GrandChildFileUploadState, { - "state.file_state_base1.file_state_base2.grand_child_file_upload_state": { + GrandChildFileUploadState.get_full_name(): { "img_list": ["image1.jpg", "image2.jpg"] } }, @@ -1065,7 +1016,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( val=exp_val, ), _event( - name="state.set_is_hydrated", + name=f"{State.get_name()}.set_is_hydrated", payload={"value": True}, val=exp_val, router_data={}, @@ -1188,7 +1139,10 @@ async def test_process_events(mocker, token: str): app = App(state=GenState) mocker.patch.object(app, "_postprocess", AsyncMock()) event = Event( - token=token, name="gen_state.go", payload={"c": 5}, router_data=router_data + token=token, + name=f"{GenState.get_name()}.go", + payload={"c": 5}, + router_data=router_data, ) async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"): diff --git a/tests/test_state.py b/tests/test_state.py index d81d88d82..2fc149389 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -217,7 +217,7 @@ def child_state(test_state) -> ChildState: Returns: A test child state. """ - child_state = test_state.get_substate(["child_state"]) + child_state = test_state.get_substate([ChildState.get_name()]) assert child_state is not None return child_state @@ -232,7 +232,7 @@ def child_state2(test_state) -> ChildState2: Returns: A second test child state. """ - child_state2 = test_state.get_substate(["child_state2"]) + child_state2 = test_state.get_substate([ChildState2.get_name()]) assert child_state2 is not None return child_state2 @@ -247,7 +247,7 @@ def grandchild_state(child_state) -> GrandchildState: Returns: A test state. """ - grandchild_state = child_state.get_substate(["grandchild_state"]) + grandchild_state = child_state.get_substate([GrandchildState.get_name()]) assert grandchild_state is not None return grandchild_state @@ -357,20 +357,20 @@ def test_computed_vars(test_state): assert test_state.upper == "HELLO WORLD" -def test_dict(test_state): +def test_dict(test_state: TestState): """Test that the dict representation of a state is correct. Args: test_state: A state. """ substates = { - "test_state", - "test_state.child_state", - "test_state.child_state.grandchild_state", - "test_state.child_state2", - "test_state.child_state2.grandchild_state2", - "test_state.child_state3", - "test_state.child_state3.grandchild_state3", + test_state.get_full_name(), + ChildState.get_full_name(), + GrandchildState.get_full_name(), + ChildState2.get_full_name(), + GrandchildState2.get_full_name(), + ChildState3.get_full_name(), + GrandchildState3.get_full_name(), } test_state_dict = test_state.dict() assert set(test_state_dict) == substates @@ -394,22 +394,30 @@ def test_default_setters(test_state): def test_class_indexing_with_vars(): """Test that we can index into a state var with another var.""" prop = TestState.array[TestState.num1] - assert str(prop) == "{test_state.array.at(test_state.num1)}" + assert ( + str(prop) == f"{{{TestState.get_name()}.array.at({TestState.get_name()}.num1)}}" + ) prop = TestState.mapping["a"][TestState.num1] - assert str(prop) == '{test_state.mapping["a"].at(test_state.num1)}' + assert ( + str(prop) + == f'{{{TestState.get_name()}.mapping["a"].at({TestState.get_name()}.num1)}}' + ) prop = TestState.mapping[TestState.map_key] - assert str(prop) == "{test_state.mapping[test_state.map_key]}" + assert ( + str(prop) + == f"{{{TestState.get_name()}.mapping[{TestState.get_name()}.map_key]}}" + ) def test_class_attributes(): """Test that we can get class attributes.""" prop = TestState.obj.prop1 - assert str(prop) == "{test_state.obj.prop1}" + assert str(prop) == f"{{{TestState.get_name()}.obj.prop1}}" prop = TestState.complex[1].prop1 - assert str(prop) == "{test_state.complex[1].prop1}" + assert str(prop) == f"{{{TestState.get_name()}.complex[1].prop1}}" def test_get_parent_state(): @@ -431,27 +439,40 @@ def test_get_substates(): def test_get_name(): """Test getting the name of a state.""" - assert TestState.get_name() == "test_state" - assert ChildState.get_name() == "child_state" - assert ChildState2.get_name() == "child_state2" - assert GrandchildState.get_name() == "grandchild_state" + assert TestState.get_name() == "tests___test_state____test_state" + assert ChildState.get_name() == "tests___test_state____child_state" + assert ChildState2.get_name() == "tests___test_state____child_state2" + assert GrandchildState.get_name() == "tests___test_state____grandchild_state" def test_get_full_name(): """Test getting the full name.""" - assert TestState.get_full_name() == "test_state" - assert ChildState.get_full_name() == "test_state.child_state" - assert ChildState2.get_full_name() == "test_state.child_state2" - assert GrandchildState.get_full_name() == "test_state.child_state.grandchild_state" + assert TestState.get_full_name() == "tests___test_state____test_state" + assert ( + ChildState.get_full_name() + == "tests___test_state____test_state.tests___test_state____child_state" + ) + assert ( + ChildState2.get_full_name() + == "tests___test_state____test_state.tests___test_state____child_state2" + ) + assert ( + GrandchildState.get_full_name() + == "tests___test_state____test_state.tests___test_state____child_state.tests___test_state____grandchild_state" + ) def test_get_class_substate(): """Test getting the substate of a class.""" - assert TestState.get_class_substate(("child_state",)) == ChildState - assert TestState.get_class_substate(("child_state2",)) == ChildState2 - assert ChildState.get_class_substate(("grandchild_state",)) == GrandchildState + assert TestState.get_class_substate((ChildState.get_name(),)) == ChildState + assert TestState.get_class_substate((ChildState2.get_name(),)) == ChildState2 assert ( - TestState.get_class_substate(("child_state", "grandchild_state")) + ChildState.get_class_substate((GrandchildState.get_name(),)) == GrandchildState + ) + assert ( + TestState.get_class_substate( + (ChildState.get_name(), GrandchildState.get_name()) + ) == GrandchildState ) with pytest.raises(ValueError): @@ -459,7 +480,7 @@ def test_get_class_substate(): with pytest.raises(ValueError): TestState.get_class_substate( ( - "child_state", + ChildState.get_name(), "invalid_child", ) ) @@ -471,13 +492,15 @@ def test_get_class_var(): assert TestState.get_class_var(("num2",)).equals(TestState.num2) assert ChildState.get_class_var(("value",)).equals(ChildState.value) assert GrandchildState.get_class_var(("value2",)).equals(GrandchildState.value2) - assert TestState.get_class_var(("child_state", "value")).equals(ChildState.value) + assert TestState.get_class_var((ChildState.get_name(), "value")).equals( + ChildState.value + ) assert TestState.get_class_var( - ("child_state", "grandchild_state", "value2") + (ChildState.get_name(), GrandchildState.get_name(), "value2") ).equals( GrandchildState.value2, ) - assert ChildState.get_class_var(("grandchild_state", "value2")).equals( + assert ChildState.get_class_var((GrandchildState.get_name(), "value2")).equals( GrandchildState.value2, ) with pytest.raises(ValueError): @@ -485,7 +508,7 @@ def test_get_class_var(): with pytest.raises(ValueError): TestState.get_class_var( ( - "child_state", + ChildState.get_name(), "invalid_var", ) ) @@ -513,11 +536,15 @@ def test_set_parent_and_substates(test_state, child_state, grandchild_state): grandchild_state: A grandchild state. """ assert len(test_state.substates) == 3 - assert set(test_state.substates) == {"child_state", "child_state2", "child_state3"} + assert set(test_state.substates) == { + ChildState.get_name(), + ChildState2.get_name(), + ChildState3.get_name(), + } assert child_state.parent_state == test_state assert len(child_state.substates) == 1 - assert set(child_state.substates) == {"grandchild_state"} + assert set(child_state.substates) == {GrandchildState.get_name()} assert grandchild_state.parent_state == child_state assert len(grandchild_state.substates) == 0 @@ -584,18 +611,21 @@ def test_get_substate(test_state, child_state, child_state2, grandchild_state): child_state2: A child state. grandchild_state: A grandchild state. """ - assert test_state.get_substate(("child_state",)) == child_state - assert test_state.get_substate(("child_state2",)) == child_state2 + assert test_state.get_substate((ChildState.get_name(),)) == child_state + assert test_state.get_substate((ChildState2.get_name(),)) == child_state2 assert ( - test_state.get_substate(("child_state", "grandchild_state")) == grandchild_state + test_state.get_substate((ChildState.get_name(), GrandchildState.get_name())) + == grandchild_state ) - assert child_state.get_substate(("grandchild_state",)) == grandchild_state + assert child_state.get_substate((GrandchildState.get_name(),)) == grandchild_state with pytest.raises(ValueError): test_state.get_substate(("invalid",)) with pytest.raises(ValueError): - test_state.get_substate(("child_state", "invalid")) + test_state.get_substate((ChildState.get_name(), "invalid")) with pytest.raises(ValueError): - test_state.get_substate(("child_state", "grandchild_state", "invalid")) + test_state.get_substate( + (ChildState.get_name(), GrandchildState.get_name(), "invalid") + ) def test_set_dirty_var(test_state): @@ -638,7 +668,7 @@ def test_set_dirty_substate(test_state, child_state, child_state2, grandchild_st # Setting a var should mark it as dirty. child_state.value = "test" assert child_state.dirty_vars == {"value"} - assert test_state.dirty_substates == {"child_state"} + assert test_state.dirty_substates == {ChildState.get_name()} assert child_state.dirty_substates == set() # Cleaning the parent state should remove the dirty substate. @@ -648,12 +678,12 @@ def test_set_dirty_substate(test_state, child_state, child_state2, grandchild_st # Setting a var on the grandchild should bubble up. grandchild_state.value2 = "test2" - assert child_state.dirty_substates == {"grandchild_state"} - assert test_state.dirty_substates == {"child_state"} + assert child_state.dirty_substates == {GrandchildState.get_name()} + assert test_state.dirty_substates == {ChildState.get_name()} # Cleaning the middle state should keep the parent state dirty. child_state._clean() - assert test_state.dirty_substates == {"child_state"} + assert test_state.dirty_substates == {ChildState.get_name()} assert child_state.dirty_substates == set() assert grandchild_state.dirty_vars == set() @@ -698,7 +728,11 @@ def test_reset(test_state, child_state): assert child_state.dirty_vars == {"count", "value"} # The dirty substates should be reset. - assert test_state.dirty_substates == {"child_state", "child_state2", "child_state3"} + assert test_state.dirty_substates == { + ChildState.get_name(), + ChildState2.get_name(), + ChildState3.get_name(), + } @pytest.mark.asyncio @@ -719,8 +753,8 @@ async def test_process_event_simple(test_state): # The delta should contain the changes, including computed vars. # assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}} assert update.delta == { - "test_state": {"num1": 69, "sum": 72.14, "upper": ""}, - "test_state.child_state3.grandchild_state3": {"computed": ""}, + TestState.get_full_name(): {"num1": 69, "sum": 72.14, "upper": ""}, + GrandchildState3.get_full_name(): {"computed": ""}, } assert update.events == [] @@ -738,15 +772,17 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) assert child_state.value == "" assert child_state.count == 23 event = Event( - token="t", name="child_state.change_both", payload={"value": "hi", "count": 12} + token="t", + name=f"{ChildState.get_name()}.change_both", + payload={"value": "hi", "count": 12}, ) update = await test_state._process(event).__anext__() assert child_state.value == "HI" assert child_state.count == 24 assert update.delta == { - "test_state": {"sum": 3.14, "upper": ""}, - "test_state.child_state": {"value": "HI", "count": 24}, - "test_state.child_state3.grandchild_state3": {"computed": ""}, + TestState.get_full_name(): {"sum": 3.14, "upper": ""}, + ChildState.get_full_name(): {"value": "HI", "count": 24}, + GrandchildState3.get_full_name(): {"computed": ""}, } test_state._clean() @@ -754,15 +790,15 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) assert grandchild_state.value2 == "" event = Event( token="t", - name="child_state.grandchild_state.set_value2", + name=f"{GrandchildState.get_full_name()}.set_value2", payload={"value": "new"}, ) update = await test_state._process(event).__anext__() assert grandchild_state.value2 == "new" assert update.delta == { - "test_state": {"sum": 3.14, "upper": ""}, - "test_state.child_state.grandchild_state": {"value2": "new"}, - "test_state.child_state3.grandchild_state3": {"computed": ""}, + TestState.get_full_name(): {"sum": 3.14, "upper": ""}, + GrandchildState.get_full_name(): {"value2": "new"}, + GrandchildState3.get_full_name(): {"computed": ""}, } @@ -786,7 +822,7 @@ async def test_process_event_generator(): else: assert gen_state.value == count assert update.delta == { - "gen_state": {"value": count}, + GenState.get_full_name(): {"value": count}, } assert not update.final @@ -1955,7 +1991,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): mock_app, Event( token=token, - name=f"{BackgroundTaskState.get_name()}.background_task", + name=f"{BackgroundTaskState.get_full_name()}.background_task", router_data=router_data, payload={}, ), @@ -1975,7 +2011,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): mock_app, Event( token=token, - name=f"{BackgroundTaskState.get_name()}.other", + name=f"{BackgroundTaskState.get_full_name()}.other", router_data=router_data, payload={}, ), @@ -1986,7 +2022,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): # other task returns delta assert update == StateUpdate( delta={ - BackgroundTaskState.get_name(): { + BackgroundTaskState.get_full_name(): { "order": [ "background_task:start", "other", @@ -2022,10 +2058,13 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): emit_mock = mock_app.event_namespace.emit first_ws_message = json.loads(emit_mock.mock_calls[0].args[1]) - assert first_ws_message["delta"]["background_task_state"].pop("router") is not None + assert ( + first_ws_message["delta"][BackgroundTaskState.get_full_name()].pop("router") + is not None + ) assert first_ws_message == { "delta": { - "background_task_state": { + BackgroundTaskState.get_full_name(): { "order": ["background_task:start"], "computed_order": ["background_task:start"], } @@ -2036,14 +2075,16 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): for call in emit_mock.mock_calls[1:5]: assert json.loads(call.args[1]) == { "delta": { - "background_task_state": {"computed_order": ["background_task:start"]} + BackgroundTaskState.get_full_name(): { + "computed_order": ["background_task:start"], + } }, "events": [], "final": True, } assert json.loads(emit_mock.mock_calls[-2].args[1]) == { "delta": { - "background_task_state": { + BackgroundTaskState.get_full_name(): { "order": exp_order, "computed_order": exp_order, "dict_list": {}, @@ -2054,7 +2095,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): } assert json.loads(emit_mock.mock_calls[-1].args[1]) == { "delta": { - "background_task_state": { + BackgroundTaskState.get_full_name(): { "computed_order": exp_order, }, }, @@ -2683,7 +2724,7 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker): assert isinstance(update, StateUpdate) updates.append(update) assert len(updates) == 1 - assert updates[0].delta["state"].pop("router") is not None + assert updates[0].delta[State.get_name()].pop("router") is not None assert updates[0].delta == exp_is_hydrated(state, False) events = updates[0].events @@ -2727,7 +2768,7 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker): assert isinstance(update, StateUpdate) updates.append(update) assert len(updates) == 1 - assert updates[0].delta["state"].pop("router") is not None + assert updates[0].delta[State.get_name()].pop("router") is not None assert updates[0].delta == exp_is_hydrated(state, False) events = updates[0].events @@ -2759,22 +2800,27 @@ async def test_get_state(mock_app: rx.App, token: str): if isinstance(mock_app.state_manager, StateManagerMemory): # All substates are available assert tuple(sorted(test_state.substates)) == ( - "child_state", - "child_state2", - "child_state3", + ChildState.get_name(), + ChildState2.get_name(), + ChildState3.get_name(), ) else: # Sibling states are only populated if they have computed vars - assert tuple(sorted(test_state.substates)) == ("child_state2", "child_state3") + assert tuple(sorted(test_state.substates)) == ( + ChildState2.get_name(), + ChildState3.get_name(), + ) # Because ChildState3 has a computed var, it is always dirty, and always populated. assert ( - test_state.substates["child_state3"].substates["grandchild_state3"].computed + test_state.substates[ChildState3.get_name()] + .substates[GrandchildState3.get_name()] + .computed == "" ) # Get the child_state2 directly. - child_state2_direct = test_state.get_substate(["child_state2"]) + child_state2_direct = test_state.get_substate([ChildState2.get_name()]) child_state2_get_state = await test_state.get_state(ChildState2) # These should be the same object. assert child_state2_direct is child_state2_get_state @@ -2785,19 +2831,21 @@ async def test_get_state(mock_app: rx.App, token: str): # Now the original root should have all substates populated. assert tuple(sorted(test_state.substates)) == ( - "child_state", - "child_state2", - "child_state3", + ChildState.get_name(), + ChildState2.get_name(), + ChildState3.get_name(), ) # ChildState should be retrievable - child_state_direct = test_state.get_substate(["child_state"]) + child_state_direct = test_state.get_substate([ChildState.get_name()]) child_state_get_state = await test_state.get_state(ChildState) # These should be the same object. assert child_state_direct is child_state_get_state # GrandchildState instance should be the same as the one retrieved from the child_state2. - assert grandchild_state is child_state_direct.get_substate(["grandchild_state"]) + assert grandchild_state is child_state_direct.get_substate( + [GrandchildState.get_name()] + ) grandchild_state.value2 = "set_value" assert test_state.get_delta() == { @@ -2824,21 +2872,21 @@ async def test_get_state(mock_app: rx.App, token: str): test_state._clean() # All substates are available assert tuple(sorted(new_test_state.substates)) == ( - "child_state", - "child_state2", - "child_state3", + ChildState.get_name(), + ChildState2.get_name(), + ChildState3.get_name(), ) else: # With redis, we get a whole new instance assert new_test_state is not test_state # Sibling states are only populated if they have computed vars assert tuple(sorted(new_test_state.substates)) == ( - "child_state2", - "child_state3", + ChildState2.get_name(), + ChildState3.get_name(), ) # Set a value on child_state2, should update cached var in grandchild_state2 - child_state2 = new_test_state.get_substate(("child_state2",)) + child_state2 = new_test_state.get_substate((ChildState2.get_name(),)) child_state2.value = "set_c2_value" assert new_test_state.get_delta() == { @@ -2929,8 +2977,8 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str): if isinstance(mock_app.state_manager, StateManagerRedis): # When redis is used, only states with computed vars are pre-fetched. - assert "child2" not in root.substates - assert "child3" in root.substates # (due to @rx.var) + assert Child2.get_name() not in root.substates + assert Child3.get_name() in root.substates # (due to @rx.var) # Get the unconnected sibling state, which will be used to `get_state` other instances. child = root.get_substate(Child.get_full_name().split(".")) diff --git a/tests/test_state_tree.py b/tests/test_state_tree.py index f1e7c8ddb..7c1e13a91 100644 --- a/tests/test_state_tree.py +++ b/tests/test_state_tree.py @@ -237,7 +237,13 @@ def state_manager_redis(app_module_mock) -> Generator[StateManager, None, None]: [ ( Root, - ["tree_a", "tree_b", "tree_c", "tree_d", "tree_e"], + [ + TreeA.get_name(), + TreeB.get_name(), + TreeC.get_name(), + TreeD.get_name(), + TreeE.get_name(), + ], [ TreeA.get_full_name(), SubA_A.get_full_name(), @@ -261,7 +267,7 @@ def state_manager_redis(app_module_mock) -> Generator[StateManager, None, None]: ), ( TreeA, - ("tree_a", "tree_d", "tree_e"), + (TreeA.get_name(), TreeD.get_name(), TreeE.get_name()), [ TreeA.get_full_name(), SubA_A.get_full_name(), @@ -276,7 +282,7 @@ def state_manager_redis(app_module_mock) -> Generator[StateManager, None, None]: ), ( SubA_A_A_A, - ["tree_a", "tree_d", "tree_e"], + [TreeA.get_name(), TreeD.get_name(), TreeE.get_name()], [ TreeA.get_full_name(), SubA_A.get_full_name(), @@ -288,7 +294,7 @@ def state_manager_redis(app_module_mock) -> Generator[StateManager, None, None]: ), ( TreeB, - ["tree_b", "tree_d", "tree_e"], + [TreeB.get_name(), TreeD.get_name(), TreeE.get_name()], [ TreeB.get_full_name(), SubB_A.get_full_name(), @@ -300,7 +306,7 @@ def state_manager_redis(app_module_mock) -> Generator[StateManager, None, None]: ), ( SubB_B, - ["tree_b", "tree_d", "tree_e"], + [TreeB.get_name(), TreeD.get_name(), TreeE.get_name()], [ TreeB.get_full_name(), SubB_B.get_full_name(), @@ -309,7 +315,7 @@ def state_manager_redis(app_module_mock) -> Generator[StateManager, None, None]: ), ( SubB_C_A, - ["tree_b", "tree_d", "tree_e"], + [TreeB.get_name(), TreeD.get_name(), TreeE.get_name()], [ TreeB.get_full_name(), SubB_C.get_full_name(), @@ -319,7 +325,7 @@ def state_manager_redis(app_module_mock) -> Generator[StateManager, None, None]: ), ( TreeC, - ["tree_c", "tree_d", "tree_e"], + [TreeC.get_name(), TreeD.get_name(), TreeE.get_name()], [ TreeC.get_full_name(), SubC_A.get_full_name(), @@ -328,14 +334,14 @@ def state_manager_redis(app_module_mock) -> Generator[StateManager, None, None]: ), ( TreeD, - ["tree_d", "tree_e"], + [TreeD.get_name(), TreeE.get_name()], [ *ALWAYS_COMPUTED_DICT_KEYS, ], ), ( TreeE, - ["tree_d", "tree_e"], + [TreeE.get_name(), TreeD.get_name()], [ # Extra siblings of computed var included now. SubE_A_A_A_B.get_full_name(), diff --git a/tests/test_var.py b/tests/test_var.py index 918d3a177..a96b331bd 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -739,64 +739,55 @@ def test_var_unsupported_indexing_dicts(var, index): @pytest.mark.parametrize( - "fixture,full_name", + "fixture", [ - ("ParentState", "parent_state.var_without_annotation"), - ("ChildState", "parent_state__child_state.var_without_annotation"), - ( - "GrandChildState", - "parent_state__child_state__grand_child_state.var_without_annotation", - ), - ("StateWithAnyVar", "state_with_any_var.var_without_annotation"), + "ParentState", + "ChildState", + "GrandChildState", + "StateWithAnyVar", ], ) -def test_computed_var_without_annotation_error(request, fixture, full_name): +def test_computed_var_without_annotation_error(request, fixture): """Test that a type error is thrown when an attribute of a computed var is accessed without annotating the computed var. Args: request: Fixture Request. fixture: The state fixture. - full_name: The full name of the state var. """ with pytest.raises(TypeError) as err: state = request.getfixturevalue(fixture) state.var_without_annotation.foo - assert ( - err.value.args[0] - == f"You must provide an annotation for the state var `{full_name}`. Annotation cannot be `typing.Any`" - ) + full_name = state.var_without_annotation._var_full_name + assert ( + err.value.args[0] + == f"You must provide an annotation for the state var `{full_name}`. Annotation cannot be `typing.Any`" + ) @pytest.mark.parametrize( - "fixture,full_name", + "fixture", [ - ( - "StateWithCorrectVarAnnotation", - "state_with_correct_var_annotation.var_with_annotation", - ), - ( - "StateWithWrongVarAnnotation", - "state_with_wrong_var_annotation.var_with_annotation", - ), + "StateWithCorrectVarAnnotation", + "StateWithWrongVarAnnotation", ], ) -def test_computed_var_with_annotation_error(request, fixture, full_name): +def test_computed_var_with_annotation_error(request, fixture): """Test that an Attribute error is thrown when a non-existent attribute of an annotated computed var is accessed or when the wrong annotation is provided to a computed var. Args: request: Fixture Request. fixture: The state fixture. - full_name: The full name of the state var. """ with pytest.raises(AttributeError) as err: state = request.getfixturevalue(fixture) state.var_with_annotation.foo - assert ( - err.value.args[0] - == f"The State var `{full_name}` has no attribute 'foo' or may have been annotated wrongly." - ) + full_name = state.var_with_annotation._var_full_name + assert ( + err.value.args[0] + == f"The State var `{full_name}` has no attribute 'foo' or may have been annotated wrongly." + ) @pytest.mark.parametrize( @@ -1402,12 +1393,15 @@ def test_invalid_var_operations(operand1_var: Var, operand2_var, operators: List (Var.create(1), "1"), (Var.create([1, 2, 3]), "[1, 2, 3]"), (Var.create({"foo": "bar"}), '{"foo": "bar"}'), - (Var.create(ATestState.value, _var_is_string=True), "a_test_state.value"), + ( + Var.create(ATestState.value, _var_is_string=True), + f"{ATestState.get_full_name()}.value", + ), ( Var.create(f"{ATestState.value} string", _var_is_string=True), - "`${a_test_state.value} string`", + f"`${{{ATestState.get_full_name()}.value}} string`", ), - (Var.create(ATestState.dict_val), "a_test_state.dict_val"), + (Var.create(ATestState.dict_val), f"{ATestState.get_full_name()}.dict_val"), ], ) def test_var_name_unwrapped(var, expected): diff --git a/tests/utils/test_format.py b/tests/utils/test_format.py index 5e43cfc39..7037a3798 100644 --- a/tests/utils/test_format.py +++ b/tests/utils/test_format.py @@ -432,8 +432,8 @@ def test_format_cond( ], Var.create("yellow", _var_is_string=True), "(() => { switch (JSON.stringify(state__state.value)) {case JSON.stringify(1): return (`red`); break;case JSON.stringify(2): case JSON.stringify(3): " - "return (`blue`); break;case JSON.stringify(test_state.mapping): return " - "(test_state.num1); break;case JSON.stringify(`${test_state.map_key}-key`): return (`return-key`);" + f"return (`blue`); break;case JSON.stringify({TestState.get_full_name()}.mapping): return " + f"({TestState.get_full_name()}.num1); break;case JSON.stringify(`${{{TestState.get_full_name()}.map_key}}-key`): return (`return-key`);" " break;default: return (`yellow`); break;};})()", ) ], @@ -585,11 +585,14 @@ def test_get_handler_parts(input, output): @pytest.mark.parametrize( "input,output", [ - (TestState.do_something, "test_state.do_something"), - (ChildState.change_both, "test_state.child_state.change_both"), + (TestState.do_something, f"{TestState.get_full_name()}.do_something"), + ( + ChildState.change_both, + f"{ChildState.get_full_name()}.change_both", + ), ( GrandchildState.do_nothing, - "test_state.child_state.grandchild_state.do_nothing", + f"{GrandchildState.get_full_name()}.do_nothing", ), ], )