bypass pydantic runtime validation for state init

closes #4235
This commit is contained in:
Benedikt Bartscher 2024-10-28 23:48:09 +01:00
parent 41b1958626
commit c4a29229b1
No known key found for this signature in database
2 changed files with 11 additions and 4 deletions

View File

@ -344,7 +344,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
def __init__( def __init__(
self, self,
*args,
parent_state: BaseState | None = None, parent_state: BaseState | None = None,
init_substates: bool = True, init_substates: bool = True,
_reflex_internal_init: bool = False, _reflex_internal_init: bool = False,
@ -355,11 +354,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
DO NOT INSTANTIATE STATE CLASSES DIRECTLY! Use StateManager.get_state() instead. DO NOT INSTANTIATE STATE CLASSES DIRECTLY! Use StateManager.get_state() instead.
Args: Args:
*args: The args to pass to the Pydantic init method.
parent_state: The parent state. parent_state: The parent state.
init_substates: Whether to initialize the substates in this instance. init_substates: Whether to initialize the substates in this instance.
_reflex_internal_init: A flag to indicate that the state is being initialized by the framework. _reflex_internal_init: A flag to indicate that the state is being initialized by the framework.
**kwargs: The kwargs to pass to the Pydantic init method. **kwargs: The kwargs to set as attributes on the state.
Raises: Raises:
ReflexRuntimeError: If the state is instantiated directly by end user. ReflexRuntimeError: If the state is instantiated directly by end user.
@ -372,7 +370,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"See https://reflex.dev/docs/state/ for further information." "See https://reflex.dev/docs/state/ for further information."
) )
kwargs["parent_state"] = parent_state kwargs["parent_state"] = parent_state
super().__init__(*args, **kwargs) super().__init__()
for name, value in kwargs.items():
setattr(self, name, value)
# Setup the substates (for memory state manager only). # Setup the substates (for memory state manager only).
if init_substates: if init_substates:

View File

@ -3396,3 +3396,10 @@ def test_fallback_pickle():
assert len(pk) == 0 assert len(pk) == 0
with pytest.raises(EOFError): with pytest.raises(EOFError):
BaseState._deserialize(pk) BaseState._deserialize(pk)
def test_typed_state() -> None:
class TypedState(rx.State):
field: rx.Field[str] = rx.field("")
_ = TypedState(field="str")