"""Control task wrapper for CcaS-CcaR + GFP gene circuit.
Standard tracking task: Control GFP expression (F) to match a target trajectory.
Reward: Negative absolute error between F and F_target.
Task variants:
- Constant target: Fixed GFP level (default: F_target = 25)
- Sinewave target: Time-varying sinusoidal trajectory
"""
from typing import Any, NamedTuple
import jax
import jax.numpy as jnp
from flax import struct
from jax import Array
from myriad.core.spaces import Discrete
from myriad.core.types import PRNGKey
from myriad.envs.environment import Environment
from myriad.utils import filter_kwargs
from ..physics import PhysicsConfig, PhysicsParams, PhysicsParamsPrior, PhysicsState
from .base import (
BaseCcasrGfpTaskConfig,
CcasrGfpControlObs,
TaskConfig,
check_termination,
generate_constant_target,
generate_sinewave_target,
get_action_space as _get_action_space,
sample_initial_physics,
step_physics_interval,
)
class ControlTaskState(NamedTuple):
    """Per-episode state carried from one control-task step to the next.

    Attributes:
        physics: Underlying physics state (time, H, F).
        t: RL timestep counter. This counts environment steps, not Gillespie
            simulation time.
        U: Light input applied on the previous timestep, kept so the next
            step can detect an action toggle.
        F_target: GFP target trajectory laid out as
            [current, t+1, ..., t+n_horizon].
    """

    physics: PhysicsState
    t: Array
    U: Array
    F_target: Array
@struct.dataclass
class ControlTaskConfig(BaseCcasrGfpTaskConfig):
    """Static settings for the CcaS-CcaR GFP tracking task.

    Attributes:
        target_type: Target generator selector, "constant" or "sinewave".
            The step/reset code only tests for "sinewave"; every other value
            falls through to the constant target.
        n_horizon: Number of future timesteps of the target exposed in the
            observation.
        F_target_constant: GFP level handed to the constant target generator.
        sinewave_period_minutes: Period handed to the sinewave target
            generator, in minutes (600 = 10 hours).
        sinewave_amplitude: Amplitude handed to the sinewave target generator.
        sinewave_vshift: Vertical-shift value handed to the sinewave target
            generator.
    """

    # --- Target selection ---
    target_type: str = "constant"
    n_horizon: int = 1

    # --- Constant target ---
    F_target_constant: float = 25.0

    # --- Sinewave target ---
    sinewave_period_minutes: float = 600.0
    sinewave_amplitude: float = 20.0
    sinewave_vshift: float = 30.0

    @property
    def dt(self) -> float:
        """Length of one timestep in minutes, read from the physics config."""
        step_minutes = self.physics.timestep_minutes
        return step_minutes
@struct.dataclass
class ControlTaskParams:
    """Parameter bundle for the control task.

    Attributes:
        physics: Physics parameters. A default-constructed ``PhysicsParams``
            is used when none is supplied.
    """

    physics: PhysicsParams = struct.field(default_factory=PhysicsParams)
@struct.dataclass
class ControlTaskParamsPrior:
    """Prior distribution over control task parameters.

    Attributes:
        physics: Prior over the physics parameters. A default-constructed
            ``PhysicsParamsPrior`` is used when none is supplied.
    """

    physics: PhysicsParamsPrior = struct.field(default_factory=PhysicsParamsPrior)

    def sample(self, key: "PRNGKey") -> ControlTaskParams:
        """Draw one set of task parameters from the prior.

        Args:
            key: RNG key forwarded unchanged to the physics prior's ``sample``.

        Returns:
            A ``ControlTaskParams`` wrapping the sampled physics parameters.
        """
        return ControlTaskParams(physics=self.physics.sample(key))
def _step(
    key: PRNGKey,
    state: ControlTaskState,
    action: Array,
    params: ControlTaskParams,
    config: ControlTaskConfig,
) -> tuple[CcasrGfpControlObs, ControlTaskState, Array, Array, dict[str, Any]]:
    """Advance the control task by one RL timestep.

    Args:
        key: RNG key; split into one key for the physics simulation and one
            for target generation.
        state: Task state before the step.
        action: Discrete light input, 0 or 1.
        params: Task parameters. ``params.physics`` drives the physics step;
            ``params`` is also forwarded to ``get_obs``.
        config: Static task configuration.

    Returns:
        A tuple ``(obs_next, next_state, reward, done, info)``:

        - ``obs_next``: observation built from ``next_state``.
        - ``next_state``: updated task state, with ``U`` set to ``action``.
        - ``reward``: negative absolute error between the post-step ``F`` and
          ``state.F_target[0]``.
        - ``done``: termination flag returned by ``check_termination``.
        - ``info``: dict with ``"F"``, ``"H"`` and ``"F_target"`` for logging.
    """
    physics_key, target_key = jax.random.split(key)

    # The target tracked during this step is the head of the trajectory that
    # was stored in the incoming state.
    # NOTE(review): this scores the post-step F against the pre-step target;
    # confirm that is intended rather than the head of the new trajectory.
    tracked_target = state.F_target[0]

    physics_next, t_next = step_physics_interval(
        physics_key, state.physics, state.t, state.U, action, params.physics, config.physics
    )

    reward = -jnp.abs(physics_next.F - tracked_target)
    done = check_termination(t_next, config.task)

    # target_type lives on the static config, so this is an ordinary Python
    # branch resolved while tracing rather than a traced conditional.
    if config.target_type != "sinewave":
        target_next = generate_constant_target(config.n_horizon, config.F_target_constant)
    else:
        target_next = generate_sinewave_target(
            target_key,
            t_next,
            config.n_horizon,
            config.physics.timestep_minutes,
            config.sinewave_period_minutes,
            config.sinewave_amplitude,
            config.sinewave_vshift,
        )

    # The action just applied becomes U, so the next step can tell whether
    # the light input was toggled.
    next_state = ControlTaskState(
        physics=physics_next,
        t=t_next,
        U=action,
        F_target=target_next,
    )
    obs_next = get_obs(next_state, params, config)

    info = {
        "F": physics_next.F,
        "H": physics_next.H,
        "F_target": tracked_target,
    }
    return obs_next, next_state, reward, done, info
def _reset(
    key: PRNGKey,
    params: ControlTaskParams,
    config: ControlTaskConfig,
) -> tuple[CcasrGfpControlObs, ControlTaskState]:
    """Build the initial state and observation for a new episode.

    The physics state comes from ``sample_initial_physics`` (expected to be
    zero protein concentrations), the timestep counter and previous light
    input both start at 0, and the target trajectory is generated for
    timestep 0.

    Args:
        key: RNG key; split into one key for the initial physics state and
            one for target generation.
        params: Task parameters; only forwarded to ``get_obs``.
        config: Static task configuration.

    Returns:
        A tuple ``(obs, state)`` with the initial observation and task state.
    """
    physics_key, target_key = jax.random.split(key)
    t0 = jnp.array(0)

    initial_physics = sample_initial_physics(physics_key)

    # target_type lives on the static config, so this is an ordinary Python
    # branch resolved while tracing rather than a traced conditional.
    if config.target_type != "sinewave":
        initial_target = generate_constant_target(config.n_horizon, config.F_target_constant)
    else:
        initial_target = generate_sinewave_target(
            target_key,
            t0,
            config.n_horizon,
            config.physics.timestep_minutes,
            config.sinewave_period_minutes,
            config.sinewave_amplitude,
            config.sinewave_vshift,
        )

    # U starts at 0 (no light). The first action toggle, if any, will reset
    # time to 0.
    initial_state = ControlTaskState(
        physics=initial_physics,
        t=t0,
        U=jnp.array(0),
        F_target=initial_target,
    )
    return get_obs(initial_state, params, config), initial_state
# Public, JIT-compiled entry points. ``config`` is declared static, so JAX
# treats it as a compile-time constant (which is what lets _step/_reset branch
# on config.target_type in plain Python).
step = jax.jit(_step, static_argnames=("config",))
reset = jax.jit(_reset, static_argnames=("config",))
def get_obs(
    state: ControlTaskState,
    params: ControlTaskParams,
    config: ControlTaskConfig,
) -> CcasrGfpControlObs:
    """Build the structured observation for a task state.

    Both the current GFP level and the stored target trajectory are divided
    by ``config.task.F_obs_normalizer``. Classical controllers can read the
    named fields directly; neural network agents can call ``.to_array()``.

    Args:
        state: Task state to observe.
        params: Task parameters; not used.
        config: Task configuration; supplies the observation normalizer.

    Returns:
        A ``CcasrGfpControlObs`` built from ``F_normalized`` (normalized
        current F) and ``F_target`` (normalized target trajectory).
    """
    scale = config.task.F_obs_normalizer
    return CcasrGfpControlObs(
        F_normalized=state.physics.F / scale,
        F_target=state.F_target / scale,
    )
def get_obs_shape(config: ControlTaskConfig) -> tuple[int, ...]:
    """Return the shape of the flattened observation vector.

    The observation is laid out as ``[F, F_target[0:n_horizon+1]]``: one entry
    for F followed by ``n_horizon + 1`` target entries.

    Args:
        config: Task configuration; only ``n_horizon`` is read.

    Returns:
        A one-element tuple ``(1 + n_horizon + 1,)``.
    """
    n_f_entries = 1
    n_target_entries = config.n_horizon + 1
    return (n_f_entries + n_target_entries,)
def get_action_space(config: ControlTaskConfig) -> Discrete:
    """Return the discrete action space of the environment.

    Delegates to the shared ``get_action_space`` helper imported from
    ``.base``.

    Args:
        config: Task configuration; not used.

    Returns:
        The space produced by the shared helper: a ``Discrete`` space with
        2 actions, 0 (light off) and 1 (light on).
    """
    action_space = _get_action_space()
    return action_space
def make_env(
    config: ControlTaskConfig | None = None,
    params: ControlTaskParams | None = None,
    params_prior: ControlTaskParamsPrior | None = None,
    **kwargs,
) -> Environment[ControlTaskState, ControlTaskConfig, ControlTaskParams, CcasrGfpControlObs]:
    """Create a CcaS-CcaR GFP control task environment.

    Args:
        config: Custom ControlTaskConfig. When None, one is assembled from
            the matching entries of ``kwargs``.
        params: Custom ControlTaskParams. When None, one is assembled from
            the matching entries of ``kwargs``.
        params_prior: Optional prior for domain randomization. If set,
            ``env.sample_params_fn`` will sample distinct θ* per parallel env.
            When None, a prior is still created if ``kwargs`` contains any
            prior field (e.g. ``nu_scale=0.3``).
        **kwargs: Flat keyword arguments used to build whichever of
            config/params/prior was not passed explicitly.

    Returns:
        Environment instance for the control task.

    Example:
        >>> # Constant target at F=25
        >>> env = make_env(F_target_constant=25.0)
        >>> # Domain randomization
        >>> env = make_env(nu_scale=0.3, Kh_scale=0.2)
    """
    if config is None:
        # Route the flat kwargs to the nested config dataclasses.
        # filter_kwargs introspects dataclass fields, so the routing follows
        # the dataclass definitions as fields are added or removed.
        nested_fields = ("physics", "task")  # built explicitly just below
        top_level_kwargs = {}
        for name, value in filter_kwargs(kwargs, ControlTaskConfig).items():
            if name in nested_fields:
                continue
            top_level_kwargs[name] = value

        physics_config = PhysicsConfig(**filter_kwargs(kwargs, PhysicsConfig))
        task_config = TaskConfig(**filter_kwargs(kwargs, TaskConfig))
        config = ControlTaskConfig(
            physics=physics_config,
            task=task_config,
            **top_level_kwargs,
        )

    if params is None:
        physics_params = PhysicsParams(**filter_kwargs(kwargs, PhysicsParams))
        params = ControlTaskParams(physics=physics_params)

    if params_prior is None:
        # Create a prior only when at least one prior field was supplied.
        prior_kwargs = filter_kwargs(kwargs, PhysicsParamsPrior)
        if prior_kwargs:
            params_prior = ControlTaskParamsPrior(physics=PhysicsParamsPrior(**prior_kwargs))

    if params_prior is None:
        sample_fn = None
    else:
        sample_fn = params_prior.sample

    return Environment(
        step=step,
        reset=reset,
        get_action_space=get_action_space,
        get_obs_shape=get_obs_shape,
        params=params,
        config=config,
        sample_params_fn=sample_fn,
    )