fix: do not allow instantiation of State mixins (#4347)

* fix: do not allow instantiation of State mixins
Closes #4343

* improve error message for ComponentState mixins

* fix typo

Co-authored-by: Masen Furer <m_github@0x26.net>

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
benedikt-bartscher 2024-11-19 04:15:01 +01:00 committed by GitHub
parent af4fe48428
commit 1f9a17539c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 60 additions and 1 deletions

View File

@ -87,6 +87,7 @@ from reflex.utils.exceptions import (
ImmutableStateError, ImmutableStateError,
InvalidStateManagerMode, InvalidStateManagerMode,
LockExpiredError, LockExpiredError,
ReflexRuntimeError,
SetUndefinedStateVarError, SetUndefinedStateVarError,
StateSchemaMismatchError, StateSchemaMismatchError,
) )
@ -387,6 +388,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"State classes should not be instantiated directly in a Reflex app. " "State classes should not be instantiated directly in a Reflex app. "
"See https://reflex.dev/docs/state/ for further information." "See https://reflex.dev/docs/state/ for further information."
) )
if type(self)._mixin:
raise ReflexRuntimeError(
f"{type(self).__name__} is a state mixin and cannot be instantiated directly."
)
kwargs["parent_state"] = parent_state kwargs["parent_state"] = parent_state
super().__init__() super().__init__()
for name, value in kwargs.items(): for name, value in kwargs.items():
@ -2367,6 +2372,23 @@ class ComponentState(State, mixin=True):
# The number of components created from this class. # The number of components created from this class.
_per_component_state_instance_count: ClassVar[int] = 0 _per_component_state_instance_count: ClassVar[int] = 0
def __init__(self, *args, **kwargs):
"""Do not allow direct initialization of the ComponentState.
Args:
*args: The args to pass to the State init method.
**kwargs: The kwargs to pass to the State init method.
Raises:
ReflexRuntimeError: If the ComponentState is initialized directly.
"""
if type(self)._mixin:
raise ReflexRuntimeError(
f"{ComponentState.__name__} {type(self).__name__} is not meant to be initialized directly. "
+ "Use the `create` method to create a new instance and access the state via the `State` attribute."
)
super().__init__(*args, **kwargs)
@classmethod @classmethod
def __init_subclass__(cls, mixin: bool = True, **kwargs): def __init_subclass__(cls, mixin: bool = True, **kwargs):
"""Overwrite mixin default to True. """Overwrite mixin default to True.

View File

@ -1,7 +1,10 @@
"""Ensure that Components returned by ComponentState.create have independent State classes.""" """Ensure that Components returned by ComponentState.create have independent State classes."""
import pytest
import reflex as rx import reflex as rx
from reflex.components.base.bare import Bare from reflex.components.base.bare import Bare
from reflex.utils.exceptions import ReflexRuntimeError
def test_component_state(): def test_component_state():
@ -40,3 +43,21 @@ def test_component_state():
assert len(cs2.children) == 1 assert len(cs2.children) == 1
assert cs2.children[0].render() == Bare.create("b").render() assert cs2.children[0].render() == Bare.create("b").render()
assert cs2.id == "b" assert cs2.id == "b"
def test_init_component_state() -> None:
"""Ensure that ComponentState subclasses cannot be instantiated directly."""
class CS(rx.ComponentState):
@classmethod
def get_component(cls, *children, **props):
return rx.el.div()
with pytest.raises(ReflexRuntimeError):
CS()
class SubCS(CS):
pass
with pytest.raises(ReflexRuntimeError):
SubCS()

View File

@ -43,7 +43,7 @@ from reflex.state import (
) )
from reflex.testing import chdir from reflex.testing import chdir
from reflex.utils import format, prerequisites, types from reflex.utils import format, prerequisites, types
from reflex.utils.exceptions import SetUndefinedStateVarError from reflex.utils.exceptions import ReflexRuntimeError, SetUndefinedStateVarError
from reflex.utils.format import json_dumps from reflex.utils.format import json_dumps
from reflex.vars.base import Var, computed_var from reflex.vars.base import Var, computed_var
from tests.units.states.mutation import MutableSQLAModel, MutableTestState from tests.units.states.mutation import MutableSQLAModel, MutableTestState
@ -3441,3 +3441,19 @@ def test_get_value():
"bar": "foo", "bar": "foo",
} }
} }
def test_init_mixin() -> None:
"""Ensure that State mixins can not be instantiated directly."""
class Mixin(BaseState, mixin=True):
pass
with pytest.raises(ReflexRuntimeError):
Mixin()
class SubMixin(Mixin, mixin=True):
pass
with pytest.raises(ReflexRuntimeError):
SubMixin()