From b3fb1bdca025f69006d8b4824ca515300cc3c6ad Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Sun, 10 Nov 2024 20:54:35 +0100 Subject: [PATCH] fix: do not allow instantiation of State mixins Closes #4343 --- reflex/state.py | 4 ++++ .../units/components/test_component_state.py | 21 +++++++++++++++++++ tests/units/test_state.py | 18 +++++++++++++++- 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/reflex/state.py b/reflex/state.py index 66b1e3cab..724aeee90 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -386,6 +386,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): "State classes should not be instantiated directly in a Reflex app. " "See https://reflex.dev/docs/state/ for further information." ) + if type(self)._mixin: + raise ReflexRuntimeError( + f"{type(self).__name__} is a state mixin and cannot be instantiated directly." + ) kwargs["parent_state"] = parent_state super().__init__() for name, value in kwargs.items(): diff --git a/tests/units/components/test_component_state.py b/tests/units/components/test_component_state.py index 574997ba5..1b62e35c8 100644 --- a/tests/units/components/test_component_state.py +++ b/tests/units/components/test_component_state.py @@ -1,7 +1,10 @@ """Ensure that Components returned by ComponentState.create have independent State classes.""" +import pytest + import reflex as rx from reflex.components.base.bare import Bare +from reflex.utils.exceptions import ReflexRuntimeError def test_component_state(): @@ -40,3 +43,21 @@ def test_component_state(): assert len(cs2.children) == 1 assert cs2.children[0].render() == Bare.create("b").render() assert cs2.id == "b" + + +def test_init_component_state() -> None: + """Ensure that ComponentState subclasses cannot be instantiated directly.""" + + class CS(rx.ComponentState): + @classmethod + def get_component(cls, *children, **props): + return rx.el.div() + + with pytest.raises(ReflexRuntimeError): + CS() + + class SubCS(CS): + pass + + with pytest.raises(ReflexRuntimeError): + SubCS() diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 2ce0b7bd5..e427c17d9 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -43,7 +43,7 @@ from reflex.state import ( ) from reflex.testing import chdir from reflex.utils import format, prerequisites, types -from reflex.utils.exceptions import SetUndefinedStateVarError +from reflex.utils.exceptions import ReflexRuntimeError, SetUndefinedStateVarError from reflex.utils.format import json_dumps from reflex.vars.base import Var, computed_var from tests.units.states.mutation import MutableSQLAModel, MutableTestState @@ -3411,3 +3411,19 @@ def test_typed_state() -> None: field: rx.Field[str] = rx.field("") _ = TypedState(field="str") + + +def test_init_mixin() -> None: + """Ensure that State mixins can not be instantiated directly.""" + + class Mixin(BaseState, mixin=True): + pass + + with pytest.raises(ReflexRuntimeError): + Mixin() + + class SubMixin(Mixin, mixin=True): + pass + + with pytest.raises(ReflexRuntimeError): + SubMixin()