"""W&B helper utilities for fetching and inspecting runs and sweeps.
Useful both internally (seed-eval pipeline) and interactively in notebooks.
"""
import warnings
from typing import Any
import polars as pl
import wandb # type: ignore[import]
from myriad.configs.default import Config, WandbConfig
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _unwrap_wandb_value(obj: Any) -> Any:
    """Strip W&B ``{"value": x}`` envelopes and ``_``-prefixed keys from a config tree.

    A dict whose *only* key is ``"value"`` is replaced by its payload, which is
    cleaned the same way. Any other dict is rebuilt without the keys that start
    with ``_``, and each remaining value is cleaned recursively. Anything that
    is not a dict (including lists) is returned unchanged and is not descended
    into.
    """
    if not isinstance(obj, dict):
        return obj

    # Pure envelope: discard the wrapper and keep cleaning the payload.
    if list(obj) == ["value"]:
        return _unwrap_wandb_value(obj["value"])

    cleaned: dict[Any, Any] = {}
    for name, inner in obj.items():
        if name.startswith("_"):
            continue
        cleaned[name] = _unwrap_wandb_value(inner)
    return cleaned
def _unflatten_dotted_keys(d: dict[str, Any]) -> dict[str, Any]:
    """Rebuild a nested dict from keys that use ``.`` as a path separator.

    W&B sweep agents store hyperparameters with dotted keys (e.g. ``agent.lr``)
    instead of nested dicts; this reverses that so the result can be validated
    as a ``Config``.

    Dict values are themselves unflattened, and when a dict lands on a slot
    that already holds a dict the two are deep-merged rather than replaced, so
    inputs mixing flat and nested keys are supported. A non-dict value sitting
    on an intermediate path segment is replaced by a fresh dict.
    """
    tree: dict[str, Any] = {}
    for dotted, raw in d.items():
        *ancestors, leaf = dotted.split(".")

        # Walk (and create where necessary) the chain of parent dicts.
        cursor = tree
        for name in ancestors:
            child = cursor.get(name)
            if not isinstance(child, dict):
                child = {}
                cursor[name] = child
            cursor = child

        # Nested dict values may contain dotted keys of their own.
        resolved = _unflatten_dotted_keys(raw) if isinstance(raw, dict) else raw

        existing = cursor.get(leaf)
        if isinstance(resolved, dict) and isinstance(existing, dict):
            # Keep siblings that were already placed under this key.
            _deep_merge(existing, resolved)
        else:
            cursor[leaf] = resolved
    return tree
def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> None:
    """Fold *override* into *base*, mutating *base* and returning nothing.

    Where both sides hold a dict under the same key, the merge descends into
    it. In every other case the value from *override* replaces (or adds to)
    the entry in *base*.
    """
    for name, incoming in override.items():
        current = base.get(name)
        both_dicts = isinstance(incoming, dict) and isinstance(current, dict)
        if not both_dicts:
            base[name] = incoming
            continue
        _deep_merge(current, incoming)
# ---------------------------------------------------------------------------
# ID resolution
# ---------------------------------------------------------------------------
def _resolve_sweep_id(sweep_id: str) -> str:
    """Ensure a sweep ID is fully qualified as ``entity/project/sweep_id``.

    W&B's API requires the full three-part path. If ``sweep_id`` already
    contains exactly two ``/`` it is returned unchanged, without contacting
    W&B. If it is a bare ID (no ``/``), the default entity's projects are
    probed one by one until a sweep with that ID is found.

    Args:
        sweep_id: A bare sweep ID (e.g. ``"abc123"``) or a fully-qualified
            path (``"entity/project/abc123"``).

    Returns:
        Fully-qualified sweep ID as ``entity/project/sweep_id``.

    Raises:
        ValueError: If the ID has one, or more than two, ``/`` separators, or
            if no project of the default entity contains the sweep. In the
            latter case the last lookup failure (if any) is attached as
            ``__cause__`` so that authentication or network problems are not
            hidden behind a generic "not found".
    """
    separators = sweep_id.count("/")
    if separators == 2:
        return sweep_id
    if separators != 0:
        raise ValueError(f"Ambiguous sweep ID '{sweep_id}'. " "Use the fully-qualified form: entity/project/sweep_id.")

    api = wandb.Api()
    entity = api.default_entity
    last_error: Exception | None = None
    for project in api.projects(entity):
        candidate = f"{entity}/{project.name}/{sweep_id}"
        try:
            api.sweep(candidate)
        except Exception as err:
            # Best-effort probe: a failed lookup normally just means the sweep
            # is not in this project. Remember the error so it can be reported
            # if nothing matches at all.
            last_error = err
            continue
        return candidate

    raise ValueError(
        f"Could not find sweep '{sweep_id}' in any project for entity '{entity}'. "
        "Pass the fully-qualified form: entity/project/sweep_id."
    ) from last_error
# ---------------------------------------------------------------------------
# Single-run helpers
# ---------------------------------------------------------------------------
[docs]
def fetch_run(run_id: str) -> Any:
    """Look up one W&B run through the public API.

    Args:
        run_id: Fully-qualified run ID (``entity/project/run_id``).

    Returns:
        The ``wandb.Run`` object returned by ``wandb.Api().run``.
    """
    api = wandb.Api()
    return api.run(run_id)
[docs]
def config_from_wandb_run(run: Any) -> Config:
    """Rebuild a validated ``Config`` from the config stored on a W&B run.

    ``run.config`` is copied into a plain dict, stripped of ``{"value": x}``
    wrappers and ``_``-prefixed keys, un-flattened from dotted keys, and then
    handed to ``Config.model_validate``.

    Afterwards the ``wandb`` section is completed from the run itself: if it is
    absent a new ``WandbConfig`` is created from ``run.project`` and
    ``run.entity``; if it is present, only the ``project`` / ``entity`` fields
    that are ``None`` are filled in.

    Args:
        run: A ``wandb.Run`` object (from e.g. ``wandb.Api().run(...)``).

    Returns:
        A validated ``Config`` instance.
    """
    stored: dict[str, Any] = dict(run.config)
    cleaned = _unwrap_wandb_value(stored)
    assert isinstance(cleaned, dict)
    config = Config.model_validate(_unflatten_dotted_keys(cleaned))

    # Per the original author's note, the wandb section is deliberately left
    # out of run.config (it is run metadata, not experiment config), so project
    # and entity are recovered from the run to let seed-eval write back to the
    # same project.
    wandb_section = config.wandb
    if wandb_section is None:
        restored = WandbConfig(project=run.project, entity=run.entity)
        return config.model_copy(update={"wandb": restored})

    # Only touch the fields that are actually unset.
    missing: dict[str, Any] = {
        field: getattr(run, field)
        for field in ("project", "entity")
        if getattr(wandb_section, field) is None
    }
    if not missing:
        return config
    return config.model_copy(update={"wandb": wandb_section.model_copy(update=missing)})
# ---------------------------------------------------------------------------
# Sweep helpers
# ---------------------------------------------------------------------------
[docs]
def fetch_sweep_runs(sweep_id: str, *, state: str | None = None) -> list[Any]:
    """Collect the runs of a W&B sweep, optionally keeping only one state.

    The ID is passed through ``_resolve_sweep_id`` before the sweep is looked
    up.

    Args:
        sweep_id: Sweep ID, normally fully qualified (``entity/project/sweep_id``).
        state: When given, keep only runs whose ``state`` equals this value
            (e.g. ``"finished"``, ``"running"``, ``"crashed"``). ``None``
            keeps every run.

    Returns:
        List of ``wandb.Run`` objects.
    """
    api = wandb.Api()
    sweep = api.sweep(_resolve_sweep_id(sweep_id))
    if state is not None:
        return [run for run in sweep.runs if run.state == state]
    return list(sweep.runs)
[docs]
def fetch_top_k_runs(
    sweep_id: str,
    metric: str,
    top_k: int,
    *,
    maximize: bool,
) -> list[Any]:
    """Return the top-K finished runs from a W&B sweep, sorted by metric.

    Runs whose summary has no value for ``metric``, or whose value is NaN, are
    ranked after every run with a usable value. They are still returned if
    fewer than ``top_k`` runs have a usable value.

    Args:
        sweep_id: Sweep ID, normally fully qualified (``entity/project/sweep_id``).
        metric: W&B summary metric name to rank by (e.g. ``eval/return/best``).
        top_k: Number of top runs to return.
        maximize: If True, sort descending (higher is better). If False, ascending.

    Returns:
        List of ``wandb.Run`` objects, length ≤ ``top_k``.

    Warns:
        UserWarning: If fewer than ``top_k`` finished runs exist in the sweep.
    """
    import math

    finished = fetch_sweep_runs(sweep_id, state="finished")
    if len(finished) < top_k:
        warnings.warn(
            f"Requested top-{top_k} runs but only {len(finished)} finished runs exist "
            f"in sweep '{sweep_id}'. Returning all {len(finished)}.",
            UserWarning,
            stacklevel=2,
        )

    unusable: tuple[bool, float] = (True, 0.0)  # sorts after every (False, score) key

    def _sort_key(run: Any) -> tuple[bool, float]:
        raw = run.summary.get(metric)
        if raw is None:
            return unusable
        score = float(raw)
        # NaN compares False against everything, which would make the sort
        # order undefined and could place a diverged run anywhere in the
        # ranking. Rank it like a missing value instead.
        if math.isnan(score):
            return unusable
        return (False, -score if maximize else score)

    # sorted() is stable, so ties keep the order in which W&B returned the runs.
    return sorted(finished, key=_sort_key)[:top_k]
# ---------------------------------------------------------------------------
# DataFrame helper (notebook-friendly)
# ---------------------------------------------------------------------------
[docs]
def runs_to_dataframe(runs: list[Any], metrics: list[str] | None = None) -> pl.DataFrame:
    """Build a Polars DataFrame with one row per W&B run.

    Every row starts with ``run_id``, ``run_name`` and ``state``, followed by
    the run's config flattened to dot-separated keys (e.g. ``agent.lr``) and
    then the selected summary metrics. On a name clash, later entries overwrite
    earlier ones.

    Args:
        runs: List of ``wandb.Run`` objects.
        metrics: Summary keys to keep. ``None`` keeps every summary key. Keys
            starting with ``_`` are always excluded.

    Returns:
        A ``polars.DataFrame`` with one row per run.
    """
    records: list[dict[str, Any]] = []
    for run in runs:
        identity: dict[str, Any] = {"run_id": run.id, "run_name": run.name, "state": run.state}
        flat_config = _flatten_dict(_unwrap_wandb_value(dict(run.config)), sep=".")

        kept: dict[str, Any] = {}
        for key, value in run.summary.items():
            if key.startswith("_"):
                continue
            if metrics is not None and key not in metrics:
                continue
            kept[key] = value

        # Merge order matters: identity columns first, then config, then metrics.
        records.append({**identity, **flat_config, **kept})
    return pl.DataFrame(records)
def _flatten_dict(d: dict[str, Any], *, sep: str = ".", prefix: str = "") -> dict[str, Any]:
    """Collapse a nested dict into a single level, joining key paths with ``sep``.

    ``prefix`` is the path accumulated so far. When it is empty, a key is used
    as-is; otherwise it is appended to the prefix with ``sep`` in between.
    Dict values (including empty ones) contribute only their leaves, so an
    empty nested dict produces no entry.
    """
    flat: dict[str, Any] = {}
    for name, item in d.items():
        path = name if not prefix else f"{prefix}{sep}{name}"
        if not isinstance(item, dict):
            flat[path] = item
            continue
        flat.update(_flatten_dict(item, sep=sep, prefix=path))
    return flat