DQN

Deep Q-Network (DQN) agent implementation using JAX and Flax.

DQN is a classic value-based RL algorithm: it learns to estimate Q-values for state-action pairs and explores with an epsilon-greedy policy.

Features:

  • Experience replay with uniform sampling

  • Target network for stable learning

  • Epsilon-greedy exploration with linear decay

  • Supports soft (Polyak) or hard target network updates
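The update that ties these features together can be sketched as follows. This is a minimal NumPy illustration of the standard DQN target in the style of CleanRL. It assumes a batch of transitions has already been drawn uniformly from the replay buffer. The function name `dqn_td_loss` is hypothetical, and this is not the `myriad` code path. The target network supplies the bootstrap value, and terminal transitions (`dones = 1`) do not bootstrap.

```python
import numpy as np

def dqn_td_loss(q_online, q_target_next, actions, rewards, dones, gamma=0.99):
    """Mean-squared TD error for a batch of transitions (hypothetical helper).

    q_online:       (batch, action_dim) Q(s, .) from the online network
    q_target_next:  (batch, action_dim) Q(s', .) from the target network
    actions:        (batch,) integer actions taken in s
    rewards, dones: (batch,) floats; dones is 1.0 where the episode ended
    """
    # Bootstrap from the target network's greedy value in s'; zero it at terminal states.
    next_q = q_target_next.max(axis=-1)
    td_target = rewards + (1.0 - dones) * gamma * next_q
    # Select Q(s, a) for the action actually taken in each transition.
    q_pred = q_online[np.arange(q_online.shape[0]), actions]
    return np.mean((q_pred - td_target) ** 2), td_target

# Two transitions; the second one is terminal, so its target is just the reward.
q_online = np.array([[1.0, 2.0], [0.5, 0.0]])
q_target_next = np.array([[3.0, 1.0], [4.0, 2.0]])
actions = np.array([1, 0])
rewards = np.array([1.0, 0.0])
dones = np.array([0.0, 1.0])
loss, td_target = dqn_td_loss(q_online, q_target_next, actions, rewards, dones, gamma=0.5)
```

Only the online network's parameters receive gradients from this loss. The target network is a lagged copy, which keeps the regression target from shifting on every step.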

Reference: CleanRL DQN

class myriad.agents.rl.dqn.QNetwork(action_dim, hidden_dims=(64, 64), parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Simple MLP Q-network for discrete action spaces.

action_dim: int
hidden_dims: tuple[int, ...] = (64, 64)
__call__(x)

Forward pass to compute Q-values for all actions.

Parameters:

x (Array) – Observation array of shape (obs_dim,) or (batch_size, obs_dim)

Returns:

Q-values of shape (action_dim,) or (batch_size, action_dim)

Return type:

Array
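The following NumPy stand-in for `QNetwork` illustrates the input and output shapes. It is a sketch, not the Flax module. The ReLU activations, the weight initialisation, and the helper names `init_mlp` and `q_values` are assumptions. Only the layer sizes (`hidden_dims=(64, 64)`) and the shapes come from the documentation above.

```python
import numpy as np

def init_mlp(rng, obs_dim, action_dim, hidden_dims=(64, 64)):
    """Random (weight, bias) pairs with the same layer sizes as QNetwork (hypothetical)."""
    sizes = [obs_dim, *hidden_dims, action_dim]
    return [
        (rng.normal(0.0, 1.0 / np.sqrt(n_in), size=(n_in, n_out)), np.zeros(n_out))
        for n_in, n_out in zip(sizes[:-1], sizes[1:])
    ]

def q_values(params, x):
    """Forward pass for x of shape (obs_dim,) or (batch_size, obs_dim)."""
    h = x
    for w, b in params[:-1]:
        h = np.maximum(h @ w + b, 0.0)  # hidden layers with ReLU (assumed activation)
    w, b = params[-1]
    return h @ w + b  # linear output head: one Q-value per action

rng = np.random.default_rng(0)
params = init_mlp(rng, obs_dim=4, action_dim=2)
single = q_values(params, np.zeros(4))       # shape (action_dim,)
batch = q_values(params, np.zeros((32, 4)))  # shape (batch_size, action_dim)
greedy_action = int(np.argmax(single))
```

With the real Flax module, parameters follow the usual `flax.linen` pattern: `QNetwork(action_dim=2).init(key, obs)` creates them, and `.apply(params, obs)` evaluates the network. The greedy action is the `argmax` over the last axis.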

__init__(action_dim, hidden_dims=(64, 64), parent=<flax.linen.module._Sentinel object>, name=None)
name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
class myriad.agents.rl.dqn.AgentParams(action_space, learning_rate, gamma, epsilon_start, epsilon_end, epsilon_decay_steps, target_network_frequency, tau)

Bases: object

Static parameters for the DQN agent.

Variables:
  • action_space (myriad.core.spaces.Space) – Action space (must be Discrete).

  • learning_rate (float) – Learning rate for Adam optimizer.

  • gamma (float) – Discount factor for future rewards.

  • epsilon_start (float) – Initial exploration rate.

  • epsilon_end (float) – Final exploration rate after decay.

  • epsilon_decay_steps (int) – Number of steps to decay epsilon from start to end.

  • target_network_frequency (int) – Steps between target network updates.

  • tau (float) – Soft update coefficient (1.0 = hard update, <1.0 = exponential moving average).
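The epsilon fields above describe a linear schedule followed by epsilon-greedy action selection. The sketch below assumes a CleanRL-style schedule: linear interpolation from `epsilon_start` to `epsilon_end`, then held at `epsilon_end`. It uses the `make_agent` defaults. `linear_epsilon` and `epsilon_greedy` are hypothetical helper names, not part of `myriad`.

```python
import random

def linear_epsilon(step, epsilon_start=1.0, epsilon_end=0.05, epsilon_decay_steps=10_000):
    """Interpolate linearly from epsilon_start to epsilon_end, then hold at epsilon_end."""
    frac = min(step / epsilon_decay_steps, 1.0)
    return epsilon_start + frac * (epsilon_end - epsilon_start)

def epsilon_greedy(q_values, epsilon, rng):
    """With probability epsilon pick a uniformly random action, otherwise the argmax."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])

eps = linear_epsilon(5_000)
action = epsilon_greedy([0.1, 0.9, 0.3], eps, random.Random(0))
```

With these defaults, epsilon is 0.525 at step 5,000 and remains at 0.05 from step 10,000 onward.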

action_space: Space
learning_rate: float
gamma: float
epsilon_start: float
epsilon_end: float
epsilon_decay_steps: int
target_network_frequency: int
tau: float
__init__(action_space, learning_rate, gamma, epsilon_start, epsilon_end, epsilon_decay_steps, target_network_frequency, tau)
replace(**updates)

Returns a new object replacing the specified fields with new values.
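`replace` follows the functional-update pattern of immutable dataclasses; for example, `flax.struct.dataclass` generates such a method. Whether `AgentParams` is built that way is an assumption. The minimal sketch below uses a hypothetical stand-in class with a subset of the fields and relies only on the standard library:

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class ParamsSketch:
    """Hypothetical stand-in with a subset of AgentParams' fields."""
    learning_rate: float
    gamma: float
    tau: float

    def replace(self, **updates):
        # Same contract as AgentParams.replace: return a modified copy, leave self untouched.
        return dataclasses.replace(self, **updates)

base = ParamsSketch(learning_rate=1e-3, gamma=0.99, tau=1.0)
soft = base.replace(tau=0.005)  # switch to soft (Polyak) target updates
```

The original `base` is left unchanged, and `soft` is a new object that differs only in `tau`.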

class myriad.agents.rl.dqn.AgentState(train_state, target_params, global_step)

Bases: object

State of the DQN agent.

Variables:
  • train_state (flax.training.train_state.TrainState) – Flax TrainState containing network params and optimizer state.

  • target_params (Any) – Parameters of the target network (lagged copy of online network).

  • global_step (jax.Array) – Number of update steps taken (used for epsilon decay and target updates).
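`global_step` drives the target-network schedule. The sketch below assumes `myriad` follows CleanRL here: the update runs only when `global_step` is a multiple of `target_network_frequency`, and it blends weights using `tau`. The logic is shown in NumPy, with dicts standing in for Flax parameter pytrees. `maybe_update_target` is a hypothetical name.

```python
import numpy as np

def maybe_update_target(online, target, global_step, target_network_frequency=500, tau=1.0):
    """Every `target_network_frequency` steps, move target params toward online params.

    online/target: dicts of arrays standing in for Flax parameter pytrees.
    tau=1.0 copies the online weights (hard update); tau<1.0 is Polyak averaging.
    """
    if global_step % target_network_frequency != 0:
        return target  # not an update step: keep the lagged copy as-is
    return {k: tau * online[k] + (1.0 - tau) * target[k] for k in online}

online = {"w": np.array([1.0, 2.0])}
target = {"w": np.array([0.0, 0.0])}
hard = maybe_update_target(online, target, global_step=500)
soft = maybe_update_target(online, target, global_step=500, tau=0.1)
skipped = maybe_update_target(online, target, global_step=501)
```

In JAX, the same blend over a whole pytree is commonly written as `optax.incremental_update(online_params, target_params, tau)`, which is what CleanRL's JAX DQN uses.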

train_state: TrainState
target_params: Any
global_step: Array
__init__(train_state, target_params, global_step)
replace(**updates)

Returns a new object replacing the specified fields with new values.

myriad.agents.rl.dqn.make_agent(action_space, learning_rate=0.001, gamma=0.99, epsilon_start=1.0, epsilon_end=0.05, epsilon_decay_steps=10000, target_network_frequency=500, tau=1.0)

Factory function to create a DQN agent.

Parameters:
  • action_space (Space) – Action space (must be Discrete).

  • learning_rate (float) – Learning rate for Adam optimizer.

  • gamma (float) – Discount factor.

  • epsilon_start (float) – Initial exploration rate.

  • epsilon_end (float) – Final exploration rate.

  • epsilon_decay_steps (int) – Steps to decay epsilon from start to end.

  • target_network_frequency (int) – Steps between target network updates.

  • tau (float) – Soft update coefficient (1.0 = hard update, <1.0 = exponential moving average).

Returns:

An Agent instance implementing DQN.

Return type:

Agent