"""Utilities for handling different observation types.This module provides utilities for working with observations in different formats:- to_array(): Convert observations to JAX arrays- get_field_index(): Map field names to array indices (for NamedTuple observations)See src/myriad/agents/bangbang.py for an example of how to use them in practice."""fromtypingimportAnyimportjax.numpyasjnpfromjaximportArrayfrommyriad.core.typesimportObservation
[docs]defto_array(obs:Observation|Array)->Array:"""Convert observation to array format. Handles different observation types: - Arrays (JAX/numpy): Returned as-is - Observations with .to_array(): Converted via that method - Other types: Attempted conversion via jnp.asarray Args: obs: Observation (satisfying the Observation Protocol) or array Returns: Array representation of the observation Raises: ValueError: If observation cannot be converted to array """# Already an array - return as-isifisinstance(obs,(jnp.ndarray,type(jnp.array(0)))):returnobs# type: ignore[return-value]# Has .to_array() method (satisfies Observation Protocol)ifhasattr(obs,"to_array")andcallable(obs.to_array):# type: ignore[union-attr]returnobs.to_array()# type: ignore[union-attr]# Try generic conversiontry:returnjnp.asarray(obs)exceptExceptionase:raiseValueError(f"Cannot convert observation of type {type(obs)} to array. "f"Observation should be an array or have a .to_array() method.")frome
def get_field_index(sample_obs: Any, field_name: str) -> int:
    """Look up the positional index of a named field in an observation.

    The observation's ``_fields`` attribute (as provided by NamedTuple
    types) is enumerated to build a name -> position mapping, and the
    position for ``field_name`` is returned. Intended to be called once
    during agent initialization, with the result cached for use during
    action selection.

    Args:
        sample_obs: Sample observation (must be a NamedTuple or similar
            object exposing ``._fields``).
        field_name: Name of the field to look up (e.g., "theta", "x").

    Returns:
        Index of the field in the flattened observation array.

    Raises:
        ValueError: If ``sample_obs`` has no ``._fields`` attribute.
        ValueError: If ``field_name`` is not one of the observation's fields.

    Example:
        >>> from myriad.envs.cartpole.physics import PhysicsState
        >>> sample_obs = PhysicsState(x=0.0, x_dot=0.0, theta=0.0, theta_dot=0.0)
        >>> theta_idx = get_field_index(sample_obs, "theta")
        >>> theta_idx
        2
    """
    # Guard clause: only objects exposing ``_fields`` are supported.
    if not hasattr(sample_obs, "_fields"):
        raise ValueError(
            f"get_field_index requires NamedTuple observations with ._fields attribute, "
            f"but got {type(sample_obs)}. "
            f"Ensure your environment returns NamedTuple observations."
        )

    # Map every field name to its position in declaration order.
    position_by_name = {}
    for position, name in enumerate(sample_obs._fields):
        position_by_name[name] = position

    # Positions are always ints, so None unambiguously means "not found".
    position = position_by_name.get(field_name)
    if position is None:
        available_fields = list(position_by_name)
        raise ValueError(
            f"Observation field '{field_name}' not found in observation. "
            f"Available fields: {available_fields}"
        )
    return position