Source code for myriad.core.replay_buffer

"""
A JAX-native, functional implementation of a replay buffer.
"""

from typing import NamedTuple, Tuple

import chex
import jax
import jax.numpy as jnp
from flax.struct import dataclass


class ReplayBufferState(NamedTuple):
    """Snapshot of a replay buffer: its storage plus the write bookkeeping.

    Attributes:
        data: PyTree of JAX arrays; every leaf carries a leading axis of
            length ``buffer_size``, one slot per stored transition.
        position: Scalar index of the slot the next write goes to.
        size: Scalar count of slots that currently hold valid transitions.
    """

    data: chex.ArrayTree  # storage; leading axis indexes buffer slots
    position: chex.Array  # next slot to write
    size: chex.Array  # number of valid slots


@dataclass
class ReplayBuffer:
    """Pure-function replay buffer operating on an explicit ``ReplayBufferState``.

    The buffer object itself only carries configuration; all contents live in
    the ``ReplayBufferState`` that each method receives and returns, so no
    method mutates its inputs. Storage is a FIFO ring buffer: once full, new
    transitions overwrite the oldest ones.

    Attributes:
        buffer_size: Maximum number of transitions to store.
    """

    # NOTE(review): flax.struct.dataclass fields are pytree nodes by default, so
    # buffer_size becomes a traced leaf if a ReplayBuffer is passed as a jit
    # argument (init would then fail on the traced shape). Consider
    # ``flax.struct.field(pytree_node=False)`` if that usage matters.
    buffer_size: int

    def init(self, sample_transition: chex.ArrayTree) -> ReplayBufferState:
        """Create an empty buffer whose storage matches one example transition.

        Args:
            sample_transition: A single (unbatched) transition PyTree. Only its
                structure and the shapes/dtypes of its leaves are used; Python
                scalars are converted with ``jnp.asarray`` to infer them.

        Returns:
            A ``ReplayBufferState`` whose leaves are zero-filled arrays of shape
            ``(buffer_size, *leaf.shape)`` and whose ``position`` and ``size``
            are int32 zeros.

        Raises:
            ValueError: If ``buffer_size`` is not positive.
        """
        if self.buffer_size <= 0:
            raise ValueError(
                f"buffer_size must be a positive integer, got {self.buffer_size!r}"
            )

        # jnp.asarray leaves JAX arrays as they are and gives Python scalars a
        # concrete shape and dtype to allocate storage from.
        sample_transition = jax.tree_util.tree_map(jnp.asarray, sample_transition)
        data = jax.tree_util.tree_map(
            lambda leaf: jnp.zeros((self.buffer_size,) + leaf.shape, dtype=leaf.dtype),
            sample_transition,
        )
        return ReplayBufferState(
            data=data,
            position=jnp.array(0, dtype=jnp.int32),
            size=jnp.array(0, dtype=jnp.int32),
        )

    def add(self, state: ReplayBufferState, transitions: chex.ArrayTree) -> ReplayBufferState:
        """Write a batch of transitions into the ring buffer.

        Transitions are written in order starting at ``state.position``,
        wrapping around and overwriting the oldest entries once the buffer is
        full. If the batch is larger than ``buffer_size``, only its last
        ``buffer_size`` transitions are kept, which is the same result as
        inserting the transitions one at a time.

        Args:
            state: The current state of the replay buffer.
            transitions: PyTree with the same structure as ``state.data``. Every
                leaf must have the same leading (batch) dimension, e.g. the
                number of parallel environments.

        Returns:
            The new ReplayBufferState after adding the transitions.

        Raises:
            ValueError: If ``transitions`` has no leaves, a leaf has no leading
                dimension, or leaves disagree on the leading dimension.
        """
        leaves = jax.tree_util.tree_leaves(transitions)
        if not leaves:
            raise ValueError("transitions must be a non-empty PyTree")

        # Shapes are static, so this validation runs at trace time.
        shapes = [jnp.shape(leaf) for leaf in leaves]
        if any(len(shape) == 0 for shape in shapes):
            raise ValueError("every transitions leaf must have a leading batch dimension")
        num_transitions_to_add = shapes[0][0]
        if any(shape[0] != num_transitions_to_add for shape in shapes):
            raise ValueError(
                "all transitions leaves must share the same leading dimension, got "
                f"{sorted({shape[0] for shape in shapes})}"
            )

        buffer_size = self.buffer_size
        offsets = jnp.arange(num_transitions_to_add, dtype=state.position.dtype)
        indices = (state.position + offsets) % buffer_size

        # A batch larger than the buffer would map several transitions to the
        # same slot. JAX applies conflicting scatter updates in an
        # implementation-defined order, so send all but the last `buffer_size`
        # transitions to an out-of-bounds slot and drop them. This reproduces
        # sequential FIFO insertion and also works if buffer_size is traced.
        survives = offsets >= num_transitions_to_add - buffer_size
        indices = jnp.where(survives, indices, buffer_size)

        new_data = jax.tree_util.tree_map(
            lambda buffer_leaf, transition_leaf: buffer_leaf.at[indices].set(
                transition_leaf, mode="drop"
            ),
            state.data,
            transitions,
        )
        new_position = (state.position + num_transitions_to_add) % buffer_size
        new_size = jnp.minimum(state.size + num_transitions_to_add, buffer_size)
        return ReplayBufferState(data=new_data, position=new_position, size=new_size)

    def sample(
        self, state: ReplayBufferState, batch_size: int, key: chex.PRNGKey
    ) -> Tuple[ReplayBufferState, chex.ArrayTree]:
        """Sample a batch of stored transitions uniformly, with replacement.

        Indices are drawn independently from ``[0, size)``. Until the buffer
        first wraps, valid entries occupy exactly those slots because writes
        start at slot 0. Once full, all ``buffer_size`` slots are valid.

        If the buffer is empty, the range is clamped to ``[0, 1)``, so the batch
        consists of copies of the zero-filled slot 0 and no error is raised.

        Args:
            state: The current state of the replay buffer.
            batch_size: Number of transitions to sample; must be a concrete int.
            key: A JAX PRNG key for sampling.

        Returns:
            A tuple of the unchanged buffer state and the sampled batch, whose
            leaves have leading dimension ``batch_size``.
        """
        upper = jnp.maximum(state.size, 1)  # randint needs a non-empty range
        sample_indices = jax.random.randint(
            key, (int(batch_size),), 0, upper, dtype=jnp.int32
        )
        sampled = jax.tree_util.tree_map(lambda buf: buf[sample_indices], state.data)
        return state, sampled

    def add_and_sample(
        self,
        state: ReplayBufferState,
        transitions: chex.ArrayTree,
        batch_size: int,
        key: chex.PRNGKey,
    ) -> Tuple[ReplayBufferState, chex.ArrayTree]:
        """Add a batch of transitions, then sample from the updated buffer.

        Equivalent to ``self.sample(self.add(state, transitions), batch_size, key)``,
        so freshly added transitions are eligible to be sampled.

        Args:
            state: The current state of the replay buffer.
            transitions: PyTree of transitions to add; see ``add``.
            batch_size: The number of transitions to sample.
            key: A JAX PRNG key for sampling.

        Returns:
            A tuple containing the new buffer state and the sampled batch.
        """
        new_state = self.add(state, transitions)
        return self.sample(new_state, batch_size, key)