
Commit 0cb012f

eibarolle authored and facebook-github-bot committed
NP Regression Model w/ LIG Acquisition (#2683)
Summary:

## Motivation

This pull request adds a Neural Process Regression Model with a Latent Information Gain acquisition function to BoTorch.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes, and I've followed all the steps and testing.

Pull Request resolved: #2683

Test Plan: I wrote my own unit tests for both the model and the acquisition function, and all of them passed. The test files are in the appropriate folder. In addition, I ran pytest on those files, and all tests passed.

## Related

I made a repository holding the pushed files at https://github.com/eibarolle/np_regression, and it has the appropriate API documentation.

Reviewed By: sdaulton

Differential Revision: D75169618

Pulled By: hvarfner

fbshipit-source-id: 793b0bcfdac42d1997e50483565f51a3dc5e184f
1 parent: c09af15 · commit: 0cb012f

File tree

4 files changed: +829 −0 lines changed

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Latent Information Gain Acquisition Function for Neural Process Models.

References:

.. [Wu2023arxiv]
    Wu, D., Niu, R., Chinazzi, M., Vespignani, A., Ma, Y.-A., & Yu, R. (2023).
    Deep Bayesian Active Learning for Accelerating Stochastic Simulation.
    arXiv preprint arXiv:2106.02770. Retrieved from
    https://arxiv.org/abs/2106.02770

Contributor: eibarolle
"""

from __future__ import annotations

from typing import Any

import torch
from botorch.acquisition import AcquisitionFunction
from botorch_community.models.np_regression import NeuralProcessModel
from torch import Tensor

# reference: https://arxiv.org/abs/2106.02770

class LatentInformationGain(AcquisitionFunction):
    def __init__(
        self,
        model: Any,
        num_samples: int = 10,
        min_std: float = 0.01,
        scaler: float = 0.5,
    ) -> None:
        """
        Latent Information Gain (LIG) Acquisition Function.
        Uses the model's built-in posterior function to generalize KL computation.

        Args:
            model: The model instance to be used; intended for
                NeuralProcessModel, but any model exposing a posterior
                is supported via the fallback path in forward.
            num_samples: Number of samples used in the KL estimate,
                defaults to 10.
            min_std: Float representing the minimum possible standardized std,
                defaults to 0.01.
            scaler: Float scaling the std, defaults to 0.5.
        """
        super().__init__(model)
        self.model = model
        self.num_samples = num_samples
        self.min_std = min_std
        self.scaler = scaler

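    # forward() dispatches on the model type: a NeuralProcessModel is scored
    # by KL divergence in the NP latent space, while any other model falls
    # back to a KL between moment-matched diagonal-Gaussian posteriors.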
    def forward(self, candidate_x: Tensor) -> Tensor:
        """
        Conduct the Latent Information Gain acquisition function for the inputs.

        Args:
            candidate_x: Candidate input points, as a Tensor. Ideally in the
                shape (N, q, D).

        Returns:
            torch.Tensor: The LIG scores of computed KLDs, in the shape (N,).
        """
        device = candidate_x.device
        N, q, D = candidate_x.shape
        kl = torch.zeros(N, device=device, dtype=torch.float32)

        if isinstance(self.model, NeuralProcessModel):
            x_c, y_c, _, _ = self.model.random_split_context_target(
                self.model.train_X, self.model.train_Y, self.model.n_context
            )
            self.model.z_mu_context, self.model.z_logvar_context = (
                self.model.data_to_z_params(x_c, y_c)
            )

            for i in range(N):
                x_i = candidate_x[i]
                kl_i = 0.0

                for _ in range(self.num_samples):
                    # Sample a latent z from the context distribution and
                    # decode pseudo-observations at the candidate points.
                    sample_z = self.model.sample_z(
                        self.model.z_mu_context, self.model.z_logvar_context
                    )
                    if sample_z.dim() == 1:
                        sample_z = sample_z.unsqueeze(0)

                    y_pred = self.model.decoder(x_i, sample_z)

                    # Re-encode the context augmented with the candidates and
                    # measure the shift in the latent distribution.
                    combined_x = torch.cat([x_c, x_i], dim=0)
                    combined_y = torch.cat([y_c, y_pred], dim=0)

                    self.model.z_mu_all, self.model.z_logvar_all = (
                        self.model.data_to_z_params(combined_x, combined_y)
                    )
                    kl_sample = self.model.KLD_gaussian(self.min_std, self.scaler)
                    kl_i += kl_sample

                kl[i] = kl_i / self.num_samples
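            # Averaged over S = num_samples latent draws, this loop estimates
            #   LIG(x) ~ (1/S) * sum_s KL(q(z | context U {(x, y_s)}) || q(z | context)),
            # the latent information gain of [Wu2023arxiv], where y_s is
            # decoded from the s-th latent sample.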
        else:
            # Generic fallback: compare moment-matched diagonal-Gaussian
            # posteriors at the training inputs and at the candidate points.
            for i in range(N):
                x_i = candidate_x[i]
                kl_i = 0.0
                for _ in range(self.num_samples):
                    posterior_prior = self.model.posterior(self.model.train_inputs[0])
                    posterior_candidate = self.model.posterior(x_i)

                    mean_prior = posterior_prior.mean.mean(dim=0)
                    cov_prior = posterior_prior.variance.mean(dim=0)
                    mvn_prior = torch.distributions.MultivariateNormal(
                        mean_prior, torch.diag(cov_prior)
                    )

                    mean_candidate = posterior_candidate.mean.mean(dim=0)
                    cov_candidate = posterior_candidate.variance.mean(dim=0)
                    mvn_candidate = torch.distributions.MultivariateNormal(
                        mean_candidate, torch.diag(cov_candidate)
                    )

                    kl_i += torch.distributions.kl_divergence(mvn_candidate, mvn_prior)

                kl[i] = kl_i / self.num_samples

        return kl
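
For context, a minimal usage sketch of the fallback path. The module path botorch_community.acquisition.latent_information_gain is assumed (the file path is not shown above), and an illustrative SingleTaskGP stands in for a fitted NeuralProcessModel; the toy data and shapes are hypothetical:

import torch
from botorch.models import SingleTaskGP
from botorch_community.acquisition.latent_information_gain import LatentInformationGain

# Toy single-output training data; any model exposing posterior() takes the
# generic fallback path, while a NeuralProcessModel would use the latent path.
train_X = torch.rand(20, 2, dtype=torch.float64)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

lig = LatentInformationGain(model=model, num_samples=10)
candidate_x = torch.rand(5, 3, 2, dtype=torch.float64)  # (N, q, D)
scores = lig(candidate_x)  # LIG scores, one per batch: shape (N,) = (5,)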
