fix: do not allow instantiation of State mixins (#4347)
* fix: do not allow instantiation of State mixins Closes #4343 * improve error message for ComponentState mixins * fix typo Co-authored-by: Masen Furer <m_github@0x26.net> --------- Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
parent
af4fe48428
commit
1f9a17539c
@ -87,6 +87,7 @@ from reflex.utils.exceptions import (
|
|||||||
ImmutableStateError,
|
ImmutableStateError,
|
||||||
InvalidStateManagerMode,
|
InvalidStateManagerMode,
|
||||||
LockExpiredError,
|
LockExpiredError,
|
||||||
|
ReflexRuntimeError,
|
||||||
SetUndefinedStateVarError,
|
SetUndefinedStateVarError,
|
||||||
StateSchemaMismatchError,
|
StateSchemaMismatchError,
|
||||||
)
|
)
|
||||||
@ -387,6 +388,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
"State classes should not be instantiated directly in a Reflex app. "
|
"State classes should not be instantiated directly in a Reflex app. "
|
||||||
"See https://reflex.dev/docs/state/ for further information."
|
"See https://reflex.dev/docs/state/ for further information."
|
||||||
)
|
)
|
||||||
|
if type(self)._mixin:
|
||||||
|
raise ReflexRuntimeError(
|
||||||
|
f"{type(self).__name__} is a state mixin and cannot be instantiated directly."
|
||||||
|
)
|
||||||
kwargs["parent_state"] = parent_state
|
kwargs["parent_state"] = parent_state
|
||||||
super().__init__()
|
super().__init__()
|
||||||
for name, value in kwargs.items():
|
for name, value in kwargs.items():
|
||||||
@ -2367,6 +2372,23 @@ class ComponentState(State, mixin=True):
|
|||||||
# The number of components created from this class.
|
# The number of components created from this class.
|
||||||
_per_component_state_instance_count: ClassVar[int] = 0
|
_per_component_state_instance_count: ClassVar[int] = 0
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
"""Do not allow direct initialization of the ComponentState.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: The args to pass to the State init method.
|
||||||
|
**kwargs: The kwargs to pass to the State init method.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ReflexRuntimeError: If the ComponentState is initialized directly.
|
||||||
|
"""
|
||||||
|
if type(self)._mixin:
|
||||||
|
raise ReflexRuntimeError(
|
||||||
|
f"{ComponentState.__name__} {type(self).__name__} is not meant to be initialized directly. "
|
||||||
|
+ "Use the `create` method to create a new instance and access the state via the `State` attribute."
|
||||||
|
)
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __init_subclass__(cls, mixin: bool = True, **kwargs):
|
def __init_subclass__(cls, mixin: bool = True, **kwargs):
|
||||||
"""Overwrite mixin default to True.
|
"""Overwrite mixin default to True.
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
"""Ensure that Components returned by ComponentState.create have independent State classes."""
|
"""Ensure that Components returned by ComponentState.create have independent State classes."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
import reflex as rx
|
import reflex as rx
|
||||||
from reflex.components.base.bare import Bare
|
from reflex.components.base.bare import Bare
|
||||||
|
from reflex.utils.exceptions import ReflexRuntimeError
|
||||||
|
|
||||||
|
|
||||||
def test_component_state():
|
def test_component_state():
|
||||||
@ -40,3 +43,21 @@ def test_component_state():
|
|||||||
assert len(cs2.children) == 1
|
assert len(cs2.children) == 1
|
||||||
assert cs2.children[0].render() == Bare.create("b").render()
|
assert cs2.children[0].render() == Bare.create("b").render()
|
||||||
assert cs2.id == "b"
|
assert cs2.id == "b"
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_component_state() -> None:
|
||||||
|
"""Ensure that ComponentState subclasses cannot be instantiated directly."""
|
||||||
|
|
||||||
|
class CS(rx.ComponentState):
|
||||||
|
@classmethod
|
||||||
|
def get_component(cls, *children, **props):
|
||||||
|
return rx.el.div()
|
||||||
|
|
||||||
|
with pytest.raises(ReflexRuntimeError):
|
||||||
|
CS()
|
||||||
|
|
||||||
|
class SubCS(CS):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(ReflexRuntimeError):
|
||||||
|
SubCS()
|
||||||
|
@ -43,7 +43,7 @@ from reflex.state import (
|
|||||||
)
|
)
|
||||||
from reflex.testing import chdir
|
from reflex.testing import chdir
|
||||||
from reflex.utils import format, prerequisites, types
|
from reflex.utils import format, prerequisites, types
|
||||||
from reflex.utils.exceptions import SetUndefinedStateVarError
|
from reflex.utils.exceptions import ReflexRuntimeError, SetUndefinedStateVarError
|
||||||
from reflex.utils.format import json_dumps
|
from reflex.utils.format import json_dumps
|
||||||
from reflex.vars.base import Var, computed_var
|
from reflex.vars.base import Var, computed_var
|
||||||
from tests.units.states.mutation import MutableSQLAModel, MutableTestState
|
from tests.units.states.mutation import MutableSQLAModel, MutableTestState
|
||||||
@ -3441,3 +3441,19 @@ def test_get_value():
|
|||||||
"bar": "foo",
|
"bar": "foo",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_mixin() -> None:
|
||||||
|
"""Ensure that State mixins can not be instantiated directly."""
|
||||||
|
|
||||||
|
class Mixin(BaseState, mixin=True):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(ReflexRuntimeError):
|
||||||
|
Mixin()
|
||||||
|
|
||||||
|
class SubMixin(Mixin, mixin=True):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(ReflexRuntimeError):
|
||||||
|
SubMixin()
|
||||||
|
Loading…
Reference in New Issue
Block a user