PQN

Parallelized Q-Network (PQN) agent with LayerNorm, implemented in JAX and Flax.

PQN is an on-policy value-based RL algorithm designed for massively parallel training.

Features:

  • Lambda-returns (GAE-style) instead of 1-step TD targets

  • LayerNorm for stability (no target network needed)

  • Multi-epoch training on collected rollouts

  • Epsilon-greedy exploration with linear decay

  • Gradient clipping for stable training

Reference: PureJaxQL
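To make the first feature concrete, the following is a minimal plain-Python sketch of a backward lambda-return recursion. It assumes the common form G_t = r_t + gamma * (1 - done_t) * ((1 - lambda) * V_{t+1} + lambda * G_{t+1}), where V_{t+1} stands for max_a Q(s_{t+1}, a). The function name lambda_returns and its argument layout are illustrative only; the module's actual implementation is not shown on this page and may differ in detail.

```python
def lambda_returns(rewards, dones, next_values, gamma=0.99, lambda_=0.65):
    """Compute lambda-returns for one rollout (illustrative sketch).

    rewards[t]     : reward received at step t
    dones[t]       : True if the episode ended at step t
    next_values[t] : bootstrap value of the next state, e.g. max_a Q(s_{t+1}, a)
    """
    returns = [0.0] * len(rewards)
    # Past the end of the rollout no return exists yet, so bootstrap
    # from the value estimate of the final next state.
    next_return = next_values[-1]
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # Blend the 1-step bootstrap with the longer return from step t + 1.
        blended = (1.0 - lambda_) * next_values[t] + lambda_ * next_return
        returns[t] = rewards[t] + gamma * not_done * blended
        next_return = returns[t]
    return returns
```

With lambda_=0.0 every target reduces to the 1-step TD target r_t + gamma * V_{t+1}; with lambda_=1.0 the recursion accumulates discounted rewards up to the end of the rollout and bootstraps only there.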

class myriad.agents.rl.pqn.QNetwork(action_dim, hidden_size=128, num_layers=2, parent=<flax.linen.module._Sentinel object>, name=None)[source]

Bases: Module

MLP Q-network with LayerNorm for discrete action spaces.

action_dim: int
hidden_size: int = 128
num_layers: int = 2
__call__(x)[source]

Forward pass to compute Q-values for all actions.

Parameters:

x (Array) – Observation array of shape (obs_dim,) or (batch_size, obs_dim)

Returns:

Q-values of shape (action_dim,) or (batch_size, action_dim)

Return type:

Array

__init__(action_dim, hidden_size=128, num_layers=2, parent=<flax.linen.module._Sentinel object>, name=None)
name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
class myriad.agents.rl.pqn.AgentParams(action_space, learning_rate, gamma, lambda_, reward_scale, epsilon_start, epsilon_end, epsilon_decay_steps, max_grad_norm, num_epochs, num_minibatches, hidden_size, num_layers, lr_end, lr_decay_steps)[source]

Bases: object

Static parameters for the PQN agent.

Variables:
  • action_space (myriad.core.spaces.Space) – Action space (must be Discrete).

  • learning_rate (float) – Learning rate for Adam optimizer.

  • gamma (float) – Discount factor for future rewards.

  • lambda_ (float) – Lambda parameter for lambda-returns (0.0 = 1-step TD, 1.0 = Monte Carlo).

  • reward_scale (float) – Internal scaling factor applied to rewards before computing returns.

  • epsilon_start (float) – Initial exploration rate.

  • epsilon_end (float) – Final exploration rate after decay.

  • epsilon_decay_steps (int) – Number of environment steps (per env) to decay epsilon from start to end.

  • max_grad_norm (float) – Maximum gradient norm for clipping.

  • num_epochs (int) – Number of training epochs per rollout batch.

  • num_minibatches (int) – Number of minibatches per epoch.

  • hidden_size (int) – Hidden layer size for Q-network.

  • num_layers (int) – Number of hidden layers in Q-network.

  • lr_end (float) – Final learning rate after linear decay (only used when lr_decay_steps > 0).

  • lr_decay_steps (int) – Total optimizer gradient steps to decay LR over (0 = disabled).
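The three epsilon fields describe a linear schedule over per-env environment steps. A minimal sketch of that schedule follows, assuming epsilon is clamped at epsilon_end once epsilon_decay_steps is reached; the helper name linear_epsilon is hypothetical and not part of the module.

```python
def linear_epsilon(step, epsilon_start=1.0, epsilon_end=0.05, epsilon_decay_steps=50000):
    """Linearly interpolate epsilon from start to end, then hold at end."""
    # Guard against a zero-length schedule (assumption: treat it as instant decay).
    fraction = min(step / max(epsilon_decay_steps, 1), 1.0)
    return epsilon_start + fraction * (epsilon_end - epsilon_start)
```

With the default values shown here, epsilon starts at 1.0, is roughly 0.525 halfway through the schedule, and stays at 0.05 for every step at or beyond 50000.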

action_space: Space
learning_rate: float
gamma: float
lambda_: float
reward_scale: float
epsilon_start: float
epsilon_end: float
epsilon_decay_steps: int
max_grad_norm: float
num_epochs: int
num_minibatches: int
hidden_size: int
num_layers: int
lr_end: float
lr_decay_steps: int
__init__(action_space, learning_rate, gamma, lambda_, reward_scale, epsilon_start, epsilon_end, epsilon_decay_steps, max_grad_norm, num_epochs, num_minibatches, hidden_size, num_layers, lr_end, lr_decay_steps)
replace(**updates)

Returns a new object replacing the specified fields with new values.
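The copy-on-replace behavior can be sketched with a standard-library frozen dataclass. ToyParams below is a hypothetical stand-in with only two of the fields; the assumption is that params.replace(lambda_=0.9) on an AgentParams instance behaves like dataclasses.replace here.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyParams:
    """Hypothetical stand-in, not the library's AgentParams."""

    gamma: float
    lambda_: float


base = ToyParams(gamma=0.99, lambda_=0.65)
# Build a modified copy; the original instance is left untouched.
tuned = replace(base, lambda_=0.9)
```

Returning a new object rather than mutating in place keeps parameter objects safe to share across jitted functions.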

class myriad.agents.rl.pqn.AgentState(train_state, global_step)[source]

Bases: object

State of the PQN agent.

Variables:
  • train_state (flax.training.train_state.TrainState) – Flax TrainState containing network params and optimizer state.

  • global_step (jax.Array) – Number of environment steps taken per env (used for epsilon decay).

train_state: TrainState
global_step: Array
__init__(train_state, global_step)
replace(**updates)

Returns a new object replacing the specified fields with new values.

myriad.agents.rl.pqn.make_agent(action_space, learning_rate=0.00025, reward_scale=1.0, gamma=0.99, lambda_=0.65, epsilon_start=1.0, epsilon_end=0.05, epsilon_decay_steps=50000, max_grad_norm=0.5, num_epochs=4, num_minibatches=4, hidden_size=128, num_layers=2, lr_end=1e-20, lr_decay_steps=0)[source]

Factory function to create a PQN agent.

Parameters:
  • action_space (Space) – Action space (must be Discrete)

  • learning_rate (float) – Learning rate for Adam optimizer

  • reward_scale (float) – Internal scaling factor applied to rewards before computing returns

  • gamma (float) – Discount factor

  • lambda_ (float) – Lambda parameter for lambda-returns (0.0 = 1-step TD, 1.0 = Monte Carlo)

  • epsilon_start (float) – Initial exploration rate

  • epsilon_end (float) – Final exploration rate

  • epsilon_decay_steps (int) – Number of environment steps (per env) to decay epsilon from start to end. When using create_config(), pass epsilon_decay_fraction instead and the absolute step count is resolved automatically.

  • max_grad_norm (float) – Maximum gradient norm for clipping

  • num_epochs (int) – Number of training epochs per rollout

  • num_minibatches (int) – Number of minibatches per epoch

  • hidden_size (int) – Hidden layer size for Q-network

  • num_layers (int) – Number of hidden layers in Q-network

  • lr_end (float) – Final learning rate after linear decay. Only used when lr_decay_steps > 0.

  • lr_decay_steps (int) – Total optimizer gradient steps to decay LR from learning_rate to lr_end (0 = disabled, fixed LR). When using create_config(), pass lr_decay_fraction instead and the absolute step count is resolved automatically. Total gradient steps = num_updates × num_minibatches × num_epochs.

Returns:

Agent instance implementing PQN

Return type:

Agent
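The lr_decay_steps arithmetic described above can be worked through with a short sketch. Both helper names are hypothetical, and the schedule assumes a plain linear interpolation from learning_rate to lr_end that holds at lr_end afterwards; the optimizer schedule used internally is not shown on this page.

```python
def total_gradient_steps(num_updates, num_minibatches, num_epochs):
    """Total optimizer steps = num_updates x num_minibatches x num_epochs."""
    return num_updates * num_minibatches * num_epochs


def linear_lr(step, learning_rate=0.00025, lr_end=1e-20, lr_decay_steps=0):
    """Learning rate at a given optimizer step (illustrative sketch)."""
    if lr_decay_steps <= 0:
        # Decay disabled: the learning rate stays fixed.
        return learning_rate
    fraction = min(step / lr_decay_steps, 1.0)
    return learning_rate + fraction * (lr_end - learning_rate)


# Example: 100 rollout updates with the default 4 minibatches and 4 epochs.
steps = total_gradient_steps(num_updates=100, num_minibatches=4, num_epochs=4)
```

In this example steps is 1600, so passing lr_decay_steps=1600 would spread the decay across the whole run, while the default lr_decay_steps=0 keeps the learning rate constant.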