Wrappers
Environment wrappers for compatibility with different frameworks.
This module provides wrappers to adapt Myriad environments to different interfaces.
- myriad.envs.wrappers.make_array_obs_env(env, obs_to_array=None)
Wrap an environment to convert NamedTuple observations to arrays.
This wrapper is useful for compatibility with standard RL frameworks (Gym, Gymnasium) and neural network agents that expect flat array observations.
- Parameters:
env (Environment[S, C, P, Any]) – Environment with potentially structured observations (NamedTuple, etc.)
obs_to_array (Callable[[Any], Array] | None) – Optional conversion function. If None, assumes observations have a .to_array() method (the standard pattern for Myriad observations).
- Returns:
Environment with array observations
- Return type:
Environment[S, C, P, Array]
Example
>>> import jax
>>> import jax.numpy as jnp
>>> from myriad.envs.cartpole.tasks.control import make_env
>>> from myriad.envs.wrappers import make_array_obs_env
>>> env = make_env()  # Returns CartPoleObs observations
>>> array_env = make_array_obs_env(env)  # Returns array observations
>>>
>>> key = jax.random.PRNGKey(0)
>>> obs, state = array_env.reset(key, array_env.params, array_env.config)
>>> print(obs.shape)
(4,)
>>>
>>> # If you need a custom conversion:
>>> def custom_converter(obs):
...     return jnp.array([obs.x, obs.theta])  # Only position and angle
...
>>> partial_env = make_array_obs_env(env, obs_to_array=custom_converter)
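The obs_to_array rule (use .to_array() when no converter is given, otherwise call the converter) can be sketched in isolation. This is a minimal illustration, not the wrapper's actual implementation: ToyObs and convert_obs are hypothetical names, and ToyObs only mimics a CartPole-style NamedTuple observation with a .to_array() method (the x_dot and theta_dot fields are assumed).

```python
from typing import Any, Callable, NamedTuple, Optional

import jax
import jax.numpy as jnp


class ToyObs(NamedTuple):
    """Hypothetical structured observation (stands in for e.g. CartPoleObs)."""
    x: float
    x_dot: float
    theta: float
    theta_dot: float

    def to_array(self) -> jax.Array:
        # Flatten the named fields into a 1-D array, in field order.
        return jnp.array([self.x, self.x_dot, self.theta, self.theta_dot])


def convert_obs(
    obs: Any, obs_to_array: Optional[Callable[[Any], jax.Array]] = None
) -> jax.Array:
    """Hypothetical helper mirroring the documented obs_to_array rule."""
    if obs_to_array is None:
        # Default path: rely on the observation's own .to_array() method.
        return obs.to_array()
    # Custom path: the caller decides which fields end up in the array.
    return obs_to_array(obs)


obs = ToyObs(x=0.1, x_dot=0.0, theta=0.05, theta_dot=0.0)
full = convert_obs(obs)  # all four fields
partial = convert_obs(obs, lambda o: jnp.array([o.x, o.theta]))  # position and angle only

# The NamedTuple is a pytree, so the default path also works under jax.jit.
jitted = jax.jit(convert_obs)(obs)
```

A custom converter is the natural choice when an agent should only see part of the state, as in the partial-observation example above.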
- class myriad.envs.wrappers.FrameStackState(env_state, obs_buffer)
Bases: NamedTuple
State for a frame-stacking wrapper.
- Variables:
env_state (Any) – The wrapped environment’s state (any pytree-compatible type).
obs_buffer (jax.Array) – Ring buffer of the last n_frames observations. Shape: (n_frames, obs_dim), with the newest frame in the last slot.
- obs_buffer: Array
Alias for field number 1
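Because FrameStackState is a NamedTuple, it is a JAX pytree, which is what allows the wrapped state to pass through jax.jit() and related transforms. The sketch below does not import Myriad: FrameStackStateSketch is a hypothetical stand-in with the same two fields, and the dict-based inner state is an assumption chosen only for illustration.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class FrameStackStateSketch(NamedTuple):
    """Hypothetical stand-in with the same fields as FrameStackState."""
    env_state: Any         # wrapped environment's state (any pytree)
    obs_buffer: jax.Array  # shape (n_frames, obs_dim), newest frame last


n_frames, obs_dim = 3, 2
state = FrameStackStateSketch(
    env_state={"t": jnp.array(0)},  # hypothetical inner state: a step counter
    obs_buffer=jnp.zeros((n_frames, obs_dim)),
)

# NamedTuple fields are also positional: obs_buffer is field number 1.
same_field = state[1] is state.obs_buffer

# JAX flattens both fields, so the leaves are the counter and the buffer.
leaves = jax.tree_util.tree_leaves(state)


@jax.jit
def tick(s: FrameStackStateSketch) -> FrameStackStateSketch:
    # The whole state goes in and comes out of jitted code as one pytree.
    return s._replace(env_state={"t": s.env_state["t"] + 1})


next_state = tick(state)
```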
- myriad.envs.wrappers.make_frame_stack_env(env, n_frames)
Wrap an environment to stack the last n_frames observations.
Returns flat Array observations of shape (n_frames * obs_dim,). The wrapped env’s state becomes FrameStackState, which bundles the inner env state with a ring buffer of recent observations.
On reset(), the buffer is zero-filled and the initial observation is placed in the last slot (the newest position). On step(), the buffer is rolled by one and the new observation is inserted at the end, so the oldest frame is dropped.
Both reset and step are pure JAX functions, compatible with jax.jit(), jax.vmap(), and jax.lax.scan().
- Parameters:
env (Environment) – Environment to wrap. Observations must be a flat Array or have a .to_array() method (the standard Myriad pattern).
n_frames (int) – Number of consecutive observations to stack.
- Returns:
A new Environment whose reset and step return stacked observations and a FrameStackState.
- Return type:
Environment
Example
>>> import jax
>>> from myriad.envs import make_env
>>> from myriad.envs.wrappers import make_frame_stack_env
>>> inner = make_env("cartpole-control")
>>> env = make_frame_stack_env(inner, n_frames=4)
>>> obs, state = env.reset(jax.random.PRNGKey(0), env.params, env.config)
>>> obs.shape  # (4 * 4,)
(16,)
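The buffer behaviour described for reset() and step() can be sketched with plain jax.numpy. This is an illustration under stated assumptions, not Myriad's code: init_buffer, push, and flatten are hypothetical helper names, the roll direction (toward the front, so the oldest frame wraps to the end and is then overwritten) is assumed, and a synthetic observation stream stands in for a real environment's step().

```python
import jax
import jax.numpy as jnp


def init_buffer(first_obs: jax.Array, n_frames: int) -> jax.Array:
    """Reset: zero-filled (n_frames, obs_dim) buffer, first_obs in the last slot."""
    buf = jnp.zeros((n_frames,) + first_obs.shape, dtype=first_obs.dtype)
    return buf.at[-1].set(first_obs)


def push(buf: jax.Array, new_obs: jax.Array) -> jax.Array:
    """Step: roll toward the front (oldest frame wraps to the end), then overwrite it."""
    return jnp.roll(buf, shift=-1, axis=0).at[-1].set(new_obs)


def flatten(buf: jax.Array) -> jax.Array:
    """Stacked observation of shape (n_frames * obs_dim,), oldest frame first."""
    return buf.reshape(-1)


n_frames, obs_dim = 3, 2
first = jnp.array([1.0, 1.0])
buf0 = init_buffer(first, n_frames)  # [[0, 0], [0, 0], [1, 1]]

# Synthetic observation stream standing in for env.step(): [2, 2], [3, 3], [4, 4].
stream = jnp.stack([jnp.full((obs_dim,), t, dtype=jnp.float32) for t in (2.0, 3.0, 4.0)])


def scan_step(buf: jax.Array, obs: jax.Array):
    # Carry the buffer; emit the flat stacked observation at every step.
    buf = push(buf, obs)
    return buf, flatten(buf)


final_buf, stacked = jax.lax.scan(scan_step, buf0, stream)
# final_buf: [[2, 2], [3, 3], [4, 4]]; stacked has shape (3, 6)
```

Because init_buffer, push, and flatten are pure functions of arrays, the same helpers can be wrapped in jax.jit() or mapped over a batch of buffers with jax.vmap() without changes.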