diff --git a/integration/test_component_state.py b/integration/test_component_state.py index d2c10c766..e903a1b74 100644 --- a/integration/test_component_state.py +++ b/integration/test_component_state.py @@ -1,4 +1,5 @@ """Test that per-component state scaffold works and operates independently.""" + from typing import Generator import pytest diff --git a/integration/test_state_inheritance.py b/integration/test_state_inheritance.py index 08a7fc951..86ab625e1 100644 --- a/integration/test_state_inheritance.py +++ b/integration/test_state_inheritance.py @@ -45,17 +45,15 @@ def StateInheritance(): """Test that state inheritance works as expected.""" import reflex as rx - class ChildMixin: - # mixin basevars only work with pydantic/rx.Base models - # child_mixin: str = "child_mixin" + class ChildMixin(rx.State, mixin=True): + child_mixin: str = "child_mixin" @rx.var def computed_child_mixin(self) -> str: return "computed_child_mixin" - class Mixin(ChildMixin): - # mixin basevars only work with pydantic/rx.Base models - # mixin: str = "mixin" + class Mixin(ChildMixin, mixin=True): + mixin: str = "mixin" @rx.var def computed_mixin(self) -> str: @@ -64,7 +62,7 @@ def StateInheritance(): def on_click_mixin(self): return rx.call_script("alert('clicked')") - class OtherMixin(rx.Base): + class OtherMixin(rx.State, mixin=True): other_mixin: str = "other_mixin" other_mixin_clicks: int = 0 @@ -78,7 +76,7 @@ def StateInheritance(): f"{self.__class__.__name__}.clicked.{self.other_mixin_clicks}" ) - class Base1(rx.State, Mixin): + class Base1(Mixin, rx.State): _base1: str = "_base1" base1: str = "base1" @@ -122,14 +120,15 @@ def StateInheritance(): def index() -> rx.Component: return rx.vstack( - rx.chakra.input( + rx.input( id="token", value=Base1.router.session.client_token, is_read_only=True ), - # Base 1 + # Base 1 (Mixin, ChildMixin) rx.heading(Base1.computed_mixin, id="base1-computed_mixin"), rx.heading(Base1.computed_basevar, id="base1-computed_basevar"), - rx.heading(Base1.computed_child_mixin, id="base1-child-mixin"), + rx.heading(Base1.computed_child_mixin, id="base1-computed-child-mixin"), rx.heading(Base1.base1, id="base1-base1"), + rx.heading(Base1.child_mixin, id="base1-child-mixin"), rx.button( "Base1.on_click_mixin", on_click=Base1.on_click_mixin, # type: ignore @@ -138,31 +137,33 @@ def StateInheritance(): rx.heading( Base1.computed_backend_vars_base1, id="base1-computed_backend_vars" ), - # Base 2 + # Base 2 (no mixins) rx.heading(Base2.computed_basevar, id="base2-computed_basevar"), rx.heading(Base2.base2, id="base2-base2"), rx.heading( Base2.computed_backend_vars_base2, id="base2-computed_backend_vars" ), - # Child 1 + # Child 1 (Mixin, ChildMixin, OtherMixin) rx.heading(Child1.computed_basevar, id="child1-computed_basevar"), rx.heading(Child1.computed_mixin, id="child1-computed_mixin"), rx.heading(Child1.computed_other_mixin, id="child1-other-mixin"), - rx.heading(Child1.computed_child_mixin, id="child1-child-mixin"), + rx.heading(Child1.computed_child_mixin, id="child1-computed-child-mixin"), rx.heading(Child1.base1, id="child1-base1"), rx.heading(Child1.other_mixin, id="child1-other_mixin"), + rx.heading(Child1.child_mixin, id="child1-child-mixin"), rx.button( "Child1.on_click_other_mixin", on_click=Child1.on_click_other_mixin, # type: ignore id="child1-other-mixin-btn", ), - # Child 2 + # Child 2 (Mixin, ChildMixin, OtherMixin) rx.heading(Child2.computed_basevar, id="child2-computed_basevar"), rx.heading(Child2.computed_mixin, id="child2-computed_mixin"), rx.heading(Child2.computed_other_mixin, id="child2-other-mixin"), - rx.heading(Child2.computed_child_mixin, id="child2-child-mixin"), + rx.heading(Child2.computed_child_mixin, id="child2-computed-child-mixin"), rx.heading(Child2.base2, id="child2-base2"), rx.heading(Child2.other_mixin, id="child2-other_mixin"), + rx.heading(Child2.child_mixin, id="child2-child-mixin"), rx.button( "Child2.on_click_mixin", on_click=Child2.on_click_mixin, # type: ignore @@ -173,15 +174,16 @@ def StateInheritance(): on_click=Child2.on_click_other_mixin, # type: ignore id="child2-other-mixin-btn", ), - # Child 3 + # Child 3 (Mixin, ChildMixin, OtherMixin) rx.heading(Child3.computed_basevar, id="child3-computed_basevar"), rx.heading(Child3.computed_mixin, id="child3-computed_mixin"), rx.heading(Child3.computed_other_mixin, id="child3-other-mixin"), rx.heading(Child3.computed_childvar, id="child3-computed_childvar"), - rx.heading(Child3.computed_child_mixin, id="child3-child-mixin"), + rx.heading(Child3.computed_child_mixin, id="child3-computed-child-mixin"), rx.heading(Child3.child3, id="child3-child3"), rx.heading(Child3.base2, id="child3-base2"), rx.heading(Child3.other_mixin, id="child3-other_mixin"), + rx.heading(Child3.child_mixin, id="child3-child-mixin"), rx.button( "Child3.on_click_mixin", on_click=Child3.on_click_mixin, # type: ignore @@ -282,7 +284,9 @@ def test_state_inheritance( base1_computed_basevar = driver.find_element(By.ID, "base1-computed_basevar") assert base1_computed_basevar.text == "computed_basevar1" - base1_computed_child_mixin = driver.find_element(By.ID, "base1-child-mixin") + base1_computed_child_mixin = driver.find_element( + By.ID, "base1-computed-child-mixin" + ) assert base1_computed_child_mixin.text == "computed_child_mixin" base1_base1 = driver.find_element(By.ID, "base1-base1") @@ -293,6 +297,9 @@ def test_state_inheritance( ) assert base1_computed_backend_vars.text == "_base1" + base1_child_mixin = driver.find_element(By.ID, "base1-child-mixin") + assert base1_child_mixin.text == "child_mixin" + # Base 2 base2_computed_basevar = driver.find_element(By.ID, "base2-computed_basevar") assert base2_computed_basevar.text == "computed_basevar2" @@ -315,7 +322,9 @@ def test_state_inheritance( child1_computed_other_mixin = driver.find_element(By.ID, "child1-other-mixin") assert child1_computed_other_mixin.text == "other_mixin" - child1_computed_child_mixin = driver.find_element(By.ID, "child1-child-mixin") + child1_computed_child_mixin = driver.find_element( + By.ID, "child1-computed-child-mixin" + ) assert child1_computed_child_mixin.text == "computed_child_mixin" child1_base1 = driver.find_element(By.ID, "child1-base1") @@ -324,6 +333,9 @@ def test_state_inheritance( child1_other_mixin = driver.find_element(By.ID, "child1-other_mixin") assert child1_other_mixin.text == "other_mixin" + child1_child_mixin = driver.find_element(By.ID, "child1-child-mixin") + assert child1_child_mixin.text == "child_mixin" + # Child 2 child2_computed_basevar = driver.find_element(By.ID, "child2-computed_basevar") assert child2_computed_basevar.text == "computed_basevar2" @@ -334,7 +346,9 @@ def test_state_inheritance( child2_computed_other_mixin = driver.find_element(By.ID, "child2-other-mixin") assert child2_computed_other_mixin.text == "other_mixin" - child2_computed_child_mixin = driver.find_element(By.ID, "child2-child-mixin") + child2_computed_child_mixin = driver.find_element( + By.ID, "child2-computed-child-mixin" + ) assert child2_computed_child_mixin.text == "computed_child_mixin" child2_base2 = driver.find_element(By.ID, "child2-base2") @@ -343,6 +357,9 @@ def test_state_inheritance( child2_other_mixin = driver.find_element(By.ID, "child2-other_mixin") assert child2_other_mixin.text == "other_mixin" + child2_child_mixin = driver.find_element(By.ID, "child2-child-mixin") + assert child2_child_mixin.text == "child_mixin" + # Child 3 child3_computed_basevar = driver.find_element(By.ID, "child3-computed_basevar") assert child3_computed_basevar.text == "computed_basevar2" @@ -356,7 +373,9 @@ def test_state_inheritance( child3_computed_childvar = driver.find_element(By.ID, "child3-computed_childvar") assert child3_computed_childvar.text == "computed_childvar" - child3_computed_child_mixin = driver.find_element(By.ID, "child3-child-mixin") + child3_computed_child_mixin = driver.find_element( + By.ID, "child3-computed-child-mixin" + ) assert child3_computed_child_mixin.text == "computed_child_mixin" child3_child3 = driver.find_element(By.ID, "child3-child3") @@ -368,6 +387,9 @@ def test_state_inheritance( child3_other_mixin = driver.find_element(By.ID, "child3-other_mixin") assert child3_other_mixin.text == "other_mixin" + child3_child_mixin = driver.find_element(By.ID, "child3-child-mixin") + assert child3_child_mixin.text == "child_mixin" + child3_computed_backend_vars = driver.find_element( By.ID, "child3-computed_backend_vars" ) diff --git a/reflex/state.py b/reflex/state.py index ebe33aa26..287f70073 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -359,6 +359,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Whether the state has ever been touched since instantiation. _was_touched: bool = False + # Whether this state class is a mixin and should not be instantiated. + _mixin: ClassVar[bool] = False + # A special event handler for setting base vars. setvar: ClassVar[EventHandler] @@ -428,17 +431,17 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): """ return [ v - for mixin in cls.__mro__ - if mixin is cls or not issubclass(mixin, (BaseState, ABC)) + for mixin in cls._mixins() + [cls] for v in mixin.__dict__.values() if isinstance(v, ComputedVar) ] @classmethod - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls, mixin: bool = False, **kwargs): """Do some magic for the subclass initialization. Args: + mixin: Whether the subclass is a mixin and should not be initialized. **kwargs: The kwargs to pass to the pydantic init_subclass method. Raises: @@ -447,6 +450,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): from reflex.utils.exceptions import StateValueError super().__init_subclass__(**kwargs) + + cls._mixin = mixin + if mixin: + return + # Event handlers should not shadow builtin state methods. cls._check_overridden_methods() # Computed vars should not shadow builtin state props. @@ -618,8 +626,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): return [ mixin for mixin in cls.__mro__ - if not issubclass(mixin, (BaseState, ABC)) - and mixin not in [pydantic.BaseModel, Base] + if ( + mixin not in [pydantic.BaseModel, Base, cls] + and issubclass(mixin, BaseState) + and mixin._mixin is True + ) ] @classmethod @@ -742,7 +753,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): parent_states = [ base for base in cls.__bases__ - if types._issubclass(base, BaseState) and base is not BaseState + if issubclass(base, BaseState) and base is not BaseState and not base._mixin ] assert len(parent_states) < 2, "Only one parent state is allowed." return parent_states[0] if len(parent_states) == 1 else None # type: ignore @@ -1833,7 +1844,7 @@ class OnLoadInternalState(State): ] -class ComponentState(Base): +class ComponentState(State, mixin=True): """Base class to allow for the creation of a state instance per component. This allows for the bundling of UI and state logic into a single class, @@ -1875,6 +1886,18 @@ class ComponentState(Base): # The number of components created from this class. _per_component_state_instance_count: ClassVar[int] = 0 + @classmethod + def __init_subclass__(cls, mixin: bool = False, **kwargs): + """Overwrite mixin default to True. + + Args: + mixin: Whether the subclass is a mixin and should not be initialized. + **kwargs: The kwargs to pass to the pydantic init_subclass method. + """ + if ComponentState in cls.__bases__: + mixin = True + super().__init_subclass__(mixin=mixin, **kwargs) + @classmethod def get_component(cls, *children, **props) -> "Component": """Get the component instance.