From 1f9a17539cc21e652cac783c9d9a34be320240b4 Mon Sep 17 00:00:00 2001 From: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com> Date: Tue, 19 Nov 2024 04:15:01 +0100 Subject: [PATCH] fix: do not allow instantiation of State mixins (#4347) * fix: do not allow instantiation of State mixins Closes #4343 * improve error message for ComponentState mixins * fix typo Co-authored-by: Masen Furer --------- Co-authored-by: Masen Furer --- reflex/state.py | 22 +++++++++++++++++++ .../units/components/test_component_state.py | 21 ++++++++++++++++++ tests/units/test_state.py | 18 ++++++++++++++- 3 files changed, 60 insertions(+), 1 deletion(-) diff --git a/reflex/state.py b/reflex/state.py index 719ff43b3..9ff6f0ea8 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -87,6 +87,7 @@ from reflex.utils.exceptions import ( ImmutableStateError, InvalidStateManagerMode, LockExpiredError, + ReflexRuntimeError, SetUndefinedStateVarError, StateSchemaMismatchError, ) @@ -387,6 +388,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): "State classes should not be instantiated directly in a Reflex app. " "See https://reflex.dev/docs/state/ for further information." ) + if type(self)._mixin: + raise ReflexRuntimeError( + f"{type(self).__name__} is a state mixin and cannot be instantiated directly." + ) kwargs["parent_state"] = parent_state super().__init__() for name, value in kwargs.items(): @@ -2367,6 +2372,23 @@ class ComponentState(State, mixin=True): # The number of components created from this class. _per_component_state_instance_count: ClassVar[int] = 0 + def __init__(self, *args, **kwargs): + """Do not allow direct initialization of the ComponentState. + + Args: + *args: The args to pass to the State init method. + **kwargs: The kwargs to pass to the State init method. + + Raises: + ReflexRuntimeError: If the ComponentState is initialized directly. + """ + if type(self)._mixin: + raise ReflexRuntimeError( + f"{ComponentState.__name__} {type(self).__name__} is not meant to be initialized directly. " + + "Use the `create` method to create a new instance and access the state via the `State` attribute." + ) + super().__init__(*args, **kwargs) + @classmethod def __init_subclass__(cls, mixin: bool = True, **kwargs): """Overwrite mixin default to True. diff --git a/tests/units/components/test_component_state.py b/tests/units/components/test_component_state.py index 574997ba5..1b62e35c8 100644 --- a/tests/units/components/test_component_state.py +++ b/tests/units/components/test_component_state.py @@ -1,7 +1,10 @@ """Ensure that Components returned by ComponentState.create have independent State classes.""" +import pytest + import reflex as rx from reflex.components.base.bare import Bare +from reflex.utils.exceptions import ReflexRuntimeError def test_component_state(): @@ -40,3 +43,21 @@ def test_component_state(): assert len(cs2.children) == 1 assert cs2.children[0].render() == Bare.create("b").render() assert cs2.id == "b" + + +def test_init_component_state() -> None: + """Ensure that ComponentState subclasses cannot be instantiated directly.""" + + class CS(rx.ComponentState): + @classmethod + def get_component(cls, *children, **props): + return rx.el.div() + + with pytest.raises(ReflexRuntimeError): + CS() + + class SubCS(CS): + pass + + with pytest.raises(ReflexRuntimeError): + SubCS() diff --git a/tests/units/test_state.py b/tests/units/test_state.py index a69b9916a..7cebaff8e 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -43,7 +43,7 @@ from reflex.state import ( ) from reflex.testing import chdir from reflex.utils import format, prerequisites, types -from reflex.utils.exceptions import SetUndefinedStateVarError +from reflex.utils.exceptions import ReflexRuntimeError, SetUndefinedStateVarError from reflex.utils.format import json_dumps from reflex.vars.base import Var, computed_var from tests.units.states.mutation import MutableSQLAModel, MutableTestState @@ -3441,3 +3441,19 @@ def test_get_value(): "bar": "foo", } } + + +def test_init_mixin() -> None: + """Ensure that State mixins can not be instantiated directly.""" + + class Mixin(BaseState, mixin=True): + pass + + with pytest.raises(ReflexRuntimeError): + Mixin() + + class SubMixin(Mixin, mixin=True): + pass + + with pytest.raises(ReflexRuntimeError): + SubMixin()