# Source code for myriad.envs.bio.ccasr_gfp.physics

"""Pure stateless physics for the CcaS-CcaR + GFP with self-activation gene circuit system.

This module contains the ground truth stochastic dynamics for the bi-stable genetic
circuit, using the Gillespie algorithm for exact stochastic simulation.

The physics is completely decoupled from any task-specific logic (rewards, terminations, observations).
It can be reused by different tasks (control, SysID, etc.).

System Description:
    Light Input (U) → CcaSR (H) → GFP (F) with autoactivation feedback

    Five Chemical Reactions:
    1. CcaSR activation: ∅ → H  (rate: eta * U)
    2. CcaSR deactivation: H → ∅  (rate: nu * H)
    3. F creation from H: ∅ → F  (rate: 0.5 * a * H^nh / (Kh^nh + H^nh))
    4. F self-activation: ∅ → F  (rate: 0.5 * a * F^nf / (Kf^nf + F^nf))
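    5. F dilution: F → ∅  (rate: nu * F)

    Note: ``compute_propensities`` applies no explicit ``eta`` or ``a`` factor
    (reaction 1 uses ``U`` directly and reactions 3-4 use
    ``0.5 * hill_function(...)``), and neither symbol is a field of
    ``PhysicsParams``; both are effectively fixed to 1 in this implementation.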

Reference:
    Based on the bi-stable genetic circuit model from:
    "Control of a Bi-Stable Genetic System via Parallelized Reinforcement Learning"
    CDC 2025, https://gitlab.com/lugagnelab/pqn-control-cdc2025
"""

import math
from typing import NamedTuple

import chex
import jax
import jax.numpy as jnp
from crn_jax.gillespie import simulate_interval
from crn_jax.kinetics import hill_function, sample_lognormal
from flax import struct
from jax import Array

from myriad.core.types import PRNGKey


[docs] class PhysicsState(NamedTuple): """Pure physical state of the gene circuit system. Attributes: time: Current simulation time (minutes) H: CcaS-CcaR protein concentration (molecules) F: GFP reporter protein concentration (molecules) next_reaction_time: Scheduled time of next reaction (minutes). Set to inf when no reaction is pending (sample fresh). Preserved across RL step boundaries for physical accuracy. """ time: Array H: Array F: Array next_reaction_time: Array
[docs] def to_array(self) -> Array: """Convert to flat array for NN-based agents. Note: next_reaction_time is excluded as it's internal bookkeeping. Returns: Array of shape (3,) with [time, H, F] """ return jnp.stack([self.time, self.H, self.F])
[docs] @classmethod def from_array(cls, arr: Array) -> "PhysicsState": """Create state from flat array. Args: arr: Array of shape (3,) with [time, H, F] Returns: PhysicsState instance (next_reaction_time defaults to inf) """ chex.assert_shape(arr, (3,)) return cls( time=arr[0], # type: ignore H=arr[1], # type: ignore F=arr[2], # type: ignore next_reaction_time=jnp.array(jnp.inf), )
[docs] @classmethod def create( cls, time: Array, H: Array, F: Array, next_reaction_time: Array | None = None, ) -> "PhysicsState": """Factory method to create PhysicsState with default next_reaction_time. Args: time: Current simulation time H: CcaS-CcaR protein concentration F: GFP reporter protein concentration next_reaction_time: Optional pending reaction time (defaults to inf) Returns: PhysicsState instance """ if next_reaction_time is None: next_reaction_time = jnp.array(jnp.inf) return cls(time=time, H=H, F=F, next_reaction_time=next_reaction_time)
@struct.dataclass
class PhysicsConfig:
    """Structural constants of the gene circuit platform.

    These values are compile-time constants handed to ``jit`` through
    ``static_argnames``. They describe the experimental platform and the
    circuit architecture, and stay the same for every cell and experiment.

    Kinetic parameters (nu, Kh, nh, Kf, nf) are deliberately kept in
    PhysicsParams instead: they vary between cells and are what system
    identification tries to recover.
    """

    # Physical timestep in minutes; fixed by the measurement interval.
    timestep_minutes: float = 5.0

    # Safety cap for the Gillespie algorithm: maximum reactions per timestep.
    max_gillespie_steps: int = 10_000
@struct.dataclass
class PhysicsParams:
    """Kinetic parameters of the circuit, vmappable over parameter space.

    These are the quantities that differ from cell to cell (domain
    randomization) or that are unknown and have to be inferred (system
    identification).
    """

    # Protein dilution rate (1/min).
    nu: float | Array = 0.01

    # CcaR Hill half-max concentration.
    Kh: float | Array = 90.0

    # CcaR Hill cooperativity coefficient.
    nh: float | Array = 3.6

    # GFP self-activation half-max concentration.
    Kf: float | Array = 30.0

    # GFP Hill coefficient.
    nf: float | Array = 3.6
@struct.dataclass
class PhysicsParamsPrior:
    """Log-normal prior over the kinetic parameters.

    Every parameter ``p`` is drawn as ``p ~ exp(Normal(loc, scale))``.
    A zero scale collapses the distribution to a point mass at ``exp(loc)``,
    so the defaults (all scales zero) are fully deterministic and backward
    compatible.
    """

    # Each (loc, scale) pair lives in log space; the default loc values are
    # the logs of the PhysicsParams defaults.
    nu_loc: float | Array = math.log(0.01)
    nu_scale: float | Array = 0.0

    Kh_loc: float | Array = math.log(90.0)
    Kh_scale: float | Array = 0.0

    nh_loc: float | Array = math.log(3.6)
    nh_scale: float | Array = 0.0

    Kf_loc: float | Array = math.log(30.0)
    Kf_scale: float | Array = 0.0

    nf_loc: float | Array = math.log(3.6)
    nf_scale: float | Array = 0.0

    def sample(self, key: PRNGKey) -> PhysicsParams:
        """Draw one set of kinetic parameters from the prior.

        Args:
            key: RNG key; it is split into five sub-keys, one per parameter,
                in the order nu, Kh, nh, Kf, nf.

        Returns:
            PhysicsParams holding the sampled values.
        """
        subkeys = jax.random.split(key, 5)
        nu = sample_lognormal(subkeys[0], self.nu_loc, self.nu_scale)
        Kh = sample_lognormal(subkeys[1], self.Kh_loc, self.Kh_scale)
        nh = sample_lognormal(subkeys[2], self.nh_loc, self.nh_scale)
        Kf = sample_lognormal(subkeys[3], self.Kf_loc, self.Kf_scale)
        nf = sample_lognormal(subkeys[4], self.nf_loc, self.nf_scale)
        return PhysicsParams(nu=nu, Kh=Kh, nh=nh, Kf=Kf, nf=nf)
def compute_propensities(
    state: PhysicsState,
    action: Array,
    params: PhysicsParams,
) -> Array:
    """Evaluate the propensity (rate) of each of the five reactions.

    Args:
        state: Current physical state; only ``H`` and ``F`` are read.
        action: Discrete action {0, 1} representing the light input U.
        params: Kinetic parameters (nu, Kh, nh, Kf, nf) — vmappable.

    Returns:
        Array of 5 propensities ordered as [R1, R2, R3, R4, R5].
    """
    ccasr, gfp = state.H, state.F
    dilution_rate = params.nu

    rates = (
        # R1: CcaSR activation (∅ → H), driven directly by the light input.
        action,
        # R2: CcaSR deactivation (H → ∅).
        dilution_rate * ccasr,
        # R3: F creation from H (∅ → F).
        0.5 * hill_function(ccasr, params.Kh, params.nh),
        # R4: F self-activation (∅ → F).
        0.5 * hill_function(gfp, params.Kf, params.nf),
        # R5: F dilution (F → ∅).
        dilution_rate * gfp,
    )
    return jnp.array(rates)
def apply_reaction(state: PhysicsState, reaction_idx: Array) -> PhysicsState:
    """Return the state that results from firing one reaction.

    Dispatch goes through ``jax.lax.switch`` so the choice of reaction stays
    JAX-compatible control flow.

    Args:
        state: Current physical state.
        reaction_idx: Index of the reaction to apply (0-4).

    Returns:
        Updated physical state after the reaction.
    """

    def gain_h(s):
        """∅ → H: one more CcaSR molecule."""
        return s._replace(H=s.H + 1)

    def lose_h(s):
        """H → ∅: one fewer CcaSR molecule, floored at zero."""
        return s._replace(H=jnp.maximum(s.H - 1, 0))

    def gain_f(s):
        """∅ → F: one more GFP molecule."""
        return s._replace(F=s.F + 1)

    def lose_f(s):
        """F → ∅: one fewer GFP molecule, floored at zero."""
        return s._replace(F=jnp.maximum(s.F - 1, 0))

    # Position in this table is the reaction index. Reactions 3 and 4
    # (F creation from H, F self-activation) have the same effect on the
    # state, so they share a branch.
    effects = (
        gain_h,  # Reaction 1: CcaSR activation
        lose_h,  # Reaction 2: CcaSR deactivation
        gain_f,  # Reaction 3: F creation from H
        gain_f,  # Reaction 4: F self-activation
        lose_f,  # Reaction 5: F dilution
    )
    return jax.lax.switch(reaction_idx, effects, state)
def step_physics(
    key: PRNGKey,
    state: PhysicsState,
    action: Array,
    params: PhysicsParams,
    config: PhysicsConfig,
    previous_action: Array,
    interval_start: Array,
) -> PhysicsState:
    """Advance the physics by one interval with a discrete Gillespie algorithm.

    The simulation runs from the current time to the end of the current
    interval (``interval_start + timestep_minutes``). Intervals sit at fixed
    absolute times (0, 5, 10, 15...), matching the physical setup in which
    observations and actions happen at regular intervals.

    Args:
        key: RNG key for the stochastic simulation.
        state: Current physical state (``time``, ``H``, ``F``,
            ``next_reaction_time``).
        action: Discrete action ``{0, 1}`` representing the light input.
        params: Dynamic parameters.
        config: Static physics constants.
        previous_action: Action from the previous timestep. If it differs
            from ``action``, the pending reaction time is invalidated
            (propensities changed).
        interval_start: Start time of the current interval
            (``t * timestep_minutes``).

    Returns:
        Next physical state after simulating until the interval end.
    """
    # simulate_interval is given a two-argument propensity function, so the
    # kinetic parameters are bound here through a closure.
    return simulate_interval(
        key=key,
        state=state,
        input=action,
        timestep=config.timestep_minutes,
        max_steps=config.max_gillespie_steps,
        compute_propensities_fn=lambda s, u: compute_propensities(s, u, params),
        apply_reaction_fn=apply_reaction,
        previous_input=previous_action,
        interval_start=interval_start,
    )