Reduce pickle size (#4063)

* Only serialize base vars
* Never serialize router/router_data in substates
* Hash the schema to reduce serialized size
* lru_cache the schema to avoid recomputing it
This commit is contained in:
Masen Furer 2024-10-07 09:34:36 -07:00 committed by GitHub
parent 5c0518053d
commit edd17208c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 35 deletions

View File

@ -691,6 +691,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
parent_state.get_parent_state(),
)
# Reset cached schema value
cls._to_schema.cache_clear()
@classmethod
def _check_overridden_methods(cls):
"""Check for shadow methods and raise error if any.
@ -1945,20 +1948,58 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
The state dict for serialization.
"""
state = super().__getstate__()
# Never serialize parent_state or substates
state["__dict__"] = state["__dict__"].copy()
if state["__dict__"].get("parent_state") is not None:
# Do not serialize router data in substates (only the root state).
state["__dict__"].pop("router", None)
state["__dict__"].pop("router_data", None)
# Never serialize parent_state or substates.
state["__dict__"]["parent_state"] = None
state["__dict__"]["substates"] = {}
state["__dict__"].pop("_was_touched", None)
# Remove all inherited vars.
for inherited_var_name in self.inherited_vars:
state["__dict__"].pop(inherited_var_name, None)
return state
@classmethod
@functools.lru_cache()
def _to_schema(cls) -> str:
"""Convert a state to a schema.
Returns:
The hash of the schema.
"""
def _field_tuple(
field_name: str,
) -> Tuple[str, str, Any, Union[bool, None], Any]:
model_field = cls.__fields__[field_name]
return (
field_name,
model_field.name,
_serialize_type(model_field.type_),
(
model_field.required
if isinstance(model_field.required, bool)
else None
),
(model_field.default if is_serializable(model_field.default) else None),
)
return md5(
pickle.dumps(
list(sorted(_field_tuple(field_name) for field_name in cls.base_vars))
)
).hexdigest()
def _serialize(self) -> bytes:
"""Serialize the state for redis.
Returns:
The serialized state.
"""
return pickle.dumps((state_to_schema(self), self))
return pickle.dumps((self._to_schema(), self))
@classmethod
def _deserialize(
@ -1985,7 +2026,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
(substate_schema, state) = pickle.load(fp)
else:
raise ValueError("Only one of `data` or `fp` must be provided")
if substate_schema != state_to_schema(state):
if substate_schema != state._to_schema():
raise StateSchemaMismatchError()
return state
@ -2620,35 +2661,6 @@ def is_serializable(value: Any) -> bool:
return False
def state_to_schema(
state: BaseState,
) -> List[Tuple[str, str, Any, Union[bool, None], Any]]:
"""Convert a state to a schema.
Args:
state: The state to convert to a schema.
Returns:
The schema.
"""
return list(
sorted(
(
field_name,
model_field.name,
_serialize_type(model_field.type_),
(
model_field.required
if isinstance(model_field.required, bool)
else None
),
(model_field.default if is_serializable(model_field.default) else None),
)
for field_name, model_field in state.__fields__.items()
)
)
def reset_disk_state_manager():
"""Reset the disk state manager."""
states_directory = prerequisites.get_web_dir() / constants.Dirs.STATES

View File

@ -41,13 +41,13 @@ def DynamicRoute():
return rx.fragment(
rx.input(
value=DynamicState.router.session.client_token,
is_read_only=True,
read_only=True,
id="token",
),
rx.input(value=rx.State.page_id, is_read_only=True, id="page_id"), # type: ignore
rx.input(value=rx.State.page_id, read_only=True, id="page_id"), # type: ignore
rx.input(
value=DynamicState.router.page.raw_path,
is_read_only=True,
read_only=True,
id="raw_path",
),
rx.link("index", href="/", id="link_index"),