Agent
Overview
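```python
import jax
from myriad.agents import make_agent

# `env`, `sample_obs` and `obs` come from the surrounding environment loop.
agent = make_agent("dqn", action_space=env.action_space)
init_key, act_key = jax.random.split(jax.random.PRNGKey(0))
state = agent.init(init_key, sample_obs, agent.params)
action, state = agent.select_action(act_key, obs, state, agent.params, deterministic=False)
```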
Available Agents
| ID | Category | Agent | Description |
|---|---|---|---|
| | | | Uniform random action selection |
| | | | Threshold-based bang-bang controller |
| | | | Proportional-Integral-Derivative (PID) controller |
| dqn | | | Deep Q-Network (discrete actions) |
| | | | Parallelized Q-Network (on-policy) |
Factory Function
- myriad.agents.make_agent(name, **kwargs)
Create an agent instance by name.
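- Parameters:
name – Registered ID of the agent to create (e.g., "dqn").
**kwargs – Keyword arguments used to configure the agent (e.g., action_space).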
- Returns:
An instance of the requested Agent.
- Raises:
ValueError – If the agent name is not found in the registry.
- Return type:
Agent
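A factory like make_agent is often backed by a dictionary that maps names to constructors. The following is a minimal sketch of that pattern, not myriad's implementation: _REGISTRY, register, make_random_agent, make_agent_sketch and the "random" ID are all hypothetical. The only behaviour carried over from the documentation above is that an unknown name raises ValueError.

```python
from typing import Any, Callable

# Hypothetical registry mapping agent IDs to constructor callables.
_REGISTRY: dict[str, Callable[..., Any]] = {}


def register(name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator that records a constructor under `name`."""
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        _REGISTRY[name] = fn
        return fn
    return decorator


@register("random")
def make_random_agent(**kwargs: Any) -> dict[str, Any]:
    # Placeholder constructor; a real one would return an Agent container.
    return {"kind": "random", **kwargs}


def make_agent_sketch(name: str, **kwargs: Any) -> Any:
    """Look up `name` and forward the keyword arguments to its constructor."""
    try:
        constructor = _REGISTRY[name]
    except KeyError:
        known = ", ".join(sorted(_REGISTRY))
        raise ValueError(f"Unknown agent {name!r}; available: {known}") from None
    return constructor(**kwargs)
```

Registering through a decorator keeps each agent module responsible for its own entry, so the factory itself never has to change when an agent is added.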
Base Protocols
Base agent definitions for JAX-based agents.
This module provides small, focused Protocols for the two agent components (params, state) and the three pure functions (init, select_action, update), plus a typed container Agent that holds those functions. The Protocols are intentionally small and permissive so that concrete agents remain free to use dataclasses, Flax structs, NamedTuples, etc., while still benefiting from static typing and documentation.
- class myriad.agents.agent.AgentParams(*args, **kwargs)
Bases: Protocol
Protocol for agent parameter objects.
Concrete agents can use dataclasses, Flax structs, or simple NamedTuples.
- __init__(*args, **kwargs)
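Because AgentParams is a structural marker Protocol, a concrete parameter object needs no special base class. As a sketch, a hypothetical epsilon-greedy agent might describe its hyperparameters with a NamedTuple; EpsGreedyParams and its fields are illustrative names, not part of myriad.

```python
from typing import NamedTuple


class EpsGreedyParams(NamedTuple):
    """Hypothetical hyperparameters for an epsilon-greedy Q-learning agent."""

    num_actions: int = 2
    epsilon: float = 0.1
    learning_rate: float = 1e-3
    discount: float = 0.99


params = EpsGreedyParams(num_actions=4)
# NamedTuples are immutable: derive an evaluation variant instead of mutating.
eval_params = params._replace(epsilon=0.0)
```

If such an object is passed to a jax.jit-compiled function as an ordinary argument, each field becomes a traced leaf, so a field used as an array shape (such as num_actions) would need to be static or closed over.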
- class myriad.agents.agent.AgentState(*args, **kwargs)
Bases: Protocol
Protocol for agent state objects.
As with AgentParams, this is a marker Protocol. A state should be something JAX can transform (e.g., a NamedTuple or a pytree-compatible dataclass), but the Protocol leaves that choice to the implementation.
- __init__(*args, **kwargs)
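The pytree requirement can be illustrated with a NamedTuple holding array leaves. In the sketch below, LinearQState and its fields are hypothetical; the point is that JAX treats NamedTuples as pytrees, so tree utilities (and transformations such as jax.jit) can traverse the state without extra registration.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class LinearQState(NamedTuple):
    """Hypothetical agent state: linear Q-weights plus an update counter."""

    weights: jax.Array  # shape (obs_dim, num_actions)
    bias: jax.Array     # shape (num_actions,)
    step: jax.Array     # scalar int32


state = LinearQState(
    weights=jnp.ones((3, 2)),
    bias=jnp.zeros(2),
    step=jnp.array(0, dtype=jnp.int32),
)

# tree_map visits every leaf array and rebuilds a LinearQState of the results.
sizes = jax.tree_util.tree_map(lambda x: x.size, state)
num_leaves = len(jax.tree_util.tree_leaves(state))
```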
- class myriad.agents.agent.InitFn(*args, **kwargs)
Bases: Protocol[S_co, P_contra, Obs_contra]
Initialize the agent’s state.
- Parameters:
key – JAX PRNG key for stochastic initialization (e.g., network weights)
sample_obs – Sample observation to infer network architecture and field names
params – Agent hyperparameters (learning rate, network architecture, etc.)
- Returns:
Initialized agent state (e.g., network parameters, optimizer state)
- Return type:
S
- __init__(*args, **kwargs)
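A function satisfying InitFn receives the key, a sample observation and the hyperparameters, and returns a fresh state. The sketch below assumes flat vector observations and a hypothetical linear Q-function (Params, LinearQState and the 0.01 weight scale are illustrative choices); it reads the input width from sample_obs, matching the parameter description above.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    num_actions: int = 2  # hypothetical hyperparameter


class LinearQState(NamedTuple):
    weights: jax.Array  # (obs_dim, num_actions)
    bias: jax.Array     # (num_actions,)


def init(key: jax.Array, sample_obs: jax.Array, params: Params) -> LinearQState:
    """Build small random weights and a zero bias for a linear Q-function."""
    obs_dim = sample_obs.shape[-1]  # infer the input width from the sample observation
    weights = 0.01 * jax.random.normal(key, (obs_dim, params.num_actions))
    return LinearQState(weights=weights, bias=jnp.zeros(params.num_actions))


state = init(jax.random.PRNGKey(0), jnp.zeros(4), Params(num_actions=3))
```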
- class myriad.agents.agent.SelectActionFn(*args, **kwargs)
Bases: Protocol[S_inv, P_contra, Obs_contra]
Select an action given the current observation.
- Parameters:
key – JAX PRNG key for stochastic action selection (e.g., epsilon-greedy)
obs – Current observation from the environment
state – Current agent state (e.g., network parameters)
params – Agent hyperparameters
deterministic – If True, select the greedy/deterministic action (e.g., for evaluation). If False, sample from the policy distribution (e.g., for exploration).
- Returns:
Selected action and (possibly updated) agent state
- Return type:
tuple[Array, S]
- __init__(*args, **kwargs)
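As an example of the SelectActionFn contract, the sketch below implements epsilon-greedy selection over a hypothetical linear Q-function; Params, LinearQState and the epsilon field are assumptions for illustration. With deterministic=True it returns the greedy action, otherwise it explores with probability epsilon. The state is returned unchanged, which the contract ("possibly updated") permits.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    num_actions: int = 2  # hypothetical hyperparameters
    epsilon: float = 0.1


class LinearQState(NamedTuple):
    weights: jax.Array  # (obs_dim, num_actions)
    bias: jax.Array     # (num_actions,)


def select_action(
    key: jax.Array,
    obs: jax.Array,
    state: LinearQState,
    params: Params,
    deterministic: bool = False,
) -> tuple[jax.Array, LinearQState]:
    """Epsilon-greedy over linear Q-values; the state is returned unchanged."""
    greedy = jnp.argmax(obs @ state.weights + state.bias)
    if deterministic:  # Python branch, so `deterministic` must be static under jit
        return greedy, state
    explore_key, action_key = jax.random.split(key)
    random_action = jax.random.randint(action_key, (), 0, params.num_actions)
    explore = jax.random.uniform(explore_key) < params.epsilon
    return jnp.where(explore, random_action, greedy), state
```

Because deterministic drives a Python-level if, it has to be marked static when compiling, for example with jax.jit(select_action, static_argnames="deterministic").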
- class myriad.agents.agent.UpdateFn(*args, **kwargs)
Bases: Protocol[S_inv, P_contra]
Update the agent’s state from a batch of experience.
- Parameters:
key – JAX PRNG key for stochastic updates (e.g., dropout, minibatch sampling)
state – Current agent state to update
batch – Batch of experience data (structure depends on the agent/algorithm)
params – Agent hyperparameters
- Returns:
Updated agent state and a metrics dictionary (e.g., loss values)
- Return type:
tuple[S, dict[str, Any]]
- __init__(*args, **kwargs)
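To make the UpdateFn contract concrete, the sketch below performs one semi-gradient Q-learning step for a hypothetical linear Q-function and reports the loss in the metrics dictionary. The batch layout (a dict with obs, action, reward, next_obs and done arrays), Params and LinearQState are assumptions, since the Protocol leaves the batch structure to each algorithm; key is accepted but unused.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    learning_rate: float = 0.1  # hypothetical hyperparameters
    discount: float = 0.99


class LinearQState(NamedTuple):
    weights: jax.Array  # (obs_dim, num_actions)
    bias: jax.Array     # (num_actions,)


def q_values(state: LinearQState, obs: jax.Array) -> jax.Array:
    return obs @ state.weights + state.bias


def update(
    key: jax.Array, state: LinearQState, batch: dict[str, jax.Array], params: Params
) -> tuple[LinearQState, dict[str, Any]]:
    """One SGD step on the mean squared TD error; `key` is unused here."""
    next_q = q_values(state, batch["next_obs"]).max(axis=-1)
    target = batch["reward"] + params.discount * (1.0 - batch["done"]) * next_q
    target = jax.lax.stop_gradient(target)  # do not backpropagate through the target

    def loss_fn(s: LinearQState) -> jax.Array:
        q = q_values(s, batch["obs"])
        q_taken = jnp.take_along_axis(q, batch["action"][:, None], axis=-1)[:, 0]
        return jnp.mean((q_taken - target) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(state)
    new_state = jax.tree_util.tree_map(lambda p, g: p - params.learning_rate * g, state, grads)
    return new_state, {"loss": loss}
```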
- class myriad.agents.agent.Agent(params, init, select_action, update)
Bases: NamedTuple, Generic[S, P, Obs]
Typed container for a JAX-friendly agent’s pure functions.
- Variables:
params (P) – Agent hyperparameters (learning rate, network config, action space, etc.).
init (InitFn[S, P, Obs]) – Pure function to initialize the agent’s state.
select_action (SelectActionFn[S, P, Obs]) – Pure function to select an action from the agent’s policy.
update (UpdateFn[S, P]) – Pure function to update the agent’s state from experience.
- params: P
Alias for field number 0
- init: InitFn[S, P, Obs]
Alias for field number 1
- select_action: SelectActionFn[S, P, Obs]
Alias for field number 2
- update: UpdateFn[S, P]
Alias for field number 3
- classmethod __class_getitem__(params)
Parameterizes a generic class.
At least, parameterizing a generic class is the main thing this method does. For example, for some generic class Foo, this is called when we do Foo[int] - there, with cls=Foo and params=int.
However, note that this method is also called when defining generic classes in the first place with class Foo(Generic[T]): ….
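The Agent container bundles the hyperparameters with the three pure functions. So that it runs without myriad installed, the sketch below uses AgentSketch, a hypothetical stand-in NamedTuple with the same field order as Agent (params, init, select_action, update), and wires up a trivial uniform-random policy. All names are illustrative. It also shows one way to compile select_action with jax.jit while keeping deterministic static.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class AgentSketch(NamedTuple):
    """Hypothetical stand-in with the same field order as Agent."""

    params: Any
    init: Callable
    select_action: Callable
    update: Callable


class RandomParams(NamedTuple):
    num_actions: int = 3


class RandomState(NamedTuple):
    steps: jax.Array  # number of select_action calls so far


def init(key, sample_obs, params):
    return RandomState(steps=jnp.array(0, dtype=jnp.int32))


def select_action(key, obs, state, params, deterministic=False):
    # Uniform random action; deterministic mode returns action 0 (an arbitrary choice).
    if deterministic:
        action = jnp.array(0, dtype=jnp.int32)
    else:
        action = jax.random.randint(key, (), 0, params.num_actions)
    return action, state._replace(steps=state.steps + 1)


def update(key, state, batch, params):
    return state, {}  # a random policy has nothing to learn


agent = AgentSketch(RandomParams(), init, select_action, update)
# Compile once, treating the Python-level `deterministic` flag as static.
act = jax.jit(agent.select_action, static_argnames="deterministic")
```

Keeping the functions pure and passing params and state explicitly is what allows the whole call to be compiled; the container itself carries no mutable state.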