Environment

Overview

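import jax
from myriad.envs import make_env

env = make_env("cartpole-control")
reset_key, step_key = jax.random.split(jax.random.PRNGKey(0))
obs, state = env.reset(reset_key, env.params, env.config)
# `action` must be a valid element of env.get_action_space(env.config)
obs, state, reward, done, info = env.step(step_key, state, action, env.params, env.config)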

Available Environments

ID                  Category  Environment   Description
cartpole-control    Classic   CartPole      Inverted pendulum balancing
pendulum-control    Classic   Pendulum      Swing-up control
opto-hill-1d-sysid  Biology   Opto Hill 1D  1D optogenetic circuit, continuous light input
ccasr-gfp-control   Biology   CcaSR-GFP     CcaS/CcaR + GFP optogenetic gene circuit

Factory Function

myriad.envs.make_env(name, **kwargs)

Create an environment instance by name.

Parameters:
  • name (str) – Unique identifier for the environment.

  • **kwargs (Any) – Keyword arguments passed to the environment’s factory function.

Returns:

An instance of the requested Environment.

Raises:

ValueError – If the environment name is not found in the registry.

Return type:

Any
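The lookup-by-name behaviour described above can be sketched with a plain dictionary registry. This is only an illustration of the pattern, not myriad's implementation: `_REGISTRY`, `register`, `make_env_sketch`, and the `toy-env` factory are all hypothetical names.

```python
from typing import Any, Callable

# Hypothetical registry mapping environment IDs to factory callables.
_REGISTRY: dict[str, Callable[..., Any]] = {}


def register(name: str, factory: Callable[..., Any]) -> None:
    """Associate an environment ID with the callable that builds it."""
    _REGISTRY[name] = factory


def make_env_sketch(name: str, **kwargs: Any) -> Any:
    """Look up `name` and forward keyword arguments to its factory."""
    if name not in _REGISTRY:
        available = ", ".join(sorted(_REGISTRY))
        raise ValueError(f"Unknown environment {name!r}. Available: {available}")
    return _REGISTRY[name](**kwargs)


# A stand-in factory; a real one would return an Environment instance.
register("toy-env", lambda max_steps=100: {"id": "toy-env", "max_steps": max_steps})

env = make_env_sketch("toy-env", max_steps=50)
```

Forwarding `**kwargs` untouched keeps the factory function generic: each registered factory decides which keyword arguments it accepts.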

Base Protocols

Base environment definitions for JAX-based RL environments.

This module provides small, focused Protocols for the three environment components (config, params, state) and a typed container Environment which holds the environment’s pure functions. The Protocols are intentionally small and permissive so concrete environments remain free to use dataclasses, Flax structs, NamedTuples, etc., while still providing helpful static typing and documentation.

Design Rationale: Config vs Params

Environments separate static configuration (EnvironmentConfig) from dynamic parameters (EnvironmentParams) to optimize JAX compilation:

  • EnvironmentConfig: Static, compile-time configuration passed via static_argnames to jax.jit(). Because its values are fixed at trace time, the compiler can specialize on them; changing the config triggers recompilation. Use for: physics constants, termination thresholds, max_steps, environment structure.

  • EnvironmentParams: Dynamic, runtime parameters that can vary between episodes without recompilation. Use for: randomized dynamics, curriculum learning parameters, domain randomization values, etc.

If your environment doesn’t need dynamic parameters, EnvironmentParams can be empty, but keep the structure for protocol consistency.
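The effect of this split can be observed by counting how often JAX traces a function. The sketch below is a toy under stated assumptions: `ToyConfig`, `ToyParams`, and `_step` are hypothetical, and a frozen standard-library dataclass stands in for a Flax struct so that the config is hashable.

```python
from dataclasses import dataclass
from typing import NamedTuple

import jax
import jax.numpy as jnp


@dataclass(frozen=True)  # frozen => hashable, which jit requires for static arguments
class ToyConfig:
    max_steps: int = 200
    dt: float = 0.02


class ToyParams(NamedTuple):  # NamedTuples are pytrees, so their fields are traced
    gain: jax.Array


trace_count = 0


def _step(x, params, config):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return x + config.dt * params.gain * x


step = jax.jit(_step, static_argnames=["config"])

x = jnp.array(1.0, dtype=jnp.float32)
one = jnp.array(1.0, dtype=jnp.float32)
two = jnp.array(2.0, dtype=jnp.float32)

step(x, ToyParams(gain=one), config=ToyConfig())
step(x, ToyParams(gain=two), config=ToyConfig())  # new params value: cached trace reused
traces_after_param_change = trace_count
step(x, ToyParams(gain=two), config=ToyConfig(dt=0.01))  # new config value: retraced
traces_after_config_change = trace_count
```

Changing `gain` reuses the compiled function because only the array's shape and dtype matter for caching, while a config with a different `dt` hashes differently and forces a new trace.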

class myriad.envs.environment.EnvironmentConfig(*args, **kwargs)

Bases: Protocol

Protocol for environment configuration objects.

Variables:

max_steps (int) – Maximum steps per episode (required for training loops).

Notes

Typical fields include:

  • Physics constants (gravity, mass, friction coefficients)

  • Environment structure (grid size, number of agents)

  • Termination thresholds

  • Integration timestep (dt)

  • Any parameter that defines “what kind of environment this is”

Implementation: Use @flax.struct.dataclass for JAX compatibility.
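Because the Protocol is structural, any object exposing `max_steps` satisfies it without inheriting from it. A minimal sketch, assuming a hypothetical `ConfigLike` Protocol in place of EnvironmentConfig and a frozen standard-library dataclass in place of `@flax.struct.dataclass`:

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@runtime_checkable
class ConfigLike(Protocol):
    """Hypothetical stand-in for EnvironmentConfig: requires only max_steps."""

    @property
    def max_steps(self) -> int: ...


@dataclass(frozen=True)  # frozen, so instances are hashable
class ToyCartPoleConfig:
    max_steps: int = 500
    gravity: float = 9.8
    dt: float = 0.02


@dataclass(frozen=True)
class NotAConfig:
    gravity: float = 9.8  # no max_steps, so it does not satisfy the Protocol


def episode_budget(config: ConfigLike, num_episodes: int) -> int:
    """Upper bound on environment steps for a training run."""
    return config.max_steps * num_episodes


cfg = ToyCartPoleConfig()
satisfies = isinstance(cfg, ConfigLike)
rejected = isinstance(NotAConfig(), ConfigLike)
budget = episode_budget(cfg, num_episodes=10)
```

The `isinstance` checks are only there to make the structural match visible; static type checkers perform the same check without `runtime_checkable`.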

property max_steps: int

class myriad.envs.environment.EnvironmentParams(*args, **kwargs)

Bases: Protocol

Protocol for environment parameter objects.

Notes

Use cases include:

  • Domain randomization (randomized dynamics, varying targets)

  • Curriculum learning (difficulty parameters that change over training)

  • Multi-task learning (task-specific parameters)

  • Any parameter you want to sweep/randomize frequently

If your environment doesn’t need runtime variation, this can be empty, but maintain the structure for protocol consistency.

Implementation: Use @flax.struct.dataclass for JAX compatibility.

This is intentionally an empty Protocol — it’s a structural marker for type consistency in the Environment container.
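For the domain-randomization use case, one plausible shape is a function that maps a PRNG key to a params object, matching the `Callable[[Array], P]` type of the container's sample_params_fn field. The sketch below is an assumption about how such a function might look: `ToyParams`, `sample_params`, and the sampling ranges are hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyParams(NamedTuple):
    pole_length: jax.Array
    cart_mass: jax.Array


def sample_params(key: jax.Array) -> ToyParams:
    """Draw one randomized parameter set from a PRNG key (ranges are illustrative)."""
    length_key, mass_key = jax.random.split(key)
    return ToyParams(
        pole_length=jax.random.uniform(length_key, (), minval=0.4, maxval=0.6),
        cart_mass=jax.random.uniform(mass_key, (), minval=0.8, maxval=1.2),
    )


# One key per environment instance; vmap turns each field into a batch of 8 values.
keys = jax.random.split(jax.random.PRNGKey(0), 8)
batched_params = jax.vmap(sample_params)(keys)
```

Since the params are ordinary array pytrees, a fresh batch can be drawn every episode without triggering recompilation of jitted reset or step functions.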

class myriad.envs.environment.EnvironmentState(*args, **kwargs)

Bases: Protocol

Protocol for environment state objects.

As with EnvironmentParams, this is a marker Protocol. A state should be something JAX can transform (e.g., a NamedTuple or a pytree-compatible dataclass), but the Protocol leaves that choice to the implementation.

class myriad.envs.environment.GetActionSpaceFn(*args, **kwargs)

Bases: Protocol[C_contra]

Return the environment’s action space.

Parameters:

config (EnvironmentConfig) – Environment configuration (structural info like action dimensions).

Returns:

The action space specification.

Return type:

Space

class myriad.envs.environment.GetObsShapeFn(*args, **kwargs)

Bases: Protocol[C_contra]

Return the shape of observations produced by the environment.

Parameters:

config (EnvironmentConfig) – Environment configuration (structural info like state dimensions).

Returns:

shape – Shape tuple for observations (e.g., (4,) for CartPole).

Return type:

tuple[int, ...]
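A function of this kind depends only on the static config, so a plain Python function is enough. The following is a sketch with hypothetical names (`GridConfig`, `ObsShapeFn`, `get_obs_shape`), not a definition taken from myriad:

```python
from dataclasses import dataclass
from typing import Protocol


@dataclass(frozen=True)
class GridConfig:
    """Hypothetical config whose fields determine the observation layout."""

    max_steps: int = 100
    height: int = 8
    width: int = 8
    channels: int = 3


class ObsShapeFn(Protocol):
    """Hypothetical callable Protocol mirroring the documented signature."""

    def __call__(self, config: GridConfig) -> tuple[int, ...]: ...


def get_obs_shape(config: GridConfig) -> tuple[int, ...]:
    # The shape depends only on static config, so it is known before any tracing.
    return (config.height, config.width, config.channels)


shape_fn: ObsShapeFn = get_obs_shape  # a plain function satisfies the callable Protocol
shape = shape_fn(GridConfig(height=4, width=6))
```

Keeping the shape a function of config alone means buffers and network input layers can be sized before the environment is ever reset.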

class myriad.envs.environment.ResetFn(*args, **kwargs)

Bases: Protocol[S_co, P_contra, C_contra, Obs_co]

Reset the environment to an initial state.

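Parameters:
  • key – PRNG key for any randomness in the initial state.

  • params (EnvironmentParams) – Dynamic environment parameters.

  • config (EnvironmentConfig) – Static environment configuration.

Returns: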

Initial observation and initial environment state.

Return type:

tuple[Observation, EnvironmentState]

class myriad.envs.environment.StepFn(*args, **kwargs)

Bases: Protocol[S_inv, P_contra, C_contra, Obs_co]

Advance the environment by one timestep.

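Parameters:
  • key – PRNG key for stochastic transitions.

  • state (EnvironmentState) – Current environment state.

  • action – Action to apply.

  • params (EnvironmentParams) – Dynamic environment parameters.

  • config (EnvironmentConfig) – Static environment configuration.

Returns: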

A 5-tuple containing:

  • next_obs – Observation after the transition.

  • next_state – Updated environment state.

  • reward – Scalar reward signal.

  • done – Boolean termination flag.

  • info – Auxiliary information dictionary.

Return type:

tuple[Observation, EnvironmentState, jax.Array, jax.Array, dict[str, Any]]
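The functional style implied by this signature, where state goes in and a new state comes out, can be sketched in plain Python. The toy below is hypothetical (`CounterState`, `CounterConfig`, and the reward are invented for illustration); it uses floats and a Python loop for readability, whereas a JAX environment would return arrays and be driven by a traced loop. The argument order follows the Overview.

```python
from typing import Any, NamedTuple


class CounterState(NamedTuple):
    position: float
    t: int


class CounterConfig(NamedTuple):
    max_steps: int = 5
    goal: float = 3.0


def reset(key: Any, params: Any, config: CounterConfig):
    """Return the initial observation and state (key and params unused in this toy)."""
    state = CounterState(position=0.0, t=0)
    return state.position, state


def step(key: Any, state: CounterState, action: float, params: Any, config: CounterConfig):
    """Return (next_obs, next_state, reward, done, info) without mutating `state`."""
    next_state = CounterState(position=state.position + action, t=state.t + 1)
    reward = -abs(config.goal - next_state.position)
    done = next_state.position >= config.goal or next_state.t >= config.max_steps
    info = {"t": next_state.t}
    return next_state.position, next_state, reward, done, info


config = CounterConfig()
obs, state = reset(None, None, config)
total_reward, done = 0.0, False
while not done:
    obs, state, reward, done, info = step(None, state, 1.0, None, config)
    total_reward += reward
```

Because `step` never mutates its input, earlier states remain valid and can be replayed or branched from.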

class myriad.envs.environment.Environment(config, params, get_action_space, get_obs_shape, reset, step, sample_params_fn=None)

Bases: NamedTuple, Generic[S, C, P, Obs]

Typed container for a JAX-friendly environment’s pure functions.

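This is a generic class parameterized by:
  • S – Environment state type.

  • C – Environment configuration type.

  • P – Environment parameters type.

  • Obs – Observation type.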

Variables:
  • config (C) – Static configuration used as static_argnames when jitting functions. Changes to config require recompilation.

  • params (P) – Dynamic parameters that can vary between runs without recompilation. Passed as a regular (non-static) argument to step/reset functions.

  • get_action_space (GetActionSpaceFn[C]) – Pure function returning the action space specification.

  • get_obs_shape (GetObsShapeFn[C]) – Pure function returning the observation shape tuple.

  • reset (ResetFn[S, P, C, Obs]) – Pure function to reset the environment to an initial state.

  • step (StepFn[S, P, C, Obs]) – Pure function to advance the environment by one timestep.

Examples

When jitting environment functions:

import jax
# _step and _reset are the environment's un-jitted pure functions
step = jax.jit(_step, static_argnames=["config"])
reset = jax.jit(_reset, static_argnames=["config"])

This allows params to vary without recompilation while keeping config static.
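The container itself is just a NamedTuple of data and functions. A simplified, non-generic sketch follows; `EnvSketch`, `Cfg`, `Prm`, and the toy functions are hypothetical and use plain Python values rather than JAX arrays.

```python
from typing import Any, Callable, NamedTuple, Optional


class Cfg(NamedTuple):
    max_steps: int = 2


class Prm(NamedTuple):
    gain: float = 0.5


class EnvSketch(NamedTuple):
    """Hypothetical, non-generic stand-in for the Environment container."""

    config: Any
    params: Any
    get_action_space: Callable[[Any], Any]
    get_obs_shape: Callable[[Any], tuple]
    reset: Callable[..., tuple]
    step: Callable[..., tuple]
    sample_params_fn: Optional[Callable[[Any], Any]] = None


def _reset(key, params, config):
    return 0.0, {"x": 0.0, "t": 0}


def _step(key, state, action, params, config):
    next_state = {"x": state["x"] + params.gain * action, "t": state["t"] + 1}
    done = next_state["t"] >= config.max_steps
    return next_state["x"], next_state, -abs(next_state["x"]), done, {}


env = EnvSketch(
    config=Cfg(),
    params=Prm(),
    get_action_space=lambda config: (-1.0, 1.0),  # toy (low, high) bounds
    get_obs_shape=lambda config: (),
    reset=_reset,
    step=_step,
)

obs, state = env.reset(None, env.params, env.config)
obs, state, reward, done, info = env.step(None, state, 1.0, env.params, env.config)

# NamedTuples are immutable; _replace returns a new container with different params.
harder_env = env._replace(params=Prm(gain=2.0))
```

Bundling functions as fields rather than methods keeps them pure: everything they need arrives through their arguments, which is what allows them to be jitted or vmapped independently of the container.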

config: C

Alias for field number 0

params: P

Alias for field number 1

get_action_space: GetActionSpaceFn[C]

Alias for field number 2

get_obs_shape: GetObsShapeFn[C]

Alias for field number 3

reset: ResetFn[S, P, C, Obs]

Alias for field number 4

step: StepFn[S, P, C, Obs]

Alias for field number 5

sample_params_fn: Callable[[Array], P] | None

Alias for field number 6
