     reshape_and_detach,
     SaasPyroModel,
 )
-from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
 from botorch.models.multitask import MultiTaskGP
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform
 from botorch.posteriors.fully_bayesian import GaussianMixturePosterior, MCMC_DIM
-from gpytorch.distributions import MultivariateNormal
+from gpytorch.distributions.multivariate_normal import MultivariateNormal
 from gpytorch.kernels import MaternKernel
-from gpytorch.kernels.index_kernel import IndexKernel
 from gpytorch.kernels.kernel import Kernel
 from gpytorch.likelihoods.likelihood import Likelihood
 from gpytorch.means.mean import Mean
@@ -134,7 +132,7 @@ def sample_task_lengthscale(

     def load_mcmc_samples(
         self, mcmc_samples: dict[str, Tensor]
-    ) -> tuple[Mean, Kernel, Likelihood, Kernel]:
+    ) -> tuple[Mean, Kernel, Likelihood, Kernel, Parameter]:
         r"""Load the MCMC samples into the mean_module, covar_module, and likelihood."""
         tkwargs = {"device": self.train_X.device, "dtype": self.train_X.dtype}
         num_mcmc_samples = len(mcmc_samples["mean"])
@@ -144,32 +142,27 @@ def load_mcmc_samples(
             mcmc_samples=mcmc_samples
         )

-        latent_covar_module = MaternKernel(
+        task_covar_module = MaternKernel(
             nu=2.5,
             ard_num_dims=self.task_rank,
             batch_shape=batch_shape,
         ).to(**tkwargs)
-        latent_covar_module.lengthscale = reshape_and_detach(
-            target=latent_covar_module.lengthscale,
+        task_covar_module.lengthscale = reshape_and_detach(
+            target=task_covar_module.lengthscale,
             new_value=mcmc_samples["task_lengthscale"],
         )
-        latent_features = mcmc_samples["latent_features"]
-        task_covar = latent_covar_module(latent_features)
-        task_covar_module = IndexKernel(
-            num_tasks=self.num_tasks,
-            rank=self.task_rank,
-            batch_shape=latent_features.shape[:-2],
+        latent_features = Parameter(
+            torch.rand(
+                batch_shape + torch.Size([self.num_tasks, self.task_rank]),
+                requires_grad=True,
+                **tkwargs,
+            )
         )
-        task_covar_module.covar_factor = Parameter(
-            task_covar.cholesky().to_dense().detach()
+        latent_features = reshape_and_detach(
+            target=latent_features,
+            new_value=mcmc_samples["latent_features"],
         )
-
-        # NOTE: 'var' is implicitly assumed to be zero from the sampling procedure in
-        # the FBMTGP model but not in the regular MTGP. I dont how if the var parameter
-        # affects predictions in practice, but setting it to zero is consistent with the
-        # previous implementation.
-        task_covar_module.var = torch.zeros_like(task_covar_module.var)
-        return mean_module, covar_module, likelihood, task_covar_module
+        return mean_module, covar_module, likelihood, task_covar_module, latent_features


 class SaasFullyBayesianMultiTaskGP(MultiTaskGP):
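Note: the revised MultitaskSaasPyroModel.load_mcmc_samples no longer materializes an IndexKernel from the MCMC draws; it returns the batched Matern task kernel together with the sampled latent task features as a plain Parameter. The standalone sketch below illustrates that idea only; the sizes (16 MCMC samples, 3 tasks, rank 2) and variable names are illustrative assumptions, not part of the diff.

import torch
from gpytorch.kernels import MaternKernel
from torch.nn import Parameter

num_mcmc_samples, num_tasks, task_rank = 16, 3, 2
batch_shape = torch.Size([num_mcmc_samples])

# Batched task kernel: one lengthscale sample per MCMC draw (stand-in for the
# values loaded via reshape_and_detach in the hunk above).
task_covar_module = MaternKernel(nu=2.5, ard_num_dims=task_rank, batch_shape=batch_shape)

# Sampled latent task embeddings, kept as a plain Parameter instead of being
# folded into an IndexKernel's covar_factor.
latent_features = Parameter(torch.rand(batch_shape + torch.Size([num_tasks, task_rank])))

# Evaluating the kernel on the latent features recovers the dense task covariance
# that the removed code precomputed via a Cholesky factor.
task_covar = task_covar_module(latent_features).to_dense()
print(task_covar.shape)  # torch.Size([16, 3, 3])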
@@ -368,6 +361,7 @@ def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None:
             self.covar_module,
             self.likelihood,
             self.task_covar_module,
+            self.latent_features,
         ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)

     def posterior(
@@ -397,7 +391,30 @@ def posterior(

     def forward(self, X: Tensor) -> MultivariateNormal:
         self._check_if_fitted()
-        return super().forward(X)
+        x_basic, task_idcs = self._split_inputs(X)
+
+        mean_x = self.mean_module(x_basic)
+        covar_x = self.covar_module(x_basic)
+
+        tsub_idcs = task_idcs.squeeze(-1)
+        if tsub_idcs.ndim > 1:
+            tsub_idcs = tsub_idcs.squeeze(-2)
+        latent_features = self.latent_features[:, tsub_idcs, :]
+
+        if X.ndim > 3:
+            # batch eval mode
+            # for X (batch_shape x num_samples x q x d), task_idcs[:,i,:,] are the same
+            # reshape X to (batch_shape x num_samples x q x d)
+            latent_features = latent_features.permute(
+                [-i for i in range(X.ndim - 1, 2, -1)]
+                + [0]
+                + [-i for i in range(2, 0, -1)]
+            )
+
+        # Combine the two in an ICM fashion
+        covar_i = self.task_covar_module(latent_features)
+        covar = covar_x.mul(covar_i)
+        return MultivariateNormal(mean_x, covar)


     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
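Note: the new forward combines the two kernels ICM-style, k((x, i), (x', j)) = k_x(x, x') * k_task(z_i, z_j), by indexing the per-point latent features and multiplying the data and task covariances element-wise. A minimal sketch of that combination, with made-up sizes and names (not from the diff), follows.

import torch
from gpytorch.kernels import MaternKernel
from torch.nn import Parameter

num_mcmc_samples, num_tasks, task_rank, d, n = 16, 3, 2, 4, 5
batch_shape = torch.Size([num_mcmc_samples])

covar_module = MaternKernel(nu=2.5, ard_num_dims=d, batch_shape=batch_shape)
task_covar_module = MaternKernel(nu=2.5, ard_num_dims=task_rank, batch_shape=batch_shape)
latent_features = Parameter(torch.rand(batch_shape + torch.Size([num_tasks, task_rank])))

x_basic = torch.rand(n, d)                   # data features without the task column
tsub_idcs = torch.randint(num_tasks, (n,))   # task index of each data point

covar_x = covar_module(x_basic)                                # num_mcmc x n x n data covariance
covar_i = task_covar_module(latent_features[:, tsub_idcs, :])  # num_mcmc x n x n task covariance
covar = covar_x.mul(covar_i)                                   # element-wise ICM-style product
print(covar.shape)  # torch.Size([16, 5, 5])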
@@ -439,40 +456,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
             self.covar_module,
             self.likelihood,
             self.task_covar_module,
+            self.latent_features,
         ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
         # Load the actual samples from the state dict
         super().load_state_dict(state_dict=state_dict, strict=strict)
-
-    def condition_on_observations(
-        self, X: Tensor, Y: Tensor, **kwargs: Any
-    ) -> BatchedMultiOutputGPyTorchModel:
-        """Conditions on additional observations for a Fully Bayesian model (either
-        identical across models or unique per-model).
-
-        Args:
-            X: A `batch_shape x num_samples x d`-dim Tensor, where `d` is
-                the dimension of the feature space and `batch_shape` is the number of
-                sampled models.
-            Y: A `batch_shape x num_samples x 1`-dim Tensor, where `d` is
-                the dimension of the feature space and `batch_shape` is the number of
-                sampled models.
-
-        Returns:
-            BatchedMultiOutputGPyTorchModel: A fully bayesian model conditioned on
-                given observations. The returned model has `batch_shape` copies of the
-                training data in case of identical observations (and `batch_shape`
-                training datasets otherwise).
-        """
-        if X.ndim == 2 and Y.ndim == 2:
-            # To avoid an error in GPyTorch when inferring the batch dimension, we add
-            # the explicit batch shape here. The result is that the conditioned model
-            # will have 'batch_shape' copies of the training data.
-            X = X.repeat(self.batch_shape + (1, 1))
-            Y = Y.repeat(self.batch_shape + (1, 1))
-
-        elif X.ndim < Y.ndim:
-            # We need to duplicate the training data to enable correct batch
-            # size inference in gpytorch.
-            X = X.repeat(*(Y.shape[:-2] + (1, 1)))
-
-        return super().condition_on_observations(X, Y, **kwargs)
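For context, a hedged end-to-end usage sketch (assuming the standard BoTorch fitting utilities; the data, seed, and MCMC settings below are made up, not taken from this change):

import torch
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP

torch.manual_seed(0)
X = torch.rand(20, 3, dtype=torch.float64)
task = torch.randint(2, (20, 1)).to(torch.float64)
train_X = torch.cat([X, task], dim=-1)  # task index appended as the last input column
train_Y = torch.sin(train_X[:, :1]) + 0.1 * torch.randn(20, 1, dtype=torch.float64)

model = SaasFullyBayesianMultiTaskGP(train_X=train_X, train_Y=train_Y, task_feature=-1)
fit_fully_bayesian_model_nuts(
    model, warmup_steps=32, num_samples=16, thinning=4, disable_progbar=True
)
# After fitting, model.latent_features holds one latent task embedding per retained
# MCMC sample, and the posterior mixes over the sampled hyperparameters.
posterior = model.posterior(torch.rand(5, 3, dtype=torch.float64))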