"""JAX-friendly space definitions for RL environments."""fromtypingimportAny,Tupleimportjaximportjax.numpyasjnpfromjaximportArrayfrom.typesimportPRNGKey
def sample(self, key: PRNGKey) -> Array:
    """Draw a single point uniformly at random from the box.

    The draw comes from ``jax.random.uniform`` with ``minval=self.low`` and
    ``maxval=self.high`` (documented by JAX as the half-open range
    ``[minval, maxval)``). The result has shape ``self.shape`` and dtype
    ``self.dtype``.

    Args:
        key: JAX PRNG key consumed by the draw.

    Returns:
        The sampled point as a JAX array.
    """
    draw = jax.random.uniform(
        key,
        shape=self.shape,
        minval=self.low,
        maxval=self.high,
        dtype=self.dtype,
    )
    return jnp.asarray(draw)
def contains(self, x: Array) -> bool:
    """Check if x is within bounds.

    Args:
        x: Candidate point. Anything ``jnp.asarray`` accepts: JAX or NumPy
            arrays, Python scalars, or (nested) lists.

    Returns:
        True when ``x`` has exactly ``self.shape`` and every element
        satisfies ``self.low <= x <= self.high`` (both bounds inclusive).
        False otherwise, including when the shape does not match.
    """
    # Convert up front so list/tuple inputs are compared as arrays,
    # matching how Discrete.contains normalises its input.
    value = jnp.asarray(x)
    # Without this check a wrong-shaped x is broadcast against the bounds.
    # A scalar could then pass for a whole vector, and a non-broadcastable
    # shape would raise instead of returning False.
    if value.shape != tuple(self.shape):
        return False
    # `&` keeps the combination an array operation rather than relying on
    # Python's `and` forcing truthiness of an intermediate array.
    return bool(jnp.all(value >= self.low) & jnp.all(value <= self.high))
class Discrete(Space):
    """A finite set of integer actions {0, 1, ..., n-1}.

    Instances carry ``n`` (the number of elements), ``dtype`` (the dtype
    requested when sampling) and ``shape`` (always ``()``: each element is a
    scalar).
    """
def __init__(self, n: int, dtype: Any = jnp.int32):
    """Create a discrete space with ``n`` elements ``{0, ..., n-1}``.

    Args:
        n: Number of elements. It is coerced with ``int()``, and the coerced
            value must be positive.
        dtype: dtype stored as ``self.dtype`` for the space's elements.

    Raises:
        ValueError: If ``int(n)`` is not positive.
    """
    # Coerce before validating. Checking the raw value first let inputs such
    # as 0.5 pass the positivity test and then truncate to an empty space.
    size = int(n)
    if size <= 0:
        raise ValueError("Discrete space size must be positive")
    self.n = size
    self.dtype = dtype
    # Elements are scalars, so the space's shape is the empty tuple.
    self.shape: Tuple[int, ...] = ()
def sample(self, key: PRNGKey) -> Array:
    """Draw an element uniformly from ``{0, ..., self.n - 1}``.

    Args:
        key: JAX PRNG key consumed by the draw.

    Returns:
        A JAX array of shape ``self.shape`` and dtype ``self.dtype``, produced
        by ``jax.random.randint`` (``maxval`` is exclusive).
    """
    lowest, upper_exclusive = 0, self.n
    return jax.random.randint(
        key,
        shape=self.shape,
        minval=lowest,
        maxval=upper_exclusive,
        dtype=self.dtype,
    )
def contains(self, x: Array) -> bool:
    """Check if x is a valid discrete value.

    Args:
        x: Candidate element. Anything ``jnp.asarray`` accepts.

    Returns:
        True when ``x`` is a scalar (0-d) whose value is a whole number in
        ``[0, self.n)``. False for non-scalars, out-of-range values, NaN, and
        fractional floats such as 1.5.
    """
    value = jnp.asarray(x)
    if value.ndim != 0:
        return False
    valid = (value >= 0) & (value < self.n)
    # The range check alone accepts fractional floats (e.g. 1.5), which are
    # not members of {0, ..., n-1}. Integral floats such as 2.0 stay valid.
    if jnp.issubdtype(value.dtype, jnp.floating):
        valid = valid & (value == jnp.floor(value))
    return bool(valid)