Source code for myriad.envs.classic.pendulum.tasks.base
"""Shared utilities for Pendulum task wrappers."""fromtypingimportNamedTupleimportjaximportjax.numpyasjnpfromflaximportstructfromjaximportArrayfrommyriad.core.spacesimportBoxfrommyriad.core.typesimportPRNGKeyfrom..physicsimportPhysicsConfig,PhysicsState
@struct.dataclass
class TaskConfig:
    """Configuration base shared by every Pendulum task.

    Holds the task-specific episode limit.

    Attributes:
        max_steps: Episode step limit; defaults to 200.
    """

    max_steps: int = 200
class PendulumObservation(NamedTuple):
    """Pendulum observation using a cos/sin encoding of the angle.

    The angle is carried as its cosine and sine instead of a raw value,
    which avoids the discontinuity at +/-pi and keeps those two components
    bounded, a convenient property for neural network inputs.

    Attributes:
        cos_theta: Cosine of the angle measured from vertical down.
        sin_theta: Sine of the angle measured from vertical down.
        theta_dot: Angular velocity (rad/s).
    """

    cos_theta: Array
    sin_theta: Array
    theta_dot: Array

    def to_array(self) -> Array:
        """Pack the observation into one flat array for NN-based agents.

        Returns:
            The fields stacked along a new leading axis in the order
            [cos_theta, sin_theta, theta_dot]; shape (3,) when every field
            is a scalar.
        """
        # Field order is part of the contract: consumers index by position.
        components = (self.cos_theta, self.sin_theta, self.theta_dot)
        return jnp.stack(components)
def get_pendulum_obs(physics_state: PhysicsState) -> PendulumObservation:
    """Build the standard Pendulum observation from a physics state.

    The angle is replaced by its cosine and sine so the observation has no
    discontinuity; the angular velocity is passed through unchanged.

    Args:
        physics_state: PhysicsState providing ``theta`` and ``theta_dot``.

    Returns:
        PendulumObservation holding cos(theta), sin(theta) and theta_dot.
    """
    angle = physics_state.theta
    cosine = jnp.cos(angle)
    sine = jnp.sin(angle)
    return PendulumObservation(
        cos_theta=cosine,
        sin_theta=sine,
        theta_dot=physics_state.theta_dot,
    )
def get_pendulum_obs_shape() -> tuple[int, ...]:
    """Return the shape of the Pendulum observation space.

    Returns:
        The tuple ``(3,)``, one entry each for
        [cos_theta, sin_theta, theta_dot].
    """
    # One slot per observation component: cos_theta, sin_theta, theta_dot.
    num_components = 3
    return (num_components,)
def get_pendulum_action_space(config: PhysicsConfig) -> Box:
    """Build the continuous action space for Pendulum.

    Args:
        config: Physics configuration providing ``max_torque``.

    Returns:
        Box of shape (1,) with bounds [-max_torque, max_torque].
    """
    # The torque bound is symmetric about zero.
    torque_limit = config.max_torque
    return Box(low=-torque_limit, high=torque_limit, shape=(1,))
def sample_initial_physics(key: PRNGKey) -> PhysicsState:
    """Draw a random initial physics state.

    The incoming key is split in two so the angle and the angular velocity
    are sampled from separate random streams.

    Args:
        key: RNG key used for the random initialization.

    Returns:
        PhysicsState whose ``theta`` is uniform over [-pi, pi) and whose
        ``theta_dot`` is uniform over [-1, 1).
    """
    angle_key, velocity_key = jax.random.split(key)

    # Any starting angle around the full circle.
    initial_theta = jax.random.uniform(
        angle_key,
        shape=(),
        minval=-jnp.pi,
        maxval=jnp.pi,
    )

    # Small starting angular velocity.
    initial_theta_dot = jax.random.uniform(
        velocity_key,
        shape=(),
        minval=-1.0,
        maxval=1.0,
    )

    return PhysicsState(theta=initial_theta, theta_dot=initial_theta_dot)