Source code for myriad.envs.wrappers

"""Environment wrappers for compatibility with different frameworks.

This module provides wrappers to adapt Myriad environments to different interfaces.
"""

import math
import warnings
from typing import Any, Callable, NamedTuple, TypeVar

import jax.numpy as jnp
from jax import Array

from myriad.envs.environment import Environment, EnvironmentConfig, EnvironmentParams, EnvironmentState
from myriad.utils.observations import to_array

S = TypeVar("S", bound=EnvironmentState)
C = TypeVar("C", bound=EnvironmentConfig)
P = TypeVar("P", bound=EnvironmentParams)


def make_array_obs_env(
    env: Environment[S, C, P, Any],
    obs_to_array: Callable[[Any], Array] | None = None,
) -> Environment[S, C, P, Array]:  # type: ignore[type-var]
    """Return a copy of ``env`` whose observations are converted to arrays.

    Handy when the consumer (a Gym/Gymnasium-style training loop, a neural
    network agent, ...) expects array observations while the environment
    emits structured ones such as NamedTuples.

    Only ``step`` and ``reset`` are swapped out; every other field of ``env``
    is carried over unchanged by ``env._replace``.

    Args:
        env: Environment with potentially structured observations
            (NamedTuple, etc.).
        obs_to_array: Optional conversion function applied to every
            observation. When ``None``, each observation's own
            ``.to_array()`` method is called (the standard pattern for
            Myriad observations).

    Returns:
        A new environment whose ``reset`` returns ``(array_obs, state)`` and
        whose ``step`` returns ``(array_obs, next_state, reward, done, info)``.

    Example:
        >>> import jax
        >>> from myriad.envs.cartpole.tasks.control import make_env
        >>> env = make_env()  # Returns CartPoleObs observations
        >>> array_env = make_array_obs_env(env)  # Returns array observations
        >>>
        >>> key = jax.random.PRNGKey(0)
        >>> obs, state = array_env.reset(key, array_env.params, array_env.config)
        >>> print(obs.shape)  # (4,)
        >>>
        >>> # If you need a custom conversion:
        >>> def custom_converter(obs):
        ...     return jnp.array([obs.x, obs.theta])  # Only position and angle
        >>> partial_env = make_array_obs_env(env, obs_to_array=custom_converter)
    """
    # Pick the conversion once, up front, so the wrapped functions stay branch-free.
    if obs_to_array is None:

        def _as_array(raw_obs: Any) -> Array:
            return raw_obs.to_array()

    else:
        _as_array = obs_to_array

    def _reset_with_array_obs(
        key: Array,
        params: P,
        config: C,
    ):
        first_obs, initial_state = env.reset(key, params, config)
        return _as_array(first_obs), initial_state

    def _step_with_array_obs(
        key: Array,
        state: S,
        action: Array,
        params: P,
        config: C,
    ):
        raw_obs, following_state, reward, done, info = env.step(key, state, action, params, config)
        return _as_array(raw_obs), following_state, reward, done, info

    # Environment is immutable: build a new one with just the two functions replaced.
    return env._replace(
        step=_step_with_array_obs,
        reset=_reset_with_array_obs,
    )
# --------------------------------------------------------------------------- # Frame stacking # ---------------------------------------------------------------------------
class FrameStackState(NamedTuple):
    """Carry state of a frame-stacking wrapper.

    Pairs the wrapped environment's own state with the window of recent
    observations that the wrapper turns into the stacked observation.

    Attributes:
        env_state: State of the wrapped (inner) environment; stored as-is.
        obs_buffer: The last ``n_frames`` flattened observations, shape
            ``(n_frames, obs_dim)``. Rows are ordered oldest to newest, so
            the most recent frame sits in the **last** row.
    """

    # Inner environment state, passed through untouched.
    env_state: Any
    # Sliding window of recent observations, newest in the last row.
    obs_buffer: Array
def make_frame_stack_env(env: Environment, n_frames: int) -> Environment:
    """Wrap an environment to stack the last ``n_frames`` observations.

    Returns flat ``Array`` observations of shape ``(n_frames * obs_dim,)``.
    The wrapped env's state becomes :class:`FrameStackState`, which bundles the
    inner env state with a window of recent observations.

    On ``reset``, the window is zero-filled with the initial observation placed
    in the last slot (newest position). On ``step``, the oldest frame is dropped
    and the new observation is appended at the end. Both functions only use
    ``jax.numpy`` operations with static shapes.

    Notes:
        * The per-frame size ``obs_dim`` is computed once, at wrap time, from
          ``env.get_obs_shape(env.config)``; ``reset`` and ``step`` reshape
          every observation to that size.
        * The wrapper itself never inspects ``done``: frames stay in the window
          until the caller invokes ``reset`` again.

    Args:
        env: Environment to wrap. Observations are converted with
            ``myriad.utils.observations.to_array`` and then flattened.
        n_frames: Number of consecutive observations to stack. Must be an
            integer (anything supporting ``__index__``) and at least 1.

    Returns:
        A new :class:`~myriad.envs.environment.Environment` whose ``reset`` and
        ``step`` return stacked observations and :class:`FrameStackState`, and
        whose ``get_obs_shape`` reports the stacked shape.

    Raises:
        TypeError: If ``n_frames`` is not an integer (e.g. a float).
        ValueError: If ``n_frames`` is smaller than 1.

    Warns:
        UserWarning: If ``n_frames == 1``, since stacking a single frame adds
            overhead without changing the information in the observation.

    Example:
        >>> import jax
        >>> from myriad.envs import make_env
        >>> from myriad.envs.wrappers import make_frame_stack_env
        >>> inner = make_env("cartpole-control")
        >>> env = make_frame_stack_env(inner, n_frames=4)
        >>> obs, state = env.reset(jax.random.PRNGKey(0), env.params, env.config)
        >>> obs.shape  # (4 * 4,)
        (16,)
    """
    import operator  # local import: keeps this block self-contained

    # Reject non-integers up front. Without this a float such as 2.5 passes the
    # range check below and only blows up later, inside reset, when it is used
    # as an array dimension.
    try:
        n_frames = operator.index(n_frames)
    except TypeError:
        raise TypeError(f"n_frames must be an integer (got {type(n_frames).__name__}).") from None

    if n_frames < 1:
        raise ValueError(f"n_frames must be >= 1 (got {n_frames}).")
    if n_frames == 1:
        warnings.warn(
            "make_frame_stack_env called with n_frames=1. This adds wrapper overhead with no benefit "
            "— the stacked obs is identical to the raw obs. Use the env directly instead.",
            UserWarning,
            stacklevel=2,
        )

    # Size of one flattened frame, fixed at wrap time from the env's own config.
    inner_obs_shape = env.get_obs_shape(env.config)
    obs_dim = math.prod(inner_obs_shape)

    def _get_obs_shape(config) -> tuple[int, ...]:
        return (n_frames * math.prod(env.get_obs_shape(config)),)

    def _reset(key, params, config):
        obs, inner_state = env.reset(key, params, config)
        obs_flat = to_array(obs).reshape(obs_dim)
        # Pad with n_frames - 1 zero frames so the first real frame is the newest row.
        buffer = jnp.concatenate([jnp.zeros((n_frames - 1, obs_dim), dtype=obs_flat.dtype), obs_flat[None]], axis=0)
        return buffer.reshape(-1), FrameStackState(env_state=inner_state, obs_buffer=buffer)

    def _step(key, state, action, params, config):
        obs, inner_state, reward, done, info = env.step(key, state.env_state, action, params, config)
        obs_flat = to_array(obs).reshape(obs_dim)
        # Drop the oldest row and append the newest; the window keeps n_frames rows.
        new_buffer = jnp.concatenate([state.obs_buffer[1:], obs_flat[None]], axis=0)
        return new_buffer.reshape(-1), FrameStackState(env_state=inner_state, obs_buffer=new_buffer), reward, done, info

    return env._replace(
        reset=_reset,
        step=_step,
        get_obs_shape=_get_obs_shape,
    )