From c4a29229b13c59f61723d069d2abce93e537db1d Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Mon, 28 Oct 2024 23:48:09 +0100 Subject: [PATCH] bypass pydantic runtime validation for state init closes #4235 --- reflex/state.py | 8 ++++---- tests/units/test_state.py | 7 +++++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 6e229b97d..9e0ad3118 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -344,7 +344,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): def __init__( self, - *args, parent_state: BaseState | None = None, init_substates: bool = True, _reflex_internal_init: bool = False, @@ -355,11 +354,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): DO NOT INSTANTIATE STATE CLASSES DIRECTLY! Use StateManager.get_state() instead. Args: - *args: The args to pass to the Pydantic init method. parent_state: The parent state. init_substates: Whether to initialize the substates in this instance. _reflex_internal_init: A flag to indicate that the state is being initialized by the framework. - **kwargs: The kwargs to pass to the Pydantic init method. + **kwargs: The kwargs to set as attributes on the state. Raises: ReflexRuntimeError: If the state is instantiated directly by end user. @@ -372,7 +370,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): "See https://reflex.dev/docs/state/ for further information." ) kwargs["parent_state"] = parent_state - super().__init__(*args, **kwargs) + super().__init__() + for name, value in kwargs.items(): + setattr(self, name, value) # Setup the substates (for memory state manager only). if init_substates: diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 89dd1fd3d..46bcc23b3 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3396,3 +3396,10 @@ def test_fallback_pickle(): assert len(pk) == 0 with pytest.raises(EOFError): BaseState._deserialize(pk) + + +def test_typed_state() -> None: + class TypedState(rx.State): + field: rx.Field[str] = rx.field("") + + _ = TypedState(field="str")