Wrappers

Environment wrappers for compatibility with different frameworks.

This module provides wrappers to adapt Myriad environments to different interfaces.

myriad.envs.wrappers.make_array_obs_env(env, obs_to_array=None)

Wrap an environment to convert NamedTuple observations to arrays.

This wrapper is useful for compatibility with standard RL frameworks (Gym, Gymnasium) and neural network agents that expect flat array observations.

Parameters:
  • env (Environment[S, C, P, Any]) – Environment whose observations may be structured (a NamedTuple, etc.).

  • obs_to_array (Callable[[Any], Array] | None) – Optional conversion function. If None, assumes observations have a .to_array() method (the standard pattern for Myriad observations).

Returns:

Environment with array observations

Return type:

Environment[S, C, P, Array]

Example

>>> import jax
>>> import jax.numpy as jnp
>>> from myriad.envs.cartpole.tasks.control import make_env
>>> from myriad.envs.wrappers import make_array_obs_env
>>> env = make_env()  # Returns CartPoleObs observations
>>> array_env = make_array_obs_env(env)  # Returns flat array observations
>>>
>>> key = jax.random.PRNGKey(0)
>>> obs, state = array_env.reset(key, array_env.params, array_env.config)
>>> obs.shape
(4,)
>>>
>>> # For a custom conversion, pass obs_to_array explicitly:
>>> def custom_converter(obs):
...     return jnp.array([obs.x, obs.theta])  # Only position and angle
>>> partial_env = make_array_obs_env(env, obs_to_array=custom_converter)
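
The conversion rule described above (use obs_to_array when it is given, otherwise call .to_array()) can be sketched in isolation. This is a minimal sketch, not Myriad's implementation: ToyObs, toy_reset, resolve_converter and wrap_reset are hypothetical names, and only a reset(key, params, config) -> (obs, state) function is wrapped; step() would presumably be wrapped the same way.

```python
from typing import Any, Callable, NamedTuple, Optional

import jax
import jax.numpy as jnp


class ToyObs(NamedTuple):
    """Hypothetical structured observation following the .to_array() pattern."""

    x: jax.Array
    x_dot: jax.Array
    theta: jax.Array
    theta_dot: jax.Array

    def to_array(self) -> jax.Array:
        # Stack the scalar fields, in field order, into one flat vector.
        return jnp.stack([self.x, self.x_dot, self.theta, self.theta_dot])


def resolve_converter(
    obs_to_array: Optional[Callable[[Any], jax.Array]] = None,
) -> Callable[[Any], jax.Array]:
    # Prefer the user's converter; otherwise fall back to obs.to_array().
    if obs_to_array is not None:
        return obs_to_array
    return lambda obs: obs.to_array()


def wrap_reset(reset_fn, obs_to_array=None):
    # Hypothetical helper: same call signature, but the observation is converted.
    convert = resolve_converter(obs_to_array)

    def reset(key, params, config):
        obs, state = reset_fn(key, params, config)
        return convert(obs), state

    return reset


def toy_reset(key, params, config):
    # Hypothetical inner reset: small random initial state, structured observation.
    vals = jax.random.uniform(key, (4,), minval=-0.05, maxval=0.05)
    return ToyObs(vals[0], vals[1], vals[2], vals[3]), vals


key = jax.random.PRNGKey(0)

# Default conversion via .to_array(); the wrapped reset still works under jit.
obs, state = jax.jit(wrap_reset(toy_reset))(key, None, None)

# Custom conversion keeping only position and angle.
partial_reset = wrap_reset(toy_reset, lambda o: jnp.stack([o.x, o.theta]))
partial_obs, _ = partial_reset(key, None, None)

# Batched resets over 8 keys; params and config are not mapped.
keys = jax.random.split(key, 8)
batch_obs, _ = jax.vmap(wrap_reset(toy_reset), in_axes=(0, None, None))(keys, None, None)
```

Resolving the converter once, when the wrapper is built, keeps the returned function free of Python branching at call time, so it traces cleanly under jax.jit() and jax.vmap().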

class myriad.envs.wrappers.FrameStackState(env_state, obs_buffer)

Bases: NamedTuple

State for a frame-stacking wrapper.

Variables:
  • env_state (Any) – The wrapped environment’s state (any pytree-compatible type).

  • obs_buffer (jax.Array) – Ring buffer of the last n_frames observations. Shape: (n_frames, obs_dim), with the newest frame in the last slot.

env_state: Any

Alias for field number 0

obs_buffer: Array

Alias for field number 1
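
Because FrameStackState is a NamedTuple, JAX treats it as a pytree, which is what lets the wrapped state pass through jax.jit(), jax.vmap() and jax.lax.scan(). The sketch below uses a local stand-in named StackState (hypothetical, mirroring the two documented fields) with a made-up dict as the inner env state. It shows that batching adds a leading axis to every leaf in both fields and that the newest frame is obs_buffer[-1].

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class StackState(NamedTuple):
    """Hypothetical stand-in mirroring FrameStackState's two fields."""

    env_state: Any          # inner env state: any pytree
    obs_buffer: jax.Array   # (n_frames, obs_dim), newest frame in the last slot


n_frames, obs_dim = 3, 2

# Inner state as a dict pytree (made up for illustration).
single = StackState(
    env_state={"pos": jnp.zeros(obs_dim), "t": jnp.array(0)},
    obs_buffer=jnp.arange(n_frames * obs_dim, dtype=jnp.float32).reshape(n_frames, obs_dim),
)

# NamedTuples are pytrees: leaves are collected from both fields.
leaves = jax.tree_util.tree_leaves(single)

# Build a batch of 5 states by broadcasting every leaf to a leading batch axis.
batch = jax.tree_util.tree_map(lambda x: jnp.broadcast_to(x, (5,) + x.shape), single)

# The newest observation sits in the last slot of the buffer.
newest = single.obs_buffer[-1]

# vmap accepts the NamedTuple directly and maps over every leaf's leading axis.
batch_newest = jax.vmap(lambda s: s.obs_buffer[-1])(batch)
```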

myriad.envs.wrappers.make_frame_stack_env(env, n_frames)

Wrap an environment to stack the last n_frames observations.

Returns flat Array observations of shape (n_frames * obs_dim,). The wrapped env’s state becomes FrameStackState, which bundles the inner env state with a ring buffer of recent observations.

On reset(), the buffer is zero-filled and the initial observation is placed in the last slot (the newest position). On step(), the buffer is rolled by one slot, which drops the oldest frame, and the new observation is written to the last slot.

Both reset() and step() are pure JAX functions, so they can be used with jax.jit(), jax.vmap(), and jax.lax.scan().

Parameters:
  • env (Environment) – Environment to wrap. Observations must be a flat Array or have a .to_array() method (the standard Myriad pattern).

  • n_frames (int) – Number of consecutive observations to stack.

Returns:

A new Environment whose reset and step return stacked observations and a FrameStackState.

Return type:

Environment

Example

>>> import jax
>>> from myriad.envs import make_env
>>> from myriad.envs.wrappers import make_frame_stack_env
>>> inner = make_env("cartpole-control")
>>> env = make_frame_stack_env(inner, n_frames=4)
>>> obs, state = env.reset(jax.random.PRNGKey(0), env.params, env.config)
>>> obs.shape  # (n_frames * obs_dim,) = (4 * 4,)
(16,)
>>> state.obs_buffer.shape  # (n_frames, obs_dim)
(4, 4)
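
The buffer update described above (zero-fill on reset, roll and insert on step) can be sketched with plain JAX array operations. This is a minimal sketch under assumptions, not Myriad's code: stack_reset, stack_step and flatten are hypothetical helpers that act on the buffer alone, leaving out the inner environment and FrameStackState, and jnp.roll with shift=-1 is one way to realize the "roll by one" step.

```python
import jax
import jax.numpy as jnp


def stack_reset(obs: jax.Array, n_frames: int) -> jax.Array:
    # Zero-filled buffer with the initial observation in the last (newest) slot.
    buf = jnp.zeros((n_frames,) + obs.shape, dtype=obs.dtype)
    return buf.at[-1].set(obs)


def stack_step(buf: jax.Array, obs: jax.Array) -> jax.Array:
    # Roll by one (the oldest frame wraps to the end), then overwrite it with the new obs.
    return jnp.roll(buf, shift=-1, axis=0).at[-1].set(obs)


def flatten(buf: jax.Array) -> jax.Array:
    # (n_frames, obs_dim) -> (n_frames * obs_dim,), oldest frame first.
    return buf.reshape(-1)


n_frames, obs_dim = 4, 2
first = jnp.array([1.0, 1.0])
buf0 = stack_reset(first, n_frames)


def scan_body(buf, obs):
    # One environment step's worth of buffer update; emit the flat stacked obs.
    buf = stack_step(buf, obs)
    return buf, flatten(buf)


# Feed observations 2, 3, 4, 5 (each repeated across obs_dim) through lax.scan.
new_obs = jnp.repeat(jnp.arange(2.0, 6.0)[:, None], obs_dim, axis=1)
final_buf, stacked = jax.lax.scan(scan_body, buf0, new_obs)
```

Because the update is pure, it runs unchanged inside jax.lax.scan(). In this sketch, flattening in row order puts the oldest frame first and the newest last, giving the (n_frames * obs_dim,) shape described above.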