"""Agent state serialization using Flax msgpack format.

Provides utilities for serializing and deserializing agent states (typically
Flax optimizer states and neural network parameters) using Flax's msgpack
serialization. This is more reliable than pickle for JAX/Flax objects.

All functions raise RuntimeError with clear messages on failure
(``load_agent_state`` additionally raises FileNotFoundError for a missing
file).
"""

from pathlib import Path
from typing import Any

from flax import serialization
def serialize_agent_state(agent_state: Any) -> bytes:
    """Serialize agent state to msgpack bytes.

    The state is first converted to a Flax state dict and then msgpack-encoded
    (``serialization.to_bytes``). This makes Flax ``struct.dataclass`` objects
    such as ``TrainState``, ``FrozenDict`` parameters, and optax optimizer
    states (tuples / NamedTuples) serializable, in addition to plain nested
    dicts of arrays.

    Note:
        Fields declared with ``pytree_node=False`` (on a ``TrainState``:
        ``apply_fn`` and ``tx``) are not stored. Lists and tuples are stored as
        dicts keyed by string indices. To get the original types back, restore
        into a template of the same structure (``serialization.from_bytes``).

    Args:
        agent_state: Agent state to serialize (typically Flax TrainState or
            similar).

    Returns:
        Serialized bytes.

    Raises:
        RuntimeError: If serialization fails.
    """
    try:
        # to_bytes == msgpack_serialize(to_state_dict(x)). Calling
        # msgpack_serialize directly packs with strict msgpack types and
        # rejects dataclasses, FrozenDicts and tuples.
        return serialization.to_bytes(agent_state)
    except Exception as e:
        raise RuntimeError(
            f"Failed to serialize agent state. Ensure the agent state contains only "
            f"JAX/Flax types (pytrees, arrays, etc.). Original error: {e}"
        ) from e
def deserialize_agent_state(data: bytes, target: Any = None) -> Any:
    """Deserialize agent state from msgpack bytes.

    Args:
        data: Msgpack-serialized bytes.
        target: Optional template with the same structure as the saved state
            (e.g. a freshly initialized ``TrainState``). When given, the decoded
            state dict is restored into it with ``serialization.from_bytes``,
            so the original container types come back. When ``None`` (the
            default), the raw decoded state dict is returned unchanged.

    Returns:
        The restored object when ``target`` is given. Otherwise, the raw
        decoded pytree (nested dicts/lists of arrays and scalars).

    Raises:
        RuntimeError: If deserialization fails (corrupted data, or data whose
            structure does not match ``target``).
    """
    try:
        if target is None:
            return serialization.msgpack_restore(data)
        return serialization.from_bytes(target, data)
    except Exception as e:
        raise RuntimeError(
            f"Failed to deserialize agent state. The data may be corrupted or "
            f"incompatible. Original error: {e}"
        ) from e
def save_agent_state(agent_state: Any, path: str | Path) -> None:
    """Serialize and save agent state to file.

    The bytes are first written to a sibling temporary file (``<name>.tmp``)
    and then moved over ``path`` with ``Path.replace``. A failed or
    interrupted write (e.g. a full disk) therefore does not clobber an existing
    checkpoint at ``path``.

    Args:
        agent_state: Agent state to save.
        path: File path (typically with .msgpack extension). Missing parent
            directories are created.

    Raises:
        RuntimeError: If serialization or file writing fails.
    """
    path = Path(path)
    # serialize_agent_state already raises RuntimeError with its own message;
    # keeping it outside the I/O try-block lets that error propagate as-is.
    data = serialize_agent_state(agent_state)
    # Same directory as the destination, so the final rename stays on one
    # filesystem (os.replace semantics; atomic on POSIX).
    tmp_path = path.parent / f"{path.name}.tmp"
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path.write_bytes(data)
        tmp_path.replace(path)
    except Exception as e:
        # Best-effort removal of a partial temp file; the original error is
        # the one worth reporting.
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            pass
        raise RuntimeError(
            f"Failed to write agent state to {path}. Check file permissions and "
            f"disk space. Original error: {e}"
        ) from e
def load_agent_state(path: str | Path) -> Any:
    """Load and deserialize agent state from file.

    Args:
        path: File path to load from.

    Returns:
        Deserialized agent state, as produced by ``deserialize_agent_state``.

    Raises:
        FileNotFoundError: If the file doesn't exist.
        RuntimeError: If the file cannot be read or deserialization fails.
    """
    path = Path(path)
    try:
        data = path.read_bytes()
    except (FileNotFoundError, NotADirectoryError):
        # Translate the read failure (EAFP) instead of checking exists() first.
        # A pre-check races with concurrent deletion, and the file vanishing
        # would then be misreported as corruption.
        raise FileNotFoundError(f"Agent checkpoint not found: {path}") from None
    except Exception as e:
        raise RuntimeError(
            f"Failed to read agent state from {path}. The file may be corrupted. "
            f"Original error: {e}"
        ) from e
    # Deserialization errors are already RuntimeErrors with their own message.
    return deserialize_agent_state(data)