PQN
Parallelized Q-Network (PQN) agent with LayerNorm, implemented in JAX and Flax.
PQN is an on-policy value-based RL algorithm designed for massively parallel training.
Features:
Lambda-returns (GAE-style) instead of 1-step TD targets
LayerNorm for stability (no target network needed)
Multi-epoch training on collected rollouts
Epsilon-greedy exploration with linear decay
Gradient clipping for stable training
Reference: PureJaxQL
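The lambda-return targets listed under Features can be illustrated with a small backward recursion. The sketch below is plain Python for readability and is not the library's JAX code; the function name lambda_returns and its argument layout are assumptions made for illustration, with next_values[t] standing for the bootstrap value max_a Q(s_{t+1}, a).

```python
def lambda_returns(rewards, dones, next_values, gamma=0.99, lambda_=0.65):
    """Backward recursion for lambda-returns over a single trajectory.

    rewards[t], dones[t]: reward and terminal flag observed at step t.
    next_values[t]: assumed bootstrap value max_a Q(s_{t+1}, a).
    """
    returns = [0.0] * len(rewards)
    # Beyond the end of the rollout only the bootstrap value is available.
    next_return = next_values[-1]
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # Blend the 1-step bootstrap with the longer return from step t + 1.
        blended = (1.0 - lambda_) * next_values[t] + lambda_ * next_return
        returns[t] = rewards[t] + gamma * not_done * blended
        next_return = returns[t]
    return returns


# lambda_ = 1.0 accumulates discounted rewards up to the terminal step.
print(lambda_returns([1.0, 1.0, 1.0], [False, False, True],
                     [0.5, 0.5, 0.5], gamma=0.5, lambda_=1.0))
# [1.75, 1.5, 1.0]
```

With lambda_ = 0.0 the same function reduces to the 1-step TD target r_t + gamma * next_values[t] for every non-terminal step.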
- class myriad.agents.rl.pqn.QNetwork(action_dim, hidden_size=128, num_layers=2, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
MLP Q-network with LayerNorm for discrete action spaces.
- __call__(x)
Forward pass to compute Q-values for all actions.
- Parameters:
x (Array) – Observation array of shape (obs_dim,) or (batch_size, obs_dim)
- Returns:
Q-values of shape (action_dim,) or (batch_size, action_dim)
- Return type:
Array
- __init__(action_dim, hidden_size=128, num_layers=2, parent=<flax.linen.module._Sentinel object>, name=None)
- class myriad.agents.rl.pqn.AgentParams(action_space, learning_rate, gamma, lambda_, reward_scale, epsilon_start, epsilon_end, epsilon_decay_steps, max_grad_norm, num_epochs, num_minibatches, hidden_size, num_layers, lr_end, lr_decay_steps)
Bases: object
Static parameters for the PQN agent.
- Variables:
action_space (myriad.core.spaces.Space) – Action space (must be Discrete).
learning_rate (float) – Learning rate for Adam optimizer.
gamma (float) – Discount factor for future rewards.
lambda_ (float) – Lambda parameter for lambda-returns (0.0 = 1-step TD, 1.0 = Monte Carlo).
reward_scale (float) – Internal scaling factor applied to rewards before computing returns.
epsilon_start (float) – Initial exploration rate.
epsilon_end (float) – Final exploration rate after decay.
epsilon_decay_steps (int) – Number of environment steps (per env) to decay epsilon from start to end.
max_grad_norm (float) – Maximum gradient norm for clipping.
num_epochs (int) – Number of training epochs per rollout batch.
num_minibatches (int) – Number of minibatches per epoch.
hidden_size (int) – Hidden layer size for Q-network.
num_layers (int) – Number of hidden layers in Q-network.
lr_end (float) – Final learning rate after linear decay (only used when lr_decay_steps > 0).
lr_decay_steps (int) – Total optimizer gradient steps to decay LR over (0 = disabled).
- __init__(action_space, learning_rate, gamma, lambda_, reward_scale, epsilon_start, epsilon_end, epsilon_decay_steps, max_grad_norm, num_epochs, num_minibatches, hidden_size, num_layers, lr_end, lr_decay_steps)
- replace(**updates)
Returns a new object replacing the specified fields with new values.
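The epsilon parameters above describe a linear decay over per-env environment steps. A minimal sketch of such a schedule follows, in plain Python rather than JAX; the helper name linear_epsilon is hypothetical, the defaults are copied from the make_agent signature, and epsilon_decay_steps is assumed to be positive.

```python
def linear_epsilon(step, epsilon_start=1.0, epsilon_end=0.05,
                   epsilon_decay_steps=50_000):
    """Linearly interpolate epsilon from start to end, then hold at end.

    step: environment steps taken per env (the role global_step plays).
    Assumes epsilon_decay_steps > 0.
    """
    # Fraction of the decay window already consumed, clamped to [0, 1].
    fraction = min(max(step / epsilon_decay_steps, 0.0), 1.0)
    return epsilon_start + fraction * (epsilon_end - epsilon_start)


for step in (0, 25_000, 50_000, 100_000):
    print(step, linear_epsilon(step))
```

Once step passes epsilon_decay_steps the value stays at epsilon_end, so exploration never drops to zero with the default settings.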
- class myriad.agents.rl.pqn.AgentState(train_state, global_step)
Bases: object
State of the PQN agent.
- Variables:
train_state (flax.training.train_state.TrainState) – Flax TrainState containing network params and optimizer state.
global_step (jax.jaxlib._jax.Array) – Number of environment steps taken per env (used for epsilon decay).
- train_state: TrainState
- global_step: Array
- __init__(train_state, global_step)
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- myriad.agents.rl.pqn.make_agent(action_space, learning_rate=0.00025, reward_scale=1.0, gamma=0.99, lambda_=0.65, epsilon_start=1.0, epsilon_end=0.05, epsilon_decay_steps=50000, max_grad_norm=0.5, num_epochs=4, num_minibatches=4, hidden_size=128, num_layers=2, lr_end=1e-20, lr_decay_steps=0)
Factory function to create a PQN agent.
- Parameters:
action_space (Space) – Action space (must be Discrete)
learning_rate (float) – Learning rate for Adam optimizer
reward_scale (float) – Internal scaling factor for the rewards
gamma (float) – Discount factor
lambda_ (float) – Lambda parameter for lambda-returns (0.0 = 1-step TD, 1.0 = Monte Carlo)
epsilon_start (float) – Initial exploration rate
epsilon_end (float) – Final exploration rate
epsilon_decay_steps (int) – Number of environment steps (per env) to decay epsilon from start to end. When using create_config(), pass epsilon_decay_fraction instead and the absolute step count is resolved automatically.
max_grad_norm (float) – Maximum gradient norm for clipping
num_epochs (int) – Number of training epochs per rollout
num_minibatches (int) – Number of minibatches per epoch
hidden_size (int) – Hidden layer size for Q-network
num_layers (int) – Number of hidden layers in Q-network
lr_end (float) – Final learning rate after linear decay. Only used when lr_decay_steps > 0.
lr_decay_steps (int) – Total optimizer gradient steps to decay LR from learning_rate to lr_end (0 = disabled, fixed LR). When using create_config(), pass lr_decay_fraction instead and the absolute step count is resolved automatically. Total gradient steps = num_updates × num_minibatches × num_epochs.
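The gradient-step arithmetic and the linear learning-rate decay described for lr_decay_steps can be sketched as follows. This is plain Python for illustration only: the helper names total_gradient_steps and linear_lr are hypothetical, the value of num_updates is made up, and the library may build its schedule differently (for example with an optimizer-library schedule).

```python
def total_gradient_steps(num_updates, num_minibatches=4, num_epochs=4):
    """Total gradient steps = num_updates x num_minibatches x num_epochs."""
    return num_updates * num_minibatches * num_epochs


def linear_lr(grad_step, learning_rate=2.5e-4, lr_end=1e-20, lr_decay_steps=0):
    """Linear decay from learning_rate to lr_end; fixed LR when disabled."""
    if lr_decay_steps <= 0:
        # lr_decay_steps = 0 disables the decay, so lr_end is ignored.
        return learning_rate
    fraction = min(grad_step / lr_decay_steps, 1.0)
    return learning_rate + fraction * (lr_end - learning_rate)


# Hypothetical run with 1000 rollout updates and the default 4 x 4 split.
steps = total_gradient_steps(num_updates=1000)
print(steps)                                          # 16000
print(linear_lr(steps // 2, lr_decay_steps=steps))    # halfway through the decay
```

Because each rollout is reused for num_epochs passes over num_minibatches minibatches, the decay horizon is counted in optimizer steps, not environment steps.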
- Returns:
Agent instance with PQN implementation
- Return type: