From 5cd04b4176f5eaf825acc0a4d3c2fb7c1ed37d0e Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Fri, 10 Jan 2025 16:25:15 -0800 Subject: [PATCH] improve type support for .get_state --- reflex/state.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index a31aae032..9a2448819 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1543,7 +1543,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Return the direct parent of target_state_cls for subsequent linking. return parent_state - def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState: + def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE: """Get a state instance from the cache. Args: @@ -1553,9 +1553,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): The instance of state_cls associated with this state's client_token. """ root_state = self._get_root_state() - return root_state.get_substate(state_cls.get_full_name().split(".")) + substate = root_state.get_substate(state_cls.get_full_name().split(".")) + if not isinstance(substate, state_cls): + raise ValueError( + f"Searched for state {state_cls.get_full_name()} but found {substate}." + ) + return substate - async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState: + async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE: """Get a state instance from redis. Args: @@ -1577,14 +1582,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. " "(All states should already be available -- this is likely a bug).", ) - return await state_manager.get_state( + + state_in_redis = await state_manager.get_state( token=_substate_key(self.router.session.client_token, state_cls), top_level=False, get_substates=True, parent_state=parent_state_of_state_cls, ) - async def get_state(self, state_cls: Type[BaseState]) -> BaseState: + if not isinstance(state_in_redis, state_cls): + raise ValueError( + f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}." + ) + + return state_in_redis + + async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE: """Get an instance of the state associated with this token. Allows for arbitrary access to sibling states from within an event handler. @@ -2316,6 +2329,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): return state +T_STATE = TypeVar("T_STATE", bound=BaseState) + + class State(BaseState): """The app Base State."""