Skip to content

Commit a9c7b70

Browse files
blethamfacebook-github-bot
authored andcommitted
Fix variance shape bug in Riemann posterior (#2939)
Summary: Pull Request resolved: #2939 Currently, BoundedRiemannPosterior.mean returns a Size([b, 1]) tensor, but BoundedRiemannPosterior.mean_of_square returns a Size([b]) tensor. As a result, when these are combined to get the variance, we end up with a Size([b, b]) variance instead of the correct Size([b, 1]). This fixes the issue. Reviewed By: SamuelGabriel Differential Revision: D78911131 fbshipit-source-id: 5ba69604b3c198ddce68b70a4af806e64fdbc0b1
1 parent 29877b8 commit a9c7b70

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

botorch_community/posteriors/riemann.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def mean(self):
107107
r"""The mean of the posterior distribution."""
108108
bucket_widths = self.borders[1:] - self.borders[:-1]
109109
bucket_means = self.borders[:-1] + bucket_widths / 2
110-
return (bucket_means * (self.probabilities)).sum(-1, keepdim=True)
110+
return (self.probabilities @ bucket_means).unsqueeze(-1)
111111

112112
@property
113113
def mean_of_square(self) -> torch.Tensor:
@@ -119,7 +119,7 @@ def mean_of_square(self) -> torch.Tensor:
119119
+ right_borders.square()
120120
+ left_borders * right_borders
121121
) / 3.0
122-
return self.probabilities @ bucket_mean_of_square
122+
return (self.probabilities @ bucket_mean_of_square).unsqueeze(-1)
123123

124124
@property
125125
def variance(self) -> torch.Tensor:

test_community/posteriors/test_riemann.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,13 @@ def test_variance(self):
148148
computed_variance = posterior.variance
149149
self.assertLess((computed_variance - true_variance).abs().item(), 0.05)
150150

151+
# Check with batch dimension
152+
probabilities = torch.rand(2, n_buckets, **tkwargs)
153+
probabilities = probabilities / probabilities.sum(-1, keepdim=True)
154+
posterior = BoundedRiemannPosterior(borders, probabilities)
155+
self.assertEqual(posterior.variance.shape, torch.Size([2, 1]))
156+
self.assertEqual(posterior.mean.shape, torch.Size([2, 1]))
157+
151158
def test_confidence_region(self):
152159
torch.manual_seed(13)
153160
for dtype in (torch.float, torch.double):

0 commit comments

Comments
 (0)