improve type support for .get_state

This commit is contained in:
Khaleel Al-Adhami 2025-01-10 16:25:15 -08:00
parent 427d7c56ab
commit 5cd04b4176

View File

@ -1543,7 +1543,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Return the direct parent of target_state_cls for subsequent linking.
return parent_state
def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get a state instance from the cache.
Args:
@ -1553,9 +1553,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
The instance of state_cls associated with this state's client_token.
"""
root_state = self._get_root_state()
return root_state.get_substate(state_cls.get_full_name().split("."))
substate = root_state.get_substate(state_cls.get_full_name().split("."))
if not isinstance(substate, state_cls):
raise ValueError(
f"Searched for state {state_cls.get_full_name()} but found {substate}."
)
return substate
async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get a state instance from redis.
Args:
@ -1577,14 +1582,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
"(All states should already be available -- this is likely a bug).",
)
return await state_manager.get_state(
state_in_redis = await state_manager.get_state(
token=_substate_key(self.router.session.client_token, state_cls),
top_level=False,
get_substates=True,
parent_state=parent_state_of_state_cls,
)
async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
if not isinstance(state_in_redis, state_cls):
raise ValueError(
f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
)
return state_in_redis
async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get an instance of the state associated with this token.
Allows for arbitrary access to sibling states from within an event handler.
@ -2316,6 +2329,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
return state
T_STATE = TypeVar("T_STATE", bound=BaseState)
class State(BaseState):
"""The app Base State."""