fix: do not allow instantiation of State mixins

Closes #4343
This commit is contained in:
Benedikt Bartscher 2024-11-10 20:54:35 +01:00
parent e0d1a58496
commit b3fb1bdca0
No known key found for this signature in database
3 changed files with 42 additions and 1 deletions

View File

@ -386,6 +386,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"State classes should not be instantiated directly in a Reflex app. " "State classes should not be instantiated directly in a Reflex app. "
"See https://reflex.dev/docs/state/ for further information." "See https://reflex.dev/docs/state/ for further information."
) )
if type(self)._mixin:
raise ReflexRuntimeError(
f"{type(self).__name__} is a state mixin and cannot be instantiated directly."
)
kwargs["parent_state"] = parent_state kwargs["parent_state"] = parent_state
super().__init__() super().__init__()
for name, value in kwargs.items(): for name, value in kwargs.items():

View File

@ -1,7 +1,10 @@
"""Ensure that Components returned by ComponentState.create have independent State classes.""" """Ensure that Components returned by ComponentState.create have independent State classes."""
import pytest
import reflex as rx import reflex as rx
from reflex.components.base.bare import Bare from reflex.components.base.bare import Bare
from reflex.utils.exceptions import ReflexRuntimeError
def test_component_state(): def test_component_state():
@ -40,3 +43,21 @@ def test_component_state():
assert len(cs2.children) == 1 assert len(cs2.children) == 1
assert cs2.children[0].render() == Bare.create("b").render() assert cs2.children[0].render() == Bare.create("b").render()
assert cs2.id == "b" assert cs2.id == "b"
def test_init_component_state() -> None:
"""Ensure that ComponentState subclasses cannot be instantiated directly."""
class CS(rx.ComponentState):
@classmethod
def get_component(cls, *children, **props):
return rx.el.div()
with pytest.raises(ReflexRuntimeError):
CS()
class SubCS(CS):
pass
with pytest.raises(ReflexRuntimeError):
SubCS()

View File

@ -43,7 +43,7 @@ from reflex.state import (
) )
from reflex.testing import chdir from reflex.testing import chdir
from reflex.utils import format, prerequisites, types from reflex.utils import format, prerequisites, types
from reflex.utils.exceptions import SetUndefinedStateVarError from reflex.utils.exceptions import ReflexRuntimeError, SetUndefinedStateVarError
from reflex.utils.format import json_dumps from reflex.utils.format import json_dumps
from reflex.vars.base import Var, computed_var from reflex.vars.base import Var, computed_var
from tests.units.states.mutation import MutableSQLAModel, MutableTestState from tests.units.states.mutation import MutableSQLAModel, MutableTestState
@ -3411,3 +3411,19 @@ def test_typed_state() -> None:
field: rx.Field[str] = rx.field("") field: rx.Field[str] = rx.field("")
_ = TypedState(field="str") _ = TypedState(field="str")
def test_init_mixin() -> None:
"""Ensure that State mixins can not be instantiated directly."""
class Mixin(BaseState, mixin=True):
pass
with pytest.raises(ReflexRuntimeError):
Mixin()
class SubMixin(Mixin, mixin=True):
pass
with pytest.raises(ReflexRuntimeError):
SubMixin()