diff --git a/vizier/_src/algorithms/designers/gp/acquisitions.py b/vizier/_src/algorithms/designers/gp/acquisitions.py index 71add9d6d..56f0cc0bd 100644 --- a/vizier/_src/algorithms/designers/gp/acquisitions.py +++ b/vizier/_src/algorithms/designers/gp/acquisitions.py @@ -196,6 +196,26 @@ def __call__( return tfp_bo.acquisition.GaussianProcessExpectedImprovement(dist, labels)() +@struct.dataclass +class ExpectedValue(AcquisitionFunction): + """Expected Improvement plus posterior mean acquisition function.""" + + def __call__( + self, + dist: tfd.Distribution, + features: Optional[types.ModelInput] = None, + labels: Optional[types.PaddedArray] = None, + seed: Optional[jax.random.KeyArray] = None, + ) -> jax.Array: + del features, seed + if labels is not None: + labels = labels.replace_fill_value(-np.inf).padded_array + return ( + tfp_bo.acquisition.GaussianProcessExpectedImprovement(dist, labels)() + + dist.mean() + ) + + @struct.dataclass class PI(AcquisitionFunction): """Probability of Improvement acquisition function.""" @@ -245,6 +265,16 @@ def default_ucb_pi(cls) -> 'AcquisitionTrustRegion': UCB(1.8), PI(), bad_acq_value=-1e12, threshold=0.3, apply_tr_after=0 ) + @classmethod + def default_ucb_expected_value(cls) -> 'AcquisitionTrustRegion': + return cls( + UCB(1.8), + ExpectedValue(), + bad_acq_value=-1e12, + threshold=None, + apply_tr_after=0, + ) + @classmethod def default_ucb_lcb(cls) -> 'AcquisitionTrustRegion': return cls( @@ -289,9 +319,10 @@ def __call__( apply_tr = False if labels is not None: labels_padded = labels.replace_fill_value(np.nan).padded_array - threshold = jnp.minimum( - jnp.nanmean(labels_padded), jnp.nanmedian(labels_padded) - ) + # threshold = jnp.minimum( + # jnp.nanmean(labels_padded), jnp.nanmedian(labels_padded) + # ) + threshold = jnp.nanmedian(labels_padded) apply_tr = labels._original_shape[0] <= self.apply_tr_after if self.threshold is not None: threshold = self.threshold