"""Loading API for run artifacts.Provides utilities to load configs, results, checkpoints, and metadata fromcompleted training/evaluation runs."""from__future__importannotationsimportpicklefromdataclassesimportdataclassfrompathlibimportPathfromtypingimportAny,Generic,TypeVarimportyamlfromomegaconfimportOmegaConffrommyriad.configs.defaultimportConfig,EvalConfigfrommyriad.platform.typesimportEvaluationResults,TrainingResultsfrom.constantsimport(CHECKPOINT_EXTENSION,CHECKPOINTS_DIR,FINAL_CHECKPOINT_NAME,METADATA_FILENAME,RESULTS_FILENAME,)from.serializationimportload_agent_state# Generic type variables for RunArtifactsConfigT=TypeVar("ConfigT",bound=Config|EvalConfig)ResultsT=TypeVar("ResultsT",bound=TrainingResults|EvaluationResults)
def load_run_config(run_path: str | Path) -> Config | EvalConfig:
    """Load the config of a completed run and validate it with Pydantic.

    Reads ``<run_path>/.hydra/config.yaml`` with OmegaConf and converts it to
    plain Python objects. It then validates them with the Pydantic model
    selected by the ``run_type`` field of the run metadata file.

    Args:
        run_path: Path to the run directory.

    Returns:
        ``Config`` when the metadata's ``run_type`` is ``"training"``,
        otherwise ``EvalConfig``.

    Raises:
        FileNotFoundError: If ``.hydra/config.yaml`` or the run metadata file
            is not found.
        RuntimeError: If the metadata is empty, is not a mapping, or has no
            ``run_type`` field.

    Example:
        >>> config = load_run_config("outputs/2026-02-12/14-30-52")
        >>> print(config.run.seed)
    """
    run_path = Path(run_path)
    config_path = run_path / ".hydra" / "config.yaml"
    if not config_path.exists():
        raise FileNotFoundError(f"No config.yaml found in {run_path}/.hydra/")

    # Determine the run type first (metadata is mandatory) so that a missing or
    # malformed metadata file fails fast, before the OmegaConf parse/convert.
    metadata = load_run_metadata(run_path)
    # An empty YAML file parses to None (and a malformed one may parse to a
    # list/scalar). Report these as a missing field rather than letting the
    # membership test raise an opaque TypeError.
    if not isinstance(metadata, dict) or "run_type" not in metadata:
        raise RuntimeError(
            f"Missing 'run_type' field in {run_path}/{METADATA_FILENAME}. "
            f"Cannot determine whether to load Config or EvalConfig."
        )
    run_type = metadata["run_type"]

    # Load config using OmegaConf and convert to plain containers for Pydantic.
    cfg = OmegaConf.load(config_path)
    config_dict = OmegaConf.to_object(cfg)

    # NOTE(review): every run_type other than "training" is validated as an
    # EvalConfig, so a misspelled run_type surfaces as an EvalConfig validation
    # error. Confirm which value(s) evaluation runs write.
    if run_type == "training":
        return Config.model_validate(config_dict)
    return EvalConfig.model_validate(config_dict)
def load_run_results(run_path: str | Path) -> TrainingResults | EvaluationResults:
    """Unpickle the results object stored in a run directory.

    The file read is ``<run_path>/<RESULTS_FILENAME>``.

    Args:
        run_path: Run directory, as a string or ``Path``.

    Returns:
        The unpickled results object (annotated as ``TrainingResults`` or
        ``EvaluationResults``).

    Raises:
        FileNotFoundError: If the results file does not exist in ``run_path``.

    Warning:
        This uses :func:`pickle.load`, which can execute arbitrary code while
        deserializing. Only load run directories from a trusted source.

    Example:
        >>> results = load_run_results("outputs/2026-02-12/14-30-52")
        >>> print(results.summary())
    """
    directory = Path(run_path)
    pickled = directory / RESULTS_FILENAME
    if pickled.exists():
        # SECURITY: pickle executes code embedded in the file; trusted runs only.
        with pickled.open("rb") as stream:
            return pickle.load(stream)
    raise FileNotFoundError(f"No {RESULTS_FILENAME} found in {directory}")
def load_run_checkpoint(
    run_path: str | Path,
    checkpoint: str = FINAL_CHECKPOINT_NAME,
) -> Any:
    """Deserialize a saved agent checkpoint from a run directory.

    The file read is ``<run_path>/<CHECKPOINTS_DIR>/<checkpoint><CHECKPOINT_EXTENSION>``.
    It is handed to ``load_agent_state``.

    Args:
        run_path: Run directory, as a string or ``Path``.
        checkpoint: Checkpoint name without its extension; defaults to
            ``FINAL_CHECKPOINT_NAME``.

    Returns:
        Whatever ``load_agent_state`` returns for the checkpoint file.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: If deserialization fails (raised from
            ``load_agent_state``).

    Example:
        >>> agent_state = load_run_checkpoint("outputs/2026-02-12/14-30-52")
        >>> # Use with evaluate()
        >>> results = evaluate(config, agent_state=agent_state)
    """
    base = Path(run_path)
    filename = f"{checkpoint}{CHECKPOINT_EXTENSION}"
    target = base / CHECKPOINTS_DIR / filename
    if target.exists():
        return load_agent_state(target)
    raise FileNotFoundError(f"No checkpoint '{checkpoint}' in {base}/{CHECKPOINTS_DIR}/")
def load_run_metadata(run_path: str | Path) -> dict:
    """Load the run metadata YAML file from a run directory.

    Args:
        run_path: Path to the run directory.

    Returns:
        The parsed metadata. Other loaders in this module read its
        ``run_type`` key, and the docs describe timestamp, git hash and version
        entries. An empty file yields ``{}``.

    Raises:
        FileNotFoundError: If the metadata file is not found.

    Example:
        >>> metadata = load_run_metadata("outputs/2026-02-12/14-30-52")
        >>> print(metadata["git_hash"])
    """
    run_path = Path(run_path)
    metadata_path = run_path / METADATA_FILENAME
    if not metadata_path.exists():
        raise FileNotFoundError(
            f"No {METADATA_FILENAME} in {run_path}. Run metadata is required "
            f"to determine run type and configuration."
        )
    # Decode explicitly: YAML files are UTF-8, and the default text encoding
    # of open() depends on the platform locale.
    with open(metadata_path, encoding="utf-8") as f:
        metadata = yaml.safe_load(f)
    # safe_load returns None for an empty document; honour the ``dict`` return
    # type so callers can use membership tests / .get() safely.
    return {} if metadata is None else metadata
@dataclass
class RunArtifacts(Generic[ConfigT, ResultsT]):
    """Bundle of everything recorded for a single run.

    Config, results and metadata are held in memory. Agent checkpoints are not
    stored on the object; they are read from disk each time
    :meth:`load_checkpoint` is called.

    Type parameters:
        ConfigT: ``Config`` or ``EvalConfig``.
        ResultsT: ``TrainingResults`` or ``EvaluationResults``.
    """

    config: ConfigT
    """Configuration the run was launched with."""

    results: ResultsT
    """Results produced by the run."""

    metadata: dict
    """Run metadata (timestamp, git hash, versions)."""

    run_path: Path
    """Directory the artifacts belong to."""

    def load_checkpoint(self, checkpoint: str = FINAL_CHECKPOINT_NAME) -> Any:
        """Read an agent checkpoint for this run from disk.

        Nothing is cached: every call reads the checkpoint file again.

        Args:
            checkpoint: Checkpoint name without its extension; defaults to
                ``FINAL_CHECKPOINT_NAME``.

        Returns:
            The agent state for the requested checkpoint.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
            RuntimeError: If deserialization fails.
        """
        return load_run_checkpoint(run_path=self.run_path, checkpoint=checkpoint)
def load_run(run_path: str | Path) -> RunArtifacts:
    """Load a run's config, results and metadata in one call.

    This is the main entry point for inspecting a finished run. Agent
    checkpoints are not read here; call ``load_checkpoint()`` on the returned
    object when one is needed.

    Args:
        run_path: Run directory, as a string or ``Path``.

    Returns:
        A ``RunArtifacts`` holding the config, results, metadata and run path.

    Raises:
        FileNotFoundError: If a required artifact file is missing (propagated
            from the individual loaders).
        RuntimeError: If the metadata lacks a ``run_type`` field (propagated
            from ``load_run_config``).

    Example:
        >>> run = load_run("outputs/2026-02-12/14-30-52")
        >>> print(f"Final return: {run.results.summary()['mean_return']}")
        >>> agent = run.load_checkpoint()  # Lazy load if needed
    """
    directory = Path(run_path)
    config = load_run_config(directory)
    results = load_run_results(directory)
    # NOTE: load_run_config already reads the metadata file internally to pick
    # the config model; it is read again here to populate the container.
    metadata = load_run_metadata(directory)
    return RunArtifacts(
        config=config,
        results=results,
        metadata=metadata,
        run_path=directory,
    )