diff --git a/vizier/_src/algorithms/designers/gp/acquisitions.py b/vizier/_src/algorithms/designers/gp/acquisitions.py index 96fab621d..f184c0a98 100644 --- a/vizier/_src/algorithms/designers/gp/acquisitions.py +++ b/vizier/_src/algorithms/designers/gp/acquisitions.py @@ -170,6 +170,50 @@ def score_with_aux( # Vizier library acquisition functions use `flax.struct`, instead of `attrs` and # a hash function, so that acquisition functions can be passed as args to JIT-ed # functions without triggering retracing when attribute values change. +@struct.dataclass +class MEANPENALIZED(AcquisitionFunction): + """Mean with relative region AcquisitionFunction.""" + + coefficient: float = 1.8 + max_lcb_thresh: float = -jnp.inf + + def __call__( + self, + dist: tfd.Distribution, + seed: Optional[jax.random.KeyArray] = None, + ) -> jax.Array: + del seed + ucb = dist.mean() + self.coefficient * dist.stddev() + acquisition = jnp.where( + (ucb >= self.max_lcb_thresh), + dist.mean(), + -1e12 - ucb, + ) + return acquisition + + +@struct.dataclass +class STDPENALIZED(AcquisitionFunction): + """Standard Deviation with relative region AcquisitionFunction.""" + + coefficient: float = 1.8 + max_lcb_thresh: float = -jnp.inf + + def __call__( + self, + dist: tfd.Distribution, + seed: Optional[jax.random.KeyArray] = None, + ) -> jax.Array: + del seed + ucb = dist.mean() + self.coefficient * dist.stddev() + acquisition = jnp.where( + (ucb >= self.max_lcb_thresh), + dist.stddev(), + -1e12 - ucb, + ) + return acquisition + + @struct.dataclass class UCB(AcquisitionFunction): """UCB AcquisitionFunction."""