# Source code for myriad.platform.evaluation

"""Evaluation-only functionality for trained and non-learning agents.

This module provides evaluation capabilities for:
- Non-learning controllers (random, bang-bang, PID)
- Pre-trained models
- Baseline comparisons
- Benchmarking and validation
"""

from __future__ import annotations

import logging

import jax
import numpy as np

from myriad.agents.agent import AgentState
from myriad.configs.default import EvalConfig

from .display import format_eval_config
from .initialization import initialize_environment_and_agent
from .logging import SessionLogger
from .metadata import RunMetadata
from .output_dir import format_artifacts_tree, get_or_create_output_dir
from .steps import make_eval_rollout_fn
from .types import EvaluationResults

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


def evaluate(
    config: EvalConfig,
    agent_state: AgentState | None = None,
    return_episodes: bool = False,
    save_episodes_to_disk_flag: bool | None = None,
) -> EvaluationResults:
    """Run an evaluation-only session (no training) and return its results.

    Intended for non-learning controllers (random, bang-bang, PID),
    pre-trained models, baseline comparisons, and benchmarking/validation.

    The output directory is obtained from ``get_or_create_output_dir(None)``:
    under Hydra the current (Hydra-managed) directory is used, otherwise a
    timestamped directory is created in ``outputs/``.

    Args:
        config: EvalConfig describing the environment, agent, and evaluation
            parameters. Use ``config_to_eval_config()`` to convert a training
            Config if needed.
        agent_state: Optional pre-initialized agent state. When ``None``, the
            environment is reset and ``agent.init`` builds a fresh state from
            keys derived from ``config.run.seed``.
        return_episodes: When True, the full episode trajectories
            (observations, actions, rewards, dones per step) are exposed on
            ``EvaluationResults.episodes``.
        save_episodes_to_disk_flag: When True, episodes are handed to the
            session logger for saving to disk. When ``None``, it is inferred
            as ``config.run.eval_episode_save_frequency > 0``. Episodes can be
            saved to disk without being returned (``return_episodes=False``).

    Returns:
        EvaluationResults holding summary statistics of returns and lengths
        (mean, std, min, max), the raw per-episode returns and lengths,
        optional trajectories (only if ``return_episodes=True``), and
        metadata (number of episodes, seed, config, agent state, run dir).

    Raises:
        BaseException: Anything raised during evaluation is re-raised after
            the session logger has been finalized.
    """
    logger.info(format_eval_config(config))

    # Resolve where artifacts go. The config itself is written later by
    # results.save(), which avoids duplicate I/O.
    run_dir = get_or_create_output_dir(None)

    # Unified session logger (handles W&B init/close automatically).
    session_logger = SessionLogger.for_evaluation(config, run_dir=run_dir)

    exit_code = 0
    try:
        with RunMetadata(run_dir, run_type="evaluation"):
            # Pull the evaluation settings off the run section.
            run_cfg = config.run
            seed = run_cfg.seed
            eval_rollouts = run_cfg.eval_rollouts
            eval_max_steps = run_cfg.eval_max_steps

            # Without an explicit choice, disk saving follows the config.
            if save_episodes_to_disk_flag is None:
                save_episodes_to_disk_flag = run_cfg.eval_episode_save_frequency > 0

            # Trajectories are needed either for the in-memory return value
            # or for writing them to disk.
            collect_episodes = return_episodes or save_episodes_to_disk_flag

            # Derive the independent RNG streams from the seed.
            root_key, env_key, agent_key = jax.random.split(jax.random.PRNGKey(seed), 3)

            # Shared environment/agent construction.
            env, agent, _ = initialize_environment_and_agent(config)

            # Build a fresh agent state only when the caller supplied none.
            if agent_state is None:
                first_obs, _ = env.reset(env_key, env.params, env.config)
                agent_state = agent.init(agent_key, first_obs, agent.params)

            # Build the rollout function and execute it once.
            rollout_fn = make_eval_rollout_fn(agent, env, eval_rollouts, eval_max_steps)
            _, rollout_key = jax.random.split(root_key)
            _, rollout_out = rollout_fn(
                rollout_key,
                agent_state,
                return_episodes=collect_episodes,
            )

            # Move per-episode results from device to host.
            episode_returns = jax.device_get(rollout_out["episode_return"])
            episode_lengths = jax.device_get(rollout_out["episode_length"])

            # Trajectories are present only if the rollout collected them.
            host_episodes = (
                {name: jax.device_get(value) for name, value in rollout_out["episodes"].items()}
                if "episodes" in rollout_out
                else None
            )

            # Summary statistics plus everything needed for reproducibility.
            results = EvaluationResults(
                mean_return=float(np.mean(episode_returns)),
                std_return=float(np.std(episode_returns)),
                min_return=float(np.min(episode_returns)),
                max_return=float(np.max(episode_returns)),
                mean_length=float(np.mean(episode_lengths)),
                std_length=float(np.std(episode_lengths)),
                min_length=int(np.min(episode_lengths)),
                max_length=int(np.max(episode_lengths)),
                episode_returns=episode_returns,
                episode_lengths=episode_lengths,
                num_episodes=eval_rollouts,
                seed=seed,
                config=config,  # kept for reproducibility
                # Trajectories collected only for disk saving are not returned.
                episodes=host_episodes if return_episodes else None,
                agent_state=agent_state,  # kept for potential checkpoint saving
                run_dir=run_dir,  # kept for tests and inspection
            )

            # Single unified logging call: metrics capture, disk saving,
            # W&B logging, artifact upload.
            fallback_dones = np.ones(eval_rollouts, dtype=bool)
            eval_payload = {
                "episode_return": episode_returns,
                "episode_length": episode_lengths,
                "dones": jax.device_get(rollout_out.get("dones", fallback_dones)),
            }
            if host_episodes is not None:
                eval_payload["episodes"] = host_episodes

            # A falsy configured count means "use the number of rollouts".
            episodes_to_save = run_cfg.eval_episode_save_count or eval_rollouts
            session_logger.log_evaluation(
                global_step=0,
                steps_per_env=0,
                eval_results=eval_payload,
                save_episodes=save_episodes_to_disk_flag,
                episode_save_count=episodes_to_save,
            )

            # Save artifacts directly.
            results.save(run_dir, save_checkpoint=run_cfg.save_agent_checkpoint)
            logger.info(format_artifacts_tree(run_dir))
    except BaseException as exc:
        # KeyboardInterrupt / SystemExit are intentional stops, so exit_code
        # stays 0 for them; every other failure is reported as exit code 1.
        if not isinstance(exc, (KeyboardInterrupt, SystemExit)):
            exit_code = 1
        raise
    finally:
        session_logger.finalize(exit_code=exit_code)

    return results