Core & Utilities

Spaces

JAX-friendly space definitions for RL environments.

class myriad.core.spaces.Space

Bases: object

Base class for all spaces.

sample(key)

Sample a random value from the space.

contains(x)

Check if x is a valid value in this space.

class myriad.core.spaces.Box(low, high, shape=(), dtype=jax.numpy.float32)

Bases: Space

A box in R^n with lower and upper bounds per dimension.

__init__(low, high, shape=(), dtype=jax.numpy.float32)

sample(key)

Sample uniformly from the box.

contains(x)

Check if x is within bounds.
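To illustrate the sampling and membership semantics above, here is a minimal sketch of a Box-like space built on jax.random.uniform. BoxSketch is a hypothetical re-implementation, not Myriad's source; it assumes low and high broadcast against shape and that contains uses inclusive bounds.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class BoxSketch(NamedTuple):
    """Hypothetical stand-in for a Box space (not Myriad's implementation)."""

    low: Any  # scalar or array broadcastable to `shape` (assumption)
    high: Any
    shape: tuple = ()
    dtype: Any = jnp.float32

    def sample(self, key: jax.Array) -> jax.Array:
        # Uniform draw in [low, high); low/high broadcast against `shape`.
        return jax.random.uniform(
            key, self.shape, dtype=self.dtype, minval=self.low, maxval=self.high
        )

    def contains(self, x) -> jax.Array:
        x = jnp.asarray(x)
        # Shapes are static, so this Python-level check is safe under jit.
        if x.shape != self.shape:
            return jnp.asarray(False)
        # Inclusive bounds check on every element (assumed semantics).
        return jnp.all((x >= self.low) & (x <= self.high))
```

Returning a JAX boolean rather than a Python bool keeps contains usable inside jitted code.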

class myriad.core.spaces.Discrete(n, dtype=jax.numpy.int32)

Bases: Space

A finite set of integer actions {0, 1, …, n-1}.

__init__(n, dtype=jax.numpy.int32)

sample(key)

Sample uniformly from the discrete set.

contains(x)

Check if x is a valid discrete value.
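A comparable sketch for a discrete space over {0, 1, …, n-1}, using jax.random.randint, whose upper bound is exclusive. DiscreteSketch is hypothetical, and treating non-integer inputs as invalid in contains is an assumption about what "valid discrete value" means.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class DiscreteSketch(NamedTuple):
    """Hypothetical stand-in for a Discrete space (not Myriad's implementation)."""

    n: int
    dtype: Any = jnp.int32

    def sample(self, key: jax.Array) -> jax.Array:
        # randint's maxval is exclusive, so values fall in [0, n).
        return jax.random.randint(key, (), 0, self.n, dtype=self.dtype)

    def contains(self, x) -> jax.Array:
        x = jnp.asarray(x)
        # Assumed rule: integer-typed and within [0, n).
        is_int = jnp.issubdtype(x.dtype, jnp.integer)
        return jnp.asarray(is_int) & (x >= 0) & (x < self.n)
```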

Core Types

class myriad.core.types.Observation(*args, **kwargs)

Bases: Protocol

Protocol for structured observation pytrees.

Observations are typically NamedTuples with named fields that can be converted to flat arrays for neural network input. The to_array() method enables agents to work with observations in either structured (for field introspection) or flattened (for network input) form.

to_array()

Convert observation to a flat JAX array.

__init__(*args, **kwargs)
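A sketch of how a structured observation might satisfy this protocol. Because a NamedTuple is already a JAX pytree, its fields can be inspected by name while to_array() produces the flat network input. PendulumObs and its field names are hypothetical, and the local Observation class only mirrors the protocol described above.

```python
from typing import NamedTuple, Protocol

import jax
import jax.numpy as jnp


class Observation(Protocol):
    """Structural interface: anything with a to_array() method."""

    def to_array(self) -> jax.Array: ...


class PendulumObs(NamedTuple):
    """Hypothetical structured observation; field names are illustrative."""

    angle: jax.Array
    angular_velocity: jax.Array

    def to_array(self) -> jax.Array:
        # Stack along the last axis: scalar fields give shape (2,),
        # batched fields of shape (B,) give shape (B, 2).
        return jnp.stack([self.angle, self.angular_velocity], axis=-1)


def flatten(obs: Observation) -> jax.Array:
    # Agent code can depend on the protocol, not the concrete class.
    return obs.to_array()
```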

class myriad.core.types.Transition(obs, action, reward, next_obs, done)

Bases: NamedTuple

obs: Array

Alias for field number 0

action: Array

Alias for field number 1

reward: Array

Alias for field number 2

next_obs: Array

Alias for field number 3

done: Array

Alias for field number 4
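Since Transition is a NamedTuple, it is a JAX pytree and can hold a whole batch of parallel-environment steps, with a leading batch axis on every leaf. The sketch below redeclares a class with the same five fields so it runs on its own, then uses jax.tree_util.tree_map to slice out one environment's transition. The batch sizes are arbitrary.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Transition(NamedTuple):
    obs: jax.Array
    action: jax.Array
    reward: jax.Array
    next_obs: jax.Array
    done: jax.Array


num_envs, obs_dim = 4, 3  # arbitrary sizes for illustration

# One step from num_envs parallel environments: each leaf has a leading num_envs axis.
batch = Transition(
    obs=jnp.zeros((num_envs, obs_dim)),
    action=jnp.zeros((num_envs,), dtype=jnp.int32),
    reward=jnp.ones((num_envs,)),
    next_obs=jnp.ones((num_envs, obs_dim)),
    done=jnp.zeros((num_envs,), dtype=bool),
)

# tree_map applies the same indexing to every field and keeps the Transition type.
first_env = jax.tree_util.tree_map(lambda leaf: leaf[0], batch)
```

Positional indexing also works: batch[2] is the same object as batch.reward, which is what "Alias for field number 2" refers to.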

class myriad.core.types.BaseModel

Bases: pydantic.BaseModel

Pydantic BaseModel subclass to be used throughout the codebase.

model_config: ClassVar[ConfigDict] = {'arbitrary_types_allowed': True}

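Configuration for the model; must be a dictionary conforming to pydantic.config.ConfigDict. Setting arbitrary_types_allowed=True lets fields use types Pydantic has no schema for, which are then validated with an isinstance check.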
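A minimal sketch of why the flag matters, assuming Pydantic v2: a field annotated with jax.Array only works when arbitrary types are allowed, and Pydantic then accepts any instance of that class. EnvParams is a hypothetical model, not one defined by Myriad.

```python
import jax
import jax.numpy as jnp
from pydantic import BaseModel as PydanticBaseModel
from pydantic import ConfigDict, ValidationError


class BaseModel(PydanticBaseModel):
    # Without this flag, defining a model with a jax.Array field fails,
    # because Pydantic cannot generate a schema for that type.
    model_config = ConfigDict(arbitrary_types_allowed=True)


class EnvParams(BaseModel):  # hypothetical model for illustration
    init_state: jax.Array  # checked with isinstance(value, jax.Array)
    gravity: float = 9.81
```

Note that arbitrary types get no coercion: a Python list is rejected rather than converted to an array.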

Replay Buffer

A JAX-native, functional implementation of a replay buffer.

class myriad.core.replay_buffer.ReplayBufferState(data, position, size)

Bases: NamedTuple

State of the replay buffer. Contains the stored data, the current position, and the current size.

Variables:

data: chex.ArrayTree

Alias for field number 0

position: chex.Array

Alias for field number 1

size: chex.Array

Alias for field number 2

class myriad.core.replay_buffer.ReplayBuffer(buffer_size)

Bases: object

A class that holds the pure functions for a replay buffer.

Variables:

buffer_size (int) – The maximum number of transitions to store.

buffer_size: int

init(sample_transition)

Initializes the replay buffer state.

Parameters:

sample_transition (chex.ArrayTree) – An example transition PyTree, used to infer the shapes and dtypes of the stored data.

Returns:

The initial ReplayBufferState.

Return type:

ReplayBufferState
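A minimal sketch of what init might do, assuming sample_transition is a single unbatched transition and that storage is preallocated with a leading buffer_size axis on every leaf. The names mirror this page, but the function body is an illustration, not Myriad's code.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class ReplayBufferState(NamedTuple):
    data: Any  # pytree; every leaf has shape (buffer_size, *leaf_shape)
    position: jax.Array  # next index to write
    size: jax.Array  # number of valid entries


def init(buffer_size: int, sample_transition: Any) -> ReplayBufferState:
    # Allocate zeros matching each leaf's shape and dtype, plus a leading buffer axis.
    data = jax.tree_util.tree_map(
        lambda leaf: jnp.zeros(
            (buffer_size, *jnp.shape(leaf)), dtype=jnp.asarray(leaf).dtype
        ),
        sample_transition,
    )
    zero = jnp.array(0, dtype=jnp.int32)
    return ReplayBufferState(data=data, position=zero, size=zero)
```

Preallocating fixed shapes keeps every later add and sample shape-stable, which is what allows them to run under jax.jit.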

add(state, transitions)

Adds a batch of transitions to the buffer.

Parameters:
  • state (ReplayBufferState) – The current state of the replay buffer.

  • transitions (chex.ArrayTree) – A PyTree of transitions to add. Each leaf must have a leading dimension matching the number of parallel environments.

Returns:

The new ReplayBufferState after adding the transitions.

Return type:

ReplayBufferState
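The batched write can be sketched as a circular buffer: the i-th of N incoming transitions goes to index (position + i) mod buffer_size, position advances by N, and size saturates at buffer_size. This sketch assumes N does not exceed buffer_size (otherwise indices would collide within one scatter) and is illustrative, not Myriad's implementation.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class ReplayBufferState(NamedTuple):
    data: Any
    position: jax.Array
    size: jax.Array


def add(state: ReplayBufferState, transitions: Any) -> ReplayBufferState:
    # Both sizes come from static array shapes, so this is jit-compatible.
    buffer_size = jax.tree_util.tree_leaves(state.data)[0].shape[0]
    num_new = jax.tree_util.tree_leaves(transitions)[0].shape[0]
    # Circular write indices: once full, the oldest entries are overwritten.
    idx = (state.position + jnp.arange(num_new)) % buffer_size
    data = jax.tree_util.tree_map(
        lambda buf, new: buf.at[idx].set(new), state.data, transitions
    )
    return ReplayBufferState(
        data=data,
        position=(state.position + num_new) % buffer_size,
        size=jnp.minimum(state.size + num_new, buffer_size),
    )
```

Because .at[...].set returns a new array, the input state is never mutated.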

sample(state, batch_size, key)

Samples a random batch of transitions from the buffer.

Parameters:
  • state (ReplayBufferState) – The current state of the replay buffer.

  • batch_size (int) – The number of transitions to sample.

  • key (Array) – A JAX PRNG key for sampling.

Returns:

A tuple containing the unchanged buffer state and the sampled batch.

Return type:

tuple[ReplayBufferState, chex.ArrayTree]
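Sampling can be sketched as drawing uniform indices over the filled region [0, size) and gathering them from every leaf. This assumes sampling with replacement from a non-empty buffer; batch_size must be static under jax.jit because it determines an output shape. It is an illustration, not Myriad's code.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class ReplayBufferState(NamedTuple):
    data: Any
    position: jax.Array
    size: jax.Array


def sample(state: ReplayBufferState, batch_size: int, key: jax.Array):
    # Uniform indices (with replacement) over the filled region; assumes size > 0.
    idx = jax.random.randint(key, (batch_size,), 0, state.size)
    batch = jax.tree_util.tree_map(lambda buf: buf[idx], state.data)
    # Sampling leaves the buffer contents untouched.
    return state, batch


# batch_size fixes an output shape, so mark it static for jit.
sample_jit = jax.jit(sample, static_argnums=1)
```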

add_and_sample(state, transitions, batch_size, key)

Adds a batch of transitions to the buffer and samples a random batch. This is a pure function.

Parameters:
  • state (ReplayBufferState) – The current state of the replay buffer.

  • transitions (chex.ArrayTree) – A PyTree of transitions to add. Each leaf must have a leading dimension matching the number of parallel environments.

  • batch_size (int) – The number of transitions to sample.

  • key (Array) – A JAX PRNG key for sampling.

Returns:

A tuple containing the new buffer state and the sampled batch.

Return type:

tuple[ReplayBufferState, chex.ArrayTree]

__init__(buffer_size)

replace(**updates)

Returns a new object replacing the specified fields with new values.

Utilities

Utility functions for Myriad.

myriad.utils.filter_kwargs(kwargs, cls)

Return only the kwargs whose names match fields of the given dataclass.

Uses dataclass field introspection, so the filtering stays in sync with the class definition automatically and no manually maintained set of field names is needed.

Works with both standard dataclasses and flax.struct.dataclass.
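A minimal sketch of the idea using dataclasses.fields, which also applies to flax.struct.dataclass classes because those are built on standard dataclasses. AgentParams and the keyword names are hypothetical.

```python
import dataclasses
from typing import Any


def filter_kwargs(kwargs: dict[str, Any], cls: type) -> dict[str, Any]:
    # Field names are read from the class itself, so adding or renaming
    # a field requires no other bookkeeping.
    field_names = {f.name for f in dataclasses.fields(cls)}
    return {name: value for name, value in kwargs.items() if name in field_names}


@dataclasses.dataclass(frozen=True)
class AgentParams:  # hypothetical target class
    learning_rate: float = 1e-3
    gamma: float = 0.99


# Mixed kwargs, e.g. from a flat config; num_envs belongs elsewhere.
mixed = {"learning_rate": 3e-4, "num_envs": 8, "gamma": 0.95}
params = AgentParams(**filter_kwargs(mixed, AgentParams))
```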

myriad.utils.load_config(path, config_cls)

Load a YAML config file and convert it to a Pydantic config object.

Parameters:
  • path (str | Path) – Path to YAML config file

  • config_cls (Type[T]) – Pydantic config class to instantiate (e.g., EvalConfig)

Returns:

Instantiated and validated config object

Return type:

T

Example

>>> from myriad.configs.default import EvalConfig
>>> from myriad.utils import load_config
>>> config = load_config("config.yaml", EvalConfig)
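A minimal sketch of the load-then-validate flow, assuming PyYAML's safe_load and Pydantic v2's model_validate; the actual function may differ. DemoConfig is a hypothetical config class standing in for something like EvalConfig.

```python
from pathlib import Path
from typing import Type, TypeVar, Union

import yaml
from pydantic import BaseModel

T = TypeVar("T", bound=BaseModel)


def load_config(path: Union[str, Path], config_cls: Type[T]) -> T:
    # safe_load parses plain YAML without constructing arbitrary Python objects.
    with open(path) as f:
        raw = yaml.safe_load(f) or {}  # an empty file yields None
    # model_validate (Pydantic v2) type-checks the data and fills in defaults.
    return config_cls.model_validate(raw)


class DemoConfig(BaseModel):  # hypothetical config class
    seed: int
    env_name: str = "cartpole"
```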

myriad.utils.to_array(obs)

Convert observation to array format.

Handles different observation types:

  • Arrays (JAX/NumPy): returned as-is.

  • Observations with a .to_array() method: converted via that method.

  • Other types: conversion attempted via jnp.asarray.

Parameters:

obs (Observation | Array) – Observation (satisfying the Observation Protocol) or array

Returns:

Array representation of the observation

Raises:

ValueError – If observation cannot be converted to array

Return type:

Array
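The three-way dispatch described above can be sketched as follows. Mapping a failed jnp.asarray call (which raises TypeError for unsupported objects) to ValueError is an assumption chosen to match the documented exception, not Myriad's confirmed code.

```python
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np


def to_array(obs: Any) -> jax.Array:
    # 1. JAX and NumPy arrays pass through unchanged.
    if isinstance(obs, (jax.Array, np.ndarray)):
        return obs
    # 2. Structured observations convert themselves.
    if hasattr(obs, "to_array"):
        return obs.to_array()
    # 3. Anything else (lists, scalars, ...) goes through jnp.asarray.
    try:
        return jnp.asarray(obs)
    except (TypeError, ValueError) as err:
        raise ValueError(f"Cannot convert {type(obs).__name__} to an array") from err
```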

myriad.utils.plot_training_curve(results, labels=None, xlabel='Steps per Env', ylabel='Mean Return', title=None, figsize=(8, 4), show_std=True, ax=None)

Plot training curve(s) showing mean return with optional standard deviation.

Parameters:
  • results (TrainingResults | list[TrainingResults]) – Single TrainingResults or list of results to plot

  • labels (str | list[str] | None) – Legend label(s) for the curve(s). If None, uses agent name from config

  • xlabel (str) – Label for x-axis

  • ylabel (str) – Label for y-axis

  • title (str | None) – Plot title. If None, auto-generates from environment name

  • figsize (tuple[float, float]) – Figure size (width, height) in inches

  • show_std (bool) – Whether to show standard deviation as shaded region

  • ax (Axes | None) – Existing axes to plot on. If None, creates new figure

Returns:

Tuple of (figure, axes) objects

Return type:

tuple[Figure, Axes]

Example

>>> import matplotlib.pyplot as plt
>>> from myriad.utils import plot_training_curve
>>> results = train_and_evaluate(config)
>>> fig, ax = plot_training_curve(results)
>>> plt.show()
>>> # Compare multiple runs
>>> results_list = [results_dqn, results_ppo]
>>> fig, ax = plot_training_curve(results_list, labels=["DQN", "PPO"])
>>> plt.show()