Skip to content

Commit 0653bfc

Browse files
slishak-PX authored and
facebook-github-bot committed
Add qPosteriorStandardDeviation acquisition function (#2634)
Summary: <!-- Thank you for sending the PR! We appreciate you spending the time to make BoTorch better. Help us understand your motivation by explaining why you decided to make this change. You can learn more about contributing to BoTorch here: https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md --> ## Motivation This is a small collection of changes for improving support for optimisation with deterministic (posterior mean) and pure exploration (posterior std) acquisition functions: 1. Using `PosteriorMeanModel` with `optimize_acqf` is currently not supported as `PosteriorMeanModel` does not implement `num_outputs` or `batch_shape`. 2. The `PosteriorStandardDeviation` acquisition function has no MC equivalent. This PR addresses the points above, and consequentially adds support for the constrained PSTD acquisition function. ### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Yes Pull Request resolved: #2634 Test Plan: TODO - just submitting draft for now for discussion. ## Related PRs (If this PR adds or changes functionality, please take some time to update the docs at https://github.com/pytorch/botorch, and link to your PR here.) Reviewed By: saitcakmak Differential Revision: D68713704 Pulled By: Balandat fbshipit-source-id: 3e791345a1abd0cf14247fc94178ff1e52d67988
1 parent 2144440 commit 0653bfc

File tree

6 files changed

+209
-7
lines changed

6 files changed

+209
-7
lines changed

botorch/acquisition/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@
5555
from botorch.acquisition.monte_carlo import (
5656
MCAcquisitionFunction,
5757
qExpectedImprovement,
58+
qLowerConfidenceBound,
5859
qNoisyExpectedImprovement,
60+
qPosteriorStandardDeviation,
5961
qProbabilityOfImprovement,
6062
qSimpleRegret,
6163
qUpperConfidenceBound,
@@ -120,6 +122,8 @@
120122
"qNegIntegratedPosteriorVariance",
121123
"qProbabilityOfImprovement",
122124
"qSimpleRegret",
125+
"qPosteriorStandardDeviation",
126+
"qLowerConfidenceBound",
123127
"qUpperConfidenceBound",
124128
"ConstrainedMCObjective",
125129
"GenericMCObjective",

botorch/acquisition/monte_carlo.py

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -747,7 +747,7 @@ class qSimpleRegret(SampleReducingMCAcquisitionFunction):
747747
non-negative. `qSimpleRegret` acquisition values can be negative, so we instead use
748748
a `ConstrainedMCObjective` which applies constraints to the objectives (e.g. before
749749
computing the acquisition function) and shifts negative objective values using
750-
by an infeasible cost to ensure non-negativity (before applying constraints and
750+
an infeasible cost to ensure non-negativity (before applying constraints and
751751
shifting them back).
752752
753753
Example:
@@ -813,11 +813,11 @@ class qUpperConfidenceBound(SampleReducingMCAcquisitionFunction):
813813
`SampleReducingMCAcquisitionFunction` computes the acquisition values on the sample
814814
level and then weights the sample-level acquisition values by a soft feasibility
815815
indicator. Hence, it expects non-log acquisition function values to be
816-
non-negative. `qSimpleRegret` acquisition values can be negative, so we instead use
817-
a `ConstrainedMCObjective` which applies constraints to the objectives (e.g. before
818-
computing the acquisition function) and shifts negative objective values using
819-
by an infeasible cost to ensure non-negativity (before applying constraints and
820-
shifting them back).
816+
non-negative. `qUpperConfidenceBound` acquisition values can be negative, so we
817+
instead use a `ConstrainedMCObjective` which applies constraints to the objectives
818+
(e.g. before computing the acquisition function) and shifts negative objective
819+
values using an infeasible cost to ensure non-negativity (before applying
820+
constraints and shifting them back).
821821
822822
Example:
823823
>>> model = SingleTaskGP(train_X, train_Y)
@@ -887,3 +887,70 @@ class qLowerConfidenceBound(qUpperConfidenceBound):
887887
def _get_beta_prime(self, beta: float) -> float:
888888
"""Multiply beta prime by -1 to get the lower confidence bound."""
889889
return -super()._get_beta_prime(beta=beta)
890+
891+
892+
class qPosteriorStandardDeviation(SampleReducingMCAcquisitionFunction):
893+
r"""MC-based batch Posterior Standard Deviation.
894+
895+
An acquisition function for pure exploration.
896+
897+
Example:
898+
>>> model = SingleTaskGP(train_X, train_Y)
899+
>>> sampler = SobolQMCNormalSampler(1024)
900+
>>> qPSTD = qPosteriorStandardDeviation(model, sampler)
901+
>>> std = qPSTD(test_X)
902+
"""
903+
904+
def __init__(
905+
self,
906+
model: Model,
907+
sampler: MCSampler | None = None,
908+
objective: MCAcquisitionObjective | None = None,
909+
posterior_transform: PosteriorTransform | None = None,
910+
X_pending: Tensor | None = None,
911+
constraints: list[Callable[[Tensor], Tensor]] | None = None,
912+
eta: Tensor | float = 1e-3,
913+
) -> None:
914+
r"""q-Posterior Standard Deviation.
915+
916+
Args:
917+
model: A fitted model.
918+
sampler: The sampler used to draw base samples. See `MCAcquisitionFunction`
919+
for more details.
920+
objective: The MCAcquisitionObjective under which the samples are
921+
evaluated. Defaults to `IdentityMCObjective()`.
922+
posterior_transform: A PosteriorTransform (optional).
923+
X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points that have
924+
been submitted for function evaluation but have not yet
925+
been evaluated. Concatenated into X upon forward call. Copied and set to
926+
have no gradient.
927+
constraints: A list of constraint callables which map a Tensor of posterior
928+
samples of dimension `sample_shape x batch-shape x q x m`-dim to a
929+
`sample_shape x batch-shape x q`-dim Tensor. The associated constraints
930+
are considered satisfied if the output is less than zero.
931+
eta: Temperature parameter(s) governing the smoothness of the sigmoid
932+
approximation to the constraint indicators. For more details on this
933+
parameter, see the docs of `compute_smoothed_feasibility_indicator`.
934+
"""
935+
super().__init__(
936+
model=model,
937+
sampler=sampler,
938+
objective=objective,
939+
posterior_transform=posterior_transform,
940+
X_pending=X_pending,
941+
constraints=constraints,
942+
eta=eta,
943+
)
944+
self._scale = math.sqrt(math.pi / 2)
945+
946+
def _sample_forward(self, obj: Tensor) -> Tensor:
947+
r"""Evaluate qPosteriorStandardDeviation per sample on the candidate set `X`.
948+
949+
Args:
950+
obj: A `sample_shape x batch_shape x q`-dim Tensor of MC objective values.
951+
952+
Returns:
953+
A `sample_shape x batch_shape x q`-dim Tensor of acquisition values.
954+
"""
955+
mean = obj.mean(dim=0)
956+
return (obj - mean).abs() * self._scale

botorch/models/deterministic.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,16 @@ def __init__(self, model: Model) -> None:
162162
def forward(self, X: Tensor) -> Tensor:
163163
return self.model.posterior(X).mean
164164

165+
@property
166+
def num_outputs(self) -> int:
167+
r"""The number of outputs of the model."""
168+
return self.model.num_outputs
169+
170+
@property
171+
def batch_shape(self) -> torch.Size:
172+
r"""The batch shape of the model."""
173+
return self.model.batch_shape
174+
165175

166176
class FixedSingleSampleModel(DeterministicModel):
167177
r"""

botorch/utils/testing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,8 @@ def rsample(
347347
do a shape check but return the same mock samples."""
348348
if sample_shape is None:
349349
sample_shape = torch.Size()
350-
return self._samples.expand(sample_shape + self._samples.shape)
350+
extended_shape = self._extended_shape(sample_shape)
351+
return self._samples.expand(extended_shape)
351352

352353
def rsample_from_base_samples(
353354
self,

test/acquisition/test_monte_carlo.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
qExpectedImprovement,
2121
qLowerConfidenceBound,
2222
qNoisyExpectedImprovement,
23+
qPosteriorStandardDeviation,
2324
qProbabilityOfImprovement,
2425
qSimpleRegret,
2526
qUpperConfidenceBound,
@@ -37,6 +38,7 @@
3738
from botorch.models import SingleTaskGP
3839
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
3940
from botorch.utils.low_rank import sample_cached_cholesky
41+
from botorch.utils.sampling import draw_sobol_normal_samples
4042
from botorch.utils.test_helpers import DummyNonScalarizingPosteriorTransform
4143
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
4244
from botorch.utils.transforms import standardize
@@ -1009,6 +1011,121 @@ def test_beta_prime(self):
10091011
super().test_beta_prime(negate=True)
10101012

10111013

1014+
class TestQPosteriorStandardDeviation(BotorchTestCase):
1015+
def test_q_pstd(self):
1016+
n_samples = 128
1017+
for dtype in (torch.float, torch.double):
1018+
# the event shape is `b x q x t` = 1 x 1 x 1
1019+
samples = draw_sobol_normal_samples(
1020+
1,
1021+
n_samples,
1022+
device=self.device,
1023+
dtype=dtype,
1024+
seed=0,
1025+
)[..., None, None]
1026+
# samples has shape (n_samples, 1, 1, 1)
1027+
std = samples.std(dim=0, correction=0).item()
1028+
mm = MockModel(
1029+
MockPosterior(samples=samples, base_shape=torch.Size([1, 1, 1]))
1030+
)
1031+
# X is `q x d` = 1 x 1. X is a dummy and unused b/c of mocking
1032+
X = torch.zeros(1, 1, device=self.device, dtype=dtype)
1033+
1034+
# basic test
1035+
sampler = IIDNormalSampler(sample_shape=torch.Size([n_samples]))
1036+
acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
1037+
res = acqf(X)
1038+
self.assertAllClose(res.item(), std, rtol=0.02, atol=0)
1039+
1040+
# basic test with a seeded sampler; base samples must stay fixed across calls
1041+
sampler = IIDNormalSampler(sample_shape=torch.Size([n_samples]), seed=12345)
1042+
acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
1043+
res = acqf(X)
1044+
self.assertAllClose(res.item(), std, rtol=0.02, atol=0)
1045+
self.assertEqual(
1046+
acqf.sampler.base_samples.shape, torch.Size([n_samples, 1, 1, 1])
1047+
)
1048+
bs = acqf.sampler.base_samples.clone()
1049+
res = acqf(X)
1050+
self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))
1051+
1052+
# basic test, qmc
1053+
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([n_samples]))
1054+
acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
1055+
res = acqf(X)
1056+
self.assertAllClose(res.item(), std, rtol=0.02, atol=0)
1057+
self.assertEqual(
1058+
acqf.sampler.base_samples.shape, torch.Size([n_samples, 1, 1, 1])
1059+
)
1060+
bs = acqf.sampler.base_samples.clone()
1061+
acqf(X)
1062+
self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))
1063+
1064+
# basic test for X_pending and warning
1065+
acqf.set_X_pending()
1066+
self.assertIsNone(acqf.X_pending)
1067+
acqf.set_X_pending(None)
1068+
self.assertIsNone(acqf.X_pending)
1069+
acqf.set_X_pending(X)
1070+
self.assertEqual(acqf.X_pending, X)
1071+
mm._posterior._base_shape = torch.Size([1, 2, 1])
1072+
mm._posterior._samples = mm._posterior._samples.expand(n_samples, 1, 2, 1)
1073+
res = acqf(X)
1074+
X2 = torch.zeros(
1075+
1, 1, 1, device=self.device, dtype=dtype, requires_grad=True
1076+
)
1077+
with warnings.catch_warnings(record=True) as ws:
1078+
acqf.set_X_pending(X2)
1079+
self.assertEqual(acqf.X_pending, X2)
1080+
self.assertEqual(sum(issubclass(w.category, BotorchWarning) for w in ws), 1)
1081+
1082+
def test_q_pstd_batch(self):
1083+
# the event shape is `b x q x t` = 2 x 2 x 1
1084+
for dtype in (torch.float, torch.double):
1085+
samples = torch.zeros(2, 2, 1, device=self.device, dtype=dtype)
1086+
samples[0, 0, 0] = 1.0
1087+
mm = MockModel(MockPosterior(samples=samples))
1088+
# X is a dummy and unused b/c of mocking
1089+
X = torch.zeros(2, 2, 1, device=self.device, dtype=dtype)
1090+
1091+
# test batch mode
1092+
sampler = IIDNormalSampler(sample_shape=torch.Size([8]))
1093+
acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
1094+
res = acqf(X)
1095+
self.assertEqual(res[0].item(), 0.0)
1096+
self.assertEqual(res[1].item(), 0.0)
1097+
1098+
# test batch mode with a seeded sampler
1099+
sampler = IIDNormalSampler(sample_shape=torch.Size([2]), seed=12345)
1100+
acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
1101+
res = acqf(X) # 1-dim batch
1102+
self.assertEqual(res[0].item(), 0.0)
1103+
self.assertEqual(res[1].item(), 0.0)
1104+
self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 2, 1]))
1105+
bs = acqf.sampler.base_samples.clone()
1106+
acqf(X)
1107+
self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))
1108+
res = acqf(X.expand(2, -1, 1)) # 2-dim batch
1109+
self.assertEqual(res[0].item(), 0.0)
1110+
self.assertEqual(res[1].item(), 0.0)
1111+
# the base samples should have the batch dim collapsed
1112+
self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 2, 1]))
1113+
bs = acqf.sampler.base_samples.clone()
1114+
acqf(X.expand(2, -1, 1))
1115+
self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))
1116+
1117+
# test batch mode, qmc
1118+
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2]))
1119+
acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
1120+
res = acqf(X)
1121+
self.assertEqual(res[0].item(), 0.0)
1122+
self.assertEqual(res[1].item(), 0.0)
1123+
self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 2, 1]))
1124+
bs = acqf.sampler.base_samples.clone()
1125+
acqf(X)
1126+
self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))
1127+
1128+
10121129
class TestMCAcquisitionFunctionWithConstraints(BotorchTestCase):
10131130
def test_mc_acquisition_function_with_constraints(self):
10141131
for dtype in (torch.float, torch.double):
@@ -1033,6 +1150,7 @@ def _test_mc_acquisition_function_with_constraints(self, dtype: torch.dtype):
10331150
# cache_root=True not supported by MockModel, see test_cache_root
10341151
partial(qNoisyExpectedImprovement, cache_root=False, **nei_args),
10351152
partial(qNoisyExpectedImprovement, cache_root=True, **nei_args),
1153+
partial(qPosteriorStandardDeviation, model=mm),
10361154
]:
10371155
acqf = acqf_constructor()
10381156
mm._posterior._samples = (

test/models/test_deterministic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@ def test_PosteriorMeanModel(self):
152152
train_Y = torch.rand(2, 2)
153153
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
154154
mean_model = PosteriorMeanModel(model=model)
155+
self.assertTrue(mean_model.num_outputs == train_Y.shape[-1])
156+
self.assertTrue(mean_model.batch_shape == torch.Size([]))
155157

156158
test_X = torch.rand(2, 3)
157159
post = model.posterior(test_X)

0 commit comments

Comments
 (0)