"""Config builder utilities for programmatic use.
This module provides high-level functions to create training and evaluation
configs without requiring detailed knowledge of Pydantic models.
"""
from typing import Any
from myriad.agents import get_agent_info
from myriad.core.types import BaseModel
from myriad.envs import EnvInfo, get_env_info
from .default import AgentConfig, Config, EnvConfig, EvalConfig, EvalRunConfig, RunConfig, WandbConfig
# Rollout length applied to on-policy agents when the caller supplies no
# ``rollout_steps`` (neither as an argument nor through run overrides).
DEFAULT_ON_POLICY_ROLLOUT_STEPS = 2
def _distribute_kwargs(
    kwargs: dict[str, Any],
    run_cls: type[BaseModel],
) -> tuple[dict, dict, dict, dict]:
    """Route a flat mapping of overrides into per-section keyword dicts.

    Every entry is placed by the first rule that matches:

    1. A dotted key whose prefix names a section (e.g. ``"agent.learning_rate"``)
       is stored in that section under the remainder of the key.
    2. A dict value passed under a section name (e.g. ``wandb={"project": "myriad"}``)
       is merged into that section.
    3. A key that is a declared field of ``run_cls``, ``WandbConfig`` or
       ``AgentConfig`` (probed in that order) is stored in the matching section.
    4. Anything else is treated as an agent parameter.

    Args:
        kwargs: The flat keyword arguments to route.
        run_cls: The run model whose fields identify run parameters
            (:class:`RunConfig` or :class:`EvalRunConfig`).

    Returns:
        A tuple of (env_kwargs, agent_kwargs, run_kwargs, wandb_kwargs).
    """
    env_kwargs: dict[str, Any] = {}
    agent_kwargs: dict[str, Any] = {}
    run_kwargs: dict[str, Any] = {}
    wandb_kwargs: dict[str, Any] = {}
    buckets: dict[str, dict[str, Any]] = {
        "env": env_kwargs,
        "agent": agent_kwargs,
        "run": run_kwargs,
        "wandb": wandb_kwargs,
    }

    # Probe order matters: the run and wandb models are consulted before the
    # agent model, and only declared fields count as a match.
    field_owners: tuple[tuple[str, type[BaseModel]], ...] = (
        ("run", run_cls),
        ("wandb", WandbConfig),
        ("agent", AgentConfig),
    )

    for name, value in kwargs.items():
        prefix, dot, remainder = name.partition(".")
        if dot and prefix in buckets:
            # Rule 1: explicit "section.attribute" addressing.
            buckets[prefix][remainder] = value
        elif isinstance(value, dict) and name in buckets:
            # Rule 2: a whole section supplied as a nested dict.
            buckets[name].update(value)
        else:
            # Rules 3 and 4: first model declaring the field wins; unknown
            # names fall back to the agent section.
            owner = next(
                (section for section, model in field_owners if name in model.model_fields),
                "agent",
            )
            buckets[owner][name] = value

    return env_kwargs, agent_kwargs, run_kwargs, wandb_kwargs
def _resolve_eval_max_steps(eval_max_steps: int | None, env_info: EnvInfo | None) -> int | None:
"""Resolves eval_max_steps from explicit > registry config_cls > model defaults."""
if eval_max_steps is not None:
return eval_max_steps
if env_info:
# Instantiate environment config with defaults to get its max_steps property
try:
default_env_config = env_info.config_cls()
return getattr(default_env_config, "max_steps", None)
except (TypeError, AttributeError):
pass
return None
[docs]
def create_config(
    env: str,
    agent: str,
    num_envs: int = 1,
    steps_per_env: int = 1000,
    rollout_steps: int | None = None,
    eval_max_steps: int | None = None,
    eval_frequency: int = 100,
    eval_rollouts: int = 10,
    seed: int = 42,
    wandb_enabled: bool = False,
    **kwargs: Any,
) -> Config:
    """Build a training :class:`Config` from a handful of flat arguments.

    This is the recommended programmatic entry point: it hides the nested
    Pydantic models behind simple keyword arguments.

    Args:
        env: Environment name (e.g., "cartpole-control", "ccas-ccar-control")
        agent: Agent name (e.g., "dqn", "pqn", "random")
        num_envs: Number of parallel environments to run
        steps_per_env: Number of steps to run per environment
        rollout_steps: Number of steps to collect per environment before updating
            (for on-policy agents only). If None, a value supplied through the
            run overrides is used; failing that, on-policy agents default to 2.
        eval_max_steps: Maximum steps per evaluation episode.
            If None, uses environment-specific default from registry or Config models.
        eval_frequency: Log and evaluate every N steps-per-env (0 to disable)
        eval_rollouts: Number of episodes to run during evaluation
        seed: Random seed for reproducibility
        wandb_enabled: Enable Weights & Biases logging
        **kwargs: Additional config overrides. Nested parameters may be given in
            dot notation (e.g., ``agent.learning_rate=1e-3``) or as dicts for
            whole sections (e.g., ``wandb={"project": "my-project"}``).
            Overrides take precedence over the named arguments above.

    Returns:
        Fully configured Config object ready for :func:`~myriad.platform.train_and_evaluate`
    """
    # Registry lookups for the requested agent and environment.
    agent_info = get_agent_info(agent)
    env_info = get_env_info(env)

    # Split the free-form overrides into their target sections.
    env_overrides, agent_overrides, run_overrides, wandb_overrides = _distribute_kwargs(kwargs, RunConfig)

    # Pick the rollout length: explicit argument, then run override, then the
    # on-policy default.
    if rollout_steps is None:
        rollout_steps = run_overrides.get("rollout_steps")
        if rollout_steps is None and agent_info and agent_info.is_on_policy:
            rollout_steps = DEFAULT_ON_POLICY_ROLLOUT_STEPS

    # Named arguments first, then overrides layered on top.
    run_fields: dict[str, Any] = {
        "seed": seed,
        "num_envs": num_envs,
        "steps_per_env": steps_per_env,
        "rollout_steps": rollout_steps,
        "eval_frequency": eval_frequency,
        "eval_rollouts": eval_rollouts,
        "eval_max_steps": _resolve_eval_max_steps(eval_max_steps, env_info),
    }
    run_fields.update(run_overrides)

    # Unset (None) entries are dropped so the model's own field defaults apply.
    run_config = RunConfig(**{name: value for name, value in run_fields.items() if value is not None})

    # Remaining sections.
    env_config = EnvConfig(name=env, **env_overrides)
    agent_config = AgentConfig(name=agent, **agent_overrides)
    wandb_config = WandbConfig(**({"enabled": wandb_enabled} | wandb_overrides))

    return Config(env=env_config, agent=agent_config, run=run_config, wandb=wandb_config)
[docs]
def create_eval_config(
    env: str,
    agent: str,
    eval_rollouts: int = 10,
    eval_max_steps: int | None = None,
    seed: int = 42,
    wandb_enabled: bool = False,
    **kwargs: Any,
) -> EvalConfig:
    """Build an evaluation-only :class:`EvalConfig` from flat arguments.

    Intended for evaluating non-learning controllers (random, PID, bang-bang)
    or pre-trained models, with no training phase.

    Args:
        env: Environment name (e.g., "cartpole-control")
        agent: Agent name (e.g., "random", "dqn")
        eval_rollouts: Number of episodes to evaluate
        eval_max_steps: Maximum steps per episode.
            If None, uses environment-specific default from registry or Config models.
        seed: Random seed for reproducibility
        wandb_enabled: Enable Weights & Biases logging
        **kwargs: Additional config overrides, in dot notation or as per-section
            dicts (same conventions as ``create_config``). Overrides take
            precedence over the named arguments above.

    Returns:
        Fully configured EvalConfig object ready for :func:`~myriad.platform.evaluate`
    """
    # Registry lookup for the requested environment.
    env_info = get_env_info(env)

    # Split the free-form overrides into their target sections.
    env_overrides, agent_overrides, run_overrides, wandb_overrides = _distribute_kwargs(kwargs, EvalRunConfig)

    # Named arguments first, then overrides layered on top.
    run_fields: dict[str, Any] = {
        "seed": seed,
        "eval_rollouts": eval_rollouts,
        "eval_max_steps": _resolve_eval_max_steps(eval_max_steps, env_info),
    } | run_overrides

    # Unset (None) entries are dropped so the model's own field defaults apply.
    eval_run_config = EvalRunConfig(**{name: value for name, value in run_fields.items() if value is not None})

    # Remaining sections.
    env_config = EnvConfig(name=env, **env_overrides)
    agent_config = AgentConfig(name=agent, **agent_overrides)
    wandb_config = WandbConfig(**({"enabled": wandb_enabled} | wandb_overrides))

    return EvalConfig(env=env_config, agent=agent_config, run=eval_run_config, wandb=wandb_config)