diff --git a/reflex/state.py b/reflex/state.py index a31aae032..16875bec3 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -104,6 +104,7 @@ from reflex.utils.exceptions import ( LockExpiredError, ReflexRuntimeError, SetUndefinedStateVarError, + StateMismatchError, StateSchemaMismatchError, StateSerializationError, StateTooLargeError, @@ -1543,7 +1544,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Return the direct parent of target_state_cls for subsequent linking. return parent_state - def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState: + def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE: """Get a state instance from the cache. Args: @@ -1551,11 +1552,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): Returns: The instance of state_cls associated with this state's client_token. + + Raises: + StateMismatchError: If the state instance is not of the expected type. """ root_state = self._get_root_state() - return root_state.get_substate(state_cls.get_full_name().split(".")) + substate = root_state.get_substate(state_cls.get_full_name().split(".")) + if not isinstance(substate, state_cls): + raise StateMismatchError( + f"Searched for state {state_cls.get_full_name()} but found {substate}." + ) + return substate - async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState: + async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE: """Get a state instance from redis. Args: @@ -1566,6 +1575,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): Raises: RuntimeError: If redis is not used in this backend process. + StateMismatchError: If the state instance is not of the expected type. """ # Fetch all missing parent states from redis. parent_state_of_state_cls = await self._populate_parent_states(state_cls) @@ -1577,14 +1587,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. " "(All states should already be available -- this is likely a bug).", ) - return await state_manager.get_state( + + state_in_redis = await state_manager.get_state( token=_substate_key(self.router.session.client_token, state_cls), top_level=False, get_substates=True, parent_state=parent_state_of_state_cls, ) - async def get_state(self, state_cls: Type[BaseState]) -> BaseState: + if not isinstance(state_in_redis, state_cls): + raise StateMismatchError( + f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}." + ) + + return state_in_redis + + async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE: """Get an instance of the state associated with this token. Allows for arbitrary access to sibling states from within an event handler. @@ -2316,6 +2334,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): return state +T_STATE = TypeVar("T_STATE", bound=BaseState) + + class State(BaseState): """The app Base State.""" diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index bceadc977..339abcda1 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -163,6 +163,10 @@ class StateSerializationError(ReflexError): """Raised when the state cannot be serialized.""" +class StateMismatchError(ReflexError, ValueError): + """Raised when the state retrieved does not match the expected state.""" + + class SystemPackageMissingError(ReflexError): """Raised when a system package is missing."""