fix: do not allow instantiation of State mixins (#4347)

* fix: do not allow instantiation of State mixins
Closes #4343

* improve error message for ComponentState mixins

* fix typo

Co-authored-by: Masen Furer <m_github@0x26.net>

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
benedikt-bartscher 2024-11-19 04:15:01 +01:00 committed by GitHub
parent af4fe48428
commit 1f9a17539c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 60 additions and 1 deletions

View File

@ -87,6 +87,7 @@ from reflex.utils.exceptions import (
ImmutableStateError,
InvalidStateManagerMode,
LockExpiredError,
ReflexRuntimeError,
SetUndefinedStateVarError,
StateSchemaMismatchError,
)
@ -387,6 +388,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"State classes should not be instantiated directly in a Reflex app. "
"See https://reflex.dev/docs/state/ for further information."
)
if type(self)._mixin:
raise ReflexRuntimeError(
f"{type(self).__name__} is a state mixin and cannot be instantiated directly."
)
kwargs["parent_state"] = parent_state
super().__init__()
for name, value in kwargs.items():
@ -2367,6 +2372,23 @@ class ComponentState(State, mixin=True):
# The number of components created from this class.
_per_component_state_instance_count: ClassVar[int] = 0
def __init__(self, *args, **kwargs):
"""Do not allow direct initialization of the ComponentState.
Args:
*args: The args to pass to the State init method.
**kwargs: The kwargs to pass to the State init method.
Raises:
ReflexRuntimeError: If the ComponentState is initialized directly.
"""
if type(self)._mixin:
raise ReflexRuntimeError(
f"{ComponentState.__name__} {type(self).__name__} is not meant to be initialized directly. "
+ "Use the `create` method to create a new instance and access the state via the `State` attribute."
)
super().__init__(*args, **kwargs)
@classmethod
def __init_subclass__(cls, mixin: bool = True, **kwargs):
"""Overwrite mixin default to True.

View File

@ -1,7 +1,10 @@
"""Ensure that Components returned by ComponentState.create have independent State classes."""
import pytest
import reflex as rx
from reflex.components.base.bare import Bare
from reflex.utils.exceptions import ReflexRuntimeError
def test_component_state():
@ -40,3 +43,21 @@ def test_component_state():
assert len(cs2.children) == 1
assert cs2.children[0].render() == Bare.create("b").render()
assert cs2.id == "b"
def test_init_component_state() -> None:
"""Ensure that ComponentState subclasses cannot be instantiated directly."""
class CS(rx.ComponentState):
@classmethod
def get_component(cls, *children, **props):
return rx.el.div()
with pytest.raises(ReflexRuntimeError):
CS()
class SubCS(CS):
pass
with pytest.raises(ReflexRuntimeError):
SubCS()

View File

@ -43,7 +43,7 @@ from reflex.state import (
)
from reflex.testing import chdir
from reflex.utils import format, prerequisites, types
from reflex.utils.exceptions import SetUndefinedStateVarError
from reflex.utils.exceptions import ReflexRuntimeError, SetUndefinedStateVarError
from reflex.utils.format import json_dumps
from reflex.vars.base import Var, computed_var
from tests.units.states.mutation import MutableSQLAModel, MutableTestState
@ -3441,3 +3441,19 @@ def test_get_value():
"bar": "foo",
}
}
def test_init_mixin() -> None:
"""Ensure that State mixins can not be instantiated directly."""
class Mixin(BaseState, mixin=True):
pass
with pytest.raises(ReflexRuntimeError):
Mixin()
class SubMixin(Mixin, mixin=True):
pass
with pytest.raises(ReflexRuntimeError):
SubMixin()