typed mixins and ComponentState (#3196)

* typed mixins

* implicit mixin=True kwarg for ComponentState subclasses

* fix: always init other subclasses

* adjust tests: all mixins support base vars now
This commit is contained in:
benedikt-bartscher 2024-05-15 21:07:41 +02:00 committed by GitHub
parent 87a3ddea7f
commit d96baac7d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 75 additions and 29 deletions

View File

@ -1,4 +1,5 @@
"""Test that per-component state scaffold works and operates independently."""
from typing import Generator
import pytest

View File

@ -45,17 +45,15 @@ def StateInheritance():
"""Test that state inheritance works as expected."""
import reflex as rx
class ChildMixin:
# mixin basevars only work with pydantic/rx.Base models
# child_mixin: str = "child_mixin"
class ChildMixin(rx.State, mixin=True):
child_mixin: str = "child_mixin"
@rx.var
def computed_child_mixin(self) -> str:
return "computed_child_mixin"
class Mixin(ChildMixin):
# mixin basevars only work with pydantic/rx.Base models
# mixin: str = "mixin"
class Mixin(ChildMixin, mixin=True):
mixin: str = "mixin"
@rx.var
def computed_mixin(self) -> str:
@ -64,7 +62,7 @@ def StateInheritance():
def on_click_mixin(self):
return rx.call_script("alert('clicked')")
class OtherMixin(rx.Base):
class OtherMixin(rx.State, mixin=True):
other_mixin: str = "other_mixin"
other_mixin_clicks: int = 0
@ -78,7 +76,7 @@ def StateInheritance():
f"{self.__class__.__name__}.clicked.{self.other_mixin_clicks}"
)
class Base1(rx.State, Mixin):
class Base1(Mixin, rx.State):
_base1: str = "_base1"
base1: str = "base1"
@ -122,14 +120,15 @@ def StateInheritance():
def index() -> rx.Component:
return rx.vstack(
rx.chakra.input(
rx.input(
id="token", value=Base1.router.session.client_token, is_read_only=True
),
# Base 1
# Base 1 (Mixin, ChildMixin)
rx.heading(Base1.computed_mixin, id="base1-computed_mixin"),
rx.heading(Base1.computed_basevar, id="base1-computed_basevar"),
rx.heading(Base1.computed_child_mixin, id="base1-child-mixin"),
rx.heading(Base1.computed_child_mixin, id="base1-computed-child-mixin"),
rx.heading(Base1.base1, id="base1-base1"),
rx.heading(Base1.child_mixin, id="base1-child-mixin"),
rx.button(
"Base1.on_click_mixin",
on_click=Base1.on_click_mixin, # type: ignore
@ -138,31 +137,33 @@ def StateInheritance():
rx.heading(
Base1.computed_backend_vars_base1, id="base1-computed_backend_vars"
),
# Base 2
# Base 2 (no mixins)
rx.heading(Base2.computed_basevar, id="base2-computed_basevar"),
rx.heading(Base2.base2, id="base2-base2"),
rx.heading(
Base2.computed_backend_vars_base2, id="base2-computed_backend_vars"
),
# Child 1
# Child 1 (Mixin, ChildMixin, OtherMixin)
rx.heading(Child1.computed_basevar, id="child1-computed_basevar"),
rx.heading(Child1.computed_mixin, id="child1-computed_mixin"),
rx.heading(Child1.computed_other_mixin, id="child1-other-mixin"),
rx.heading(Child1.computed_child_mixin, id="child1-child-mixin"),
rx.heading(Child1.computed_child_mixin, id="child1-computed-child-mixin"),
rx.heading(Child1.base1, id="child1-base1"),
rx.heading(Child1.other_mixin, id="child1-other_mixin"),
rx.heading(Child1.child_mixin, id="child1-child-mixin"),
rx.button(
"Child1.on_click_other_mixin",
on_click=Child1.on_click_other_mixin, # type: ignore
id="child1-other-mixin-btn",
),
# Child 2
# Child 2 (Mixin, ChildMixin, OtherMixin)
rx.heading(Child2.computed_basevar, id="child2-computed_basevar"),
rx.heading(Child2.computed_mixin, id="child2-computed_mixin"),
rx.heading(Child2.computed_other_mixin, id="child2-other-mixin"),
rx.heading(Child2.computed_child_mixin, id="child2-child-mixin"),
rx.heading(Child2.computed_child_mixin, id="child2-computed-child-mixin"),
rx.heading(Child2.base2, id="child2-base2"),
rx.heading(Child2.other_mixin, id="child2-other_mixin"),
rx.heading(Child2.child_mixin, id="child2-child-mixin"),
rx.button(
"Child2.on_click_mixin",
on_click=Child2.on_click_mixin, # type: ignore
@ -173,15 +174,16 @@ def StateInheritance():
on_click=Child2.on_click_other_mixin, # type: ignore
id="child2-other-mixin-btn",
),
# Child 3
# Child 3 (Mixin, ChildMixin, OtherMixin)
rx.heading(Child3.computed_basevar, id="child3-computed_basevar"),
rx.heading(Child3.computed_mixin, id="child3-computed_mixin"),
rx.heading(Child3.computed_other_mixin, id="child3-other-mixin"),
rx.heading(Child3.computed_childvar, id="child3-computed_childvar"),
rx.heading(Child3.computed_child_mixin, id="child3-child-mixin"),
rx.heading(Child3.computed_child_mixin, id="child3-computed-child-mixin"),
rx.heading(Child3.child3, id="child3-child3"),
rx.heading(Child3.base2, id="child3-base2"),
rx.heading(Child3.other_mixin, id="child3-other_mixin"),
rx.heading(Child3.child_mixin, id="child3-child-mixin"),
rx.button(
"Child3.on_click_mixin",
on_click=Child3.on_click_mixin, # type: ignore
@ -282,7 +284,9 @@ def test_state_inheritance(
base1_computed_basevar = driver.find_element(By.ID, "base1-computed_basevar")
assert base1_computed_basevar.text == "computed_basevar1"
base1_computed_child_mixin = driver.find_element(By.ID, "base1-child-mixin")
base1_computed_child_mixin = driver.find_element(
By.ID, "base1-computed-child-mixin"
)
assert base1_computed_child_mixin.text == "computed_child_mixin"
base1_base1 = driver.find_element(By.ID, "base1-base1")
@ -293,6 +297,9 @@ def test_state_inheritance(
)
assert base1_computed_backend_vars.text == "_base1"
base1_child_mixin = driver.find_element(By.ID, "base1-child-mixin")
assert base1_child_mixin.text == "child_mixin"
# Base 2
base2_computed_basevar = driver.find_element(By.ID, "base2-computed_basevar")
assert base2_computed_basevar.text == "computed_basevar2"
@ -315,7 +322,9 @@ def test_state_inheritance(
child1_computed_other_mixin = driver.find_element(By.ID, "child1-other-mixin")
assert child1_computed_other_mixin.text == "other_mixin"
child1_computed_child_mixin = driver.find_element(By.ID, "child1-child-mixin")
child1_computed_child_mixin = driver.find_element(
By.ID, "child1-computed-child-mixin"
)
assert child1_computed_child_mixin.text == "computed_child_mixin"
child1_base1 = driver.find_element(By.ID, "child1-base1")
@ -324,6 +333,9 @@ def test_state_inheritance(
child1_other_mixin = driver.find_element(By.ID, "child1-other_mixin")
assert child1_other_mixin.text == "other_mixin"
child1_child_mixin = driver.find_element(By.ID, "child1-child-mixin")
assert child1_child_mixin.text == "child_mixin"
# Child 2
child2_computed_basevar = driver.find_element(By.ID, "child2-computed_basevar")
assert child2_computed_basevar.text == "computed_basevar2"
@ -334,7 +346,9 @@ def test_state_inheritance(
child2_computed_other_mixin = driver.find_element(By.ID, "child2-other-mixin")
assert child2_computed_other_mixin.text == "other_mixin"
child2_computed_child_mixin = driver.find_element(By.ID, "child2-child-mixin")
child2_computed_child_mixin = driver.find_element(
By.ID, "child2-computed-child-mixin"
)
assert child2_computed_child_mixin.text == "computed_child_mixin"
child2_base2 = driver.find_element(By.ID, "child2-base2")
@ -343,6 +357,9 @@ def test_state_inheritance(
child2_other_mixin = driver.find_element(By.ID, "child2-other_mixin")
assert child2_other_mixin.text == "other_mixin"
child2_child_mixin = driver.find_element(By.ID, "child2-child-mixin")
assert child2_child_mixin.text == "child_mixin"
# Child 3
child3_computed_basevar = driver.find_element(By.ID, "child3-computed_basevar")
assert child3_computed_basevar.text == "computed_basevar2"
@ -356,7 +373,9 @@ def test_state_inheritance(
child3_computed_childvar = driver.find_element(By.ID, "child3-computed_childvar")
assert child3_computed_childvar.text == "computed_childvar"
child3_computed_child_mixin = driver.find_element(By.ID, "child3-child-mixin")
child3_computed_child_mixin = driver.find_element(
By.ID, "child3-computed-child-mixin"
)
assert child3_computed_child_mixin.text == "computed_child_mixin"
child3_child3 = driver.find_element(By.ID, "child3-child3")
@ -368,6 +387,9 @@ def test_state_inheritance(
child3_other_mixin = driver.find_element(By.ID, "child3-other_mixin")
assert child3_other_mixin.text == "other_mixin"
child3_child_mixin = driver.find_element(By.ID, "child3-child-mixin")
assert child3_child_mixin.text == "child_mixin"
child3_computed_backend_vars = driver.find_element(
By.ID, "child3-computed_backend_vars"
)

View File

@ -359,6 +359,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Whether the state has ever been touched since instantiation.
_was_touched: bool = False
# Whether this state class is a mixin and should not be instantiated.
_mixin: ClassVar[bool] = False
# A special event handler for setting base vars.
setvar: ClassVar[EventHandler]
@ -428,17 +431,17 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""
return [
v
for mixin in cls.__mro__
if mixin is cls or not issubclass(mixin, (BaseState, ABC))
for mixin in cls._mixins() + [cls]
for v in mixin.__dict__.values()
if isinstance(v, ComputedVar)
]
@classmethod
def __init_subclass__(cls, **kwargs):
def __init_subclass__(cls, mixin: bool = False, **kwargs):
"""Do some magic for the subclass initialization.
Args:
mixin: Whether the subclass is a mixin and should not be initialized.
**kwargs: The kwargs to pass to the pydantic init_subclass method.
Raises:
@ -447,6 +450,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
from reflex.utils.exceptions import StateValueError
super().__init_subclass__(**kwargs)
cls._mixin = mixin
if mixin:
return
# Event handlers should not shadow builtin state methods.
cls._check_overridden_methods()
# Computed vars should not shadow builtin state props.
@ -618,8 +626,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
return [
mixin
for mixin in cls.__mro__
if not issubclass(mixin, (BaseState, ABC))
and mixin not in [pydantic.BaseModel, Base]
if (
mixin not in [pydantic.BaseModel, Base, cls]
and issubclass(mixin, BaseState)
and mixin._mixin is True
)
]
@classmethod
@ -742,7 +753,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
parent_states = [
base
for base in cls.__bases__
if types._issubclass(base, BaseState) and base is not BaseState
if issubclass(base, BaseState) and base is not BaseState and not base._mixin
]
assert len(parent_states) < 2, "Only one parent state is allowed."
return parent_states[0] if len(parent_states) == 1 else None # type: ignore
@ -1833,7 +1844,7 @@ class OnLoadInternalState(State):
]
class ComponentState(Base):
class ComponentState(State, mixin=True):
"""Base class to allow for the creation of a state instance per component.
This allows for the bundling of UI and state logic into a single class,
@ -1875,6 +1886,18 @@ class ComponentState(Base):
# The number of components created from this class.
_per_component_state_instance_count: ClassVar[int] = 0
@classmethod
def __init_subclass__(cls, mixin: bool = False, **kwargs):
"""Overwrite mixin default to True.
Args:
mixin: Whether the subclass is a mixin and should not be initialized.
**kwargs: The kwargs to pass to the pydantic init_subclass method.
"""
if ComponentState in cls.__bases__:
mixin = True
super().__init_subclass__(mixin=mixin, **kwargs)
@classmethod
def get_component(cls, *children, **props) -> "Component":
"""Get the component instance.