Environment¶
Overview¶
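import jax
from myriad.envs import make_env
env = make_env("cartpole-control")
key = jax.random.PRNGKey(0)
reset_key, step_key = jax.random.split(key)  # separate keys for reset and step
obs, state = env.reset(reset_key, env.params, env.config)
# `action` must be a jax.Array valid for env.get_action_space(env.config)
obs, state, reward, done, info = env.step(step_key, state, action, env.params, env.config)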
Available Environments¶
| ID | Category | Environment | Description |
|---|---|---|---|
| | | | Inverted pendulum balancing |
| | | | Swing-up control |
| | | | 1D optogenetic circuit, continuous light input |
| | | | CcaS/CcaR + GFP optogenetic gene circuit |
Factory Function¶
- myriad.envs.make_env(name, **kwargs)[source]¶
Create an environment instance by name.
- Parameters:
- Returns:
An instance of the requested Environment.
- Raises:
ValueError – If the environment name is not found in the registry.
- Return type:
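The lookup-and-raise behavior described above can be illustrated with a small registry. This is only a sketch of the general factory pattern: `_REGISTRY`, `register`, `make_env_sketch`, and the `"toy-env"` entry are hypothetical names, and the actual myriad registry may be organized differently.

```python
from typing import Callable, Dict

# Hypothetical registry mapping names to factories (not myriad's real registry).
_REGISTRY: Dict[str, Callable[..., dict]] = {}


def register(name):
    """Decorator that records a factory under `name`."""
    def decorator(factory):
        _REGISTRY[name] = factory
        return factory
    return decorator


@register("toy-env")
def _make_toy_env(**kwargs):
    # Stand-in for constructing an Environment; a plain dict keeps this runnable.
    return {"name": "toy-env", "options": kwargs}


def make_env_sketch(name, **kwargs):
    """Look up `name` and build the environment, or raise ValueError."""
    try:
        factory = _REGISTRY[name]
    except KeyError:
        available = ", ".join(sorted(_REGISTRY))
        raise ValueError(
            f"Unknown environment {name!r}. Available: {available}"
        ) from None
    return factory(**kwargs)


env = make_env_sketch("toy-env", max_steps=50)
```

Listing the registered names in the error message is a design choice of this sketch; it makes a typo in the environment ID easier to spot.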
Base Protocols¶
Base environment definitions for JAX-based RL environments.
This module provides small, focused Protocols for the three environment
components (config, params, state) and a typed container Environment
which holds the environment’s pure functions. The Protocols are intentionally
small and permissive so concrete environments remain free to use dataclasses,
Flax structs, NamedTuples, etc., while still providing helpful static typing
and documentation.
Design Rationale: Config vs Params
Environments separate static configuration (EnvironmentConfig) from
dynamic parameters (EnvironmentParams) to optimize JAX compilation:
- EnvironmentConfig: Static, compile-time configuration passed as static_argnames to jax.jit(). Changes trigger recompilation but enable better optimization. Use for: physics constants, termination thresholds, max_steps, environment structure.
- EnvironmentParams: Dynamic, runtime parameters that can vary between episodes without recompilation. Use for: randomized dynamics, curriculum learning parameters, domain randomization values, etc.
If your environment doesn’t need dynamic parameters, EnvironmentParams
can be empty, but keep the structure for protocol consistency.
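The effect of this split on compilation can be observed directly. The sketch below is not myriad code: `ToyConfig` and `_step` are hypothetical, and a frozen standard-library dataclass stands in for the config because static arguments to `jax.jit` must be hashable. A Python-side list records how often the function is traced.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)  # frozen => hashable, so usable as a static argument
class ToyConfig:
    dt: float = 0.1
    max_steps: int = 100


trace_log = []  # appended to only while JAX traces the function


def _step(x, gain, config):
    trace_log.append(config.dt)  # Python side effect: runs at trace time only
    return x + config.dt * gain * x


step = jax.jit(_step, static_argnames=["config"])

x = jnp.float32(1.0)
step(x, jnp.float32(1.0), config=ToyConfig())
step(x, jnp.float32(2.0), config=ToyConfig())  # new dynamic value, same config
traces_same_config = len(trace_log)

step(x, jnp.float32(2.0), config=ToyConfig(dt=0.2))  # new static config
traces_new_config = len(trace_log)
```

Changing `gain` reuses the compiled function, while changing `config` triggers a second trace, which is the trade-off the rationale above describes.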
- class myriad.envs.environment.EnvironmentConfig(*args, **kwargs)[source]¶
Bases: Protocol
Protocol for environment configuration objects.
- Variables:
max_steps (int) – Maximum steps per episode (required for training loops).
Notes
Typical fields include:
Physics constants (gravity, mass, friction coefficients)
Environment structure (grid size, number of agents)
Termination thresholds
Integration timestep (dt)
Any parameter that defines “what kind of environment this is”
Implementation: Use @flax.struct.dataclass for JAX compatibility.
- __init__(*args, **kwargs)¶
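As an illustration of what a structural Protocol with a `max_steps` variable means in practice, the sketch below defines a stand-in protocol and a config that satisfies it without inheriting from it. `ConfigLike` and `PendulumConfig` are hypothetical names, and a frozen standard-library dataclass is used here instead of `@flax.struct.dataclass` only to keep the example dependency-free.

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@runtime_checkable
class ConfigLike(Protocol):
    """Stand-in for EnvironmentConfig: anything exposing `max_steps`."""

    max_steps: int


@dataclass(frozen=True)
class PendulumConfig:
    # Fields that define "what kind of environment this is".
    max_steps: int = 200
    dt: float = 0.02
    gravity: float = 9.81
    angle_threshold: float = 0.21


@dataclass(frozen=True)
class NotAConfig:
    dt: float = 0.02  # no max_steps, so it does not match the protocol


cfg = PendulumConfig()
matches = isinstance(cfg, ConfigLike)
rejects = not isinstance(NotAConfig(), ConfigLike)
```

The runtime check only tests for the presence of the attribute; a static type checker would additionally verify its type.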
- class myriad.envs.environment.EnvironmentParams(*args, **kwargs)[source]¶
Bases: Protocol
Protocol for environment parameter objects.
Notes
Use cases include:
Domain randomization (randomized dynamics, varying targets)
Curriculum learning (difficulty parameters that change over training)
Multi-task learning (task-specific parameters)
Any parameter you want to sweep/randomize frequently
If your environment doesn’t need runtime variation, this can be empty, but maintain the structure for protocol consistency.
Implementation: Use @flax.struct.dataclass for JAX compatibility.
This is intentionally an empty Protocol: it is a structural marker for type consistency in the Environment container.
- __init__(*args, **kwargs)¶
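Domain randomization, one of the use cases listed above, could look roughly like the following. `ToyParams` and `sample_params` are hypothetical, the parameter ranges are arbitrary, and a `NamedTuple` is used in place of `@flax.struct.dataclass` so the sketch only depends on JAX.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyParams(NamedTuple):
    """Hypothetical dynamic parameters; every field is a JAX array."""

    pole_length: jax.Array
    cart_mass: jax.Array


def sample_params(key):
    # Draw one randomized parameter set from fixed, illustrative ranges.
    k_len, k_mass = jax.random.split(key)
    return ToyParams(
        pole_length=jax.random.uniform(k_len, (), minval=0.4, maxval=0.6),
        cart_mass=jax.random.uniform(k_mass, (), minval=0.8, maxval=1.2),
    )


# One parameter set per parallel environment, sampled in a single vmap call.
keys = jax.random.split(jax.random.PRNGKey(0), 8)
batch = jax.vmap(sample_params)(keys)
```

Because the parameters are ordinary array leaves, a batch like this can be passed to a vmapped `reset` or `step` without triggering recompilation.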
- class myriad.envs.environment.EnvironmentState(*args, **kwargs)[source]¶
Bases: Protocol
Protocol for environment state objects.
As with EnvironmentParams, this is a marker Protocol. A state should be something JAX can transform (e.g., a NamedTuple or a pytree-compatible dataclass), but the Protocol leaves that choice to the implementation.
- __init__(*args, **kwargs)¶
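A minimal sketch of a state that JAX can transform, assuming a `NamedTuple` is chosen. `ToyState` and its fields are hypothetical; the point is only that JAX treats a `NamedTuple` of arrays as a pytree.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical environment state made only of array leaves."""

    position: jax.Array
    velocity: jax.Array
    t: jax.Array


state = ToyState(
    position=jnp.zeros(2),
    velocity=jnp.ones(2),
    t=jnp.array(0),
)

# A NamedTuple is a pytree, so tree utilities (and jit/vmap/scan) accept it.
leaves = jax.tree_util.tree_leaves(state)
shifted = jax.tree_util.tree_map(lambda leaf: leaf + 1, state)
```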
- class myriad.envs.environment.GetActionSpaceFn(*args, **kwargs)[source]¶
Bases: Protocol[C_contra]
Return the environment’s action space.
- Parameters:
config (EnvironmentConfig) – Environment configuration (structural info like action dimensions).
- Returns:
The action space specification.
- Return type:
Space
- __init__(*args, **kwargs)¶
- class myriad.envs.environment.GetObsShapeFn(*args, **kwargs)[source]¶
Bases: Protocol[C_contra]
Return the shape of observations produced by the environment.
- Parameters:
config (EnvironmentConfig) – Environment configuration (structural info like state dimensions).
- Returns:
shape – Shape tuple for observations (e.g., (4,) for CartPole).
- Return type:
tuple[int, ...]
- __init__(*args, **kwargs)¶
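Both GetActionSpaceFn and GetObsShapeFn take only the config, so they can be plain Python functions of static data. The sketch below assumes a hypothetical `BoxSketch` container in place of the library's `Space` type, whose real fields are not shown on this page; `ToyConfig` is likewise illustrative.

```python
from dataclasses import dataclass
from typing import NamedTuple, Tuple


class BoxSketch(NamedTuple):
    """Hypothetical stand-in for a continuous action space description."""

    low: float
    high: float
    shape: Tuple[int, ...]


@dataclass(frozen=True)
class ToyConfig:
    max_steps: int = 200
    max_force: float = 10.0
    obs_dim: int = 4


def get_action_space(config: ToyConfig) -> BoxSketch:
    # Depends only on static structure, so the result is known at trace time.
    return BoxSketch(low=-config.max_force, high=config.max_force, shape=(1,))


def get_obs_shape(config: ToyConfig) -> Tuple[int, ...]:
    return (config.obs_dim,)


space = get_action_space(ToyConfig())
obs_shape = get_obs_shape(ToyConfig(obs_dim=6))
```

Since neither function touches params or state, the results can be used to size network inputs and outputs before any environment is reset.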
- class myriad.envs.environment.ResetFn(*args, **kwargs)[source]¶
Bases: Protocol[S_co, P_contra, C_contra, Obs_co]
Reset the environment to an initial state.
- Parameters:
key (jax.Array) – JAX PRNG key for stochastic initialization.
params (EnvironmentParams) – Dynamic environment parameters.
config (EnvironmentConfig) – Static environment configuration.
- Returns:
Initial observation and initial environment state.
- Return type:
tuple[Observation, EnvironmentState]
- __init__(*args, **kwargs)¶
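A reset function matching the `(key, params, config)` signature might be sketched as follows. All names here (`ToyConfig`, `ToyParams`, `ToyState`, `_reset`) are hypothetical, and the observation is simply the offset from a target. The last lines show one way to reset several environments at once by vmapping over keys.

```python
from dataclasses import dataclass
from typing import NamedTuple

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class ToyConfig:
    max_steps: int = 100
    init_scale: float = 0.05


class ToyParams(NamedTuple):
    target: jax.Array


class ToyState(NamedTuple):
    x: jax.Array
    t: jax.Array


def _reset(key, params, config):
    # Sample a small random initial position; the step counter starts at zero.
    x = jax.random.uniform(
        key, (2,), minval=-config.init_scale, maxval=config.init_scale
    )
    state = ToyState(x=x, t=jnp.zeros((), dtype=jnp.int32))
    obs = x - params.target
    return obs, state


reset = jax.jit(_reset, static_argnames=["config"])

config = ToyConfig()
params = ToyParams(target=jnp.zeros(2))
keys = jax.random.split(jax.random.PRNGKey(0), 4)

# Batch over keys only; params and config are shared through the closure.
obs, state = jax.vmap(lambda k: reset(k, params, config))(keys)
```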
- class myriad.envs.environment.StepFn(*args, **kwargs)[source]¶
Bases: Protocol[S_inv, P_contra, C_contra, Obs_co]
Advance the environment by one timestep.
- Parameters:
key (jax.Array) – JAX PRNG key for stochastic transitions.
state (EnvironmentState) – Current environment state.
action (jax.Array) – Action to execute.
params (EnvironmentParams) – Dynamic environment parameters.
config (EnvironmentConfig) – Static environment configuration.
- Returns:
A 5-tuple containing:
next_obs – Observation after the transition.
next_state – Updated environment state.
reward – Scalar reward signal.
done – Boolean termination flag.
info – Auxiliary information dictionary.
- Return type:
tuple[Observation, EnvironmentState, jax.Array, jax.Array, dict[str, Any]]
- __init__(*args, **kwargs)¶
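The 5-tuple contract can be exercised with a toy step function and a fixed-length rollout. This is a sketch under stated assumptions: the dynamics, reward, the `distance` info entry, and all names are hypothetical, and the rollout uses `jax.lax.scan` with a simple proportional controller rather than any myriad training loop.

```python
from dataclasses import dataclass
from typing import NamedTuple

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class ToyConfig:
    max_steps: int = 5
    dt: float = 0.1
    noise_scale: float = 0.01


class ToyParams(NamedTuple):
    target: jax.Array


class ToyState(NamedTuple):
    x: jax.Array
    t: jax.Array


def _step(key, state, action, params, config):
    noise = config.noise_scale * jax.random.normal(key, state.x.shape)
    x = state.x + config.dt * action + noise
    t = state.t + 1
    obs = x - params.target
    reward = -jnp.sum(obs ** 2)       # scalar reward
    done = t >= config.max_steps      # boolean termination flag
    info = {"distance": jnp.sqrt(jnp.sum(obs ** 2))}
    return obs, ToyState(x=x, t=t), reward, done, info


step = jax.jit(_step, static_argnames=["config"])


def rollout(key, state, params, config):
    def body(carry, step_key):
        action = params.target - carry.x  # move toward the target
        _, nxt, reward, done, _ = step(step_key, carry, action, params, config)
        return nxt, (reward, done)

    keys = jax.random.split(key, config.max_steps)
    return jax.lax.scan(body, state, keys)


init = ToyState(x=jnp.array([1.0, -1.0]), t=jnp.zeros((), dtype=jnp.int32))
params = ToyParams(target=jnp.zeros(2))
final_state, (rewards, dones) = rollout(
    jax.random.PRNGKey(0), init, params, ToyConfig()
)
```

Keeping `step` pure (all randomness comes from the key, all state is returned) is what allows it to be used as a scan body.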
- class myriad.envs.environment.Environment(config, params, get_action_space, get_obs_shape, reset, step, sample_params_fn=None)[source]¶
Bases: NamedTuple, Generic[S, C, P, Obs]
Typed container for a JAX-friendly environment’s pure functions.
This is a generic class parameterized by:
S – The EnvironmentState type.
C – The EnvironmentConfig type.
P – The EnvironmentParams type.
Obs – The Observation type.
- Variables:
config (C) – Static configuration used as static_argnames when jitting functions. Changes to config require recompilation.
params (P) – Dynamic parameters that can vary between runs without recompilation. Passed as a regular (non-static) argument to step/reset functions.
get_action_space (GetActionSpaceFn[C]) – Pure function returning the action space specification.
get_obs_shape (GetObsShapeFn[C]) – Pure function returning the observation shape tuple.
reset (ResetFn[S, P, C, Obs]) – Pure function to reset the environment to an initial state.
step (StepFn[S, P, C, Obs]) – Pure function to advance the environment by one timestep.
Examples
When jitting environment functions:
step = jax.jit(_step, static_argnames=["config"])
reset = jax.jit(_reset, static_argnames=["config"])
This allows params to vary without recompilation while keeping config static.
- config: C¶
Alias for field number 0
- params: P¶
Alias for field number 1
- get_action_space: GetActionSpaceFn[C]¶
Alias for field number 2
- get_obs_shape: GetObsShapeFn[C]¶
Alias for field number 3
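To show how such a container ties the pieces together, here is a simplified stand-in. `EnvSketch` is hypothetical and omits the `Generic` parameters and `sample_params_fn` of the real class; the toy functions and the dict-based params are illustrative only. Field order follows the documented signature, so index 0 is `config` and index 1 is `params`.

```python
from dataclasses import dataclass
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class ToyConfig:
    max_steps: int = 10
    obs_dim: int = 3


class EnvSketch(NamedTuple):
    """Simplified, non-generic stand-in for the Environment container."""

    config: Any
    params: Any
    get_action_space: Callable
    get_obs_shape: Callable
    reset: Callable
    step: Callable


def _get_action_space(config):
    return (-1.0, 1.0)  # placeholder bounds, not a real Space object


def _get_obs_shape(config):
    return (config.obs_dim,)


def _reset(key, params, config):
    state = jax.random.normal(key, (config.obs_dim,))
    return state, state  # observation equals state in this toy example


def _step(key, state, action, params, config):
    next_state = state + params["gain"] * action
    reward = -jnp.sum(next_state ** 2)
    return next_state, next_state, reward, jnp.array(False), {}


env = EnvSketch(
    config=ToyConfig(),
    params={"gain": jnp.float32(0.1)},
    get_action_space=_get_action_space,
    get_obs_shape=_get_obs_shape,
    reset=jax.jit(_reset, static_argnames=["config"]),
    step=jax.jit(_step, static_argnames=["config"]),
)

key = jax.random.PRNGKey(0)
obs, state = env.reset(key, env.params, env.config)
action = jnp.ones(env.get_obs_shape(env.config))
obs, state, reward, done, info = env.step(
    key, state, action, env.params, env.config
)
```

The container holds no mutable state of its own: everything that changes between calls travels through the arguments and return values of `reset` and `step`.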
- classmethod __class_getitem__(params)¶
Parameterizes a generic class.
At least, parameterizing a generic class is the main thing this method does. For example, for some generic class Foo, this is called when we do Foo[int] - there, with cls=Foo and params=int.
However, note that this method is also called when defining generic classes in the first place with class Foo(Generic[T]): ….