Source code for myriad.envs.classic.cartpole.physics

"""Pure stateless physics for the CartPole system.

This module contains the ground truth dynamics for the cart-pole system,
completely decoupled from any task-specific logic (rewards, terminations, observations).

The physics can be reused by different tasks (control, system ID, etc.) and can be
directly accessed by model-based methods like MPC planners or Neural ODEs.
"""

from typing import NamedTuple

import chex
import jax.numpy as jnp
from flax import struct
from jax import Array


class PhysicsState(NamedTuple):
    """Physical state of the cart-pole system.

    CartPole is fully observable, so this tuple doubles as the observation:
    there is no separate observation type to keep in sync with the internal
    state, and the observability assumption is visible in the types.

    Attributes:
        x: Cart position (m).
        x_dot: Cart velocity (m/s).
        theta: Pole angle measured from vertical (rad); 0 means upright.
        theta_dot: Pole angular velocity (rad/s).
    """

    x: Array
    x_dot: Array
    theta: Array
    theta_dot: Array

    def to_array(self) -> Array:
        """Pack the four components into a single array, e.g. as NN input.

        Returns:
            The components stacked along a new leading axis in the order
            ``[x, x_dot, theta, theta_dot]``; shape ``(4,)`` when every
            component is a scalar.
        """
        components = (self.x, self.x_dot, self.theta, self.theta_dot)
        return jnp.stack(components)

    @classmethod
    def from_array(cls, arr: Array) -> "PhysicsState":
        """Build a state from a flat array laid out as ``to_array`` produces.

        Args:
            arr: Array of shape ``(4,)`` holding ``[x, x_dot, theta, theta_dot]``.
                The shape is validated with ``chex.assert_shape`` before the
                entries are read.

        Returns:
            A ``PhysicsState`` whose fields are the four entries of ``arr``.
        """
        chex.assert_shape(arr, (4,))
        position, velocity, angle, angular_velocity = arr[0], arr[1], arr[2], arr[3]
        return cls(
            x=position,
            x_dot=velocity,
            theta=angle,
            theta_dot=angular_velocity,
        )
@struct.dataclass
class PhysicsConfig:
    """Fixed physical constants of the cart-pole system.

    These values are intended to be compile-time constants, passed to ``jit``
    via ``static_argnames``. Changing one therefore triggers recompilation,
    in exchange for better optimization of the compiled step.
    """

    # Gravitational acceleration, m/s^2.
    gravity: float = 9.8
    # Mass of the cart, kg.
    cart_mass: float = 1.0
    # Mass of the pole, kg.
    pole_mass: float = 0.1
    # Half-length of the pole (not the full length), m.
    pole_length: float = 0.5
    # Magnitude of the force applied to the cart per action, N.
    force_magnitude: float = 10.0
    # Integration timestep, s.
    dt: float = 0.02
@struct.dataclass
class PhysicsParams:
    """Per-episode (dynamic) physics parameters, for domain randomization.

    CartPole defines no such parameters yet, so this container has no fields.
    It exists to keep the physics protocol uniform and to leave room for
    future randomization (for example, varying masses or lengths per episode).
    """

    pass
def step_physics(
    state: PhysicsState,
    action: Array,
    params: PhysicsParams,
    config: PhysicsConfig,
) -> PhysicsState:
    """Advance the cart-pole by one ``config.dt`` with explicit Euler integration.

    The equations of motion follow Barto, Sutton, Anderson (1983).

    Args:
        state: Current physical state ``(x, x_dot, theta, theta_dot)``.
        action: Discrete action in {0, 1} selecting the direction of the push.
        params: Dynamic parameters; not read here, reserved for future
            domain randomization.
        config: Static physics constants (masses, pole half-length, gravity,
            force magnitude, timestep).

    Returns:
        The physical state one timestep later.
    """
    del params  # part of the shared signature; nothing to read for CartPole yet

    position, velocity, angle, angular_velocity = state
    dt = config.dt

    # Action 0 pushes with -force_magnitude, action 1 with +force_magnitude.
    push = (2 * action - 1) * config.force_magnitude

    sin_angle = jnp.sin(angle)
    cos_angle = jnp.cos(angle)

    system_mass = config.cart_mass + config.pole_mass
    pole_moment = config.pole_mass * config.pole_length

    # Term shared by both accelerations: applied force plus the
    # angular-velocity-squared contribution of the pole, per unit total mass.
    shared = (push + pole_moment * angular_velocity**2 * sin_angle) / system_mass

    # Angular acceleration of the pole.
    angular_numerator = config.gravity * sin_angle - cos_angle * shared
    angular_denominator = config.pole_length * (
        4.0 / 3.0 - config.pole_mass * cos_angle**2 / system_mass
    )
    angular_acceleration = angular_numerator / angular_denominator

    # Linear acceleration of the cart, which depends on the angular one.
    acceleration = shared - pole_moment * angular_acceleration * cos_angle / system_mass

    # Explicit Euler: positions advance with the *current* velocities,
    # velocities advance with the accelerations computed above.
    return PhysicsState(
        position + dt * velocity,
        velocity + dt * acceleration,
        angle + dt * angular_velocity,
        angular_velocity + dt * angular_acceleration,
    )
def create_physics_params(**kwargs) -> PhysicsParams:
    """Construct the dynamic physics parameters for CartPole.

    Args:
        **kwargs: Accepted but currently ignored; reserved for future
            domain randomization parameters.

    Returns:
        A default-constructed ``PhysicsParams``.
    """
    del kwargs  # nothing consumes these yet
    return PhysicsParams()