Core & Utilities
Spaces
JAX-friendly space definitions for RL environments.
- class myriad.core.spaces.Box(low, high, shape=(), dtype=jax.numpy.float32)
Bases: Space
A box in R^n with a lower and an upper bound for each dimension.
Core Types
- class myriad.core.types.Observation(*args, **kwargs)
Bases: Protocol
Protocol for structured observation pytrees.
Observations are typically NamedTuples with named fields that can be converted to flat arrays for neural network input. The to_array() method lets agents work with observations in either structured form (for field introspection) or flattened form (for network input).
- __init__(*args, **kwargs)
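As an illustration of the structured versus flattened forms described above, here is a minimal sketch of a NamedTuple observation that exposes to_array(). The names ObservationLike and CartPoleObs and all field names are hypothetical and are not part of Myriad; only the to_array() method name comes from the documentation above.

```python
from typing import NamedTuple, Protocol, runtime_checkable

import jax
import jax.numpy as jnp


@runtime_checkable
class ObservationLike(Protocol):
    """Hypothetical stand-in for the Observation protocol (assumption)."""

    def to_array(self) -> jax.Array: ...


class CartPoleObs(NamedTuple):
    """Hypothetical structured observation with named fields."""

    position: jax.Array
    velocity: jax.Array
    angle: jax.Array
    angular_velocity: jax.Array

    def to_array(self) -> jax.Array:
        # Stack the named fields along the last axis to get a flat vector.
        return jnp.stack(
            [self.position, self.velocity, self.angle, self.angular_velocity],
            axis=-1,
        )


obs = CartPoleObs(
    position=jnp.asarray(0.1),
    velocity=jnp.asarray(-0.2),
    angle=jnp.asarray(0.05),
    angular_velocity=jnp.asarray(0.0),
)

structured_angle = obs.angle  # structured access by field name
flat = obs.to_array()         # flattened form for network input, shape (4,)
```

Because a NamedTuple is already a pytree, an observation like this can be passed through jax.jit or jax.vmap without extra registration.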
- class myriad.core.types.Transition(obs, action, reward, next_obs, done)
Bases: NamedTuple
- obs: Array
Alias for field number 0
- action: Array
Alias for field number 1
- reward: Array
Alias for field number 2
- next_obs: Array
Alias for field number 3
- done: Array
Alias for field number 4
- class myriad.core.types.BaseModel
Bases: BaseModel
Pydantic BaseModel subclass to be used throughout the codebase.
- model_config: ClassVar[ConfigDict] = {'arbitrary_types_allowed': True}
Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
Replay Buffer
A JAX-native, functional implementation of a replay buffer.
- class myriad.core.replay_buffer.ReplayBufferState(data, position, size)
Bases: NamedTuple
State of the replay buffer. Holds the stored data, the current write position, and the number of valid entries.
- Variables:
data (chex.ArrayTree) – A PyTree of JAX arrays where each leaf has shape (buffer_size, …).
position (chex.Array) – The index in the buffer at which the next transition is written.
size (chex.Array) – The current number of valid transitions stored in the buffer.
- class myriad.core.replay_buffer.ReplayBuffer(buffer_size)
Bases: object
A class that holds the pure functions for a replay buffer.
- Variables:
buffer_size (int) – The maximum number of transitions to store.
- add(state, transitions)
Adds a batch of transitions to the buffer.
- Parameters:
state (ReplayBufferState) – The current state of the replay buffer.
transitions – A PyTree of transitions to add.
- Returns:
The new ReplayBufferState after adding the transitions.
- Return type:
ReplayBufferState
- sample(state, batch_size, key)
Samples a random batch of transitions from the buffer.
- Parameters:
state (ReplayBufferState) – The current state of the replay buffer.
batch_size (int) – The number of transitions to sample.
key (Array) – A JAX PRNG key for sampling.
- Returns:
A tuple containing the unchanged buffer state and the sampled batch.
- Return type:
tuple[ReplayBufferState, chex.ArrayTree]
- add_and_sample(state, transitions, batch_size, key)
Adds a batch of transitions to the buffer and samples a random batch. This is a pure function.
- Parameters:
state (ReplayBufferState) – The current state of the replay buffer.
transitions (chex.ArrayTree) – A PyTree of transitions to add. Each leaf must have a leading dimension matching the number of parallel environments.
batch_size (int) – The number of transitions to sample.
key (Array) – A JAX PRNG key for sampling.
- Returns:
A tuple containing the new buffer state and the sampled batch.
- Return type:
tuple[ReplayBufferState, chex.ArrayTree]
- __init__(buffer_size)
- replace(**updates)
Returns a new object replacing the specified fields with new values.
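To make the functional style concrete, the following is a minimal, self-contained sketch of a circular replay buffer built from pure functions over a (data, position, size) state. It is not Myriad's implementation: the names BufferState, init, add, and sample are hypothetical, buffer_size is passed explicitly instead of being stored on a class, and uniform sampling with replacement is an assumption, since the documentation above does not specify the sampling scheme.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class BufferState(NamedTuple):
    """Hypothetical state mirroring the documented (data, position, size) fields."""

    data: Any            # PyTree whose leaves have shape (buffer_size, ...)
    position: jax.Array  # index at which the next transition is written
    size: jax.Array      # number of valid transitions currently stored


def init(example: Any, buffer_size: int) -> BufferState:
    # Allocate zeroed storage with a leading buffer axis for every leaf.
    data = jax.tree_util.tree_map(
        lambda x: jnp.zeros((buffer_size, *jnp.shape(x)), jnp.asarray(x).dtype),
        example,
    )
    return BufferState(data, jnp.zeros((), jnp.int32), jnp.zeros((), jnp.int32))


def add(state: BufferState, transitions: Any, buffer_size: int) -> BufferState:
    # The batch size is read from the static leading dimension of the first leaf.
    n = jax.tree_util.tree_leaves(transitions)[0].shape[0]
    # Write indices wrap around, so the oldest entries are overwritten first.
    idx = (state.position + jnp.arange(n)) % buffer_size
    data = jax.tree_util.tree_map(
        lambda buf, x: buf.at[idx].set(x), state.data, transitions
    )
    return BufferState(
        data=data,
        position=(state.position + n) % buffer_size,
        size=jnp.minimum(state.size + n, buffer_size),
    )


def sample(state: BufferState, batch_size: int, key: jax.Array):
    # Assumption: uniform sampling with replacement from the valid region [0, size).
    idx = jax.random.randint(key, (batch_size,), 0, jnp.maximum(state.size, 1))
    batch = jax.tree_util.tree_map(lambda buf: buf[idx], state.data)
    # The state is returned unchanged, matching the documented return shape.
    return state, batch


# Usage: a buffer of capacity 4 receiving two batches of 3 transitions each.
example = {"obs": jnp.zeros(3), "reward": jnp.zeros(())}
state = init(example, buffer_size=4)
batch = {"obs": jnp.ones((3, 3)), "reward": jnp.arange(3.0)}
state = add(state, batch, buffer_size=4)
state = add(state, batch, buffer_size=4)  # second write wraps around
_, sampled = sample(state, 2, jax.random.PRNGKey(0))
```

Since every function takes the state as input and returns a new state, the whole update can be wrapped in jax.jit or carried through jax.lax.scan. This sketch does not handle a batch larger than the buffer capacity.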
Utilities
Utility functions for Myriad.
- myriad.utils.filter_kwargs(kwargs, cls)
Return only the kwargs whose names match fields of the given dataclass.
Uses dataclass field introspection, so routing stays in sync with the class definition automatically and no manually maintained field-name sets are needed.
Works with both standard dataclasses and flax.struct.dataclass.
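The field-introspection approach described above can be sketched with the standard library alone. The function name filter_kwargs_sketch and the AgentConfig dataclass are hypothetical and exist only for this example; the actual implementation in myriad.utils may differ.

```python
import dataclasses
from typing import Any


def filter_kwargs_sketch(kwargs: dict[str, Any], cls: type) -> dict[str, Any]:
    """Keep only the entries of kwargs whose keys are dataclass fields of cls."""
    field_names = {f.name for f in dataclasses.fields(cls)}
    return {k: v for k, v in kwargs.items() if k in field_names}


@dataclasses.dataclass
class AgentConfig:  # hypothetical config used only for this example
    learning_rate: float = 1e-3
    gamma: float = 0.99


# "num_envs" is not a field of AgentConfig, so it is dropped.
raw = {"learning_rate": 3e-4, "gamma": 0.95, "num_envs": 8}
agent_kwargs = filter_kwargs_sketch(raw, AgentConfig)
agent_config = AgentConfig(**agent_kwargs)
```

Because the field names are read from the class at call time, adding a field to the dataclass changes what passes the filter without any other code edits.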
- myriad.utils.load_config(path, config_cls)
Load a YAML config file and convert it to a Pydantic config object.
- Parameters:
path – Path to the YAML config file.
config_cls – The Pydantic config class to instantiate.
- Returns:
Instantiated and validated config object
- Return type:
T
Example
>>> from myriad.configs.default import EvalConfig
>>> from myriad.utils import load_config
>>> config = load_config("config.yaml", EvalConfig)
- myriad.utils.to_array(obs)
Convert observation to array format.
Handles different observation types:
- Arrays (JAX/numpy): Returned as-is
- Observations with .to_array(): Converted via that method
- Other types: Attempted conversion via jnp.asarray
- Parameters:
obs (Observation | Array) – Observation (satisfying the Observation Protocol) or array
- Returns:
Array representation of the observation
- Raises:
ValueError – If observation cannot be converted to array
- Return type:
Array
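The three-way dispatch documented above might look roughly like the sketch below. The names to_array_sketch and PolarObs are hypothetical, and the set of exceptions that gets translated into ValueError is an assumption; the real function may order or handle these cases differently.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


def to_array_sketch(obs: Any) -> Any:
    # 1. JAX and NumPy arrays pass through untouched.
    if isinstance(obs, (jax.Array, np.ndarray)):
        return obs
    # 2. Structured observations flatten themselves via to_array().
    if hasattr(obs, "to_array"):
        return obs.to_array()
    # 3. Anything else: best-effort conversion, reported as ValueError on failure.
    try:
        return jnp.asarray(obs)
    except (TypeError, ValueError) as err:
        raise ValueError(f"Cannot convert {type(obs).__name__} to an array") from err


class PolarObs(NamedTuple):
    """Hypothetical structured observation used only for this example."""

    radius: jax.Array
    theta: jax.Array

    def to_array(self) -> jax.Array:
        return jnp.stack([self.radius, self.theta], axis=-1)


already_array = jnp.ones(3)
from_array = to_array_sketch(already_array)          # returned as-is
from_struct = to_array_sketch(
    PolarObs(radius=jnp.asarray(1.0), theta=jnp.asarray(0.5))
)                                                    # via to_array()
from_list = to_array_sketch([1.0, 2.0])              # via jnp.asarray
```

Checking for arrays first keeps the common case cheap and avoids calling to_array() on objects that are already usable as network input.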
- myriad.utils.plot_training_curve(results, labels=None, xlabel='Steps per Env', ylabel='Mean Return', title=None, figsize=(8, 4), show_std=True, ax=None)
Plot training curve(s) showing mean return with optional standard deviation.
- Parameters:
results (TrainingResults | list[TrainingResults]) – Single TrainingResults or list of results to plot
labels (str | list[str] | None) – Legend label(s) for the curve(s). If None, uses agent name from config
xlabel (str) – Label for x-axis
ylabel (str) – Label for y-axis
title (str | None) – Plot title. If None, auto-generates from environment name
figsize (tuple[float, float]) – Figure size (width, height) in inches
show_std (bool) – Whether to show standard deviation as shaded region
ax (Axes | None) – Existing axes to plot on. If None, creates new figure
- Returns:
Tuple of (figure, axes) objects
- Return type:
tuple[Figure, Axes]
Example
>>> import matplotlib.pyplot as plt
>>> results = train_and_evaluate(config)
>>> fig, ax = plot_training_curve(results)
>>> plt.show()
>>> # Compare multiple runs
>>> results_list = [results_dqn, results_ppo]
>>> fig, ax = plot_training_curve(results_list, labels=["DQN", "PPO"])
>>> plt.show()