"""Base environment definitions for JAX-based RL environments.
This module provides small, focused Protocols for the three environment
components (config, params, state) and a typed container :class:`Environment`
which holds the environment's pure functions. The Protocols are intentionally
small and permissive so concrete environments remain free to use dataclasses,
Flax structs, NamedTuples, etc., while still providing helpful static typing
and documentation.
**Design Rationale: Config vs Params**
Environments separate static configuration (:class:`EnvironmentConfig`) from
dynamic parameters (:class:`EnvironmentParams`) to optimize JAX compilation:
- **EnvironmentConfig**: Static, compile-time configuration passed as
``static_argnames`` to :func:`jax.jit`. Changes trigger recompilation but
enable better optimization. Use for: physics constants, termination
thresholds, ``max_steps``, environment structure.
- **EnvironmentParams**: Dynamic, runtime parameters that can vary between
episodes without recompilation. Use for: randomized dynamics, curriculum
learning parameters, domain randomization values, etc.
If your environment doesn't need dynamic parameters, :class:`EnvironmentParams`
can be empty, but keep the structure for protocol consistency.
"""
from typing import Any, Generic, NamedTuple, Protocol, TypeVar
from jax import Array
from myriad.core.spaces import Space
from myriad.core.types import Observation, PRNGKey
[docs]
class EnvironmentConfig(Protocol):
"""Protocol for environment configuration objects.
Attributes
----------
max_steps : int
Maximum steps per episode (required for training loops).
Notes
-----
Typical fields include:
- Physics constants (gravity, mass, friction coefficients)
- Environment structure (grid size, number of agents)
- Termination thresholds
- Integration timestep (``dt``)
- Any parameter that defines "what kind of environment this is"
Implementation: Use ``@flax.struct.dataclass`` for JAX compatibility.
"""
@property
def max_steps(self) -> int: ...
[docs]
class EnvironmentParams(Protocol):
"""Protocol for environment parameter objects.
Notes
-----
Use cases include:
- Domain randomization (randomized dynamics, varying targets)
- Curriculum learning (difficulty parameters that change over training)
- Multi-task learning (task-specific parameters)
- Any parameter you want to sweep/randomize frequently
If your environment doesn't need runtime variation, this can be empty,
but maintain the structure for protocol consistency.
Implementation: Use ``@flax.struct.dataclass`` for JAX compatibility.
This is intentionally an empty Protocol — it's a structural marker
for type consistency in the :class:`Environment` container.
"""
...
[docs]
class EnvironmentState(Protocol):
"""Protocol for environment state objects.
As with :class:`EnvironmentParams`, this is a marker Protocol. A state
should be something JAX can transform (e.g., a ``NamedTuple`` or a
pytree-compatible dataclass), but the Protocol leaves that choice to the
implementation.
"""
...
# Type variables bound to the small Protocols above
S = TypeVar("S", bound=EnvironmentState)
P = TypeVar("P", bound=EnvironmentParams)
C = TypeVar("C", bound=EnvironmentConfig)
Obs = TypeVar("Obs", bound=Observation)
# Variance-specific type variables for Protocol definitions
S_co = TypeVar("S_co", bound=EnvironmentState, covariant=True)
S_inv = TypeVar("S_inv", bound=EnvironmentState)
P_contra = TypeVar("P_contra", bound=EnvironmentParams, contravariant=True)
C_contra = TypeVar("C_contra", bound=EnvironmentConfig, contravariant=True)
Obs_co = TypeVar("Obs_co", bound=Observation, covariant=True)
# ---------------------------------------------------------------------------
# Callback Protocols for Environment functions
# ---------------------------------------------------------------------------
[docs]
class GetActionSpaceFn(Protocol[C_contra]):
"""Return the environment's action space.
Parameters
----------
config : EnvironmentConfig
Environment configuration (structural info like action dimensions).
Returns
-------
Space
The action space specification.
"""
def __call__(self, config: C_contra) -> Space: ...
[docs]
class GetObsShapeFn(Protocol[C_contra]):
"""Return the shape of observations produced by the environment.
Parameters
----------
config : EnvironmentConfig
Environment configuration (structural info like state dimensions).
Returns
-------
shape: tuple[int, ...]
Shape tuple for observations (e.g., ``(4,)`` for CartPole).
"""
def __call__(self, config: C_contra) -> tuple[int, ...]: ...
[docs]
class ResetFn(Protocol[S_co, P_contra, C_contra, Obs_co]):
"""Reset the environment to an initial state.
Parameters
----------
key : PRNGKey
JAX PRNG key for stochastic initialization.
params : EnvironmentParams
Dynamic environment parameters.
config : EnvironmentConfig
Static environment configuration.
Returns
-------
tuple[:class:`~myriad.core.types.Observation`, EnvironmentState]
Initial observation and initial environment state.
"""
def __call__(self, key: PRNGKey, params: P_contra, config: C_contra) -> tuple[Obs_co, S_co]: ...
[docs]
class StepFn(Protocol[S_inv, P_contra, C_contra, Obs_co]):
"""Advance the environment by one timestep.
Parameters
----------
key : PRNGKey
JAX PRNG key for stochastic transitions.
state : EnvironmentState
Current environment state.
action : Array
Action to execute.
params : EnvironmentParams
Dynamic environment parameters.
config : EnvironmentConfig
Static environment configuration.
Returns
-------
tuple[Observation, EnvironmentState, Array, Array, dict[str, Any]]
A 5-tuple containing:
- **next_obs** -- :class:`~myriad.core.types.Observation` after the transition.
- **next_state** -- Updated environment state.
- **reward** -- Scalar reward signal.
- **done** -- Boolean termination flag.
- **info** -- Auxiliary information dictionary.
"""
def __call__(
self,
key: PRNGKey,
state: S_inv,
action: Array,
params: P_contra,
config: C_contra,
) -> tuple[Obs_co, S_inv, Array, Array, dict[str, Any]]: ...
# ---------------------------------------------------------------------------
# Environment container
# ---------------------------------------------------------------------------
[docs]
class Environment(NamedTuple, Generic[S, C, P, Obs]):
"""Typed container for a JAX-friendly environment's pure functions.
This is a generic class parameterized by:
- ``S`` -- The :class:`EnvironmentState` type.
- ``C`` -- The :class:`EnvironmentConfig` type.
- ``P`` -- The :class:`EnvironmentParams` type.
- ``Obs`` -- The :class:`~myriad.core.types.Observation` type.
Attributes
----------
config : C
Static configuration used as ``static_argnames`` when jitting functions.
Changes to config require recompilation.
params : P
Dynamic parameters that can vary between runs without recompilation.
Passed as a regular (non-static) argument to ``step``/``reset``
functions.
get_action_space : GetActionSpaceFn[C]
Pure function returning the action space specification.
get_obs_shape : GetObsShapeFn[C]
Pure function returning the observation shape tuple.
reset : ResetFn[S, P, C, Obs]
Pure function to reset the environment to an initial state.
step : StepFn[S, P, C, Obs]
Pure function to advance the environment by one timestep.
Examples
--------
When jitting environment functions:
.. code-block:: python
step = jax.jit(_step, static_argnames=["config"])
reset = jax.jit(_reset, static_argnames=["config"])
This allows ``params`` to vary without recompilation while keeping
``config`` static.
"""
config: C
params: P
# Action / observation helpers
get_action_space: GetActionSpaceFn[C]
get_obs_shape: GetObsShapeFn[C]
# Pure, jitted environment functions
reset: ResetFn[S, P, C, Obs]
step: StepFn[S, P, C, Obs]