[REF-3961] move "warn_if_too_large" logic into BaseState

Check for too large serialized state whenever `BaseState._serialize` is used,
so it can apply to all state managers, not just `StateManagerRedis`.
This commit is contained in:
Masen Furer 2024-11-04 10:11:04 -08:00 committed by GitHub
parent b70f33d972
commit 0d9fc53a7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -104,6 +104,8 @@ var = computed_var
# Size threshold, in bytes, for a serialized state: anything larger than
# this is considered a performance issue.
TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100 KiB
# Full names of the states that have already triggered the size warning,
# kept so that each state class is only warned about once.
_WARNED_ABOUT_STATE_SIZE: Set[str] = set()
# Errors caught during pickling of state
HANDLED_PICKLE_ERRORS = (
@ -2046,6 +2048,27 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
state["__dict__"].pop(inherited_var_name, None)
return state
def _warn_if_too_large(
    self,
    pickle_state_size: int,
):
    """Emit a one-time console warning for an oversized serialized state.

    The warning is printed only when all of the following hold: this
    state's full name has not been warned about before, the pickled size
    is greater than ``TOO_LARGE_SERIALIZED_STATE``, and the state has
    substates. After warning, the full name is recorded in
    ``_WARNED_ABOUT_STATE_SIZE`` so the warning is not repeated.

    Args:
        pickle_state_size: The size of the pickled state, in bytes.
    """
    state_full_name = self.get_full_name()

    # Guard clauses, checked in the same order as the warning conditions.
    if state_full_name in _WARNED_ABOUT_STATE_SIZE:
        return
    if pickle_state_size <= TOO_LARGE_SERIALIZED_STATE:
        return
    if not self.substates:
        return

    message = (
        f"State {state_full_name} serializes to {pickle_state_size} bytes "
        "which may present performance issues. Consider reducing the size of this state."
    )
    console.warn(message)
    # Remember this state so the warning is only shown once.
    _WARNED_ABOUT_STATE_SIZE.add(state_full_name)
@classmethod
@functools.lru_cache()
def _to_schema(cls) -> str:
@ -2084,7 +2107,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
The serialized state.
"""
try:
return pickle.dumps((self._to_schema(), self))
pickle_state = pickle.dumps((self._to_schema(), self))
self._warn_if_too_large(len(pickle_state))
return pickle_state
except HANDLED_PICKLE_ERRORS as og_pickle_error:
error = (
f"Failed to serialize state {self.get_full_name()} due to unpicklable object. "
@ -3075,9 +3100,6 @@ class StateManagerRedis(StateManager):
b"evicted",
}
# Only warn about each state class size once.
_warned_about_state_size: ClassVar[Set[str]] = set()
async def _get_parent_state(
self, token: str, state: BaseState | None = None
) -> BaseState | None:
@ -3221,29 +3243,6 @@ class StateManagerRedis(StateManager):
return state._get_root_state()
return state
def _warn_if_too_large(
    self,
    state: BaseState,
    pickle_state_size: int,
):
    """Emit a one-time console warning for an oversized serialized state.

    The warning is printed only when all of the following hold: the
    state's full name is not yet in ``self._warned_about_state_size``,
    the pickled size is greater than ``TOO_LARGE_SERIALIZED_STATE``, and
    the state has substates. After warning, the full name is recorded so
    the warning is not repeated.

    Args:
        state: The state to check.
        pickle_state_size: The size of the pickled state, in bytes.
    """
    state_full_name = state.get_full_name()

    # Guard clauses, checked in the same order as the warning conditions.
    if state_full_name in self._warned_about_state_size:
        return
    if pickle_state_size <= TOO_LARGE_SERIALIZED_STATE:
        return
    if not state.substates:
        return

    message = (
        f"State {state_full_name} serializes to {pickle_state_size} bytes "
        "which may present performance issues. Consider reducing the size of this state."
    )
    console.warn(message)
    # Remember this state so the warning is only shown once.
    self._warned_about_state_size.add(state_full_name)
@override
async def set_state(
self,
@ -3294,7 +3293,6 @@ class StateManagerRedis(StateManager):
# Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
if state._get_was_touched():
pickle_state = state._serialize()
self._warn_if_too_large(state, len(pickle_state))
if pickle_state:
await self.redis.set(
_substate_key(client_token, state),