DQN
Deep Q-Network (DQN) agent implementation using JAX and Flax.
A classic value-based RL algorithm that learns to estimate Q-values for state-action pairs and uses epsilon-greedy exploration.
Features:
Experience replay with uniform sampling
Target network for stable learning
Epsilon-greedy exploration with linear decay
Supports soft (Polyak) or hard target network updates
Reference: CleanRL DQN
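The epsilon-greedy exploration with linear decay listed above can be sketched as follows. This is a minimal illustration in plain Python, not the library's code: the helper names `linear_epsilon` and `epsilon_greedy` are hypothetical, the defaults are copied from the `make_agent` signature, and holding epsilon at `epsilon_end` after the decay window is an assumption.

```python
import random


def linear_epsilon(step, epsilon_start=1.0, epsilon_end=0.05, epsilon_decay_steps=10_000):
    """Interpolate linearly from epsilon_start to epsilon_end, then hold (assumed)."""
    fraction = min(step / epsilon_decay_steps, 1.0)
    return epsilon_start + fraction * (epsilon_end - epsilon_start)


def epsilon_greedy(q_values, epsilon, rng):
    """With probability epsilon pick a uniform random action, else the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


rng = random.Random(0)
eps_halfway = linear_epsilon(step=5_000)    # halfway through the decay window
eps_after = linear_epsilon(step=20_000)     # past the window: held at epsilon_end
greedy_action = epsilon_greedy([0.1, 0.7, 0.2], epsilon=0.0, rng=rng)
```

With `epsilon=0.0` the random branch is never taken, so the last call always returns the index of the largest Q-value.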
- class myriad.agents.rl.dqn.QNetwork(action_dim, hidden_dims=(64, 64), parent=<flax.linen.module._Sentinel object>, name=None)
Bases:
Module
Simple MLP Q-network for discrete action spaces.
- __call__(x)
Forward pass to compute Q-values for all actions.
- Parameters:
x (Array) – Observation array of shape (obs_dim,) or (batch_size, obs_dim)
- Returns:
Q-values of shape (action_dim,) or (batch_size, action_dim)
- Return type:
Array
- __init__(action_dim, hidden_dims=(64, 64), parent=<flax.linen.module._Sentinel object>, name=None)
- class myriad.agents.rl.dqn.AgentParams(action_space, learning_rate, gamma, epsilon_start, epsilon_end, epsilon_decay_steps, target_network_frequency, tau)
Bases:
object
Static parameters for the DQN agent.
- Variables:
action_space (myriad.core.spaces.Space) – Action space (must be Discrete).
learning_rate (float) – Learning rate for Adam optimizer.
gamma (float) – Discount factor for future rewards.
epsilon_start (float) – Initial exploration rate.
epsilon_end (float) – Final exploration rate after decay.
epsilon_decay_steps (int) – Number of steps to decay epsilon from start to end.
target_network_frequency (int) – Steps between target network updates.
tau (float) – Soft update coefficient (1.0 = hard update, <1.0 = exponential moving average).
- __init__(action_space, learning_rate, gamma, epsilon_start, epsilon_end, epsilon_decay_steps, target_network_frequency, tau)
- replace(**updates)
Returns a new object replacing the specified fields with new values.
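The interaction between `target_network_frequency` and `tau` can be sketched as below. This is an illustration under stated assumptions, not the library's update code: `blend_target` and `maybe_update_target` are hypothetical helpers, and the plain Python `if` assumes `global_step` is a Python integer (a jitted version would need `jax.lax.cond` or similar).

```python
import jax
import jax.numpy as jnp


def blend_target(online_params, target_params, tau):
    """Polyak blend: tau * online + (1 - tau) * target, applied leaf by leaf."""
    return jax.tree_util.tree_map(
        lambda online, target: tau * online + (1.0 - tau) * target,
        online_params,
        target_params,
    )


def maybe_update_target(online_params, target_params, global_step,
                        target_network_frequency, tau):
    """Refresh the target parameters only every target_network_frequency steps."""
    if global_step % target_network_frequency == 0:
        return blend_target(online_params, target_params, tau)
    return target_params


online = {"w": jnp.array([1.0, 2.0])}
target = {"w": jnp.array([0.0, 0.0])}

hard = maybe_update_target(online, target, 500, 500, tau=1.0)     # full copy of online
soft = maybe_update_target(online, target, 500, 500, tau=0.1)     # 10% step toward online
skipped = maybe_update_target(online, target, 501, 500, tau=1.0)  # off-schedule: unchanged
```

With `tau=1.0` the blend reduces to copying the online parameters, which is why a single coefficient can cover both the hard and the soft update modes.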
- class myriad.agents.rl.dqn.AgentState(train_state, target_params, global_step)
Bases:
object
State of the DQN agent.
- Variables:
train_state (flax.training.train_state.TrainState) – Flax TrainState containing network params and optimizer state.
target_params (Any) – Parameters of the target network (lagged copy of online network).
global_step (jax.Array) – Number of update steps taken (used for epsilon decay and target updates).
- train_state: TrainState
- global_step: Array
- __init__(train_state, target_params, global_step)
- replace(**updates)
Returns a new object replacing the specified fields with new values.
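For orientation, the roles of the online parameters (held in `train_state`) and `target_params` in a standard DQN update can be sketched as follows. The helpers `td_targets` and `td_loss` are hypothetical, the arrays stand in for network outputs, and the squared-error loss is an assumption; the library may compute its loss differently.

```python
import jax.numpy as jnp


def td_targets(rewards, dones, next_q_target, gamma=0.99):
    """One-step targets: r + gamma * (1 - done) * max_a Q_target(s', a)."""
    max_next_q = jnp.max(next_q_target, axis=-1)
    return rewards + gamma * (1.0 - dones) * max_next_q


def td_loss(q_online, actions, targets):
    """Mean squared error between Q_online(s, a) for the taken actions and the targets."""
    q_taken = jnp.take_along_axis(q_online, actions[:, None], axis=-1).squeeze(-1)
    return jnp.mean((q_taken - targets) ** 2)


# Toy batch of two transitions; the second one ends its episode.
rewards = jnp.array([1.0, 0.0])
dones = jnp.array([0.0, 1.0])
next_q_target = jnp.array([[0.5, 2.0], [3.0, 1.0]])  # would come from target_params
q_online = jnp.array([[2.98, 0.0], [0.0, 1.0]])      # would come from train_state.params
actions = jnp.array([0, 1])

targets = td_targets(rewards, dones, next_q_target)  # approximately [2.98, 0.0]
loss = td_loss(q_online, actions, targets)           # approximately 0.5
```

The terminal transition contributes no bootstrapped value, because the `(1 - done)` factor zeroes out the target network's estimate.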
- myriad.agents.rl.dqn.make_agent(action_space, learning_rate=0.001, gamma=0.99, epsilon_start=1.0, epsilon_end=0.05, epsilon_decay_steps=10000, target_network_frequency=500, tau=1.0)
Factory function to create a DQN agent.
- Parameters:
action_space (Space) – Action space (must be Discrete)
learning_rate (float) – Learning rate for Adam optimizer
gamma (float) – Discount factor
epsilon_start (float) – Initial exploration rate
epsilon_end (float) – Final exploration rate
epsilon_decay_steps (int) – Steps to decay epsilon from start to end
target_network_frequency (int) – Steps between target network updates
tau (float) – Soft update coefficient (1.0 = hard update)
- Returns:
Agent instance with DQN implementation
- Return type: