"""Control task wrapper for CartPole.
Standard balancing task: Keep the pole upright for as long as possible.
Reward: +1 per timestep the pole remains balanced.
"""
from typing import Any, Dict, NamedTuple
import jax
import jax.numpy as jnp
from flax import struct
from jax import Array
from myriad.core.spaces import Discrete
from myriad.core.types import PRNGKey
from myriad.envs.environment import Environment
from ..physics import PhysicsConfig, PhysicsParams, PhysicsState, create_physics_params, step_physics
from .base import (
TaskConfig,
check_termination,
get_cartpole_action_space,
get_cartpole_obs,
get_cartpole_obs_shape,
sample_initial_physics,
)
[docs]
class ControlTaskState(NamedTuple):
    """Immutable state carried between steps of the control task.

    Attributes:
        physics: Physics state of the cart-pole system
            (x, x_dot, theta, theta_dot).
        t: Step counter; set to 0 on reset and incremented by 1 on every step.
    """

    physics: PhysicsState
    t: Array
[docs]
@struct.dataclass
class ControlTaskConfig:
    """Static configuration of the CartPole control task.

    Bundles two independent sub-configurations -- one for the physics
    simulation and one for the task itself -- and re-exposes the values
    that callers read directly from the top-level config.

    Attributes:
        physics: Physics configuration; defaults to ``PhysicsConfig()``.
        task: Task configuration; defaults to ``TaskConfig()``.
    """

    physics: PhysicsConfig = struct.field(default_factory=PhysicsConfig)
    task: TaskConfig = struct.field(default_factory=TaskConfig)

    @property
    def dt(self) -> float:
        """Duration of one timestep in seconds, read from the physics config."""
        return self.physics.dt

    @property
    def max_steps(self) -> int:
        """Step limit read from the task config.

        Exposed at this level because the EnvironmentConfig protocol requires it.
        """
        return self.task.max_steps
[docs]
@struct.dataclass
class ControlTaskParams:
    """Parameter container for the control task.

    Attributes:
        physics: Parameters forwarded to the physics step; defaults to
            ``PhysicsParams()``.
    """

    physics: PhysicsParams = struct.field(default_factory=PhysicsParams)
def _step(
    key: PRNGKey,
    state: ControlTaskState,
    action: Array,
    params: ControlTaskParams,
    config: ControlTaskConfig,
) -> tuple[PhysicsState, ControlTaskState, Array, Array, Dict[str, Any]]:
    """Advance the control task by a single timestep.

    Args:
        key: RNG key. Not used by this function; accepted so the signature
            matches the environment step protocol.
        state: Task state before the step.
        action: Discrete action in {0, 1}.
        params: Task parameters; ``params.physics`` is forwarded to the
            physics step.
        config: Task configuration (static under ``jax.jit``).

    Returns:
        A 5-tuple ``(obs_next, next_state, reward, done, info)``:

        * ``obs_next`` -- observation of the new state (a ``PhysicsState``,
          i.e. the task is fully observable).
        * ``next_state`` -- new task state with the step counter advanced by 1.
        * ``reward`` -- constant ``float32`` 1.0 (standard CartPole reward).
        * ``done`` -- termination flag exactly as returned by
          ``check_termination`` for the new physics state and step count.
        * ``info`` -- always an empty dict.
    """
    elapsed = state.t + 1

    # Pure physics update; termination is judged on the *post-step* state.
    physics_after = step_physics(state.physics, action, params.physics, config.physics)
    terminated = check_termination(physics_after, elapsed, config.task)

    next_state = ControlTaskState(physics=physics_after, t=elapsed)
    observation = get_obs(next_state, params, config)

    # Standard CartPole: +1 for every step taken.
    return observation, next_state, jnp.float32(1.0), terminated, {}
def _reset(
    key: PRNGKey,
    params: ControlTaskParams,
    config: ControlTaskConfig,
) -> tuple[PhysicsState, ControlTaskState]:
    """Start a new episode of the control task.

    The physics state is drawn by ``sample_initial_physics(key)`` (small random
    perturbations around the upright position) and the step counter is
    set to zero.

    Args:
        key: RNG key consumed by the initial-state sampler.
        params: Task parameters (only passed through to ``get_obs``).
        config: Task configuration (static under ``jax.jit``; only passed
            through to ``get_obs``).

    Returns:
        A pair ``(obs, state)``: the initial observation (a ``PhysicsState``
        with named fields) and the initial task state.
    """
    initial_state = ControlTaskState(
        physics=sample_initial_physics(key),
        t=jnp.array(0),
    )
    return get_obs(initial_state, params, config), initial_state
# Public entry points: jit-compiled versions of `_step` and `_reset`.
# `config` is declared static, so JAX treats it as a compile-time constant:
# it must be hashable, and calling with a different config value triggers
# a fresh trace and compilation.
step = jax.jit(_step, static_argnames=["config"])
reset = jax.jit(_reset, static_argnames=["config"])
[docs]
def get_obs(
    state: ControlTaskState,
    params: ControlTaskParams,
    config: ControlTaskConfig,
) -> PhysicsState:
    """Build the observation for a given task state.

    The control task is fully observable: the observation is derived solely
    from ``state.physics`` via ``get_cartpole_obs``. ``PhysicsState`` is
    defined outside this module; it is expected to offer ``.to_array()`` for
    agents that need a flat array instead of named fields.

    Args:
        state: Current task state.
        params: Task parameters (unused).
        config: Task configuration (unused).

    Returns:
        PhysicsState with named fields (x, x_dot, theta, theta_dot).
    """
    del params, config  # accepted for interface uniformity only
    physics_state = state.physics
    return get_cartpole_obs(physics_state)
[docs]
def get_obs_shape(config: ControlTaskConfig) -> tuple[int, ...]:
    """Return the observation shape for this task.

    Delegates to ``get_cartpole_obs_shape``; the result does not depend on
    the configuration.

    Args:
        config: Task configuration (unused).

    Returns:
        Observation shape as a tuple of ints.
    """
    del config  # accepted for interface uniformity only
    obs_shape = get_cartpole_obs_shape()
    return obs_shape
[docs]
def get_action_space(config: ControlTaskConfig) -> Discrete:
    """Return the action space for this task.

    Delegates to ``get_cartpole_action_space``; the result does not depend on
    the configuration.

    Args:
        config: Task configuration (unused).

    Returns:
        Discrete space with 2 actions: 0 (push left) and 1 (push right).
    """
    del config  # accepted for interface uniformity only
    action_space = get_cartpole_action_space()
    return action_space
[docs]
def make_env(
    config: ControlTaskConfig | None = None,
    params: ControlTaskParams | None = None,
    **kwargs,
) -> Environment[ControlTaskState, ControlTaskConfig, ControlTaskParams, PhysicsState]:
    """Create a CartPole control task environment.

    Args:
        config: Custom ControlTaskConfig. If None, one is assembled from
            ``kwargs`` (see below), falling back to defaults for anything
            not supplied.
        params: Custom ControlTaskParams. If None, the physics parameters
            come from ``create_physics_params()`` called with no arguments;
            ``kwargs`` do not influence them.
        **kwargs: Field overrides used only when ``config`` is None.
            Recognised physics fields: ``gravity``, ``cart_mass``,
            ``pole_mass``, ``pole_length``, ``force_magnitude``, ``dt``.
            Recognised task fields: ``max_steps``, ``theta_threshold``,
            ``x_threshold``. Any other keyword is silently ignored, and all
            keywords are ignored when ``config`` is given.

    Returns:
        Environment instance wired to this module's jitted ``step``/``reset``
        and its ``get_action_space``/``get_obs_shape`` helpers.
    """
    env_config = config
    if env_config is None:
        physics_fields = {"gravity", "cart_mass", "pole_mass", "pole_length", "force_magnitude", "dt"}
        task_fields = {"max_steps", "theta_threshold", "x_threshold"}

        # Route each recognised keyword to the sub-config that owns it;
        # the two field sets are disjoint, so a keyword lands in at most one.
        physics_overrides: dict[str, Any] = {}
        task_overrides: dict[str, Any] = {}
        for name, value in kwargs.items():
            if name in physics_fields:
                physics_overrides[name] = value
            elif name in task_fields:
                task_overrides[name] = value

        # Unpacking an empty dict is the same as calling with no arguments,
        # so the defaults apply when nothing was overridden.
        env_config = ControlTaskConfig(
            physics=PhysicsConfig(**physics_overrides),
            task=TaskConfig(**task_overrides),
        )

    env_params = params
    if env_params is None:
        env_params = ControlTaskParams(physics=create_physics_params())

    return Environment(
        step=step,
        reset=reset,
        get_action_space=get_action_space,
        get_obs_shape=get_obs_shape,
        params=env_params,
        config=env_config,
    )