Parallelism in Myriad¶
Myriad is designed to run large numbers of environments in parallel on a single GPU. This tutorial shows what that looks like in practice, using PQN (Parallelized Q-Network) as a concrete example.
PQN (Simplifying Deep Temporal Difference Learning,
Gallici et al., ICLR 2025 Spotlight) is a good fit for Myriad: it drops the replay buffer
and target network of DQN in favour of λ-returns and layer normalisation, and updates
directly on rollouts collected from many parallel environments. No replay buffer means
no fixed memory overhead — just scale up num_envs.
We use hyperparameters from PureJaxQL — the reference JAX implementation — and run both algorithms entirely within Myriad.
Three things this tutorial demonstrates:
Correctness: Myriad’s PQN reproduces PureJaxQL’s CartPole result with the same hyperparameters
Sample efficiency: PQN (32 envs) converges faster per environment step than single-env DQN
Scalability: going from 1 to 2048 parallel environments barely increases wall-clock time — JAX vectorises the computation across the GPU, so adding environments is nearly free
import shutil
import time
import jax
import matplotlib.pyplot as plt
import numpy as np
from _helpers import setup_logging
from myriad import create_config, train_and_evaluate
setup_logging()
devices = jax.devices()
print(f"Devices: {devices}")
gpu_devices = [d for d in devices if "gpu" in str(d).lower()]
if gpu_devices:
print(f"GPU: {gpu_devices[0]}")
INFO:2026-02-23 16:48:55,026:jax._src.xla_bridge:822: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
[16:48:55 INFO] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
Devices: [CudaDevice(id=0)]
Section A: Myriad PQN — Replicating PureJaxQL’s CartPole Result¶
We use PureJaxQL’s published hyperparameters verbatim (source). Same seed, same config, expected convergence to ~500 return. Results are identical on the same GPU hardware; minor numerical differences may occur across different GPU models or driver versions due to non-deterministic GPU kernel optimizations.
# PureJaxQL published hyperparameters for CartPole
# Reference: https://github.com/mttga/purejaxql
PUREJAXQL_PARAMS = {
"env": "cartpole-control",
"num_envs": 32,
"steps_per_env": 3_000,
"rollout_steps": 64,
"learning_rate": 1e-4,
"max_grad_norm": 10.0,
"num_epochs": 4,
"num_minibatches": 16,
"gamma": 0.99,
"lambda_": 0.95,
"reward_scale": 0.1,
"epsilon_start": 1.0,
"epsilon_end": 0.2,
"epsilon_decay_fraction": 0.2, # decay over 20% of updates (= first 62 of 312 total)
"hidden_size": 256,
"num_layers": 2,
"eval_frequency": 640,
"eval_rollouts": 128,
"seed": 0,
}
config_pqn = create_config(agent="pqn", **PUREJAXQL_PARAMS)
results_pqn = train_and_evaluate(config_pqn)
pqn_final_return = results_pqn.eval_metrics.mean_return[-1]
print(f"PQN final return: {pqn_final_return:.1f} (PureJaxQL reported: ~495)")
[16:49:02 INFO] Step 640/3000 ( 21%) | loss=0.343 | eval_return=305.90
[16:49:02 INFO] Step 1280/3000 ( 43%) | loss=0.140 | eval_return=113.06
[16:49:03 INFO] Step 1920/3000 ( 64%) | loss=0.131 | eval_return=500.00
[16:49:03 INFO] Step 2560/3000 ( 85%) | loss=0.029 | eval_return=500.00
[16:49:03 INFO] Artifacts: outputs/2026-02-23/16-48-57
├── .hydra/ (config snapshot)
├── results.pkl (metrics & config)
└── run_metadata.yaml (timing & status)
PQN final return: 500.0 (PureJaxQL reported: ~495)
Section B: Myriad DQN — Classical Single-Environment Baseline¶
DQN (Deep Q-Network) is the classic off-policy baseline. Unlike PQN, it uses:
A replay buffer to decorrelate experience
A target network to stabilise TD targets
We run DQN with a single environment for 50K steps — enough to converge on CartPole. We’ll compare its learning curve against PQN in the next section.
config_dqn = create_config(
env="cartpole-control",
agent="dqn",
num_envs=1, # classical single-environment DQN
steps_per_env=50_000,
eval_frequency=5_000,
eval_rollouts=50,
epsilon_decay_fraction=0.4,
target_network_frequency=100,
seed=0,
)
results_dqn = train_and_evaluate(config_dqn)
dqn_final_return = results_dqn.eval_metrics.mean_return[-1]
print(f"DQN final return: {dqn_final_return:.1f} (100K sequential steps)")
[16:49:10 INFO] Step 5000/50000 ( 10%) | loss=6.275 | eval_return=176.90
[16:49:10 INFO] Step 10000/50000 ( 20%) | loss=13.510 | eval_return=169.30
[16:49:11 INFO] Step 15000/50000 ( 30%) | loss=40.784 | eval_return=196.68
[16:49:11 INFO] Step 20000/50000 ( 40%) | loss=2.466 | eval_return=147.30
[16:49:12 INFO] Step 25000/50000 ( 50%) | loss=1.565 | eval_return=175.00
[16:49:12 INFO] Step 30000/50000 ( 60%) | loss=1.913 | eval_return=489.02
[16:49:12 INFO] Step 35000/50000 ( 70%) | loss=1.953 | eval_return=500.00
[16:49:13 INFO] Step 40000/50000 ( 80%) | loss=4.523 | eval_return=92.70
[16:49:13 INFO] Step 45000/50000 ( 90%) | loss=0.842 | eval_return=500.00
[16:49:14 INFO] Step 50000/50000 (100%) | loss=370.990 | eval_return=500.00
[16:49:14 INFO] Artifacts: outputs/2026-02-23/16-49-05
├── .hydra/ (config snapshot)
├── results.pkl (metrics & config)
└── run_metadata.yaml (timing & status)
DQN final return: 500.0 (100K sequential steps)
Section C: PQN vs DQN — Learning Curves¶
Both algorithms on the same plot, x-axis is steps per environment. PQN (32 parallel envs) converges in far fewer per-environment steps — the large parallel batch gives better gradient signal early on.
fig, ax = plt.subplots(figsize=(8, 5))
# PQN (32 envs): shade ±1 std across eval rollouts
steps_pqn = np.array(results_pqn.eval_metrics.steps_per_env)
mean_pqn = np.array(results_pqn.eval_metrics.mean_return)
std_pqn = np.array(results_pqn.eval_metrics.std_return)
ax.plot(steps_pqn, mean_pqn, "o-", label="PQN (32 envs)", color="#3498db", alpha=0.9)
ax.fill_between(steps_pqn, mean_pqn - std_pqn, mean_pqn + std_pqn, alpha=0.15, color="#3498db")
# DQN (1 env)
steps_dqn = np.array(results_dqn.eval_metrics.steps_per_env)
mean_dqn = np.array(results_dqn.eval_metrics.mean_return)
ax.plot(steps_dqn, mean_dqn, "s--", label="DQN (1 env)", color="#e67e22", alpha=0.9)
ax.axhline(500, color="gray", linestyle=":", alpha=0.5, label="Max return (500)")
ax.set_xlabel("Steps per Environment")
ax.set_ylabel("Mean Return")
ax.set_title("PQN vs DQN: Learning Curves on CartPole")
ax.legend()
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()
print(f"{'Algorithm':<10} {'Envs':>6} {'Steps/Env':>10} {'Final Return':>14}")
print("-" * 44)
print(f"{'PQN':<10} {32:>6} {20_000:>10,} {pqn_final_return:>14.1f}")
print(f"{'DQN':<10} {1:>6} {100_000:>10,} {dqn_final_return:>14.1f}")
Algorithm Envs Steps/Env Final Return
--------------------------------------------
PQN 32 20,000 500.0
DQN 1 100,000 500.0
Section D: Myriad Scales — Throughput vs Parallelism¶
The key question: does wall-clock time increase as we add more parallel environments?
With JAX, the answer is mostly no. Vectorised environment steps and batched network updates mean that scaling from 1 to 2048 environments costs very little extra time on the GPU — throughput (steps/sec) grows by orders of magnitude while wall-clock time stays roughly flat.
CartPole is a trivially easy problem, and that’s fine here. We’re not measuring sample efficiency or final performance — we’re measuring raw throughput. The trivial task means the scaling signal is clean, uncontaminated by learning dynamics.
The same capability applies directly to harder problems: complex physics, stochastic dynamics, system identification across uncertain parameter spaces. CartPole just makes it easy to verify that the machinery works.
sweep_num_envs = [1, 32, 128, 512, 1024, 2048]
sweep_results = {}
sweep_times = {}
for num_envs in sweep_num_envs:
print(f"Training PQN with {num_envs} parallel environment(s)...")
config = create_config(
agent="pqn",
**{**PUREJAXQL_PARAMS, "num_envs": num_envs},
)
t0 = time.perf_counter()
results = train_and_evaluate(config)
elapsed = time.perf_counter() - t0
sweep_results[num_envs] = results
sweep_times[num_envs] = elapsed
total = num_envs * PUREJAXQL_PARAMS["steps_per_env"]
throughput = total / elapsed
print(f" final return: {results.eval_metrics.mean_return[-1]:.1f} "
f"time: {elapsed:.1f}s throughput: {throughput:,.0f} steps/s\n")
Training PQN with 1 parallel environment(s)...
[16:49:26 INFO] Step 640/3000 ( 21%) | loss=0.213 | eval_return=150.16
[16:49:26 INFO] Step 1280/3000 ( 43%) | loss=0.147 | eval_return=33.28
[16:49:26 INFO] Step 1920/3000 ( 64%) | loss=0.084 | eval_return=55.29
[16:49:26 INFO] Step 2560/3000 ( 85%) | loss=0.076 | eval_return=124.15
[16:49:26 INFO] Artifacts: outputs/2026-02-23/16-49-22
├── .hydra/ (config snapshot)
├── results.pkl (metrics & config)
└── run_metadata.yaml (timing & status)
final return: 124.1 time: 3.7s throughput: 821 steps/s
Training PQN with 32 parallel environment(s)...
[16:49:29 INFO] Step 640/3000 ( 21%) | loss=0.343 | eval_return=305.90
[16:49:29 INFO] Step 1280/3000 ( 43%) | loss=0.140 | eval_return=113.06
[16:49:29 INFO] Step 1920/3000 ( 64%) | loss=0.131 | eval_return=500.00
[16:49:30 INFO] Step 2560/3000 ( 85%) | loss=0.029 | eval_return=500.00
[16:49:30 INFO] Artifacts: outputs/2026-02-23/16-49-26
├── .hydra/ (config snapshot)
├── results.pkl (metrics & config)
└── run_metadata.yaml (timing & status)
final return: 500.0 time: 3.7s throughput: 25,641 steps/s
Training PQN with 128 parallel environment(s)...
[16:49:34 INFO] Step 640/3000 ( 21%) | loss=0.397 | eval_return=399.20
[16:49:34 INFO] Step 1280/3000 ( 43%) | loss=0.582 | eval_return=392.35
[16:49:34 INFO] Step 1920/3000 ( 64%) | loss=0.661 | eval_return=500.00
[16:49:34 INFO] Step 2560/3000 ( 85%) | loss=0.482 | eval_return=500.00
[16:49:34 INFO] Artifacts: outputs/2026-02-23/16-49-30
├── .hydra/ (config snapshot)
├── results.pkl (metrics & config)
└── run_metadata.yaml (timing & status)
final return: 500.0 time: 4.7s throughput: 82,078 steps/s
Training PQN with 512 parallel environment(s)...
[16:49:39 INFO] Step 640/3000 ( 21%) | loss=0.454 | eval_return=431.59
[16:49:39 INFO] Step 1280/3000 ( 43%) | loss=0.261 | eval_return=500.00
[16:49:39 INFO] Step 1920/3000 ( 64%) | loss=0.271 | eval_return=500.00
[16:49:39 INFO] Step 2560/3000 ( 85%) | loss=0.954 | eval_return=394.79
[16:49:39 INFO] Artifacts: outputs/2026-02-23/16-49-34
├── .hydra/ (config snapshot)
├── results.pkl (metrics & config)
└── run_metadata.yaml (timing & status)
final return: 394.8 time: 4.8s throughput: 319,636 steps/s
Training PQN with 1024 parallel environment(s)...
[16:49:44 INFO] Step 640/3000 ( 21%) | loss=0.377 | eval_return=448.72
[16:49:44 INFO] Step 1280/3000 ( 43%) | loss=0.260 | eval_return=500.00
[16:49:44 INFO] Step 1920/3000 ( 64%) | loss=0.603 | eval_return=408.43
[16:49:44 INFO] Step 2560/3000 ( 85%) | loss=0.114 | eval_return=379.23
[16:49:45 INFO] Artifacts: outputs/2026-02-23/16-49-39
├── .hydra/ (config snapshot)
├── results.pkl (metrics & config)
└── run_metadata.yaml (timing & status)
final return: 379.2 time: 5.4s throughput: 572,866 steps/s
Training PQN with 2048 parallel environment(s)...
[16:49:49 INFO] Step 640/3000 ( 21%) | loss=0.423 | eval_return=351.34
[16:49:50 INFO] Step 1280/3000 ( 43%) | loss=0.216 | eval_return=500.00
[16:49:50 INFO] Step 1920/3000 ( 64%) | loss=0.343 | eval_return=500.00
[16:49:50 INFO] Step 2560/3000 ( 85%) | loss=0.342 | eval_return=500.00
[16:49:50 INFO] Artifacts: outputs/2026-02-23/16-49-45
├── .hydra/ (config snapshot)
├── results.pkl (metrics & config)
└── run_metadata.yaml (timing & status)
final return: 500.0 time: 5.5s throughput: 1,108,173 steps/s
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))
colors_sweep = ["#e74c3c", "#e67e22", "#f1c40f", "#2ecc71", "#3498db", "#9b59b6"]
# Plot 1: Learning curves (return vs steps_per_env)
for color, num_envs in zip(colors_sweep, sweep_num_envs):
results = sweep_results[num_envs]
steps = results.eval_metrics.steps_per_env
mean = results.eval_metrics.mean_return
std = results.eval_metrics.std_return
label = f"{num_envs} env{'s' if num_envs > 1 else ''}"
ax1.plot(steps, mean, "o-", label=label, color=color, alpha=0.9)
ax1.fill_between(
steps,
np.array(mean) - np.array(std),
np.array(mean) + np.array(std),
alpha=0.15,
color=color,
)
ax1.set_xlabel("Steps per Env")
ax1.set_ylabel("Mean Return")
ax1.set_title("Learning Curves: PQN on CartPole")
ax1.legend(fontsize=8)
ax1.grid(alpha=0.3)
# Plot 2: Throughput vs num_envs (log scale — values span 3 orders of magnitude)
steps_per_env = PUREJAXQL_PARAMS["steps_per_env"]
throughputs = [
(n * steps_per_env) / sweep_times[n]
for n in sweep_num_envs
]
bars = ax2.bar(
np.arange(len(sweep_num_envs)),
throughputs,
color=colors_sweep,
edgecolor="white",
width=0.6,
)
ax2.set_xticks(np.arange(len(sweep_num_envs)))
ax2.set_xticklabels([str(n) for n in sweep_num_envs], rotation=30, ha="right")
for i, tp in enumerate(throughputs):
ax2.text(i, tp * 1.08, f"{tp/1e6:.2f}M" if tp >= 1e6 else f"{tp/1e3:.0f}K",
ha="center", va="bottom", fontsize=8)
ax2.set_yscale("log")
ax2.set_xlabel("Number of Parallel Environments")
ax2.set_ylabel("Throughput (steps/sec, log scale)")
ax2.set_title("GPU Throughput vs Parallelism")
ax2.grid(axis="y", alpha=0.3)
plt.tight_layout()
plt.show()
The Kernel crashed while executing code in the current cell or a previous cell.
Please review the code in the cell(s) to identify a possible cause of the failure.
Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info.
View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details.
Key Takeaways¶
Myriad works with high-parallelism algorithms out of the box (Sections A–B):
Using PureJaxQL’s published hyperparameters verbatim,
Myriad’s PQN reaches ~500 mean return on CartPole — matching their reported results with no adaptation needed. This confirms that algorithms designed around large parallel batches integrate
directly into Myriad’s API.
Parallelism improves sample efficiency (Section C):
On a shared x-axis (steps per environment), PQN with 32 parallel environments converges
much faster than single-environment DQN. Larger rollout batches provide better gradient
signal per update — a direct benefit of Myriad’s parallel infrastructure.
Myriad scales: more environments, same wall-clock time (Section D):
Scaling from 1 to 2048 parallel environments, throughput grows by roughly 3 orders of
magnitude while wall-clock time barely changes. JAX vectorises environment steps and
network updates across the GPU — adding environments is nearly free.
CartPole is a proof of concept, not the point:
CartPole is trivially easy, and any configuration solves it quickly. The real value of
this scalability appears on harder problems — complex physics, stochastic dynamics,
system identification with uncertain parameters — where running thousands of parallel
experiments is the only way to gather sufficient statistical signal.
References:
PQN paper: Simplifying Deep Temporal Difference Learning (Gallici et al., ICLR 2025 Spotlight)
PureJaxQL: https://github.com/mttga/purejaxql
Cleanup¶
for results in [results_pqn, results_dqn, *sweep_results.values()]:
shutil.rmtree(results.run_dir, ignore_errors=True)