# Source code for myriad.envs.bio.ccasr_gfp.tasks.base

"""Shared utilities for CcaS-CcaR + GFP gene circuit task wrappers."""

from typing import NamedTuple

import chex
import jax.numpy as jnp
from flax import struct
from jax import Array

from myriad.core.spaces import Discrete
from myriad.core.types import PRNGKey

from ..physics import PhysicsConfig, PhysicsParams, PhysicsState, step_physics


class CcasrGfpControlObs(NamedTuple):
    """CcaS-CcaR control task observation with named fields.

    Note:
        This is a partially observable system. The agent does not directly
        observe the light input (U) or the CcaSR concentration (H).

    Attributes:
        F_normalized: GFP fluorescence normalized by ``F_obs_normalizer``
            (a scalar, or a single-element array).
        F_target: Target trajectory ``[current, t+1, ..., t+n_horizon]``.
    """

    F_normalized: Array
    F_target: Array

    def to_array(self) -> Array:
        """Convert to a flat array for NN-based agents.

        Returns:
            Array of shape ``(1 + n_horizon + 1,)`` laid out as
            ``[F, F_target...]``.
        """
        # Reshape to (1,) instead of wrapping in a list: this accepts both a
        # 0-d scalar and a single-element array (a list wrapper would turn
        # the latter into shape (1, 1) and break the concatenate), while any
        # other size still raises.
        f_flat = jnp.reshape(self.F_normalized, (1,))
        return jnp.concatenate([f_flat, self.F_target])

    @classmethod
    def from_array(cls, arr: Array) -> "CcasrGfpControlObs":
        """Create an observation from a flat array (inverse of ``to_array``).

        Args:
            arr: Rank-1 array of shape ``(1 + n_horizon + 1,)`` laid out as
                ``[F, F_target...]``.

        Returns:
            CcasrGfpControlObs instance.

        Raises:
            AssertionError: If ``arr`` is not rank 1 (``chex.assert_rank``).
        """
        chex.assert_rank(arr, 1)
        return cls(
            F_normalized=arr[0],  # type: ignore
            F_target=arr[1:],  # type: ignore
        )
@struct.dataclass
class TaskConfig:
    """Task-level settings common to every CcaS-CcaR task.

    Holds the episode length limit and the normalization constant applied
    to GFP fluorescence (F) observations.
    """

    # Episode length in RL steps: 24 hours of 5-minute steps = 288.
    max_steps: int = (24 * 60) // 5
    # Normalization constant for F observations.
    F_obs_normalizer: float = 80.0
@struct.dataclass
class BaseCcasrGfpTaskConfig:
    """Configuration fields shared by all CcaS-CcaR task wrappers.

    Bundles the physics configuration and the task configuration, and
    surfaces ``max_steps`` at the top level, so individual task configs do
    not have to redeclare any of these.
    """

    # Physics settings; a fresh default instance is built per config.
    physics: PhysicsConfig = struct.field(default_factory=PhysicsConfig)
    # Episode limits and observation normalization.
    task: TaskConfig = struct.field(default_factory=TaskConfig)

    @property
    def max_steps(self) -> int:
        """Episode step limit, forwarded from ``self.task``.

        Required by the EnvironmentConfig protocol.
        """
        task_config = self.task
        return task_config.max_steps
def step_physics_interval(
    key: PRNGKey,
    physics: PhysicsState,
    t: Array,
    prev_action: Array,
    action: Array,
    params: PhysicsParams,
    config: PhysicsConfig,
) -> tuple[PhysicsState, Array]:
    """Advance the physics by one control interval.

    Computes the interval's start time (``t * config.timestep_minutes``) and
    passes it to ``step_physics`` together with the previous action, so that
    each task's ``_step`` function does not have to duplicate this logic.

    Args:
        key: RNG key forwarded to ``step_physics``.
        physics: Physics state before the step.
        t: Current step counter.
        prev_action: Previous action, forwarded as ``previous_action``.
        action: Action for this interval.
        params: Physics parameters.
        config: Physics configuration (supplies ``timestep_minutes``).

    Returns:
        Tuple ``(next_physics, t + 1)``.
    """
    start_minutes = t * config.timestep_minutes
    advanced = step_physics(
        key,
        physics,
        action,
        params,
        config,
        interval_start=start_minutes,
        previous_action=prev_action,
    )
    next_t = t + 1
    return advanced, next_t
def check_termination(t: Array, task_config: TaskConfig) -> Array:
    """Common termination check for CcaS-CcaR tasks.

    The episode terminates once the step counter reaches ``max_steps``.

    Args:
        t: Current timestep counter. A JAX/NumPy array or a plain Python
            int is accepted.
        task_config: TaskConfig with ``max_steps``.

    Returns:
        done: float32 array, 1.0 if terminated and 0.0 otherwise (float
        rather than bool for JAX compatibility).
    """
    # jnp.asarray ensures the comparison yields an array even when ``t`` is a
    # Python int; ``int >= int`` returns a plain bool, which has no
    # ``.astype`` and would raise AttributeError. Array inputs are unchanged.
    return (jnp.asarray(t) >= task_config.max_steps).astype(jnp.float32)
def get_action_space() -> Discrete:
    """Build the discrete action space for CcaS-CcaR control.

    Returns:
        A ``Discrete`` space with two actions: 0 (red light) and
        1 (green light).
    """
    num_light_choices = 2  # 0 = red, 1 = green
    return Discrete(n=num_light_choices)
def sample_initial_physics(key: PRNGKey) -> PhysicsState:
    """Return the starting physics state for an episode.

    The state is deterministic: time, H and F all start at zero, i.e. no
    protein is present before any light input is applied.

    Args:
        key: RNG key. Unused; the initial state does not depend on it.

    Returns:
        PhysicsState with ``time``, ``H`` and ``F`` all set to 0.0.
    """
    del key  # unused: nothing here is random
    zero_fields = {name: jnp.array(0.0) for name in ("time", "H", "F")}
    return PhysicsState.create(**zero_fields)
def generate_constant_target(
    n_horizon: int,
    F_target_value: float,
) -> Array:
    """Build a flat (time-invariant) target trajectory.

    Args:
        n_horizon: Number of future timesteps beyond the current one.
        F_target_value: Value taken by every entry of the trajectory.

    Returns:
        float32 array of shape ``(n_horizon + 1,)`` (the current step plus
        the horizon), filled with ``F_target_value``.
    """
    length = n_horizon + 1
    return jnp.full((length,), F_target_value, dtype=jnp.float32)
def generate_sinewave_target(
    key: PRNGKey,
    t: Array,
    n_horizon: int,
    timestep_minutes: float,
    period_minutes: float = 600.0,
    amplitude: float = 20.0,
    vshift: float = 30.0,
    phase: float = 0.0,
) -> Array:
    """Generate a sinusoidal target trajectory.

    Creates a time-varying target following
    ``vshift + amplitude * sin(2*pi * time / period_minutes + phase)``,
    sampled at the current step and the next ``n_horizon`` steps. Used for
    testing tracking performance on dynamic targets.

    Args:
        key: RNG key. Currently unused: the trajectory is deterministic.
            Callers that want a random phase should sample it themselves
            and pass it via ``phase``.
        t: Current timestep counter.
        n_horizon: Number of future timesteps to predict.
        timestep_minutes: Duration of each RL timestep in minutes.
        period_minutes: Period of the sine wave in minutes
            (default: 600 = 10 hours).
        amplitude: Amplitude of the sine wave (default: 20).
        vshift: Vertical shift / DC offset (default: 30).
        phase: Phase offset in radians added inside the sine
            (default: 0.0).

    Returns:
        Array of shape ``(n_horizon + 1,)`` with the target values at steps
        ``t, t+1, ..., t+n_horizon``.
    """
    del key  # unused; see docstring

    # Absolute time in minutes of the current step and each horizon step.
    current_time_minutes = t * timestep_minutes
    future_steps = jnp.arange(n_horizon + 1)
    future_times = current_time_minutes + future_steps * timestep_minutes

    # Angular frequency in radians per minute.
    omega = 2.0 * jnp.pi / period_minutes
    return vshift + amplitude * jnp.sin(omega * future_times + phase)