Agent

Overview

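import jax
from myriad.agents import make_agent

# `env`, `sample_obs`, and `obs` are assumed to come from an existing environment.
agent = make_agent("dqn", action_space=env.action_space)

# Use a separate PRNG key for each stochastic call instead of reusing one.
init_key, action_key = jax.random.split(jax.random.PRNGKey(0))
state = agent.init(init_key, sample_obs, agent.params)
action, state = agent.select_action(action_key, obs, state, agent.params, deterministic=False)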

Available Agents

ID        Category   Agent      Description
random    Classical  Random     Uniform random action selection
bangbang  Classical  Bang-Bang  Threshold-based bang-bang controller
pid       Classical  PID        Proportional-Integral-Derivative controller
dqn       RL         DQN        Deep Q-Network (discrete actions)
pqn       RL         PQN        Parallelized Q-Network (on-policy)

Factory Function

myriad.agents.make_agent(name, **kwargs)

Create an agent instance by name.

Parameters:
  • name (str) – Unique identifier for the agent.

  • **kwargs (Any) – Keyword arguments passed to the agent’s factory function.

Returns:

An instance of the requested Agent.

Raises:

ValueError – If the agent name is not found in the registry.

Return type:

Any
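The lookup-by-name behaviour described above can be pictured as a small registry. The sketch below is a hypothetical stand-in rather than myriad's actual implementation: `_REGISTRY`, `register_agent`, and `make_agent_sketch` are illustrative names, and the registered factory only echoes its configuration.

```python
from typing import Any, Callable

# Hypothetical registry mapping agent IDs to factory callables.
_REGISTRY: dict[str, Callable[..., Any]] = {}


def register_agent(name: str, factory: Callable[..., Any]) -> None:
    """Associate an agent ID with the callable that builds it."""
    _REGISTRY[name] = factory


def make_agent_sketch(name: str, **kwargs: Any) -> Any:
    """Build an agent by ID, forwarding keyword arguments to its factory."""
    if name not in _REGISTRY:
        known = ", ".join(sorted(_REGISTRY))
        raise ValueError(f"Unknown agent '{name}'. Registered agents: {known}")
    return _REGISTRY[name](**kwargs)


# Illustrative factory that simply echoes its configuration.
register_agent("random", lambda **kw: {"kind": "random", **kw})

agent = make_agent_sketch("random", action_space="discrete(2)")
```

When the agent name comes from a configuration file or the command line, catching `ValueError` at the call site is a reasonable way to report a mistyped ID.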

Base Protocols

Base agent definitions for JAX-based agents.

This module provides small, focused Protocols for the two agent components (params, state) and a typed container, Agent, which holds the agent’s pure functions. The Protocols are intentionally small and permissive so that concrete agents remain free to use dataclasses, Flax structs, NamedTuples, etc., while still benefiting from static typing and documentation.

class myriad.agents.agent.AgentParams(*args, **kwargs)

Bases: Protocol

Protocol for agent parameter objects.

Concrete agents can use dataclasses, Flax structs, or simple NamedTuples.

action_space: Space

class myriad.agents.agent.AgentState(*args, **kwargs)

Bases: Protocol

Protocol for agent state objects.

As with AgentParams, this is a marker Protocol. A state should be something JAX can transform (e.g., a NamedTuple or a pytree-compatible dataclass), but the Protocol leaves that choice to the implementation.
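As a minimal sketch of how plain NamedTuples can fill both roles, the example below uses only the standard library. `ParamsLike`, `ThresholdParams`, and `ThresholdState` are hypothetical names chosen for illustration, and `ParamsLike` only approximates AgentParams by requiring an `action_space` attribute.

```python
from typing import Any, NamedTuple, Protocol, runtime_checkable


@runtime_checkable
class ParamsLike(Protocol):
    """Hypothetical stand-in for AgentParams: anything exposing action_space."""

    action_space: Any


class ThresholdParams(NamedTuple):
    """Illustrative hyperparameters; immutable, so safe to share across calls."""

    action_space: Any
    threshold: float = 0.0


class ThresholdState(NamedTuple):
    """Illustrative state; JAX treats NamedTuples as pytrees."""

    step: int = 0


params = ThresholdParams(action_space=("low", "high"), threshold=0.5)
state = ThresholdState()

# Structural typing: no inheritance needed, only the right attribute.
satisfies_protocol = isinstance(params, ParamsLike)

# "Updating" state means building a new value rather than mutating in place.
next_state = state._replace(step=state.step + 1)
```

Because neither class inherits from the Protocol, the same pattern should carry over to dataclasses or Flax structs, provided they expose the expected attributes.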

class myriad.agents.agent.InitFn(*args, **kwargs)

Bases: Protocol[S_co, P_contra, Obs_contra]

Initialize the agent’s state.

Parameters:
  • key – JAX PRNG key for stochastic initialization (e.g., network weights)

  • sample_obs – Sample observation to infer network architecture and field names

  • params – Agent hyperparameters (learning rate, network architecture, etc.)

Returns:

Initialized agent state (e.g., network parameters, optimizer state)

Return type:

S
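To illustrate how a sample observation can determine the network shape, here is a hedged sketch of an init function for a linear Q-function. `QParams`, `QState`, and the `num_actions` field are assumptions made for this example and are not part of myriad.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class QParams(NamedTuple):
    num_actions: int  # Hypothetical hyperparameter.


class QState(NamedTuple):
    weights: jnp.ndarray  # Shape: (obs_dim, num_actions).


def init(key, sample_obs, params):
    """Size a linear Q-function from the shape of the sample observation."""
    obs_dim = sample_obs.shape[-1]
    weights = 0.01 * jax.random.normal(key, (obs_dim, params.num_actions))
    return QState(weights=weights)


state = init(jax.random.PRNGKey(0), jnp.zeros(5), QParams(num_actions=3))
```

Only the shape of `sample_obs` is read here, so any array with the right trailing dimension works; calling `init` again with the same key reproduces the same weights.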

class myriad.agents.agent.SelectActionFn(*args, **kwargs)

Bases: Protocol[S_inv, P_contra, Obs_contra]

Select an action given the current observation.

Parameters:
  • key – JAX PRNG key for stochastic action selection (e.g., epsilon-greedy)

  • obs – Current observation from the environment

  • state – Current agent state (e.g., network parameters)

  • params – Agent hyperparameters

  • deterministic – If True, select the greedy/deterministic action (e.g., for evaluation). If False, sample from the policy distribution (e.g., for exploration).

Returns:

Selected action and (possibly updated) agent state

Return type:

tuple[Array, S]
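The effect of the `deterministic` flag can be sketched with epsilon-greedy selection over a vector of Q-values. This is an illustration under stated assumptions, not myriad's DQN policy: the `epsilon_greedy` helper and its `q_values` and `epsilon` arguments are hypothetical.

```python
import jax
import jax.numpy as jnp


def epsilon_greedy(key, q_values, epsilon, deterministic):
    """Pick argmax(q_values), or a uniform random action with probability epsilon."""
    explore_key, action_key = jax.random.split(key)
    greedy_action = jnp.argmax(q_values)
    random_action = jax.random.randint(action_key, (), 0, q_values.shape[0])
    # Exploration is switched off entirely when deterministic is True.
    explore = (jax.random.uniform(explore_key) < epsilon) & jnp.logical_not(deterministic)
    return jnp.where(explore, random_action, greedy_action)


q_values = jnp.array([0.1, 0.9, 0.3])
key = jax.random.PRNGKey(0)

eval_action = epsilon_greedy(key, q_values, epsilon=1.0, deterministic=True)
train_action = epsilon_greedy(key, q_values, epsilon=1.0, deterministic=False)
```

Using `jnp.where` instead of a Python `if` keeps the function branch-free, so it can be traced by `jax.jit` even when `deterministic` is passed as an array.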

class myriad.agents.agent.UpdateFn(*args, **kwargs)

Bases: Protocol[S_inv, P_contra]

Update the agent’s state from a batch of experience.

Parameters:
  • key – JAX PRNG key for stochastic updates (e.g., dropout, minibatch sampling)

  • state – Current agent state to update

  • batch – Batch of experience data (structure depends on the agent/algorithm)

  • params – Agent hyperparameters

Returns:

Updated agent state and a metrics dictionary (e.g., loss values)

Return type:

tuple[S, dict[str, Any]]
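A minimal sketch of an update function that follows the documented argument order and returns the `(state, metrics)` pair is shown below. The linear model, the `SgdParams` and `LinearState` containers, and the batch keys `"obs"` and `"target"` are all assumptions made for illustration.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SgdParams(NamedTuple):
    learning_rate: float  # Hypothetical hyperparameter.


class LinearState(NamedTuple):
    weights: jnp.ndarray  # Weights of a linear value estimator.


def update(key, state, batch, params):
    """One SGD step on a mean-squared-error loss; returns (new_state, metrics)."""
    del key  # Unused here; kept to mirror the documented argument order.

    def loss_fn(weights):
        predictions = batch["obs"] @ weights
        return jnp.mean((predictions - batch["target"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(state.weights)
    new_state = LinearState(weights=state.weights - params.learning_rate * grads)
    return new_state, {"loss": loss}


state = LinearState(weights=jnp.zeros(2))
batch = {"obs": jnp.array([[1.0, 0.0], [0.0, 1.0]]), "target": jnp.array([1.0, -1.0])}
new_state, metrics = update(jax.random.PRNGKey(0), state, batch, SgdParams(learning_rate=0.1))
```

The function never mutates `state`; it returns a fresh value, which is what lets the caller wrap it in `jax.jit` or carry it through `jax.lax.scan`.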

class myriad.agents.agent.Agent(params, init, select_action, update)

Bases: NamedTuple, Generic[S, P, Obs]

Typed container for a JAX-friendly agent’s pure functions.

Variables:
  • params (myriad.agents.agent.P) – Agent hyperparameters (learning rate, network config, action space, etc.).

  • init (myriad.agents.agent.InitFn[myriad.agents.agent.S, myriad.agents.agent.P, myriad.agents.agent.Obs]) – Pure function to initialize the agent’s state.

  • select_action (myriad.agents.agent.SelectActionFn[myriad.agents.agent.S, myriad.agents.agent.P, myriad.agents.agent.Obs]) – Pure function to select an action from the agent’s policy.

  • update (myriad.agents.agent.UpdateFn[myriad.agents.agent.S, myriad.agents.agent.P]) – Pure function to update the agent’s state from experience.

params: P

Alias for field number 0

init: InitFn[S, P, Obs]

Alias for field number 1

select_action: SelectActionFn[S, P, Obs]

Alias for field number 2

update: UpdateFn[S, P]

Alias for field number 3
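To show how the four fields fit together, the sketch below assembles a uniform-random agent in a container that mirrors the documented field order. `AgentSketch`, `RandomParams`, and `RandomState` are hypothetical names; the real Agent class is generic over state, params, and observation types, which this simplified version omits.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class RandomParams(NamedTuple):
    num_actions: int  # Stand-in for a real action space.


class RandomState(NamedTuple):
    steps: jnp.ndarray  # Number of actions selected so far.


class AgentSketch(NamedTuple):
    """Hypothetical mirror of the four documented Agent fields."""

    params: Any
    init: Callable
    select_action: Callable
    update: Callable


def init(key, sample_obs, params):
    return RandomState(steps=jnp.zeros((), dtype=jnp.int32))


def select_action(key, obs, state, params, deterministic=False):
    # A uniform policy has no greedy action, so `deterministic` is ignored.
    action = jax.random.randint(key, (), 0, params.num_actions)
    return action, RandomState(steps=state.steps + 1)


def update(key, state, batch, params):
    return state, {}  # Nothing to learn for a random policy.


agent = AgentSketch(RandomParams(num_actions=4), init, select_action, update)

key, init_key, act_key = jax.random.split(jax.random.PRNGKey(0), 3)
obs = jnp.zeros(3)
state = agent.init(init_key, obs, agent.params)
action, state = agent.select_action(act_key, obs, state, agent.params)
```

Because the container holds functions rather than methods, state and params are always passed explicitly, which keeps each call pure and straightforward to transform with JAX.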
