diff --git a/reflex/state.py b/reflex/state.py index 96435dbaf..26bef5d7e 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -461,10 +461,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): for name, value in cls.__dict__.items() if types.is_backend_base_variable(name, cls) } - # Add annotated backend vars that do not have a default value. + # Add annotated backend vars that may not have a default value. new_backend_vars.update( { - name: Var("", _var_type=annotation_value).get_default_value() + name: cls._get_var_default(name, annotation_value) for name, annotation_value in get_type_hints(cls).items() if name not in new_backend_vars and types.is_backend_base_variable(name, cls) @@ -990,6 +990,26 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Ensure frontend uses null coalescing when accessing. object.__setattr__(prop, "_var_type", Optional[prop._var_type]) + @classmethod + def _get_var_default(cls, name: str, annotation_value: Any) -> Any: + """Get the default value of a (backend) var. + + Args: + name: The name of the var. + annotation_value: The annotation value of the var. + + Returns: + The default value of the var or None. + """ + try: + return getattr(cls, name) + except AttributeError: + try: + return Var("", _var_type=annotation_value).get_default_value() + except TypeError: + pass + return None + @staticmethod def _get_base_functions() -> dict[str, FunctionType]: """Get all functions of the state class excluding dunder methods. diff --git a/tests/integration/test_component_state.py b/tests/integration/test_component_state.py index 77b8b3fa1..7b35f8116 100644 --- a/tests/integration/test_component_state.py +++ b/tests/integration/test_component_state.py @@ -5,6 +5,7 @@ from typing import Generator import pytest from selenium.webdriver.common.by import By +from reflex.state import State, _substate_key from reflex.testing import AppHarness from . import utils @@ -12,13 +13,21 @@ from . import utils def ComponentStateApp(): """App using per component state.""" + from typing import Generic, TypeVar + import reflex as rx - class MultiCounter(rx.ComponentState): + E = TypeVar("E") + + class MultiCounter(rx.ComponentState, Generic[E]): count: int = 0 + _be: E + _be_int: int + _be_str: str = "42" def increment(self): self.count += 1 + self._be = self.count # type: ignore @classmethod def get_component(cls, *children, **props): @@ -48,6 +57,14 @@ def ComponentStateApp(): on_click=mc_a.State.increment, # type: ignore id="inc-a", ), + rx.text( + mc_a.State.get_name() if mc_a.State is not None else "", + id="a_state_name", + ), + rx.text( + mc_b.State.get_name() if mc_b.State is not None else "", + id="b_state_name", + ), ) @@ -80,6 +97,7 @@ async def test_component_state_app(component_state_app: AppHarness): ss = utils.SessionStorage(driver) assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found" + root_state_token = _substate_key(ss.get("token"), State) count_a = driver.find_element(By.ID, "count-a") count_b = driver.find_element(By.ID, "count-b") @@ -87,6 +105,18 @@ async def test_component_state_app(component_state_app: AppHarness): button_b = driver.find_element(By.ID, "button-b") button_inc_a = driver.find_element(By.ID, "inc-a") + # Check that backend vars in mixins are okay + a_state_name = driver.find_element(By.ID, "a_state_name").text + b_state_name = driver.find_element(By.ID, "b_state_name").text + root_state = await component_state_app.get_state(root_state_token) + a_state = root_state.substates[a_state_name] + b_state = root_state.substates[b_state_name] + assert a_state._backend_vars == a_state.backend_vars + assert a_state._backend_vars == b_state._backend_vars + assert a_state._backend_vars["_be"] is None + assert a_state._backend_vars["_be_int"] == 0 + assert a_state._backend_vars["_be_str"] == "42" + assert count_a.text == "0" button_a.click() @@ -98,6 +128,14 @@ async def test_component_state_app(component_state_app: AppHarness): button_inc_a.click() assert component_state_app.poll_for_content(count_a, exp_not_equal="2") == "3" + root_state = await component_state_app.get_state(root_state_token) + a_state = root_state.substates[a_state_name] + b_state = root_state.substates[b_state_name] + assert a_state._backend_vars != a_state.backend_vars + assert a_state._be == a_state._backend_vars["_be"] == 3 + assert b_state._be is None + assert b_state._backend_vars["_be"] is None + assert count_b.text == "0" button_b.click() @@ -105,3 +143,9 @@ async def test_component_state_app(component_state_app: AppHarness): button_b.click() assert component_state_app.poll_for_content(count_b, exp_not_equal="1") == "2" + + root_state = await component_state_app.get_state(root_state_token) + a_state = root_state.substates[a_state_name] + b_state = root_state.substates[b_state_name] + assert b_state._backend_vars != b_state.backend_vars + assert b_state._be == b_state._backend_vars["_be"] == 2