From a52af6779bd15fc5cb866a449006b21b54a6e287 Mon Sep 17 00:00:00 2001 From: Setareh Ariafar Date: Fri, 22 Sep 2023 08:19:25 -0700 Subject: [PATCH] extra paretopt fixes PiperOrigin-RevId: 567627034 --- .../algorithms/designers/gp/acquisitions.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/vizier/_src/algorithms/designers/gp/acquisitions.py b/vizier/_src/algorithms/designers/gp/acquisitions.py index 96fab621d..f184c0a98 100644 --- a/vizier/_src/algorithms/designers/gp/acquisitions.py +++ b/vizier/_src/algorithms/designers/gp/acquisitions.py @@ -170,6 +170,50 @@ def score_with_aux( # Vizier library acquisition functions use `flax.struct`, instead of `attrs` and # a hash function, so that acquisition functions can be passed as args to JIT-ed # functions without triggering retracing when attribute values change. +@struct.dataclass +class MEANPENALIZED(AcquisitionFunction): + """Mean with relative region AcquisitionFunction.""" + + coefficient: float = 1.8 + max_lcb_thresh: float = -jnp.inf + + def __call__( + self, + dist: tfd.Distribution, + seed: Optional[jax.random.KeyArray] = None, + ) -> jax.Array: + del seed + ucb = dist.mean() + self.coefficient * dist.stddev() + acquisition = jnp.where( + (ucb >= self.max_lcb_thresh), + dist.mean(), + -1e12 - ucb, + ) + return acquisition + + +@struct.dataclass +class STDPENALIZED(AcquisitionFunction): + """Standard Deviation with relative region AcquisitionFunction.""" + + coefficient: float = 1.8 + max_lcb_thresh: float = -jnp.inf + + def __call__( + self, + dist: tfd.Distribution, + seed: Optional[jax.random.KeyArray] = None, + ) -> jax.Array: + del seed + ucb = dist.mean() + self.coefficient * dist.stddev() + acquisition = jnp.where( + (ucb >= self.max_lcb_thresh), + dist.stddev(), + -1e12 - ucb, + ) + return acquisition + + @struct.dataclass class UCB(AcquisitionFunction): """UCB AcquisitionFunction."""