improve type support for .get_state (#4623)

* improve type support for .get_state

* dang it darglint
This commit is contained in:
Khaleel Al-Adhami 2025-01-15 14:21:33 -08:00 committed by GitHub
parent 9fe8e6f1ce
commit b50b7692b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 5 deletions

View File

@ -104,6 +104,7 @@ from reflex.utils.exceptions import (
LockExpiredError, LockExpiredError,
ReflexRuntimeError, ReflexRuntimeError,
SetUndefinedStateVarError, SetUndefinedStateVarError,
StateMismatchError,
StateSchemaMismatchError, StateSchemaMismatchError,
StateSerializationError, StateSerializationError,
StateTooLargeError, StateTooLargeError,
@ -1543,7 +1544,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Return the direct parent of target_state_cls for subsequent linking. # Return the direct parent of target_state_cls for subsequent linking.
return parent_state return parent_state
def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState: def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get a state instance from the cache. """Get a state instance from the cache.
Args: Args:
@ -1551,11 +1552,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Returns: Returns:
The instance of state_cls associated with this state's client_token. The instance of state_cls associated with this state's client_token.
Raises:
StateMismatchError: If the state instance is not of the expected type.
""" """
root_state = self._get_root_state() root_state = self._get_root_state()
return root_state.get_substate(state_cls.get_full_name().split(".")) substate = root_state.get_substate(state_cls.get_full_name().split("."))
if not isinstance(substate, state_cls):
raise StateMismatchError(
f"Searched for state {state_cls.get_full_name()} but found {substate}."
)
return substate
async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState: async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get a state instance from redis. """Get a state instance from redis.
Args: Args:
@ -1566,6 +1575,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Raises: Raises:
RuntimeError: If redis is not used in this backend process. RuntimeError: If redis is not used in this backend process.
StateMismatchError: If the state instance is not of the expected type.
""" """
# Fetch all missing parent states from redis. # Fetch all missing parent states from redis.
parent_state_of_state_cls = await self._populate_parent_states(state_cls) parent_state_of_state_cls = await self._populate_parent_states(state_cls)
@ -1577,14 +1587,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. " f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
"(All states should already be available -- this is likely a bug).", "(All states should already be available -- this is likely a bug).",
) )
return await state_manager.get_state(
state_in_redis = await state_manager.get_state(
token=_substate_key(self.router.session.client_token, state_cls), token=_substate_key(self.router.session.client_token, state_cls),
top_level=False, top_level=False,
get_substates=True, get_substates=True,
parent_state=parent_state_of_state_cls, parent_state=parent_state_of_state_cls,
) )
async def get_state(self, state_cls: Type[BaseState]) -> BaseState: if not isinstance(state_in_redis, state_cls):
raise StateMismatchError(
f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
)
return state_in_redis
async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get an instance of the state associated with this token. """Get an instance of the state associated with this token.
Allows for arbitrary access to sibling states from within an event handler. Allows for arbitrary access to sibling states from within an event handler.
@ -2316,6 +2334,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
return state return state
T_STATE = TypeVar("T_STATE", bound=BaseState)
class State(BaseState): class State(BaseState):
"""The app Base State.""" """The app Base State."""

View File

@ -163,6 +163,10 @@ class StateSerializationError(ReflexError):
"""Raised when the state cannot be serialized.""" """Raised when the state cannot be serialized."""
class StateMismatchError(ReflexError, ValueError):
"""Raised when the state retrieved does not match the expected state."""
class SystemPackageMissingError(ReflexError): class SystemPackageMissingError(ReflexError):
"""Raised when a system package is missing.""" """Raised when a system package is missing."""