diff --git a/integration/test_background_task.py b/integration/test_background_task.py index 493dc5180..bc70ff01f 100644 --- a/integration/test_background_task.py +++ b/integration/test_background_task.py @@ -93,7 +93,7 @@ def BackgroundTask(): rx.button("Reset", on_click=State.reset_counter, id="reset"), ) - app = rx.App(state=State) + app = rx.App(state=rx.State) app.add_page(index) app.compile() diff --git a/integration/test_call_script.py b/integration/test_call_script.py index 8b24ce19b..95bdf37c9 100644 --- a/integration/test_call_script.py +++ b/integration/test_call_script.py @@ -135,7 +135,7 @@ def CallScript(): yield rx.call_script("inline_counter = 0; external_counter = 0") self.reset() - app = rx.App(state=CallScriptState) + app = rx.App(state=rx.State) with open("assets/external.js", "w") as f: f.write(external_scripts) diff --git a/integration/test_client_storage.py b/integration/test_client_storage.py index 90b98f16d..a9a311d8c 100644 --- a/integration/test_client_storage.py +++ b/integration/test_client_storage.py @@ -97,7 +97,7 @@ def ClientSide(): rx.box(ClientSideSubSubState.l1s, id="l1s"), ) - app = rx.App(state=ClientSideState) + app = rx.App(state=rx.State) app.add_page(index) app.add_page(index, route="/foo") app.compile() @@ -263,7 +263,6 @@ async def test_client_side_state( state_var_input.send_keys("c7") input_value_input.send_keys("c7 value") set_sub_state_button.click() - state_var_input.send_keys("l1") input_value_input.send_keys("l1 value") set_sub_state_button.click() @@ -276,7 +275,6 @@ async def test_client_side_state( state_var_input.send_keys("l4") input_value_input.send_keys("l4 value") set_sub_state_button.click() - state_var_input.send_keys("c1s") input_value_input.send_keys("c1s value") set_sub_sub_state_button.click() @@ -285,28 +283,28 @@ async def test_client_side_state( set_sub_sub_state_button.click() exp_cookies = { - "client_side_state.client_side_sub_state.c1": { + "state.client_side_state.client_side_sub_state.c1": { "domain": "localhost", "httpOnly": False, - "name": "client_side_state.client_side_sub_state.c1", + "name": "state.client_side_state.client_side_sub_state.c1", "path": "/", "sameSite": "Lax", "secure": False, "value": "c1%20value", }, - "client_side_state.client_side_sub_state.c2": { + "state.client_side_state.client_side_sub_state.c2": { "domain": "localhost", "httpOnly": False, - "name": "client_side_state.client_side_sub_state.c2", + "name": "state.client_side_state.client_side_sub_state.c2", "path": "/", "sameSite": "Lax", "secure": False, "value": "c2%20value", }, - "client_side_state.client_side_sub_state.c4": { + "state.client_side_state.client_side_sub_state.c4": { "domain": "localhost", "httpOnly": False, - "name": "client_side_state.client_side_sub_state.c4", + "name": "state.client_side_state.client_side_sub_state.c4", "path": "/", "sameSite": "Strict", "secure": False, @@ -321,19 +319,19 @@ async def test_client_side_state( "secure": False, "value": "c6%20value", }, - "client_side_state.client_side_sub_state.c7": { + "state.client_side_state.client_side_sub_state.c7": { "domain": "localhost", "httpOnly": False, - "name": "client_side_state.client_side_sub_state.c7", + "name": "state.client_side_state.client_side_sub_state.c7", "path": "/", "sameSite": "Lax", "secure": False, "value": "c7%20value", }, - "client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s": { + "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s": { "domain": "localhost", "httpOnly": False, - "name": "client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s", + "name": "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s", "path": "/", "sameSite": "Lax", "secure": False, @@ -354,40 +352,45 @@ async def test_client_side_state( input_value_input.send_keys("c3 value") set_sub_state_button.click() AppHarness._poll_for( - lambda: "client_side_state.client_side_sub_state.c3" in cookie_info_map(driver) + lambda: "state.client_side_state.client_side_sub_state.c3" + in cookie_info_map(driver) ) - c3_cookie = cookie_info_map(driver)["client_side_state.client_side_sub_state.c3"] + c3_cookie = cookie_info_map(driver)[ + "state.client_side_state.client_side_sub_state.c3" + ] assert c3_cookie.pop("expiry") is not None assert c3_cookie == { "domain": "localhost", "httpOnly": False, - "name": "client_side_state.client_side_sub_state.c3", + "name": "state.client_side_state.client_side_sub_state.c3", "path": "/", "sameSite": "Lax", "secure": False, "value": "c3%20value", } time.sleep(2) # wait for c3 to expire - assert "client_side_state.client_side_sub_state.c3" not in cookie_info_map(driver) + assert "state.client_side_state.client_side_sub_state.c3" not in cookie_info_map( + driver + ) local_storage_items = local_storage.items() local_storage_items.pop("chakra-ui-color-mode", None) assert ( - local_storage_items.pop("client_side_state.client_side_sub_state.l1") + local_storage_items.pop("state.client_side_state.client_side_sub_state.l1") == "l1 value" ) assert ( - local_storage_items.pop("client_side_state.client_side_sub_state.l2") + local_storage_items.pop("state.client_side_state.client_side_sub_state.l2") == "l2 value" ) assert local_storage_items.pop("l3") == "l3 value" assert ( - local_storage_items.pop("client_side_state.client_side_sub_state.l4") + local_storage_items.pop("state.client_side_state.client_side_sub_state.l4") == "l4 value" ) assert ( local_storage_items.pop( - "client_side_state.client_side_sub_state.client_side_sub_sub_state.l1s" + "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.l1s" ) == "l1s value" ) @@ -482,12 +485,15 @@ async def test_client_side_state( # make sure c5 cookie shows up on the `/foo` route AppHarness._poll_for( - lambda: "client_side_state.client_side_sub_state.c5" in cookie_info_map(driver) + lambda: "state.client_side_state.client_side_sub_state.c5" + in cookie_info_map(driver) ) - assert cookie_info_map(driver)["client_side_state.client_side_sub_state.c5"] == { + assert cookie_info_map(driver)[ + "state.client_side_state.client_side_sub_state.c5" + ] == { "domain": "localhost", "httpOnly": False, - "name": "client_side_state.client_side_sub_state.c5", + "name": "state.client_side_state.client_side_sub_state.c5", "path": "/foo/", "sameSite": "Lax", "secure": False, diff --git a/integration/test_connection_banner.py b/integration/test_connection_banner.py index c078df1f7..0468e6b53 100644 --- a/integration/test_connection_banner.py +++ b/integration/test_connection_banner.py @@ -19,7 +19,7 @@ def ConnectionBanner(): def index(): return rx.text("Hello World") - app = rx.App(state=State) + app = rx.App(state=rx.State) app.add_page(index) app.compile() diff --git a/integration/test_dynamic_routes.py b/integration/test_dynamic_routes.py index 5ab11b693..782e625ee 100644 --- a/integration/test_dynamic_routes.py +++ b/integration/test_dynamic_routes.py @@ -56,7 +56,7 @@ def DynamicRoute(): def redirect_page(): return rx.fragment(rx.text("redirecting...")) - app = rx.App(state=DynamicState) + app = rx.App(state=rx.State) app.add_page(index) app.add_page(index, route="/page/[page_id]", on_load=DynamicState.on_load) # type: ignore app.add_page(index, route="/static/x", on_load=DynamicState.on_load) # type: ignore @@ -143,10 +143,12 @@ def poll_for_order( return await dynamic_route.get_state(token) async def _check(): - return (await _backend_state()).order == exp_order + return (await _backend_state()).substates[ + "dynamic_state" + ].order == exp_order await AppHarness._poll_for_async(_check) - assert (await _backend_state()).order == exp_order + assert (await _backend_state()).substates["dynamic_state"].order == exp_order return _poll_for_order diff --git a/integration/test_event_actions.py b/integration/test_event_actions.py index 8f5e0788b..da444a5cf 100644 --- a/integration/test_event_actions.py +++ b/integration/test_event_actions.py @@ -130,7 +130,7 @@ def TestEventAction(): on_click=EventActionState.on_click("outer"), # type: ignore ) - app = rx.App(state=EventActionState) + app = rx.App(state=rx.State) app.add_page(index) app.compile() @@ -211,10 +211,14 @@ def poll_for_order( return await event_action.get_state(token) async def _check(): - return (await _backend_state()).order == exp_order + return (await _backend_state()).substates[ + "event_action_state" + ].order == exp_order await AppHarness._poll_for_async(_check) - assert (await _backend_state()).order == exp_order + assert (await _backend_state()).substates[ + "event_action_state" + ].order == exp_order return _poll_for_order diff --git a/integration/test_event_chain.py b/integration/test_event_chain.py index 5fbf7cc14..7d003635e 100644 --- a/integration/test_event_chain.py +++ b/integration/test_event_chain.py @@ -122,7 +122,7 @@ def EventChain(): time.sleep(0.5) self.interim_value = "final" - app = rx.App(state=State) + app = rx.App(state=rx.State) token_input = rx.input( value=State.router.session.client_token, is_read_only=True, id="token" @@ -401,12 +401,12 @@ async def test_event_chain_click( btn.click() async def _has_all_events(): - return len((await event_chain.get_state(token)).event_order) == len( - exp_event_order - ) + return len( + (await event_chain.get_state(token)).substates["state"].event_order + ) == len(exp_event_order) await AppHarness._poll_for_async(_has_all_events) - event_order = (await event_chain.get_state(token)).event_order + event_order = (await event_chain.get_state(token)).substates["state"].event_order assert event_order == exp_event_order @@ -453,12 +453,12 @@ async def test_event_chain_on_load( token = assert_token(event_chain, driver) async def _has_all_events(): - return len((await event_chain.get_state(token)).event_order) == len( - exp_event_order - ) + return len( + (await event_chain.get_state(token)).substates["state"].event_order + ) == len(exp_event_order) await AppHarness._poll_for_async(_has_all_events) - backend_state = await event_chain.get_state(token) + backend_state = (await event_chain.get_state(token)).substates["state"] assert backend_state.event_order == exp_event_order assert backend_state.is_hydrated is True @@ -529,12 +529,12 @@ async def test_event_chain_on_mount( unmount_button.click() async def _has_all_events(): - return len((await event_chain.get_state(token)).event_order) == len( - exp_event_order - ) + return len( + (await event_chain.get_state(token)).substates["state"].event_order + ) == len(exp_event_order) await AppHarness._poll_for_async(_has_all_events) - event_order = (await event_chain.get_state(token)).event_order + event_order = (await event_chain.get_state(token)).substates["state"].event_order assert event_order == exp_event_order diff --git a/integration/test_form_submit.py b/integration/test_form_submit.py index 4a9f6c2d1..cc36b5f25 100644 --- a/integration/test_form_submit.py +++ b/integration/test_form_submit.py @@ -22,7 +22,7 @@ def FormSubmit(): def form_submit(self, form_data: dict): self.form_data = form_data - app = rx.App(state=FormState) + app = rx.App(state=rx.State) @app.add_page def index(): @@ -75,7 +75,7 @@ def FormSubmitName(): def form_submit(self, form_data: dict): self.form_data = form_data - app = rx.App(state=FormState) + app = rx.App(state=rx.State) @app.add_page def index(): @@ -210,7 +210,7 @@ async def test_submit(driver, form_submit: AppHarness): submit_input.click() async def get_form_data(): - return (await form_submit.get_state(token)).form_data + return (await form_submit.get_state(token)).substates["form_state"].form_data # wait for the form data to arrive at the backend form_data = await AppHarness._poll_for_async(get_form_data) diff --git a/integration/test_input.py b/integration/test_input.py index d24f76815..4a5179850 100644 --- a/integration/test_input.py +++ b/integration/test_input.py @@ -16,7 +16,7 @@ def FullyControlledInput(): class State(rx.State): text: str = "initial" - app = rx.App(state=State) + app = rx.App(state=rx.State) @app.add_page def index(): @@ -85,13 +85,15 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): debounce_input.send_keys("foo") time.sleep(0.5) assert debounce_input.get_attribute("value") == "ifoonitial" - assert (await fully_controlled_input.get_state(token)).text == "ifoonitial" + assert (await fully_controlled_input.get_state(token)).substates[ + "state" + ].text == "ifoonitial" assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial" # clear the input on the backend async with fully_controlled_input.modify_state(token) as state: - state.text = "" - assert (await fully_controlled_input.get_state(token)).text == "" + state.substates["state"].text = "" + assert (await fully_controlled_input.get_state(token)).substates["state"].text == "" assert ( fully_controlled_input.poll_for_value( debounce_input, exp_not_equal="ifoonitial" @@ -103,9 +105,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): debounce_input.send_keys("getting testing done") time.sleep(0.5) assert debounce_input.get_attribute("value") == "getting testing done" - assert ( - await fully_controlled_input.get_state(token) - ).text == "getting testing done" + assert (await fully_controlled_input.get_state(token)).substates[ + "state" + ].text == "getting testing done" assert fully_controlled_input.poll_for_value(value_input) == "getting testing done" # type into the on_change input @@ -113,7 +115,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): time.sleep(0.5) assert debounce_input.get_attribute("value") == "overwrite the state" assert on_change_input.get_attribute("value") == "overwrite the state" - assert (await fully_controlled_input.get_state(token)).text == "overwrite the state" + assert (await fully_controlled_input.get_state(token)).substates[ + "state" + ].text == "overwrite the state" assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state" clear_button.click() diff --git a/integration/test_login_flow.py b/integration/test_login_flow.py index 68e0a864d..f53635743 100644 --- a/integration/test_login_flow.py +++ b/integration/test_login_flow.py @@ -42,7 +42,7 @@ def LoginSample(): rx.button("Do it", on_click=State.login, id="doit"), ) - app = rx.App(state=State) + app = rx.App(state=rx.State) app.add_page(index) app.add_page(login) app.compile() @@ -137,6 +137,6 @@ def test_login_flow( logout_button = driver.find_element(By.ID, "logout") logout_button.click() - assert login_sample._poll_for(lambda: local_storage["state.auth_token"] == "") + assert login_sample._poll_for(lambda: local_storage["state.state.auth_token"] == "") with pytest.raises(NoSuchElementException): driver.find_element(By.ID, "auth-token") diff --git a/integration/test_radix_themes.py b/integration/test_radix_themes.py index 8f07786c5..d4731d20e 100644 --- a/integration/test_radix_themes.py +++ b/integration/test_radix_themes.py @@ -81,7 +81,7 @@ def RadixThemesApp(): ) app = rx.App( - state=State, + state=rx.State, theme=rdxt.theme(rdxt.theme_panel(), accent_color="grass"), ) app.add_page(index) diff --git a/integration/test_server_side_event.py b/integration/test_server_side_event.py index b24f47368..31a38ba36 100644 --- a/integration/test_server_side_event.py +++ b/integration/test_server_side_event.py @@ -33,7 +33,7 @@ def ServerSideEvent(): def set_value_return_c(self): return rx.set_value("c", "") - app = rx.App(state=SSState) + app = rx.App(state=rx.State) @app.add_page def index(): diff --git a/integration/test_table.py b/integration/test_table.py index e947514a0..00e6a8b22 100644 --- a/integration/test_table.py +++ b/integration/test_table.py @@ -26,7 +26,7 @@ def Table(): caption: str = "random caption" - app = rx.App(state=TableState) + app = rx.App(state=rx.State) @app.add_page def index(): diff --git a/integration/test_upload.py b/integration/test_upload.py index 648c68be5..13fb28d36 100644 --- a/integration/test_upload.py +++ b/integration/test_upload.py @@ -113,7 +113,7 @@ def UploadFile(): ), ) - app = rx.App(state=UploadState) + app = rx.App(state=rx.State) app.add_page(index) app.compile() @@ -192,7 +192,7 @@ async def test_upload_file( # look up the backend state and assert on uploaded contents async def get_file_data(): - return (await upload_file.get_state(token))._file_data + return (await upload_file.get_state(token)).substates["upload_state"]._file_data file_data = await AppHarness._poll_for_async(get_file_data) assert isinstance(file_data, dict) @@ -205,8 +205,8 @@ async def test_upload_file( state = await upload_file.get_state(token) if secondary: # only the secondary form tracks progress and chain events - assert state.event_order.count("upload_progress") == 1 - assert state.event_order.count("chain_event") == 1 + assert state.substates["upload_state"].event_order.count("upload_progress") == 1 + assert state.substates["upload_state"].event_order.count("chain_event") == 1 @pytest.mark.asyncio @@ -251,7 +251,7 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver): # look up the backend state and assert on uploaded contents async def get_file_data(): - return (await upload_file.get_state(token))._file_data + return (await upload_file.get_state(token)).substates["upload_state"]._file_data file_data = await AppHarness._poll_for_async(get_file_data) assert isinstance(file_data, dict) @@ -349,7 +349,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive # look up the backend state and assert on progress state = await upload_file.get_state(token) - assert state.progress_dicts - assert exp_name not in state._file_data + assert state.substates["upload_state"].progress_dicts + assert exp_name not in state.substates["upload_state"]._file_data target_file.unlink() diff --git a/integration/test_var_operations.py b/integration/test_var_operations.py index f327b662c..374344a87 100644 --- a/integration/test_var_operations.py +++ b/integration/test_var_operations.py @@ -30,7 +30,7 @@ def VarOperations(): dict2: dict = {3: 4} html_str: str = "
hello
" - app = rx.App(state=VarOperationState) + app = rx.App(state=rx.State) @app.add_page def index(): diff --git a/reflex/app.py b/reflex/app.py index 68ad8b639..33b679378 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -57,6 +57,7 @@ from reflex.route import ( verify_route_validity, ) from reflex.state import ( + BaseState, RouterData, State, StateManager, @@ -98,7 +99,7 @@ class App(Base): socket_app: Optional[ASGIApp] = None # The state class to use for the app. - state: Optional[Type[State]] = None + state: Optional[Type[BaseState]] = None # Class to manage many client states. _state_manager: Optional[StateManager] = None @@ -149,25 +150,24 @@ class App(Base): "`connect_error_component` is deprecated, use `overlay_component` instead" ) super().__init__(*args, **kwargs) - state_subclasses = State.__subclasses__() - inferred_state = state_subclasses[-1] if state_subclasses else None + state_subclasses = BaseState.__subclasses__() is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ - # Special case to allow test cases have multiple subclasses of rx.State. + # Special case to allow test cases have multiple subclasses of rx.BaseState. if not is_testing_env: - # Only one State class is allowed. + # Only one Base State class is allowed. if len(state_subclasses) > 1: raise ValueError( - "rx.State has been subclassed multiple times. Only one subclass is allowed" + "rx.BaseState cannot be subclassed multiple times. use rx.State instead" ) # verify that provided state is valid - if self.state and inferred_state and self.state is not inferred_state: + if self.state and self.state is not State: console.warn( f"Using substate ({self.state.__name__}) as root state in `rx.App` is currently not supported." - f" Defaulting to root state: ({inferred_state.__name__})" + f" Defaulting to root state: ({State.__name__})" ) - self.state = inferred_state + self.state = State # Get the config config = get_config() @@ -265,7 +265,7 @@ class App(Base): raise ValueError("The state manager has not been initialized.") return self._state_manager - async def preprocess(self, state: State, event: Event) -> StateUpdate | None: + async def preprocess(self, state: BaseState, event: Event) -> StateUpdate | None: """Preprocess the event. This is where middleware can modify the event before it is processed. @@ -290,7 +290,7 @@ class App(Base): return out # type: ignore async def postprocess( - self, state: State, event: Event, update: StateUpdate + self, state: BaseState, event: Event, update: StateUpdate ) -> StateUpdate: """Postprocess the event. @@ -764,7 +764,7 @@ class App(Base): future.result() @contextlib.asynccontextmanager - async def modify_state(self, token: str) -> AsyncIterator[State]: + async def modify_state(self, token: str) -> AsyncIterator[BaseState]: """Modify the state out of band. Args: @@ -792,7 +792,9 @@ class App(Base): sid=state.router.session.session_id, ) - def _process_background(self, state: State, event: Event) -> asyncio.Task | None: + def _process_background( + self, state: BaseState, event: Event + ) -> asyncio.Task | None: """Process an event in the background and emit updates as they arrive. Args: diff --git a/reflex/app.pyi b/reflex/app.pyi index f7e63727a..667ebf52a 100644 --- a/reflex/app.pyi +++ b/reflex/app.pyi @@ -33,6 +33,7 @@ from reflex.route import ( ) from reflex.state import ( State as State, + BaseState as BaseState, StateManager as StateManager, StateUpdate as StateUpdate, ) @@ -69,7 +70,7 @@ class App(Base): api: FastAPI sio: Optional[AsyncServer] socket_app: Optional[ASGIApp] - state: Type[State] + state: Type[BaseState] state_manager: StateManager style: ComponentStyle middleware: List[Middleware] diff --git a/reflex/base.py b/reflex/base.py index 62bc70855..9efea60e0 100644 --- a/reflex/base.py +++ b/reflex/base.py @@ -1,11 +1,42 @@ """Define the base Reflex class.""" from __future__ import annotations -from typing import Any +import os +from typing import Any, List, Type import pydantic +from pydantic import BaseModel from pydantic.fields import ModelField +from reflex import constants + + +def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None: + """Ensure that the field's name does not shadow an existing attribute of the model. + + Args: + bases: List of base models to check for shadowed attrs. + field_name: name of attribute + + Raises: + NameError: If state var field shadows another in its parent state + """ + reload = os.getenv(constants.RELOAD_CONFIG) == "True" + for base in bases: + try: + if not reload and getattr(base, field_name, None): + pass + except TypeError as te: + raise NameError( + f'State var "{field_name}" in {base} has been shadowed by a substate var; ' + f'use a different field name instead".' + ) from te + + +# monkeypatch pydantic validate_field_name method to skip validating +# shadowed state vars when reloading app via utils.prerequisites.get_app(reload=True) +pydantic.main.validate_field_name = validate_field_name # type: ignore + class Base(pydantic.BaseModel): """The base class subclassed by all Reflex classes. diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 1e5215872..1740af50a 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -15,7 +15,7 @@ from reflex.components.component import ( StatefulComponent, ) from reflex.config import get_config -from reflex.state import State +from reflex.state import BaseState from reflex.utils.imports import ImportVar @@ -63,7 +63,7 @@ def _compile_theme(theme: dict) -> str: return templates.THEME.render(theme=theme) -def _compile_contexts(state: Optional[Type[State]]) -> str: +def _compile_contexts(state: Optional[Type[BaseState]]) -> str: """Compile the initial state and contexts. Args: @@ -87,7 +87,7 @@ def _compile_contexts(state: Optional[Type[State]]) -> str: def _compile_page( component: Component, - state: Type[State], + state: Type[BaseState], ) -> str: """Compile the component given the app state. @@ -337,7 +337,7 @@ def compile_theme(style: ComponentStyle) -> tuple[str, str]: return output_path, code -def compile_contexts(state: Optional[Type[State]]) -> tuple[str, str]: +def compile_contexts(state: Optional[Type[BaseState]]) -> tuple[str, str]: """Compile the initial state / context. Args: @@ -353,7 +353,7 @@ def compile_contexts(state: Optional[Type[State]]) -> tuple[str, str]: def compile_page( - path: str, component: Component, state: Type[State] + path: str, component: Component, state: Type[BaseState] ) -> tuple[str, str]: """Compile a single page. diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 3a866051a..0a4df7a4e 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -21,7 +21,7 @@ from reflex.components.base import ( Title, ) from reflex.components.component import Component, ComponentStyle, CustomComponent -from reflex.state import Cookie, LocalStorage, State +from reflex.state import BaseState, Cookie, LocalStorage from reflex.style import Style from reflex.utils import console, format, imports, path_ops @@ -128,7 +128,7 @@ def get_import_dict(lib: str, default: str = "", rest: list[str] | None = None) } -def compile_state(state: Type[State]) -> dict: +def compile_state(state: Type[BaseState]) -> dict: """Compile the state of the app. Args: @@ -170,7 +170,7 @@ def _compile_client_storage_field( def _compile_client_storage_recursive( - state: Type[State], + state: Type[BaseState], ) -> tuple[dict[str, dict], dict[str, dict[str, str]]]: """Compile the client-side storage for the given state recursively. @@ -208,7 +208,7 @@ def _compile_client_storage_recursive( return cookies, local_storage -def compile_client_storage(state: Type[State]) -> dict[str, dict]: +def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]: """Compile the client-side storage for the given state. Args: diff --git a/reflex/constants/__init__.py b/reflex/constants/__init__.py index bfe112e63..45aaa7248 100644 --- a/reflex/constants/__init__.py +++ b/reflex/constants/__init__.py @@ -6,6 +6,7 @@ from .base import ( LOCAL_STORAGE, POLLING_MAX_HTTP_BUFFER_SIZE, PYTEST_CURRENT_TEST, + RELOAD_CONFIG, SKIP_COMPILE_ENV_VAR, ColorMode, Dirs, @@ -85,6 +86,7 @@ __ALL__ = [ PYTEST_CURRENT_TEST, PRODUCTION_BACKEND_URL, Reflex, + RELOAD_CONFIG, RequirementsTxt, RouteArgType, RouteRegex, diff --git a/reflex/constants/base.py b/reflex/constants/base.py index 0b28e18cb..82cb39da8 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -173,3 +173,4 @@ SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE" # Testing variables. # Testing os env set by pytest when running a test case. PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST" +RELOAD_CONFIG = "__REFLEX_RELOAD_CONFIG" diff --git a/reflex/event.py b/reflex/event.py index cbaf65b0c..f0ad0957b 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -22,7 +22,7 @@ from reflex.utils.types import ArgsSpec from reflex.vars import BaseVar, Var if TYPE_CHECKING: - from reflex.state import State + from reflex.state import BaseState class Event(Base): @@ -64,7 +64,7 @@ def background(fn): def _no_chain_background_task( - state_cls: Type["State"], name: str, fn: Callable + state_cls: Type["BaseState"], name: str, fn: Callable ) -> Callable: """Protect against directly chaining a background task from another event handler. diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py index 38d5fb14f..6108a90c4 100644 --- a/reflex/middleware/hydrate_middleware.py +++ b/reflex/middleware/hydrate_middleware.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional from reflex import constants from reflex.event import Event, fix_events, get_hydrate_event from reflex.middleware.middleware import Middleware -from reflex.state import State, StateUpdate +from reflex.state import BaseState, StateUpdate from reflex.utils import format if TYPE_CHECKING: @@ -17,7 +17,7 @@ class HydrateMiddleware(Middleware): """Middleware to handle initial app hydration.""" async def preprocess( - self, app: App, state: State, event: Event + self, app: App, state: BaseState, event: Event ) -> Optional[StateUpdate]: """Preprocess the event. diff --git a/reflex/middleware/middleware.py b/reflex/middleware/middleware.py index 726d81621..f522ff861 100644 --- a/reflex/middleware/middleware.py +++ b/reflex/middleware/middleware.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional from reflex.base import Base from reflex.event import Event -from reflex.state import State, StateUpdate +from reflex.state import BaseState, StateUpdate if TYPE_CHECKING: from reflex.app import App @@ -16,7 +16,7 @@ class Middleware(Base, ABC): """Middleware to preprocess and postprocess requests.""" async def preprocess( - self, app: App, state: State, event: Event + self, app: App, state: BaseState, event: Event ) -> Optional[StateUpdate]: """Preprocess the event. @@ -31,7 +31,7 @@ class Middleware(Base, ABC): return None async def postprocess( - self, app: App, state: State, event: Event, update: StateUpdate + self, app: App, state: BaseState, event: Event, update: StateUpdate ) -> StateUpdate: """Postprocess the event. diff --git a/reflex/state.py b/reflex/state.py index 881a6b47c..f6b10e849 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -7,6 +7,7 @@ import copy import functools import inspect import json +import os import traceback import urllib.parse import uuid @@ -81,7 +82,7 @@ class HeaderData(Base): class PageData(Base): """An object containing page data.""" - host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?) + host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?) path: str = "" raw_path: str = "" full_path: str = "" @@ -152,7 +153,7 @@ RESERVED_BACKEND_VAR_NAMES = { } -class State(Base, ABC, extra=pydantic.Extra.allow): +class BaseState(Base, ABC, extra=pydantic.Extra.allow): """The state of the app.""" # A map from the var name to the var. @@ -176,6 +177,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # The event handlers. event_handlers: ClassVar[Dict[str, EventHandler]] = {} + # A set of subclassses of this class. + class_subclasses: ClassVar[Set[Type[BaseState]]] = set() + # Mapping of var name to set of computed variables that depend on it _computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {} @@ -189,10 +193,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow): _always_dirty_substates: ClassVar[Set[str]] = set() # The parent state. - parent_state: Optional[State] = None + parent_state: Optional[BaseState] = None # The substates of the state. - substates: Dict[str, State] = {} + substates: Dict[str, BaseState] = {} # The set of dirty vars. dirty_vars: Set[str] = set() @@ -209,10 +213,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # The router data for the current page router: RouterData = RouterData() - # The hydrated bool. - is_hydrated: bool = False - - def __init__(self, *args, parent_state: State | None = None, **kwargs): + def __init__(self, *args, parent_state: BaseState | None = None, **kwargs): """Initialize the state. Args: @@ -220,28 +221,20 @@ class State(Base, ABC, extra=pydantic.Extra.allow): parent_state: The parent state. **kwargs: The kwargs to pass to the Pydantic init method. - Raises: - ValueError: If a substate class shadows another. """ kwargs["parent_state"] = parent_state super().__init__(*args, **kwargs) # Setup the substates. for substate in self.get_substates(): - substate_name = substate.get_name() - if substate_name in self.substates: - raise ValueError( - f"The substate class '{substate_name}' has been defined multiple times. Shadowing " - f"substate classes is not allowed." - ) - self.substates[substate_name] = substate(parent_state=self) + self.substates[substate.get_name()] = substate(parent_state=self) # Convert the event handlers to functions. self._init_event_handlers() # Create a fresh copy of the backend variables for this instance self._backend_vars = copy.deepcopy(self.backend_vars) - def _init_event_handlers(self, state: State | None = None): + def _init_event_handlers(self, state: BaseState | None = None): """Initialize event handlers. Allow event handlers to be called directly on the instance. This is @@ -281,17 +274,44 @@ class State(Base, ABC, extra=pydantic.Extra.allow): Args: **kwargs: The kwargs to pass to the pydantic init_subclass method. + + Raises: + ValueError: If a substate class shadows another. """ + is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ super().__init_subclass__(**kwargs) # Event handlers should not shadow builtin state methods. cls._check_overridden_methods() + # Reset subclass tracking for this class. + cls.class_subclasses = set() + # Get the parent vars. parent_state = cls.get_parent_state() if parent_state is not None: cls.inherited_vars = parent_state.vars cls.inherited_backend_vars = parent_state.backend_vars + # Check if another substate class with the same name has already been defined. + if cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses): + if is_testing_env: + # Clear existing subclass with same name when app is reloaded via + # utils.prerequisites.get_app(reload=True) + parent_state.class_subclasses = set( + c + for c in parent_state.class_subclasses + if c.__name__ != cls.__name__ + ) + else: + # During normal operation, subclasses cannot have the same name, even if they are + # defined in different modules. + raise ValueError( + f"The substate class '{cls.__name__}' has been defined multiple times. " + "Shadowing substate classes is not allowed." + ) + # Track this new subclass in the parent state's subclasses set. + parent_state.class_subclasses.add(cls) + cls.new_backend_vars = { name: value for name, value in cls.__dict__.items() @@ -437,7 +457,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): @classmethod @functools.lru_cache() - def get_parent_state(cls) -> Type[State] | None: + def get_parent_state(cls) -> Type[BaseState] | None: """Get the parent state. Returns: @@ -446,20 +466,19 @@ class State(Base, ABC, extra=pydantic.Extra.allow): parent_states = [ base for base in cls.__bases__ - if types._issubclass(base, State) and base is not State + if types._issubclass(base, BaseState) and base is not BaseState ] assert len(parent_states) < 2, "Only one parent state is allowed." return parent_states[0] if len(parent_states) == 1 else None # type: ignore @classmethod - @functools.lru_cache() - def get_substates(cls) -> set[Type[State]]: + def get_substates(cls) -> set[Type[BaseState]]: """Get the substates of the state. Returns: The substates of the state. """ - return set(cls.__subclasses__()) + return cls.class_subclasses @classmethod @functools.lru_cache() @@ -487,7 +506,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): @classmethod @functools.lru_cache() - def get_class_substate(cls, path: Sequence[str]) -> Type[State]: + def get_class_substate(cls, path: Sequence[str]) -> Type[BaseState]: """Get the class substate. Args: @@ -643,7 +662,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): """ return { func[0]: func[1] - for func in inspect.getmembers(State, predicate=inspect.isfunction) + for func in inspect.getmembers(BaseState, predicate=inspect.isfunction) if not func[0].startswith("__") } @@ -909,7 +928,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): for substate in self.substates.values(): substate._reset_client_storage() - def get_substate(self, path: Sequence[str]) -> State | None: + def get_substate(self, path: Sequence[str]) -> BaseState | None: """Get the substate. Args: @@ -933,7 +952,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): def _get_event_handler( self, event: Event - ) -> tuple[State | StateProxy, EventHandler]: + ) -> tuple[BaseState | StateProxy, EventHandler]: """Get the event handler for the given event. Args: @@ -1050,7 +1069,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): ) async def _process_event( - self, handler: EventHandler, state: State | StateProxy, payload: Dict + self, handler: EventHandler, state: BaseState | StateProxy, payload: Dict ) -> AsyncIterator[StateUpdate]: """Process event. @@ -1263,7 +1282,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): d.update(substate_d) return d - async def __aenter__(self) -> State: + async def __aenter__(self) -> BaseState: """Enter the async context manager protocol. This should not be used for the State class, but exists for @@ -1288,6 +1307,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow): pass +class State(BaseState): + """The app Base State.""" + + # The hydrated bool. + is_hydrated: bool = False + + class StateProxy(wrapt.ObjectProxy): """Proxy of a state instance to control mutability of vars for a background task. @@ -1455,10 +1481,10 @@ class StateManager(Base, ABC): """A class to manage many client states.""" # The state class to use. - state: Type[State] + state: Type[BaseState] @classmethod - def create(cls, state: Type[State]): + def create(cls, state: Type[BaseState]): """Create a new state manager. Args: @@ -1473,7 +1499,7 @@ class StateManager(Base, ABC): return StateManagerMemory(state=state) @abstractmethod - async def get_state(self, token: str) -> State: + async def get_state(self, token: str) -> BaseState: """Get the state for a token. Args: @@ -1485,7 +1511,7 @@ class StateManager(Base, ABC): pass @abstractmethod - async def set_state(self, token: str, state: State): + async def set_state(self, token: str, state: BaseState): """Set the state for a token. Args: @@ -1496,7 +1522,7 @@ class StateManager(Base, ABC): @abstractmethod @contextlib.asynccontextmanager - async def modify_state(self, token: str) -> AsyncIterator[State]: + async def modify_state(self, token: str) -> AsyncIterator[BaseState]: """Modify the state for a token while holding exclusive lock. Args: @@ -1512,7 +1538,7 @@ class StateManagerMemory(StateManager): """A state manager that stores states in memory.""" # The mapping of client ids to states. - states: Dict[str, State] = {} + states: Dict[str, BaseState] = {} # The mutex ensures the dict of mutexes is updated exclusively _state_manager_lock = asyncio.Lock() @@ -1527,7 +1553,7 @@ class StateManagerMemory(StateManager): "_states_locks": {"exclude": True}, } - async def get_state(self, token: str) -> State: + async def get_state(self, token: str) -> BaseState: """Get the state for a token. Args: @@ -1540,7 +1566,7 @@ class StateManagerMemory(StateManager): self.states[token] = self.state() return self.states[token] - async def set_state(self, token: str, state: State): + async def set_state(self, token: str, state: BaseState): """Set the state for a token. Args: @@ -1550,7 +1576,7 @@ class StateManagerMemory(StateManager): pass @contextlib.asynccontextmanager - async def modify_state(self, token: str) -> AsyncIterator[State]: + async def modify_state(self, token: str) -> AsyncIterator[BaseState]: """Modify the state for a token while holding exclusive lock. Args: @@ -1598,7 +1624,7 @@ class StateManagerRedis(StateManager): b"evicted", } - async def get_state(self, token: str) -> State: + async def get_state(self, token: str) -> BaseState: """Get the state for a token. Args: @@ -1613,7 +1639,9 @@ class StateManagerRedis(StateManager): return await self.get_state(token) return cloudpickle.loads(redis_state) - async def set_state(self, token: str, state: State, lock_id: bytes | None = None): + async def set_state( + self, token: str, state: BaseState, lock_id: bytes | None = None + ): """Set the state for a token. Args: @@ -1637,7 +1665,7 @@ class StateManagerRedis(StateManager): await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration) @contextlib.asynccontextmanager - async def modify_state(self, token: str) -> AsyncIterator[State]: + async def modify_state(self, token: str) -> AsyncIterator[BaseState]: """Modify the state for a token while holding exclusive lock. Args: @@ -1879,7 +1907,7 @@ class MutableProxy(wrapt.ObjectProxy): __mutable_types__ = (list, dict, set, Base) - def __init__(self, wrapped: Any, state: State, field_name: str): + def __init__(self, wrapped: Any, state: BaseState, field_name: str): """Create a proxy for a mutable object that tracks changes. Args: diff --git a/reflex/testing.py b/reflex/testing.py index a6bb3d400..2cc20e6a4 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -38,7 +38,7 @@ import reflex.utils.build import reflex.utils.exec import reflex.utils.prerequisites import reflex.utils.processes -from reflex.state import State, StateManagerMemory, StateManagerRedis +from reflex.state import BaseState, State, StateManagerMemory, StateManagerRedis try: from selenium import webdriver # pyright: ignore [reportMissingImports] @@ -162,6 +162,9 @@ class AppHarness: with chdir(self.app_path): # ensure config and app are reloaded when testing different app reflex.config.get_config(reload=True) + # reset rx.State subclasses + State.class_subclasses.clear() + # self.app_module.app. self.app_module = reflex.utils.prerequisites.get_app(reload=True) self.app_instance = self.app_module.app if isinstance(self.app_instance.state_manager, StateManagerRedis): @@ -434,7 +437,7 @@ class AppHarness: self._frontends.append(driver) return driver - async def get_state(self, token: str) -> State: + async def get_state(self, token: str) -> BaseState: """Get the state associated with the given token. Args: @@ -561,7 +564,7 @@ class AppHarness: ) return element.get_attribute("value") - def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, reflex.State]: + def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, BaseState]: """Poll app state_manager for any connected clients. Args: diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 17195f224..f0e53b2d2 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -124,11 +124,13 @@ def get_app(reload: bool = False) -> ModuleType: Returns: The app based on the default config. """ + os.environ[constants.RELOAD_CONFIG] = str(reload) config = get_config() module = ".".join([config.app_name, config.app_name]) sys.path.insert(0, os.getcwd()) app = __import__(module, fromlist=(constants.CompileVars.APP,)) if reload: + importlib.reload(app) return app diff --git a/reflex/vars.py b/reflex/vars.py index caec26f87..6bfd6c57d 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -41,7 +41,7 @@ from reflex.utils import console, format, imports, serializers, types from reflex.utils.imports import ImportDict, ImportVar if TYPE_CHECKING: - from reflex.state import State + from reflex.state import BaseState # Set of unique variable names. USED_VARIABLES = set() @@ -1472,7 +1472,7 @@ class Var: ) ) - def _var_set_state(self, state: Type[State] | str) -> Any: + def _var_set_state(self, state: Type[BaseState] | str) -> Any: """Set the state of the var. Args: @@ -1604,14 +1604,14 @@ class BaseVar(Var): return setter return ".".join((self._var_data.state, setter)) - def get_setter(self) -> Callable[[State, Any], None]: + def get_setter(self) -> Callable[[BaseState, Any], None]: """Get the var's setter function. Returns: A function that that creates a setter for the var. """ - def setter(state: State, value: Any): + def setter(state: BaseState, value: Any): """Get the setter for the var. Args: @@ -1643,9 +1643,9 @@ class ComputedVar(Var, property): def __init__( self, - fget: Callable[[State], Any], - fset: Callable[[State, Any], None] | None = None, - fdel: Callable[[State], Any] | None = None, + fget: Callable[[BaseState], Any], + fset: Callable[[BaseState, Any], None] | None = None, + fdel: Callable[[BaseState], Any] | None = None, doc: str | None = None, **kwargs, ): diff --git a/reflex/vars.pyi b/reflex/vars.pyi index c4c96af90..6ec1bd987 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -5,6 +5,7 @@ from _typeshed import Incomplete from reflex import constants as constants from reflex.base import Base as Base from reflex.state import State as State +from reflex.state import BaseState as BaseState from reflex.utils import console as console, format as format, types as types from reflex.utils.imports import ImportVar from types import FunctionType @@ -110,7 +111,7 @@ class Var: def as_ref(self) -> Var: ... @property def _var_full_name(self) -> str: ... - def _var_set_state(self, state: Type[State] | str) -> Any: ... + def _var_set_state(self, state: Type[BaseState] | str) -> Any: ... @dataclass(eq=False) class BaseVar(Var): @@ -123,7 +124,7 @@ class BaseVar(Var): def __hash__(self) -> int: ... def get_default_value(self) -> Any: ... def get_setter_name(self, include_state: bool = ...) -> str: ... - def get_setter(self) -> Callable[[State, Any], None]: ... + def get_setter(self) -> Callable[[BaseState, Any], None]: ... @dataclass(init=False) class ComputedVar(Var): diff --git a/tests/__init__.py b/tests/__init__.py index b318bba63..b4d8570e5 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1,4 @@ """Root directory for tests.""" +import os + +from reflex import constants diff --git a/tests/components/base/test_script.py b/tests/components/base/test_script.py index cc16ab718..7ccdc0634 100644 --- a/tests/components/base/test_script.py +++ b/tests/components/base/test_script.py @@ -2,7 +2,7 @@ import pytest from reflex.components.base.script import Script -from reflex.state import State +from reflex.state import BaseState def test_script_inline(): @@ -31,7 +31,7 @@ def test_script_neither(): Script.create() -class EvState(State): +class EvState(BaseState): """State for testing event handlers.""" def on_ready(self): diff --git a/tests/components/datadisplay/conftest.py b/tests/components/datadisplay/conftest.py index c0e61c437..93796ed23 100644 --- a/tests/components/datadisplay/conftest.py +++ b/tests/components/datadisplay/conftest.py @@ -5,6 +5,7 @@ import pandas as pd import pytest import reflex as rx +from reflex.state import BaseState @pytest.fixture @@ -18,7 +19,7 @@ def data_table_state(request): The data table state class. """ - class DataTableState(rx.State): + class DataTableState(BaseState): data = request.param["data"] columns = ["column1", "column2"] @@ -33,7 +34,7 @@ def data_table_state2(): The data table state class. """ - class DataTableState(rx.State): + class DataTableState(BaseState): _data = pd.DataFrame() @rx.var @@ -51,7 +52,7 @@ def data_table_state3(): The data table state class. """ - class DataTableState(rx.State): + class DataTableState(BaseState): _data: List = [] _columns: List = ["col1", "col2"] @@ -74,7 +75,7 @@ def data_table_state4(): The data table state class. """ - class DataTableState(rx.State): + class DataTableState(BaseState): _data: List = [] _columns: List = ["col1", "col2"] diff --git a/tests/components/datadisplay/test_table.py b/tests/components/datadisplay/test_table.py index 94e39f6e5..212f1fb59 100644 --- a/tests/components/datadisplay/test_table.py +++ b/tests/components/datadisplay/test_table.py @@ -4,12 +4,12 @@ from typing import List, Tuple import pytest from reflex.components.datadisplay.table import Tbody, Tfoot, Thead -from reflex.state import State +from reflex.state import BaseState PYTHON_GT_V38 = sys.version_info.major >= 3 and sys.version_info.minor > 8 -class TableState(State): +class TableState(BaseState): """Test State class.""" rows_List_List_str: List[List[str]] = [["random", "row"]] diff --git a/tests/components/forms/test_debounce.py b/tests/components/forms/test_debounce.py index 1965fabf5..97cfa8648 100644 --- a/tests/components/forms/test_debounce.py +++ b/tests/components/forms/test_debounce.py @@ -3,6 +3,7 @@ import pytest import reflex as rx +from reflex.state import BaseState from reflex.vars import BaseVar @@ -24,7 +25,7 @@ def test_render_many_child(): _ = rx.debounce_input("foo", "bar").render() -class S(rx.State): +class S(BaseState): """Example state for debounce tests.""" value: str = "" diff --git a/tests/components/layout/test_cond.py b/tests/components/layout/test_cond.py index 00cf4de7d..08c6883c4 100644 --- a/tests/components/layout/test_cond.py +++ b/tests/components/layout/test_cond.py @@ -15,12 +15,13 @@ from reflex.components.layout.responsive import ( tablet_only, ) from reflex.components.typography.text import Text +from reflex.state import BaseState from reflex.vars import Var @pytest.fixture def cond_state(request): - class CondState(rx.State): + class CondState(BaseState): value: request.param["value_type"] = request.param["value"] # noqa return CondState diff --git a/tests/components/layout/test_foreach.py b/tests/components/layout/test_foreach.py index aacdf3638..71ae36b23 100644 --- a/tests/components/layout/test_foreach.py +++ b/tests/components/layout/test_foreach.py @@ -4,10 +4,10 @@ import pytest from reflex.components import box, foreach, text from reflex.components.layout import Foreach -from reflex.state import State +from reflex.state import BaseState -class ForEachState(State): +class ForEachState(BaseState): """A state for testing the ForEach component.""" colors_list: List[str] = ["red", "yellow"] diff --git a/tests/components/test_component.py b/tests/components/test_component.py index d90efd658..ac5b23130 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -14,7 +14,7 @@ from reflex.components.component import ( from reflex.components.layout.box import Box from reflex.constants import EventTriggers from reflex.event import EventChain, EventHandler -from reflex.state import State +from reflex.state import BaseState from reflex.style import Style from reflex.utils import imports from reflex.utils.imports import ImportVar @@ -23,7 +23,7 @@ from reflex.vars import Var, VarData @pytest.fixture def test_state(): - class TestState(State): + class TestState(BaseState): num: int def do_something(self): @@ -400,7 +400,7 @@ def test_get_event_triggers(component1, component2): ) -class C1State(State): +class C1State(BaseState): """State for testing C1 component.""" def mock_handler(self, _e, _bravo, _charlie): diff --git a/tests/conftest.py b/tests/conftest.py index d2dd301f4..e5dddc470 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,6 @@ from typing import Dict, Generator import pytest -import reflex as rx from reflex.app import App from reflex.event import EventSpec @@ -225,23 +224,3 @@ def token() -> str: A fresh/unique token string. """ return str(uuid.uuid4()) - - -@pytest.fixture -def duplicate_substate(): - """Create a Test state that has duplicate child substates. - - Returns: - The test state. - """ - - class TestState(rx.State): - pass - - class ChildTestState(TestState): # type: ignore # noqa - pass - - class ChildTestState(TestState): # type: ignore # noqa - pass - - return TestState diff --git a/tests/middleware/test_hydrate_middleware.py b/tests/middleware/test_hydrate_middleware.py index 7767dcf8b..2f21557f0 100644 --- a/tests/middleware/test_hydrate_middleware.py +++ b/tests/middleware/test_hydrate_middleware.py @@ -2,13 +2,14 @@ from typing import Any, Dict import pytest +from reflex import constants from reflex.app import App from reflex.constants import CompileVars from reflex.middleware.hydrate_middleware import HydrateMiddleware -from reflex.state import State, StateUpdate +from reflex.state import BaseState, StateUpdate -def exp_is_hydrated(state: State) -> Dict[str, Any]: +def exp_is_hydrated(state: BaseState) -> Dict[str, Any]: """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware. Args: @@ -20,7 +21,7 @@ def exp_is_hydrated(state: State) -> Dict[str, Any]: return {state.get_name(): {CompileVars.IS_HYDRATED: True}} -class TestState(State): +class TestState(BaseState): """A test state with no return in handler.""" __test__ = False @@ -32,7 +33,7 @@ class TestState(State): self.num += 1 -class TestState2(State): +class TestState2(BaseState): """A test state with return in handler.""" __test__ = False @@ -54,7 +55,7 @@ class TestState2(State): self.name = "random" -class TestState3(State): +class TestState3(BaseState): """A test state with async handler.""" __test__ = False @@ -97,6 +98,9 @@ async def test_preprocess( event_fixture: The event fixture(an Event). expected: Expected delta. """ + test_state.add_var( + constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False + ) app = App(state=test_state, load_events={"index": [test_state.test_handler]}) state = test_state() diff --git a/tests/states/__init__.py b/tests/states/__init__.py index 2c007172a..11e891ab4 100644 --- a/tests/states/__init__.py +++ b/tests/states/__init__.py @@ -1,10 +1,12 @@ -"""Common rx.State subclasses for use in tests.""" +"""Common rx.BaseState subclasses for use in tests.""" import reflex as rx +from reflex.state import BaseState from .mutation import DictMutationTestState, ListMutationTestState, MutableTestState from .upload import ( ChildFileUploadState, FileStateBase1, + FileStateBase2, FileUploadState, GrandChildFileUploadState, SubUploadState, @@ -12,7 +14,7 @@ from .upload import ( ) -class GenState(rx.State): +class GenState(BaseState): """A state with event handlers that generate multiple updates.""" value: int diff --git a/tests/states/mutation.py b/tests/states/mutation.py index b3d98301f..5825b6d12 100644 --- a/tests/states/mutation.py +++ b/tests/states/mutation.py @@ -3,9 +3,10 @@ from typing import Dict, List, Set, Union import reflex as rx +from reflex.state import BaseState -class DictMutationTestState(rx.State): +class DictMutationTestState(BaseState): """A state for testing ReflexDict mutation.""" # plain dict @@ -62,7 +63,7 @@ class DictMutationTestState(rx.State): self.friend_in_nested_dict["friend"]["age"] = 30 -class ListMutationTestState(rx.State): +class ListMutationTestState(BaseState): """A state for testing ReflexList mutation.""" # plain list @@ -144,7 +145,7 @@ class CustomVar(rx.Base): custom: OtherBase = OtherBase() -class MutableTestState(rx.State): +class MutableTestState(BaseState): """A test state.""" array: List[Union[str, List, Dict[str, str]]] = [ diff --git a/tests/states/upload.py b/tests/states/upload.py index ec2585dd1..8abe11c24 100644 --- a/tests/states/upload.py +++ b/tests/states/upload.py @@ -3,9 +3,10 @@ from pathlib import Path from typing import ClassVar, List import reflex as rx +from reflex.state import BaseState, State -class UploadState(rx.State): +class UploadState(BaseState): """The base state for uploading a file.""" async def handle_upload1(self, files: List[rx.UploadFile]): @@ -17,7 +18,7 @@ class UploadState(rx.State): pass -class BaseState(rx.State): +class BaseState(BaseState): """The test base state.""" pass @@ -37,7 +38,7 @@ class SubUploadState(BaseState): pass -class FileUploadState(rx.State): +class FileUploadState(State): """The base state for uploading a file.""" img_list: List[str] @@ -79,7 +80,7 @@ class FileUploadState(rx.State): pass -class FileStateBase1(rx.State): +class FileStateBase1(State): """The base state for a child FileUploadState.""" pass diff --git a/tests/test_app.py b/tests/test_app.py index 6aa73974e..7d667ac1e 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -28,7 +28,7 @@ from reflex.components import Box, Component, Cond, Fragment, Text from reflex.event import Event, get_hydrate_event from reflex.middleware import HydrateMiddleware from reflex.model import Model -from reflex.state import RouterData, State, StateManagerRedis, StateUpdate +from reflex.state import BaseState, RouterData, State, StateManagerRedis, StateUpdate from reflex.style import Style from reflex.utils import format from reflex.vars import ComputedVar @@ -43,7 +43,7 @@ from .states import ( ) -class EmptyState(State): +class EmptyState(BaseState): """An empty state.""" pass @@ -77,14 +77,14 @@ def about_page(): return about -class ATestState(State): +class ATestState(BaseState): """A simple state for testing.""" var: int @pytest.fixture() -def test_state() -> Type[State]: +def test_state() -> Type[BaseState]: """A default state. Returns: @@ -94,14 +94,14 @@ def test_state() -> Type[State]: @pytest.fixture() -def redundant_test_state() -> Type[State]: +def redundant_test_state() -> Type[BaseState]: """A default state. Returns: A default state. """ - class RedundantTestState(State): + class RedundantTestState(BaseState): var: int return RedundantTestState @@ -198,12 +198,12 @@ def test_default_app(app: App): def test_multiple_states_error(monkeypatch, test_state, redundant_test_state): - """Test that an error is thrown when multiple classes subclass rx.State. + """Test that an error is thrown when multiple classes subclass rx.BaseState. Args: monkeypatch: Pytest monkeypatch object. - test_state: A test state subclassing rx.State. - redundant_test_state: Another test state subclassing rx.State. + test_state: A test state subclassing rx.BaseState. + redundant_test_state: Another test state subclassing rx.BaseState. """ monkeypatch.delenv(constants.PYTEST_CURRENT_TEST) with pytest.raises(ValueError): @@ -705,12 +705,12 @@ async def test_dict_mutation_detection__plain_list( [ ( FileUploadState, - {"file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}}, + {"state.file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}}, ), ( ChildFileUploadState, { - "file_state_base1.child_file_upload_state": { + "state.file_state_base1.child_file_upload_state": { "img_list": ["image1.jpg", "image2.jpg"] } }, @@ -718,14 +718,14 @@ async def test_dict_mutation_detection__plain_list( ( GrandChildFileUploadState, { - "file_state_base1.file_state_base2.grand_child_file_upload_state": { + "state.file_state_base1.file_state_base2.grand_child_file_upload_state": { "img_list": ["image1.jpg", "image2.jpg"] } }, ), ], ) -async def test_upload_file(tmp_path, state, delta, token: str): +async def test_upload_file(tmp_path, state, delta, token: str, mocker): """Test that file upload works correctly. Args: @@ -733,10 +733,15 @@ async def test_upload_file(tmp_path, state, delta, token: str): state: The state class. delta: Expected delta token: a Token. + mocker: pytest mocker object. """ + mocker.patch( + "reflex.state.State.class_subclasses", + {state if state is FileUploadState else FileStateBase1}, + ) state._tmp_path = tmp_path # The App state must be the "root" of the state tree - app = App(state=state if state is FileUploadState else FileStateBase1) + app = App(state=State) app.event_namespace.emit = AsyncMock() # type: ignore current_state = await app.state_manager.get_state(token) data = b"This is binary data" @@ -749,7 +754,7 @@ async def test_upload_file(tmp_path, state, delta, token: str): request_mock = unittest.mock.Mock() request_mock.headers = { "reflex-client-token": token, - "reflex-event-handler": f"{state_name}.multi_handle_upload", + "reflex-event-handler": f"state.{state_name}.multi_handle_upload", } file1 = UploadFile( @@ -851,7 +856,7 @@ async def test_upload_file_background(state, tmp_path, token): await app.state_manager.redis.close() -class DynamicState(State): +class DynamicState(BaseState): """State class for testing dynamic route var. This is defined at module level because event handlers cannot be addressed @@ -891,9 +896,7 @@ class DynamicState(State): @pytest.mark.asyncio async def test_dynamic_route_var_route_change_completed_on_load( - index_page, - windows_platform: bool, - token: str, + index_page, windows_platform: bool, token: str, mocker ): """Create app with dynamic route var, and simulate navigation. @@ -904,7 +907,12 @@ async def test_dynamic_route_var_route_change_completed_on_load( index_page: The index page. windows_platform: Whether the system is windows. token: a Token. + mocker: pytest mocker object. """ + mocker.patch("reflex.state.State.class_subclasses", {DynamicState}) + DynamicState.add_var( + constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False + ) arg_name = "dynamic" route = f"/test/[{arg_name}]" if windows_platform: diff --git a/tests/test_event.py b/tests/test_event.py index 284012b13..2326f0920 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -4,7 +4,7 @@ import pytest from reflex import event from reflex.event import Event, EventHandler, EventSpec, fix_events -from reflex.state import State +from reflex.state import BaseState from reflex.utils import format from reflex.vars import Var @@ -303,7 +303,7 @@ def test_event_actions(): def test_event_actions_on_state(): - class EventActionState(State): + class EventActionState(BaseState): def handler(self): pass diff --git a/tests/test_state.py b/tests/test_state.py index 591026d2c..b4ca0c87a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -19,11 +19,11 @@ from reflex.base import Base from reflex.constants import CompileVars, RouteVar, SocketEvent from reflex.event import Event, EventHandler from reflex.state import ( + BaseState, ImmutableStateError, LockExpiredError, MutableProxy, RouterData, - State, StateManager, StateManagerMemory, StateManagerRedis, @@ -75,7 +75,7 @@ class Object(Base): prop2: str = "hello" -class TestState(State): +class TestState(BaseState): """A test state.""" # Set this class as not test one @@ -148,7 +148,7 @@ class GrandchildState(ChildState): pass -class DateTimeState(State): +class DateTimeState(BaseState): """A State with some datetime fields.""" d: datetime.date = datetime.date.fromisoformat("1989-11-09") @@ -253,7 +253,6 @@ def test_class_vars(test_state): """ cls = type(test_state) assert set(cls.vars.keys()) == { - CompileVars.IS_HYDRATED, # added by hydrate_middleware to all State "router", "num1", "num2", @@ -641,7 +640,6 @@ def test_reset(test_state, child_state): "obj", "upper", "complex", - "is_hydrated", "fig", "key", "sum", @@ -837,7 +835,7 @@ def test_get_query_params(test_state): def test_add_var(): - class DynamicState(State): + class DynamicState(BaseState): pass ds1 = DynamicState() @@ -870,7 +868,7 @@ def test_add_var_default_handlers(test_state): assert isinstance(test_state.event_handlers["set_rand_int"], EventHandler) -class InterdependentState(State): +class InterdependentState(BaseState): """A state with 3 vars and 3 computed vars. x: a variable that no computed var depends on @@ -915,7 +913,7 @@ class InterdependentState(State): @pytest.fixture -def interdependent_state() -> State: +def interdependent_state() -> BaseState: """A state with varying dependency between vars. Returns: @@ -988,7 +986,7 @@ def test_per_state_backend_var(interdependent_state): def test_child_state(): """Test that the child state computed vars can reference parent state vars.""" - class MainState(State): + class MainState(BaseState): v: int = 2 class ChildState(MainState): @@ -1006,7 +1004,7 @@ def test_child_state(): def test_conditional_computed_vars(): """Test that computed vars can have conditionals.""" - class MainState(State): + class MainState(BaseState): flag: bool = False t1: str = "a" t2: str = "b" @@ -1051,7 +1049,7 @@ def test_event_handlers_convert_to_fns(test_state, child_state): def test_event_handlers_call_other_handlers(): """Test that event handlers can call other event handlers.""" - class MainState(State): + class MainState(BaseState): v: int = 0 def set_v(self, v: int): @@ -1077,7 +1075,7 @@ def test_computed_var_cached(): """Test that a ComputedVar doesn't recalculate when accessed.""" comp_v_calls = 0 - class ComputedState(State): + class ComputedState(BaseState): v: int = 0 @rx.cached_var @@ -1102,7 +1100,7 @@ def test_computed_var_cached(): def test_computed_var_cached_depends_on_non_cached(): """Test that a cached_var is recalculated if it depends on non-cached ComputedVar.""" - class ComputedState(State): + class ComputedState(BaseState): v: int = 0 @rx.var @@ -1144,7 +1142,7 @@ def test_computed_var_depends_on_parent_non_cached(): """Child state cached_var that depends on parent state un cached var is always recalculated.""" counter = 0 - class ParentState(State): + class ParentState(BaseState): @rx.var def no_cache_v(self) -> int: nonlocal counter @@ -1165,21 +1163,18 @@ def test_computed_var_depends_on_parent_non_cached(): dict1 = ps.dict() assert dict1[ps.get_full_name()] == { "no_cache_v": 1, - CompileVars.IS_HYDRATED: False, "router": formatted_router, } assert dict1[cs.get_full_name()] == {"dep_v": 2} dict2 = ps.dict() assert dict2[ps.get_full_name()] == { "no_cache_v": 3, - CompileVars.IS_HYDRATED: False, "router": formatted_router, } assert dict2[cs.get_full_name()] == {"dep_v": 4} dict3 = ps.dict() assert dict3[ps.get_full_name()] == { "no_cache_v": 5, - CompileVars.IS_HYDRATED: False, "router": formatted_router, } assert dict3[cs.get_full_name()] == {"dep_v": 6} @@ -1195,7 +1190,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool): """ counter = 0 - class HandlerState(State): + class HandlerState(BaseState): x: int = 42 def handler(self): @@ -1226,7 +1221,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool): def test_computed_var_dependencies(): """Test that a ComputedVar correctly tracks its dependencies.""" - class ComputedState(State): + class ComputedState(BaseState): v: int = 0 w: int = 0 x: int = 0 @@ -1293,7 +1288,7 @@ def test_computed_var_dependencies(): def test_backend_method(): """A method with leading underscore should be callable from event handler.""" - class BackendMethodState(State): + class BackendMethodState(BaseState): def _be_method(self): return True @@ -1369,7 +1364,7 @@ def test_error_on_state_method_shadow(): """Test that an error is thrown when an event handler shadows a state method.""" with pytest.raises(NameError) as err: - class InvalidTest(rx.State): + class InvalidTest(BaseState): def reset(self): pass @@ -1382,7 +1377,7 @@ def test_error_on_state_method_shadow(): def test_state_with_invalid_yield(): """Test that an error is thrown when a state yields an invalid value.""" - class StateWithInvalidYield(rx.State): + class StateWithInvalidYield(BaseState): """A state that yields an invalid value.""" def invalid_handler(self): @@ -1666,7 +1661,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): assert mcall.kwargs["to"] == grandchild_state.get_sid() -class BackgroundTaskState(State): +class BackgroundTaskState(BaseState): """A state with a background task.""" order: List[str] = [] @@ -2192,9 +2187,20 @@ def test_mutable_copy_vars(mutable_state, copy_func): assert not isinstance(var_copy, MutableProxy) -def test_duplicate_substate_class(duplicate_substate): +def test_duplicate_substate_class(mocker): + mocker.patch("reflex.state.os.environ", {}) with pytest.raises(ValueError): - duplicate_substate() + + class TestState(BaseState): + pass + + class ChildTestState(TestState): # type: ignore # noqa + pass + + class ChildTestState(TestState): # type: ignore # noqa + pass + + return TestState class Foo(Base): @@ -2206,7 +2212,7 @@ class Foo(Base): def test_json_dumps_with_mutables(): """Test that json.dumps works with Base vars inside mutable types.""" - class MutableContainsBase(State): + class MutableContainsBase(BaseState): items: List[Foo] = [Foo()] dict_val = MutableContainsBase().dict() @@ -2216,7 +2222,7 @@ def test_json_dumps_with_mutables(): f_formatted_router = str(formatted_router).replace("'", '"') assert ( val - == f'{{"{MutableContainsBase.get_full_name()}": {{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}}}' + == f'{{"{MutableContainsBase.get_full_name()}": {{"items": {f_items}, "router": {f_formatted_router}}}}}' ) @@ -2225,7 +2231,7 @@ def test_reset_with_mutables(): default = [[0, 0], [0, 1], [1, 1]] copied_default = copy.deepcopy(default) - class MutableResetState(State): + class MutableResetState(BaseState): items: List[List[int]] = default instance = MutableResetState() @@ -2273,7 +2279,7 @@ class Custom3(Base): def test_state_union_optional(): """Test that state can be defined with Union and Optional vars.""" - class UnionState(State): + class UnionState(BaseState): int_float: Union[int, float] = 0 opt_int: Optional[int] c3: Optional[Custom3] diff --git a/tests/test_var.py b/tests/test_var.py index c57c7eb23..8e7358bf5 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -6,7 +6,7 @@ import pytest from pandas import DataFrame from reflex.base import Base -from reflex.state import State +from reflex.state import BaseState from reflex.vars import ( BaseVar, ComputedVar, @@ -24,12 +24,6 @@ test_vars = [ ] -class BaseState(State): - """A Test State.""" - - val: str = "key" - - @pytest.fixture def TestObj(): class TestObj(Base): @@ -41,7 +35,7 @@ def TestObj(): @pytest.fixture def ParentState(TestObj): - class ParentState(State): + class ParentState(BaseState): foo: int bar: int @@ -74,7 +68,7 @@ def GrandChildState(ChildState, TestObj): @pytest.fixture def StateWithAnyVar(TestObj): - class StateWithAnyVar(State): + class StateWithAnyVar(BaseState): @ComputedVar def var_without_annotation(self) -> typing.Any: return TestObj @@ -84,7 +78,7 @@ def StateWithAnyVar(TestObj): @pytest.fixture def StateWithCorrectVarAnnotation(): - class StateWithCorrectVarAnnotation(State): + class StateWithCorrectVarAnnotation(BaseState): @ComputedVar def var_with_annotation(self) -> str: return "Correct annotation" @@ -94,7 +88,7 @@ def StateWithCorrectVarAnnotation(): @pytest.fixture def StateWithWrongVarAnnotation(TestObj): - class StateWithWrongVarAnnotation(State): + class StateWithWrongVarAnnotation(BaseState): @ComputedVar def var_with_annotation(self) -> str: return TestObj diff --git a/tests/utils/test_format.py b/tests/utils/test_format.py index 257f37514..e528536a1 100644 --- a/tests/utils/test_format.py +++ b/tests/utils/test_format.py @@ -528,7 +528,6 @@ formatted_router = { }, "dt": "1989-11-09 18:53:00+01:00", "fig": [], - "is_hydrated": False, "key": "", "map_key": "a", "mapping": {"a": [1, 2, 3], "b": [4, 5, 6]}, @@ -553,7 +552,6 @@ formatted_router = { DateTimeState.get_full_name(): { "d": "1989-11-09", "dt": "1989-11-09 18:53:00+01:00", - "is_hydrated": False, "t": "18:53:00+01:00", "td": "11 days, 0:11:00", "router": formatted_router, diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index b64f5f188..82085dd74 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -10,7 +10,7 @@ from packaging import version from reflex import constants from reflex.base import Base from reflex.event import EventHandler -from reflex.state import State +from reflex.state import BaseState from reflex.utils import ( build, prerequisites, @@ -43,7 +43,7 @@ V056 = version.parse("0.5.6") VMAXPLUS1 = version.parse(get_above_max_version()) -class ExampleTestState(State): +class ExampleTestState(BaseState): """Test state class.""" def test_event_handler(self):