[REF-1982] state: Warn if redis state is "too big" (#2868)

If the state serializes to over 100kb and has substates, then print a warning
suggesting the developer reduce the size of the state.
This commit is contained in:
Masen Furer 2024-03-20 16:50:48 -07:00 committed by GitHub
parent f80d7978d2
commit 58f706ac7a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -52,6 +52,10 @@ Delta = Dict[str, Any]
var = computed_var
# If the state is this large, it's considered a performance issue.
TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb
class HeaderData(Base):
"""An object containing headers data."""
@ -2183,6 +2187,9 @@ class StateManagerRedis(StateManager):
b"evicted",
}
# Only warn about each state class size once.
_warned_about_state_size: ClassVar[Set[str]] = set()
def _get_root_state(self, state: BaseState) -> BaseState:
"""Chase parent_state pointers to find an instance of the top-level state.
@ -2334,6 +2341,29 @@ class StateManagerRedis(StateManager):
return self._get_root_state(state)
return state
def _warn_if_too_large(
self,
state: BaseState,
pickle_state_size: int,
):
"""Print a warning when the state is too large.
Args:
state: The state to check.
pickle_state_size: The size of the pickled state.
"""
state_full_name = state.get_full_name()
if (
state_full_name not in self._warned_about_state_size
and pickle_state_size > TOO_LARGE_SERIALIZED_STATE
and state.substates
):
console.warn(
f"State {state_full_name} serializes to {pickle_state_size} bytes "
"which may present performance issues. Consider reducing the size of this state."
)
self._warned_about_state_size.add(state_full_name)
async def set_state(
self,
token: str,
@ -2382,9 +2412,11 @@ class StateManagerRedis(StateManager):
)
# Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
if state._get_was_touched():
pickle_state = cloudpickle.dumps(state)
self._warn_if_too_large(state, len(pickle_state))
await self.redis.set(
_substate_key(client_token, state),
cloudpickle.dumps(state),
pickle_state,
ex=self.token_expiration,
)