Parallelism in Myriad

Myriad is designed to run large numbers of environments in parallel on a single GPU. This tutorial shows what that looks like in practice, using PQN (Parallelized Q-Network) as a concrete example.

PQN (Simplifying Deep Temporal Difference Learning, Gallici et al., ICLR 2025 Spotlight) is a good fit for Myriad: it drops the replay buffer and target network of DQN in favour of λ-returns and layer normalisation, and updates directly on rollouts collected from many parallel environments. No replay buffer means no fixed memory overhead — just scale up num_envs.

We use hyperparameters from PureJaxQL — the reference JAX implementation — and run both algorithms entirely within Myriad.

Three things this tutorial demonstrates:

  1. Correctness: Myriad’s PQN reproduces PureJaxQL’s CartPole result with the same hyperparameters

  2. Sample efficiency: PQN (32 envs) converges faster per environment step than single-env DQN

  3. Scalability: going from 1 to 2048 parallel environments barely increases wall-clock time — JAX vectorises the computation across the GPU, so adding environments is nearly free

import shutil
import time

import jax
import matplotlib.pyplot as plt
import numpy as np
from _helpers import setup_logging

from myriad import create_config, train_and_evaluate

setup_logging()

devices = jax.devices()
print(f"Devices: {devices}")
gpu_devices = [d for d in devices if "gpu" in str(d).lower()]
if gpu_devices:
    print(f"GPU: {gpu_devices[0]}")
INFO:2026-02-23 16:48:55,026:jax._src.xla_bridge:822: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
[16:48:55 INFO] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
Devices: [CudaDevice(id=0)]

Section A: Myriad PQN — Replicating PureJaxQL’s CartPole Result

We use PureJaxQL’s published hyperparameters verbatim (source). Same seed, same config, expected convergence to ~500 return. Results are identical on the same GPU hardware; minor numerical differences may occur across different GPU models or driver versions due to non-deterministic GPU kernel optimizations.

# PureJaxQL published hyperparameters for CartPole
# Reference: https://github.com/mttga/purejaxql
PUREJAXQL_PARAMS = {
    "env": "cartpole-control",
    "num_envs": 32,
    "steps_per_env": 3_000,
    "rollout_steps": 64,
    "learning_rate": 1e-4,
    "max_grad_norm": 10.0,
    "num_epochs": 4,
    "num_minibatches": 16,
    "gamma": 0.99,
    "lambda_": 0.95,
    "reward_scale": 0.1,
    "epsilon_start": 1.0,
    "epsilon_end": 0.2,
    "epsilon_decay_fraction": 0.2,  # decay over 20% of updates (= first 62 of 312 total)
    "hidden_size": 256,
    "num_layers": 2,
    "eval_frequency": 640,
    "eval_rollouts": 128,
    "seed": 0,
}

config_pqn = create_config(agent="pqn", **PUREJAXQL_PARAMS)
results_pqn = train_and_evaluate(config_pqn)

pqn_final_return = results_pqn.eval_metrics.mean_return[-1]
print(f"PQN final return: {pqn_final_return:.1f} (PureJaxQL reported: ~495)")
[16:49:02 INFO] Step  640/3000 ( 21%) | loss=0.343 | eval_return=305.90
[16:49:02 INFO] Step 1280/3000 ( 43%) | loss=0.140 | eval_return=113.06
[16:49:03 INFO] Step 1920/3000 ( 64%) | loss=0.131 | eval_return=500.00
[16:49:03 INFO] Step 2560/3000 ( 85%) | loss=0.029 | eval_return=500.00
[16:49:03 INFO] Artifacts: outputs/2026-02-23/16-48-57
  ├── .hydra/  (config snapshot)
  ├── results.pkl  (metrics & config)
  └── run_metadata.yaml  (timing & status)
PQN final return: 500.0 (PureJaxQL reported: ~495)

Section B: Myriad DQN — Classical Single-Environment Baseline

DQN (Deep Q-Network) is the classic off-policy baseline. Unlike PQN, it uses:

  • A replay buffer to decorrelate experience

  • A target network to stabilise TD targets

We run DQN with a single environment for 50K steps — enough to converge on CartPole. We’ll compare its learning curve against PQN in the next section.

config_dqn = create_config(
    env="cartpole-control",
    agent="dqn",
    num_envs=1,         # classical single-environment DQN
    steps_per_env=50_000,
    eval_frequency=5_000,
    eval_rollouts=50,
    epsilon_decay_fraction=0.4,
    target_network_frequency=100,
    seed=0,
)

results_dqn = train_and_evaluate(config_dqn)

dqn_final_return = results_dqn.eval_metrics.mean_return[-1]
print(f"DQN final return: {dqn_final_return:.1f} (100K sequential steps)")
[16:49:10 INFO] Step  5000/50000 ( 10%) | loss=6.275 | eval_return=176.90
[16:49:10 INFO] Step 10000/50000 ( 20%) | loss=13.510 | eval_return=169.30
[16:49:11 INFO] Step 15000/50000 ( 30%) | loss=40.784 | eval_return=196.68
[16:49:11 INFO] Step 20000/50000 ( 40%) | loss=2.466 | eval_return=147.30
[16:49:12 INFO] Step 25000/50000 ( 50%) | loss=1.565 | eval_return=175.00
[16:49:12 INFO] Step 30000/50000 ( 60%) | loss=1.913 | eval_return=489.02
[16:49:12 INFO] Step 35000/50000 ( 70%) | loss=1.953 | eval_return=500.00
[16:49:13 INFO] Step 40000/50000 ( 80%) | loss=4.523 | eval_return=92.70
[16:49:13 INFO] Step 45000/50000 ( 90%) | loss=0.842 | eval_return=500.00
[16:49:14 INFO] Step 50000/50000 (100%) | loss=370.990 | eval_return=500.00
[16:49:14 INFO] Artifacts: outputs/2026-02-23/16-49-05
  ├── .hydra/  (config snapshot)
  ├── results.pkl  (metrics & config)
  └── run_metadata.yaml  (timing & status)
DQN final return: 500.0 (100K sequential steps)

Section C: PQN vs DQN — Learning Curves

Both algorithms on the same plot, x-axis is steps per environment. PQN (32 parallel envs) converges in far fewer per-environment steps — the large parallel batch gives better gradient signal early on.

fig, ax = plt.subplots(figsize=(8, 5))

# PQN (32 envs): shade ±1 std across eval rollouts
steps_pqn = np.array(results_pqn.eval_metrics.steps_per_env)
mean_pqn = np.array(results_pqn.eval_metrics.mean_return)
std_pqn = np.array(results_pqn.eval_metrics.std_return)
ax.plot(steps_pqn, mean_pqn, "o-", label="PQN (32 envs)", color="#3498db", alpha=0.9)
ax.fill_between(steps_pqn, mean_pqn - std_pqn, mean_pqn + std_pqn, alpha=0.15, color="#3498db")

# DQN (1 env)
steps_dqn = np.array(results_dqn.eval_metrics.steps_per_env)
mean_dqn = np.array(results_dqn.eval_metrics.mean_return)
ax.plot(steps_dqn, mean_dqn, "s--", label="DQN (1 env)", color="#e67e22", alpha=0.9)

ax.axhline(500, color="gray", linestyle=":", alpha=0.5, label="Max return (500)")
ax.set_xlabel("Steps per Environment")
ax.set_ylabel("Mean Return")
ax.set_title("PQN vs DQN: Learning Curves on CartPole")
ax.legend()
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()

print(f"{'Algorithm':<10} {'Envs':>6} {'Steps/Env':>10} {'Final Return':>14}")
print("-" * 44)
print(f"{'PQN':<10} {32:>6} {20_000:>10,} {pqn_final_return:>14.1f}")
print(f"{'DQN':<10} {1:>6} {100_000:>10,} {dqn_final_return:>14.1f}")
../../_images/6b5a4fbfc8c3e974cad1581ec6e9835ecd49d282226bddf9059162bc244f3d93.png
Algorithm    Envs  Steps/Env   Final Return
--------------------------------------------
PQN            32     20,000          500.0
DQN             1    100,000          500.0

Section D: Myriad Scales — Throughput vs Parallelism

The key question: does wall-clock time increase as we add more parallel environments?

With JAX, the answer is mostly no. Vectorised environment steps and batched network updates mean that scaling from 1 to 2048 environments costs very little extra time on the GPU — throughput (steps/sec) grows by orders of magnitude while wall-clock time stays roughly flat.

CartPole is a trivially easy problem, and that’s fine here. We’re not measuring sample efficiency or final performance — we’re measuring raw throughput. The trivial task means the scaling signal is clean, uncontaminated by learning dynamics.

The same capability applies directly to harder problems: complex physics, stochastic dynamics, system identification across uncertain parameter spaces. CartPole just makes it easy to verify that the machinery works.

sweep_num_envs = [1, 32, 128, 512, 1024, 2048]
sweep_results = {}
sweep_times = {}

for num_envs in sweep_num_envs:
    print(f"Training PQN with {num_envs} parallel environment(s)...")
    config = create_config(
        agent="pqn",
        **{**PUREJAXQL_PARAMS, "num_envs": num_envs},
    )
    t0 = time.perf_counter()
    results = train_and_evaluate(config)
    elapsed = time.perf_counter() - t0

    sweep_results[num_envs] = results
    sweep_times[num_envs] = elapsed

    total = num_envs * PUREJAXQL_PARAMS["steps_per_env"]
    throughput = total / elapsed
    print(f"  final return: {results.eval_metrics.mean_return[-1]:.1f}  "
          f"time: {elapsed:.1f}s  throughput: {throughput:,.0f} steps/s\n")
Training PQN with 1 parallel environment(s)...
[16:49:26 INFO] Step  640/3000 ( 21%) | loss=0.213 | eval_return=150.16
[16:49:26 INFO] Step 1280/3000 ( 43%) | loss=0.147 | eval_return=33.28
[16:49:26 INFO] Step 1920/3000 ( 64%) | loss=0.084 | eval_return=55.29
[16:49:26 INFO] Step 2560/3000 ( 85%) | loss=0.076 | eval_return=124.15
[16:49:26 INFO] Artifacts: outputs/2026-02-23/16-49-22
  ├── .hydra/  (config snapshot)
  ├── results.pkl  (metrics & config)
  └── run_metadata.yaml  (timing & status)
  final return: 124.1  time: 3.7s  throughput: 821 steps/s

Training PQN with 32 parallel environment(s)...
[16:49:29 INFO] Step  640/3000 ( 21%) | loss=0.343 | eval_return=305.90
[16:49:29 INFO] Step 1280/3000 ( 43%) | loss=0.140 | eval_return=113.06
[16:49:29 INFO] Step 1920/3000 ( 64%) | loss=0.131 | eval_return=500.00
[16:49:30 INFO] Step 2560/3000 ( 85%) | loss=0.029 | eval_return=500.00
[16:49:30 INFO] Artifacts: outputs/2026-02-23/16-49-26
  ├── .hydra/  (config snapshot)
  ├── results.pkl  (metrics & config)
  └── run_metadata.yaml  (timing & status)
  final return: 500.0  time: 3.7s  throughput: 25,641 steps/s

Training PQN with 128 parallel environment(s)...
[16:49:34 INFO] Step  640/3000 ( 21%) | loss=0.397 | eval_return=399.20
[16:49:34 INFO] Step 1280/3000 ( 43%) | loss=0.582 | eval_return=392.35
[16:49:34 INFO] Step 1920/3000 ( 64%) | loss=0.661 | eval_return=500.00
[16:49:34 INFO] Step 2560/3000 ( 85%) | loss=0.482 | eval_return=500.00
[16:49:34 INFO] Artifacts: outputs/2026-02-23/16-49-30
  ├── .hydra/  (config snapshot)
  ├── results.pkl  (metrics & config)
  └── run_metadata.yaml  (timing & status)
  final return: 500.0  time: 4.7s  throughput: 82,078 steps/s

Training PQN with 512 parallel environment(s)...
[16:49:39 INFO] Step  640/3000 ( 21%) | loss=0.454 | eval_return=431.59
[16:49:39 INFO] Step 1280/3000 ( 43%) | loss=0.261 | eval_return=500.00
[16:49:39 INFO] Step 1920/3000 ( 64%) | loss=0.271 | eval_return=500.00
[16:49:39 INFO] Step 2560/3000 ( 85%) | loss=0.954 | eval_return=394.79
[16:49:39 INFO] Artifacts: outputs/2026-02-23/16-49-34
  ├── .hydra/  (config snapshot)
  ├── results.pkl  (metrics & config)
  └── run_metadata.yaml  (timing & status)
  final return: 394.8  time: 4.8s  throughput: 319,636 steps/s

Training PQN with 1024 parallel environment(s)...
[16:49:44 INFO] Step  640/3000 ( 21%) | loss=0.377 | eval_return=448.72
[16:49:44 INFO] Step 1280/3000 ( 43%) | loss=0.260 | eval_return=500.00
[16:49:44 INFO] Step 1920/3000 ( 64%) | loss=0.603 | eval_return=408.43
[16:49:44 INFO] Step 2560/3000 ( 85%) | loss=0.114 | eval_return=379.23
[16:49:45 INFO] Artifacts: outputs/2026-02-23/16-49-39
  ├── .hydra/  (config snapshot)
  ├── results.pkl  (metrics & config)
  └── run_metadata.yaml  (timing & status)
  final return: 379.2  time: 5.4s  throughput: 572,866 steps/s

Training PQN with 2048 parallel environment(s)...
[16:49:49 INFO] Step  640/3000 ( 21%) | loss=0.423 | eval_return=351.34
[16:49:50 INFO] Step 1280/3000 ( 43%) | loss=0.216 | eval_return=500.00
[16:49:50 INFO] Step 1920/3000 ( 64%) | loss=0.343 | eval_return=500.00
[16:49:50 INFO] Step 2560/3000 ( 85%) | loss=0.342 | eval_return=500.00
[16:49:50 INFO] Artifacts: outputs/2026-02-23/16-49-45
  ├── .hydra/  (config snapshot)
  ├── results.pkl  (metrics & config)
  └── run_metadata.yaml  (timing & status)
  final return: 500.0  time: 5.5s  throughput: 1,108,173 steps/s
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))

colors_sweep = ["#e74c3c", "#e67e22", "#f1c40f", "#2ecc71", "#3498db", "#9b59b6"]

# Plot 1: Learning curves (return vs steps_per_env)
for color, num_envs in zip(colors_sweep, sweep_num_envs):
    results = sweep_results[num_envs]
    steps = results.eval_metrics.steps_per_env
    mean = results.eval_metrics.mean_return
    std = results.eval_metrics.std_return
    label = f"{num_envs} env{'s' if num_envs > 1 else ''}"
    ax1.plot(steps, mean, "o-", label=label, color=color, alpha=0.9)
    ax1.fill_between(
        steps,
        np.array(mean) - np.array(std),
        np.array(mean) + np.array(std),
        alpha=0.15,
        color=color,
    )

ax1.set_xlabel("Steps per Env")
ax1.set_ylabel("Mean Return")
ax1.set_title("Learning Curves: PQN on CartPole")
ax1.legend(fontsize=8)
ax1.grid(alpha=0.3)

# Plot 2: Throughput vs num_envs (log scale — values span 3 orders of magnitude)
steps_per_env = PUREJAXQL_PARAMS["steps_per_env"]
throughputs = [
    (n * steps_per_env) / sweep_times[n]
    for n in sweep_num_envs
]

bars = ax2.bar(
    np.arange(len(sweep_num_envs)),
    throughputs,
    color=colors_sweep,
    edgecolor="white",
    width=0.6,
)
ax2.set_xticks(np.arange(len(sweep_num_envs)))
ax2.set_xticklabels([str(n) for n in sweep_num_envs], rotation=30, ha="right")
for i, tp in enumerate(throughputs):
    ax2.text(i, tp * 1.08, f"{tp/1e6:.2f}M" if tp >= 1e6 else f"{tp/1e3:.0f}K",
             ha="center", va="bottom", fontsize=8)

ax2.set_yscale("log")
ax2.set_xlabel("Number of Parallel Environments")
ax2.set_ylabel("Throughput (steps/sec, log scale)")
ax2.set_title("GPU Throughput vs Parallelism")
ax2.grid(axis="y", alpha=0.3)

plt.tight_layout()
plt.show()
../../_images/bce95a45e3246cdd419bd5b3179a10a73db5e4d35dfdffa869d0360f810ec3a5.png
The Kernel crashed while executing code in the current cell or a previous cell. 

Please review the code in the cell(s) to identify a possible cause of the failure. 

Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. 

View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details.

Key Takeaways

Myriad works with high-parallelism algorithms out of the box (Sections A–B):
Using PureJaxQL’s published hyperparameters verbatim, Myriad’s PQN reaches ~500 mean return on CartPole — matching their reported results with no adaptation needed. This confirms that algorithms designed around large parallel batches integrate directly into Myriad’s API.

Parallelism improves sample efficiency (Section C):
On a shared x-axis (steps per environment), PQN with 32 parallel environments converges much faster than single-environment DQN. Larger rollout batches provide better gradient signal per update — a direct benefit of Myriad’s parallel infrastructure.

Myriad scales: more environments, same wall-clock time (Section D):
Scaling from 1 to 2048 parallel environments, throughput grows by roughly 3 orders of magnitude while wall-clock time barely changes. JAX vectorises environment steps and network updates across the GPU — adding environments is nearly free.

CartPole is a proof of concept, not the point:
CartPole is trivially easy, and any configuration solves it quickly. The real value of this scalability appears on harder problems — complex physics, stochastic dynamics, system identification with uncertain parameters — where running thousands of parallel experiments is the only way to gather sufficient statistical signal.


References:

Cleanup

for results in [results_pqn, results_dqn, *sweep_results.values()]:
    shutil.rmtree(results.run_dir, ignore_errors=True)