"""Plotting utilities for visualizing training and evaluation results."""from__future__importannotationsfromtypingimportTYPE_CHECKINGimportmatplotlib.pyplotaspltimportnumpyasnpfrommatplotlib.axesimportAxesfrommatplotlib.figureimportFigureifTYPE_CHECKING:frommyriad.platform.typesimportTrainingResults
def plot_training_curve(
    results: TrainingResults | list[TrainingResults],
    labels: str | list[str] | None = None,
    xlabel: str = "Steps per Env",  # cspell:disable-line
    ylabel: str = "Mean Return",  # cspell:disable-line
    title: str | None = None,
    figsize: tuple[float, float] = (8, 4),
    show_std: bool = True,
    ax: Axes | None = None,
) -> tuple[Figure, Axes]:
    """Plot training curve(s) showing mean return with optional standard deviation.

    Args:
        results: Single TrainingResults or list of results to plot
        labels: Legend label(s) for the curve(s). If None, uses agent name from config
        xlabel: Label for x-axis  # cspell:disable-line
        ylabel: Label for y-axis  # cspell:disable-line
        title: Plot title. If None and a single result is plotted, auto-generates
            one from the agent and environment names
        figsize: Figure size (width, height) in inches; only used when a new
            figure is created (``ax`` is None)
        show_std: Whether to show standard deviation as shaded region
        ax: Existing axes to plot on. If None, creates new figure

    Returns:
        Tuple of (figure, axes) objects

    Raises:
        ValueError: If the number of labels does not match the number of
            results, or if ``ax`` is not attached to a figure.

    Example:
        >>> results = train_and_evaluate(config)
        >>> fig, ax = plot_training_curve(results)
        >>> plt.show()

        >>> # Compare multiple runs
        >>> results_list = [results_dqn, results_ppo]
        >>> fig, ax = plot_training_curve(results_list, labels=["DQN", "PPO"])
        >>> plt.show()
    """
    # Normalize inputs to lists
    results_list = [results] if not isinstance(results, list) else results

    # Handle labels; copy into a fresh list so the caller's sequence is never aliased
    if labels is None:
        labels_list = [r.config.agent.name.upper() for r in results_list]
    elif isinstance(labels, str):
        labels_list = [labels]
    else:
        labels_list = list(labels)

    if len(labels_list) != len(results_list):
        raise ValueError(
            f"Number of labels ({len(labels_list)}) must match number of results ({len(results_list)})"
        )

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig_tmp = ax.get_figure()
        if fig_tmp is None:
            raise ValueError("Provided axes must be attached to a figure")
        # get_figure() can return SubFigure, but we treat it as Figure for our purposes
        fig = fig_tmp  # type: ignore[assignment]

    # Plot each result
    for result, label in zip(results_list, labels_list):
        metrics = result.eval_metrics
        steps = metrics.steps_per_env
        mean = metrics.mean_return
        std = metrics.std_return

        # Plot mean line; keep the Line2D so the std band can reuse its colour
        line = ax.plot(steps, mean, "o-", label=label)[0]

        # Add standard deviation band. The colour is passed explicitly because
        # fill_between otherwise draws from a colour cycle that is not tied
        # to this line, so the band can end up a different colour (e.g. when
        # the caller's axes already hold other lines).
        if show_std:
            mean_arr = np.asarray(mean)
            std_arr = np.asarray(std)
            ax.fill_between(
                steps,
                mean_arr - std_arr,
                mean_arr + std_arr,
                color=line.get_color(),
                alpha=0.2,
            )

    # Formatting
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Auto-generate title if not provided (only meaningful for a single run)
    if title is None and len(results_list) == 1:
        env_name = results_list[0].config.env.name
        agent_name = results_list[0].config.agent.name.upper()
        title = f"{agent_name} Training on {env_name}"

    if title is not None:
        ax.set_title(title)

    ax.legend()
    ax.grid(True, alpha=0.3)

    # Apply tight_layout if available (SubFigure doesn't have this method).
    # NOTE(review): this also runs on a caller-supplied figure, which may
    # override the caller's own layout engine — confirm this is intended.
    if hasattr(fig, "tight_layout"):
        fig.tight_layout()  # type: ignore[union-attr]

    return fig, ax  # type: ignore[return-value]