Source code for myriad.agents.rl.dqn

"""Deep Q-Network (DQN) agent implementation using JAX and Flax.

A classic value-based RL algorithm that learns to estimate Q-values for state-action
pairs and uses epsilon-greedy exploration.

Features:

- Experience replay with uniform sampling
- Target network for stable learning
- Epsilon-greedy exploration with linear decay
- Supports soft (Polyak) or hard target network updates

Reference: `CleanRL DQN <https://docs.cleanrl.dev/rl-algorithms/dqn/>`_
"""

from typing import Any, Tuple

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn, struct
from flax.training.train_state import TrainState
from jax import Array

from myriad.core.spaces import Discrete, Space
from myriad.core.types import Observation, PRNGKey, Transition
from myriad.utils.observations import to_array

from ..agent import Agent


[docs] class QNetwork(nn.Module): """Simple MLP Q-network for discrete action spaces.""" action_dim: int hidden_dims: tuple[int, ...] = (64, 64)
[docs] @nn.compact def __call__(self, x: Array) -> Array: """Forward pass to compute Q-values for all actions. Args: x: Observation array of shape (obs_dim,) or (batch_size, obs_dim) Returns: Q-values of shape (action_dim,) or (batch_size, action_dim) """ for hidden_dim in self.hidden_dims: x = nn.Dense(hidden_dim)(x) x = nn.relu(x) x = nn.Dense(self.action_dim)(x) return x
[docs] @struct.dataclass class AgentParams: """Static parameters for the DQN agent. Attributes: action_space: Action space (must be Discrete). learning_rate: Learning rate for Adam optimizer. gamma: Discount factor for future rewards. epsilon_start: Initial exploration rate. epsilon_end: Final exploration rate after decay. epsilon_decay_steps: Number of steps to decay epsilon from start to end. target_network_frequency: Steps between target network updates. tau: Soft update coefficient (1.0 = hard update, <1.0 = exponential moving average). """ action_space: Space learning_rate: float gamma: float epsilon_start: float epsilon_end: float epsilon_decay_steps: int target_network_frequency: int tau: float
[docs] @struct.dataclass class AgentState: """State of the DQN agent. Attributes: train_state: Flax TrainState containing network params and optimizer state. target_params: Parameters of the target network (lagged copy of online network). global_step: Number of update steps taken (used for epsilon decay and target updates). """ train_state: TrainState target_params: Any global_step: Array
def _init( key: PRNGKey, sample_obs: Observation, params: AgentParams, ) -> AgentState: """Initialize the DQN agent. Args: key: Random key for initialization sample_obs: Sample observation to infer network architecture (can be array or NamedTuple) params: Agent hyperparameters Returns: Initial agent state containing networks and optimizer """ # Convert observation to array if needed sample_obs_array = to_array(sample_obs) if not isinstance(params.action_space, Discrete): raise ValueError("DQN only supports Discrete action spaces") action_dim = params.action_space.n # Initialize Q-network q_network = QNetwork(action_dim=action_dim) q_params = q_network.init(key, sample_obs_array) # Create training state with optimizer optimizer = optax.adam(params.learning_rate) train_state = TrainState.create( apply_fn=q_network.apply, params=q_params, tx=optimizer, ) # Initialize target network with same parameters target_params = jax.tree_util.tree_map(lambda x: x.copy(), q_params) return AgentState( train_state=train_state, target_params=target_params, global_step=jnp.array(0, dtype=jnp.int32), ) def _select_action( key: PRNGKey, obs: Observation, state: AgentState, params: AgentParams, deterministic: bool, ) -> Tuple[Array, AgentState]: """Select action using epsilon-greedy policy. Args: key: Random key for exploration obs: Current observation (can be array or NamedTuple with .to_array() method) state: Current agent state params: Agent hyperparameters deterministic: If True, use greedy policy (epsilon=0). Default False. Returns: Tuple of (action, unchanged agent_state) """ # Convert observation to array if needed obs_array = to_array(obs) # Calculate current epsilon with linear decay (or use 0 if deterministic) epsilon_decayed = jnp.maximum( params.epsilon_end, params.epsilon_start - (params.epsilon_start - params.epsilon_end) * state.global_step / params.epsilon_decay_steps, ) epsilon = jax.lax.select(deterministic, jnp.array(0.0), epsilon_decayed) # Get Q-values q_values = state.train_state.apply_fn(state.train_state.params, obs_array) # Epsilon-greedy action selection key_explore, key_action = jax.random.split(key) explore = jax.random.uniform(key_explore) < epsilon # Greedy action (argmax Q) greedy_action = jnp.argmax(q_values) # Random action random_action = params.action_space.sample(key_action) # Select based on epsilon action = jax.lax.select(explore, random_action, greedy_action) return action, state def _update( key: PRNGKey, state: AgentState, batch: Transition, params: AgentParams, ) -> Tuple[AgentState, dict]: """Update the agent using a batch of transitions. Args: key: Random key (unused in DQN) state: Current agent state batch: Batch of transitions from replay buffer params: Agent hyperparameters Returns: Tuple of (new agent_state, metrics dict) """ def loss_fn(q_params): """Compute TD loss for Q-network.""" # Current Q-values: Q(s, a) q_values = state.train_state.apply_fn(q_params, batch.obs) # Select Q-values for actions taken actions_expanded = jnp.asarray(batch.action)[:, None] q_values_selected = jnp.take_along_axis(q_values, actions_expanded, axis=1).squeeze(1) # Target Q-values: r + gamma * max_a' Q_target(s', a') next_q_values = state.train_state.apply_fn(state.target_params, batch.next_obs) next_q_max = jnp.max(next_q_values, axis=1) # TD target (no gradient through target) rewards = jnp.asarray(batch.reward) dones = jnp.asarray(batch.done, dtype=jnp.float32) td_target = rewards + params.gamma * next_q_max * (1.0 - dones) td_target = jax.lax.stop_gradient(td_target) # MSE loss td_error = q_values_selected - td_target loss = jnp.mean(td_error**2) return loss, {"td_error_mean": jnp.mean(jnp.abs(td_error)), "q_value_mean": jnp.mean(q_values_selected)} # Compute gradients and update (loss, aux_metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.train_state.params) new_train_state = state.train_state.apply_gradients(grads=grads) # Update target network periodically should_update_target = (state.global_step.astype(jnp.int32) % params.target_network_frequency) == 0 if params.tau == 1.0: # Hard update new_target_params = jax.lax.cond( should_update_target, lambda: jax.tree_util.tree_map(lambda x: x.copy(), new_train_state.params), lambda: state.target_params, ) else: # Soft update: target = tau * online + (1 - tau) * target def soft_update(): return jax.tree_util.tree_map( lambda online, target: params.tau * online + (1.0 - params.tau) * target, new_train_state.params, state.target_params, ) new_target_params = jax.lax.cond( should_update_target, soft_update, lambda: state.target_params, ) new_agent_state = AgentState( train_state=new_train_state, target_params=new_target_params, global_step=state.global_step + 1, ) metrics = { "loss": loss, "td_error": aux_metrics["td_error_mean"], "q_value": aux_metrics["q_value_mean"], } return new_agent_state, metrics
[docs] def make_agent( action_space: Space, learning_rate: float = 1e-3, gamma: float = 0.99, epsilon_start: float = 1.0, epsilon_end: float = 0.05, epsilon_decay_steps: int = 10000, target_network_frequency: int = 500, tau: float = 1.0, ) -> Agent: """Factory function to create a DQN agent. Args: action_space: Action space (must be Discrete) learning_rate: Learning rate for Adam optimizer gamma: Discount factor epsilon_start: Initial exploration rate epsilon_end: Final exploration rate epsilon_decay_steps: Steps to decay epsilon from start to end target_network_frequency: Steps between target network updates tau: Soft update coefficient (1.0 = hard update) Returns: Agent instance with DQN implementation """ params = AgentParams( action_space=action_space, learning_rate=learning_rate, gamma=gamma, epsilon_start=epsilon_start, epsilon_end=epsilon_end, epsilon_decay_steps=epsilon_decay_steps, target_network_frequency=target_network_frequency, tau=tau, ) return Agent( params=params, init=_init, select_action=_select_action, update=_update, )