import logging
from pathlib import Path
import jax
import jax.numpy as jnp
import numpy as np
from myriad.configs.default import Config
from myriad.core.replay_buffer import ReplayBuffer
from myriad.utils import to_array
from .display import format_train_config
from .initialization import initialize_environment_and_agent
from .logging import SessionLogger
from .metadata import RunMetadata
from .output_dir import format_artifacts_tree, get_or_create_output_dir
from .runners import (
make_chunk_runner,
make_chunked_collector,
make_on_policy_chunk_runner,
)
from .steps import (
make_collection_step_fn,
make_eval_rollout_fn,
make_sample_transition,
make_train_step_fn,
)
from .types import TrainingEnvState, TrainingResults
# Module-level logger, named after this module, used for config/progress/artifact messages.
logger = logging.getLogger(__name__)
def _steps_until_boundary(current_step: int, frequency: int, default_steps: int) -> int:
    """Return how many steps remain until the next multiple of ``frequency``.

    Used to cap a training chunk so that it ends exactly on an evaluation boundary.

    Args:
        current_step: Per-environment step counter at the start of the chunk.
        frequency: Boundary spacing in steps; a value <= 0 means "disabled".
        default_steps: Value returned when ``frequency`` is disabled.

    Returns:
        ``default_steps`` when ``frequency <= 0``. Otherwise a strictly positive count:
        ``frequency`` if ``current_step`` already sits on a boundary, else the distance
        to the next boundary.
    """
    if frequency <= 0:
        return default_steps
    remainder = current_step % frequency
    return frequency - remainder if remainder else frequency


def _run_training_loop(config: Config, session_logger: SessionLogger, run_dir: Path) -> TrainingResults:
    """Execute the training loop and return metrics plus the trained agent.

    The mode is selected from ``config.run.rollout_steps``: when it is set, rollout-update
    cycles are run through the on-policy chunk runner; when it is ``None``, a replay buffer
    is created and the off-policy chunk runner is used. Training metrics are logged and
    evaluation rollouts are run whenever ``steps_completed`` is a multiple of
    ``config.run.eval_frequency``.

    Args:
        config: Run configuration (seed, number of environments, chunking, eval settings, ...).
        session_logger: Receives training-step and evaluation records; also the source of the
            metric histories placed in the result.
        run_dir: Output directory, stored on the returned results.

    Returns:
        TrainingResults containing trained agent state, training metrics history,
        evaluation metrics history, configuration, run directory and final environment state.

    Raises:
        ValueError: If off-policy training is selected but ``config.run.buffer_size`` is None.
    """
    run_cfg = config.run

    # --- RNG, environment and agent setup -------------------------------------------------
    key = jax.random.PRNGKey(run_cfg.seed)
    key, env_key, agent_key, buffer_key = jax.random.split(key, 4)

    # Create environment and agent using shared initialization
    env, agent, action_space = initialize_environment_and_agent(config)

    # Reset one environment per key; params/config are shared across the batch
    env_keys = jax.random.split(env_key, run_cfg.num_envs)
    obs, env_states = jax.vmap(env.reset, in_axes=(0, None, None))(env_keys, env.params, env.config)

    # Training carries observations as arrays
    obs_array = jax.vmap(to_array)(obs)
    training_env_states = TrainingEnvState(env_state=env_states, obs=obs_array)

    # The agent is initialized from the first environment's *unconverted* observation
    # (the original structure, not the array form) so it can inspect the fields.
    sample_obs = jax.tree.map(lambda x: x[0], obs)
    agent_state = agent.init(agent_key, sample_obs, agent.params)

    # Build shared jitted primitives first
    eval_rollout_fn = make_eval_rollout_fn(agent, env, run_cfg.eval_rollouts, run_cfg.eval_max_steps)
    chunk_size = max(1, run_cfg.scan_chunk_size or 1)

    # --- Mode selection -------------------------------------------------------------------
    rollout_steps = run_cfg.rollout_steps
    use_rollout_training = rollout_steps is not None
    cycles_per_chunk = 0  # only meaningful for on-policy training

    if use_rollout_training:
        # On-policy training (e.g., PPO, A2C, PQN): no replay buffer needed
        replay_buffer = None
        buffer_state = None
        assert rollout_steps is not None  # type narrowing for mypy

        # chunk_size is expressed in steps; the on-policy scan iterates over cycles.
        cycles_per_chunk = max(1, chunk_size // rollout_steps)

        collection_step_fn = make_collection_step_fn(agent, env, run_cfg.num_envs)
        rollout_fn = make_chunked_collector(collection_step_fn=collection_step_fn, total_steps=rollout_steps)
        # Chunk runner that batches multiple rollout-update cycles
        run_chunk_fn = make_on_policy_chunk_runner(
            rollout_fn=rollout_fn,
            agent=agent,
        )
    else:
        # Off-policy training (e.g., DQN): use replay buffer
        if run_cfg.buffer_size is None:
            raise ValueError("buffer_size must be set in config for off-policy training (when rollout_steps is None)")
        replay_buffer = ReplayBuffer(buffer_size=run_cfg.buffer_size)

        # The buffer is initialized from an array observation to match training transitions
        sample_obs_array = to_array(sample_obs)
        sample_transition = make_sample_transition(buffer_key, sample_obs_array, action_space)
        buffer_state = replay_buffer.init(sample_transition)

        # Chunk runner that batches multiple step-update cycles
        train_step_fn = make_train_step_fn(agent, env, replay_buffer, run_cfg.num_envs)
        run_chunk_fn = make_chunk_runner(train_step_fn, run_cfg.batch_size)

    # --- Main loop ------------------------------------------------------------------------
    # Training runs for steps_per_env steps in each environment
    steps_per_env = run_cfg.steps_per_env
    eval_frequency = run_cfg.eval_frequency
    steps_completed = 0

    # Latest formatted metrics shown on the periodic progress line
    latest_metrics: dict[str, str] = {}

    while steps_completed < steps_per_env:
        remaining_steps = steps_per_env - steps_completed  # > 0 by the loop condition

        # Steps until the next eval boundary (always >= 1; chunk_size when eval is disabled)
        steps_to_eval = _steps_until_boundary(steps_completed, eval_frequency, chunk_size)

        if use_rollout_training:
            assert rollout_steps is not None
            # Work in units of rollout-update cycles.
            cycles_to_eval = steps_to_eval // rollout_steps
            cycles_remaining = (remaining_steps + rollout_steps - 1) // rollout_steps
            # Always run at least one cycle.
            # NOTE(review): when fewer than rollout_steps steps remain before the eval boundary,
            # this forced cycle steps past it; the exact-multiple eval check below then does not
            # fire for that boundary. Confirm configs keep eval_frequency a multiple of
            # rollout_steps, or switch to boundary-crossing detection.
            num_cycles = max(1, min(cycles_per_chunk, cycles_remaining, cycles_to_eval))

            # Fixed-length mask; only the first num_cycles entries are active
            active_mask = (jnp.arange(cycles_per_chunk) < num_cycles).astype(jnp.bool_)

            (key, agent_state, training_env_states), metrics_history = run_chunk_fn(
                (key, agent_state, training_env_states),
                active_mask,
            )

            # Steps credited for this chunk, capped at what was still owed
            steps_this_chunk = min(num_cycles * rollout_steps, remaining_steps)
            active_count = num_cycles
        else:
            # Work in units of individual training steps.
            steps_this_chunk = min(chunk_size, remaining_steps, steps_to_eval)

            # Mask for the scan:
            # - always length chunk_size (for consistent JIT compilation)
            # - only the first steps_this_chunk elements are True
            # - inactive iterations execute but don't update state
            active_mask = (jnp.arange(chunk_size) < steps_this_chunk).astype(jnp.bool_)

            (key, agent_state, training_env_states, buffer_state), metrics_history = run_chunk_fn(
                (key, agent_state, training_env_states, buffer_state),
                active_mask,
            )
            active_count = steps_this_chunk

        steps_completed += steps_this_chunk
        global_step = steps_completed * run_cfg.num_envs

        # Progress-line metrics: read the last *active* scan slot. The scan always runs the
        # full chunk length, so the final slot may belong to a masked (inactive) iteration
        # when the chunk was cut short.
        for metric_name, spec in (("loss", ".3f"), ("reward", ".2f")):
            if metric_name in metrics_history:
                series = metrics_history[metric_name]
                last_active = min(active_count, len(series)) - 1
                latest_metrics[metric_name] = format(float(jax.device_get(series[last_active])), spec)

        # Log and evaluate only when landing exactly on an eval_frequency boundary
        should_eval = eval_frequency > 0 and steps_completed > 0 and steps_completed % eval_frequency == 0
        if not should_eval:
            continue

        # Log training metrics (handles both local capture and W&B)
        session_logger.log_training_step(
            global_step=global_step,
            steps_per_env=steps_completed,
            metrics_history=metrics_history,
            steps_this_chunk=steps_this_chunk,
        )

        # Determine if we should save episodes this cycle
        should_save_episodes = (
            run_cfg.eval_episode_save_frequency > 0 and steps_completed % run_cfg.eval_episode_save_frequency == 0
        )

        # Run eval once with the appropriate flag (no double evaluation).
        # NOTE(review): the training key is replaced by the key returned from the eval rollout,
        # so the half of the split kept in `key` is discarded — confirm this is intended.
        key, eval_key = jax.random.split(key)
        eval_key, eval_results_jax = eval_rollout_fn(eval_key, agent_state, return_episodes=should_save_episodes)
        key = eval_key

        # Convert to host (the "episodes" entry is a nested dict when present)
        eval_results_host = {}
        for name, value in eval_results_jax.items():
            if name == "episodes":
                eval_results_host[name] = {k: jax.device_get(v) for k, v in value.items()}
            else:
                eval_results_host[name] = jax.device_get(value)

        # Log evaluation (handles metrics + episode saving + W&B in one call)
        save_count = run_cfg.eval_episode_save_count or run_cfg.eval_rollouts
        session_logger.log_evaluation(
            global_step=global_step,
            steps_per_env=steps_completed,
            eval_results=eval_results_host,
            save_episodes=should_save_episodes,
            episode_save_count=save_count,
        )

        if "episode_return" in eval_results_host:
            mean_return = float(np.mean(eval_results_host["episode_return"]))  # type: ignore[call-overload]
            latest_metrics["eval_return"] = f"{mean_return:.2f}"

        # Print a clean progress line at each eval boundary
        pct = 100 * steps_completed / steps_per_env
        metrics_str = " | ".join(f"{k}={v}" for k, v in latest_metrics.items())
        suffix = f" | {metrics_str}" if metrics_str else ""
        logger.info(f"Step {steps_completed:>{len(str(steps_per_env))}}/{steps_per_env} ({pct:3.0f}%){suffix}")

    # --- Final bookkeeping ----------------------------------------------------------------
    # Log the final step if the last chunk did not end on an eval boundary, so the recorded
    # training history ends at the actual completion step. The condition can only hold after
    # at least one loop iteration, so metrics_history / steps_this_chunk are bound here.
    total_env_steps = steps_completed * run_cfg.num_envs
    if eval_frequency > 0 and steps_completed % eval_frequency != 0:
        session_logger.log_training_step(
            global_step=total_env_steps,
            steps_per_env=steps_completed,
            metrics_history=metrics_history,
            steps_this_chunk=steps_this_chunk,
        )

    # Get captured metrics and return complete results
    training_metrics, eval_metrics = session_logger.get_results()
    return TrainingResults(
        agent_state=agent_state,
        training_metrics=training_metrics,
        eval_metrics=eval_metrics,
        config=config,
        run_dir=run_dir,
        final_env_state=training_env_states,
    )
# [docs]  (Sphinx "viewcode" link marker left over from HTML extraction; not part of the program)
def train_and_evaluate(config: Config) -> TrainingResults:
    """Run one complete training session and return its results.

    Logs the formatted configuration, resolves the output directory, runs the training
    loop inside a ``RunMetadata`` context, saves the results, and logs the artifact tree.
    The session logger is always finalized, whether the run succeeds or fails.

    Output directory is automatically managed:
    - Under Hydra: uses current directory (Hydra-managed)
    - Otherwise: creates timestamped directory in outputs/

    Args:
        config: Training configuration specifying environment, agent, and run parameters.

    Returns:
        TrainingResults containing:
        - agent_state: Trained agent (ready for inference)
        - training_metrics: Training history (loss, reward, etc.)
        - eval_metrics: Evaluation history (episode returns, lengths)
        - config: Configuration used (for reproducibility)
        - final_env_state: Final environment states (can be used to resume training)
    """
    logger.info(format_train_config(config))

    # Resolve (or create) the directory that receives this run's artifacts.
    # The config itself is written by results.save(), avoiding duplicate I/O.
    output_dir = get_or_create_output_dir(None)
    session_logger = SessionLogger.for_training(config, run_dir=output_dir)

    # Assume failure until proven otherwise: only a completed run or an intentional
    # stop (KeyboardInterrupt / SystemExit) is reported with exit code 0.
    status = 1
    try:
        with RunMetadata(output_dir, run_type="training"):
            results = _run_training_loop(config, session_logger, output_dir)

        # Save artifacts directly
        results.save(output_dir, save_checkpoint=config.run.save_agent_checkpoint)
        logger.info(format_artifacts_tree(output_dir))
        status = 0
    except (KeyboardInterrupt, SystemExit):
        status = 0  # intentional stop
        raise
    finally:
        session_logger.finalize(exit_code=status)

    return results