
Commit 04ae7c4

Author: Sahran Ashoor (committed)
Upstream merge + pathwise test coverage + build + lint
1 parent 9774176 · commit 04ae7c4

File tree

11 files changed: +493 additions, −131 deletions


botorch/models/fully_bayesian_multitask.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 from botorch.models.transforms.outcome import OutcomeTransform
 from botorch.posteriors.fully_bayesian import GaussianMixturePosterior
 from gpytorch.distributions import MultivariateNormal
-from gpytorch.kernels import MaternKernel
+from gpytorch.kernels import IndexKernel, MaternKernel
 from gpytorch.kernels.kernel import Kernel
 from gpytorch.likelihoods.likelihood import Likelihood
 from gpytorch.means.mean import Mean
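
Note: the new IndexKernel import matches the product-kernel layout the rest of this commit assumes, i.e. a data kernel over the feature columns multiplied by an IndexKernel over the task column. A minimal sketch of that structure (illustrative only; kernel choices and sizes are made up):

from gpytorch.kernels import IndexKernel, MaternKernel, ProductKernel

# Data kernel over columns 0-1, task kernel over the task column 2 (assumed layout).
data_kernel = MaternKernel(nu=2.5, ard_num_dims=2, active_dims=[0, 1])
task_kernel = IndexKernel(num_tasks=3, rank=1, active_dims=[2])

combined = data_kernel * task_kernel  # gpytorch returns a ProductKernel
assert isinstance(combined, ProductKernel)
assert len(combined.kernels) == 2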

botorch/optim/optimize_mixed.py

Lines changed: 4 additions & 0 deletions
@@ -5,7 +5,10 @@
 # LICENSE file in the root directory of this source tree.

 import dataclasses
+import itertools
+import random
 import warnings
+from collections.abc import Sequence
 from typing import Any, Callable

 import torch
@@ -745,6 +748,7 @@ def discrete_step(
 def continuous_step(
     opt_inputs: OptimizeAcqfInputs,
     discrete_dims: Tensor,
+    cat_dims: Tensor,
     current_x: Tensor,
 ) -> tuple[Tensor, Tensor]:
     """Continuous search using L-BFGS-B through optimize_acqf.

botorch/sampling/pathwise/paths.py

Lines changed: 1 addition & 2 deletions
@@ -7,7 +7,7 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from collections.abc import Callable, Iterable, Iterator, Mapping
+from collections.abc import Callable, Iterable, Mapping
 from string import ascii_letters
 from typing import Any

@@ -142,7 +142,6 @@ def set_ensemble_as_batch(self, ensemble_as_batch: bool) -> None:
             path.set_ensemble_as_batch(ensemble_as_batch)


-
 class GeneralizedLinearPath(SamplePath):
     r"""A sample path in the form of a generalized linear model."""


botorch/sampling/pathwise/prior_samplers.py

Lines changed: 50 additions & 10 deletions
@@ -149,19 +149,59 @@ def _draw_kernel_feature_paths_MultiTaskGP(
         else model._task_feature
     )

-    # NOTE: May want to use a `ProductKernel` instead in `MultiTaskGP`
-    base_kernel = deepcopy(model.covar_module)
-    base_kernel.active_dims = torch.LongTensor(
-        [index for index in range(train_X.shape[-1]) if index != task_index],
-        device=base_kernel.device,
-    )
-
-    task_kernel = deepcopy(model.task_covar_module)
-    task_kernel.active_dims = torch.tensor([task_index], device=base_kernel.device)
+    # Extract kernels from the product kernel structure
+    # model.covar_module is a ProductKernel
+    # containing data_covar_module * task_covar_module
+    from gpytorch.kernels import ProductKernel
+
+    if isinstance(model.covar_module, ProductKernel):
+        # Get the individual kernels from the product kernel
+        kernels = model.covar_module.kernels
+
+        # Find data and task kernels based on their active_dims
+        data_kernel = None
+        task_kernel = None
+
+        for kernel in kernels:
+            if hasattr(kernel, "active_dims") and kernel.active_dims is not None:
+                if task_index in kernel.active_dims:
+                    task_kernel = deepcopy(kernel)
+                else:
+                    data_kernel = deepcopy(kernel)
+            else:
+                # If no active_dims, it's likely the data kernel
+                data_kernel = deepcopy(kernel)
+                data_kernel.active_dims = torch.LongTensor(
+                    [
+                        index
+                        for index in range(train_X.shape[-1])
+                        if index != task_index
+                    ],
+                    device=data_kernel.device,
+                )
+
+        # If we couldn't find the task kernel, create it based on the structure
+        if task_kernel is None:
+            from gpytorch.kernels import IndexKernel
+
+            task_kernel = IndexKernel(
+                num_tasks=model.num_tasks,
+                rank=model._rank,
+                active_dims=[task_index],
+            ).to(device=model.covar_module.device, dtype=model.covar_module.dtype)
+
+        # Set task kernel active dims correctly
+        task_kernel.active_dims = torch.tensor([task_index], device=task_kernel.device)
+
+        # Use the existing product kernel structure
+        combined_kernel = data_kernel * task_kernel
+    else:
+        # Fallback to using the original covar_module directly
+        combined_kernel = model.covar_module

     return _draw_kernel_feature_paths_fallback(
         mean_module=model.mean_module,
-        covar_module=base_kernel * task_kernel,
+        covar_module=combined_kernel,
         input_transform=get_input_transform(model),
         output_transform=get_output_transform(model),
         num_ambient_inputs=num_ambient_inputs,
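
The core of this hunk is the active_dims-based split of model.covar_module into a data kernel and a task kernel. A self-contained sketch of that detection step, assuming the ProductKernel layout the comment describes (the kernel choices below are placeholders):

from gpytorch.kernels import IndexKernel, ProductKernel, RBFKernel

task_index = 2
covar_module = ProductKernel(
    RBFKernel(active_dims=[0, 1]),                               # data kernel
    IndexKernel(num_tasks=3, rank=1, active_dims=[task_index]),  # task kernel
)

data_kernel, task_kernel = None, None
for kernel in covar_module.kernels:
    if kernel.active_dims is not None and task_index in kernel.active_dims:
        task_kernel = kernel
    else:
        data_kernel = kernel

assert isinstance(task_kernel, IndexKernel) and isinstance(data_kernel, RBFKernel)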

botorch/sampling/pathwise/update_strategies.py

Lines changed: 49 additions & 8 deletions
@@ -172,17 +172,58 @@ def _draw_kernel_feature_paths_MultiTaskGP(
         if model._task_feature < 0
         else model._task_feature
     )
-    base_kernel = deepcopy(model.covar_module)
-    base_kernel.active_dims = torch.LongTensor(
-        [index for index in range(num_inputs) if index != task_index],
-        device=base_kernel.device,
-    )
-    task_kernel = deepcopy(model.task_covar_module)
-    task_kernel.active_dims = torch.LongTensor([task_index], device=base_kernel.device)
+
+    # Extract kernels from the product kernel structure
+    # model.covar_module is a ProductKernel
+    # containing data_covar_module * task_covar_module
+    from gpytorch.kernels import ProductKernel
+
+    if isinstance(model.covar_module, ProductKernel):
+        # Get the individual kernels from the product kernel
+        kernels = model.covar_module.kernels
+
+        # Find data and task kernels based on their active_dims
+        data_kernel = None
+        task_kernel = None
+
+        for kernel in kernels:
+            if hasattr(kernel, "active_dims") and kernel.active_dims is not None:
+                if task_index in kernel.active_dims:
+                    task_kernel = deepcopy(kernel)
+                else:
+                    data_kernel = deepcopy(kernel)
+            else:
+                # If no active_dims, it's likely the data kernel
+                data_kernel = deepcopy(kernel)
+                data_kernel.active_dims = torch.LongTensor(
+                    [index for index in range(num_inputs) if index != task_index],
+                    device=data_kernel.device,
+                )
+
+        # If we couldn't find the task kernel, create it based on the structure
+        if task_kernel is None:
+            from gpytorch.kernels import IndexKernel
+
+            task_kernel = IndexKernel(
+                num_tasks=model.num_tasks,
+                rank=model._rank,
+                active_dims=[task_index],
+            ).to(device=model.covar_module.device, dtype=model.covar_module.dtype)
+
+        # Set task kernel active dims correctly
+        task_kernel.active_dims = torch.LongTensor(
+            [task_index], device=task_kernel.device
+        )
+
+        # Use the existing product kernel structure
+        combined_kernel = data_kernel * task_kernel
+    else:
+        # Fallback to using the original covar_module directly
+        combined_kernel = model.covar_module

     # Return exact update using product kernel
     return _gaussian_update_exact(
-        kernel=base_kernel * task_kernel,
+        kernel=combined_kernel,
         points=points,
         target_values=target_values,
         sample_values=sample_values,
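
update_strategies.py mirrors the prior_samplers.py change; the part worth double-checking is the IndexKernel fallback, which rebuilds a task kernel from model.num_tasks and model._rank when none is found in the product structure. A hedged sketch of that construction with placeholder values:

import torch
from gpytorch.kernels import IndexKernel

# Placeholders for model.num_tasks, model._rank, and the task feature column.
num_tasks, rank, task_index = 3, 1, 2
task_kernel = IndexKernel(
    num_tasks=num_tasks,
    rank=rank,
    active_dims=[task_index],
).to(dtype=torch.float64)

# One raw variance entry per task; the test helper further below relies on this shape too.
assert task_kernel.raw_var.shape[-1] == num_tasks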

test/models/test_fully_bayesian_multitask.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@
3131
)
3232
from botorch.models import ModelList, ModelListGP
3333
from botorch.models.deterministic import GenericDeterministicModel
34-
from botorch.models.fully_bayesian import MCMC_DIM, MIN_INFERRED_NOISE_LEVEL
34+
from botorch.models.fully_bayesian import (
35+
matern52_kernel,
36+
MCMC_DIM,
37+
MIN_INFERRED_NOISE_LEVEL,
38+
)
3539
from botorch.models.fully_bayesian_multitask import (
3640
MultitaskSaasPyroModel,
3741
SaasFullyBayesianMultiTaskGP,
@@ -46,7 +50,7 @@
4650
)
4751
from botorch.utils.test_helpers import gen_multi_task_dataset
4852
from botorch.utils.testing import BotorchTestCase
49-
from gpytorch.kernels import MaternKernel, ScaleKernel
53+
from gpytorch.kernels import IndexKernel, MaternKernel, ScaleKernel
5054
from gpytorch.likelihoods import FixedNoiseGaussianLikelihood
5155
from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood
5256
from gpytorch.means import ConstantMean
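
The tests now import matern52_kernel alongside the gpytorch kernels, presumably to compare sampled kernels against a reference Gram matrix. A usage sketch, assuming the (X, lengthscale) signature used in botorch.models.fully_bayesian:

import torch
from botorch.models.fully_bayesian import matern52_kernel

X = torch.rand(5, 3, dtype=torch.float64)
lengthscale = torch.ones(3, dtype=torch.float64)
K = matern52_kernel(X, lengthscale=lengthscale)  # 5 x 5 Matern-5/2 Gram matrix
assert K.shape == (5, 5)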

test/optim/test_optimize_mixed.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+import random
 from dataclasses import fields
 from itertools import product
 from typing import Any, Callable
@@ -29,6 +30,7 @@
     continuous_step,
     discrete_step,
     generate_starting_points,
+    get_categorical_neighbors,
     get_nearest_neighbors,
     get_spray_points,
     MAX_DISCRETE_VALUES,
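
get_categorical_neighbors is new in this commit and its exact signature is not shown in this diff, so the toy helper below only illustrates the underlying idea: for a categorical dimension every other category value is a neighbor, unlike the step-wise neighbors used for ordered discrete dimensions.

import torch

def toy_categorical_neighbors(x: torch.Tensor, dim: int, cardinality: int) -> torch.Tensor:
    # Return copies of x with the categorical entry at `dim` swapped to each other value.
    values = [v for v in range(cardinality) if v != int(x[dim])]
    neighbors = x.repeat(len(values), 1)
    neighbors[:, dim] = torch.tensor(values, dtype=x.dtype)
    return neighbors

x = torch.tensor([0.5, 2.0, 0.1])
print(toy_categorical_neighbors(x, dim=1, cardinality=4))
# rows with the middle entry set to 0, 1, and 3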

test/sampling/pathwise/helpers.py

Lines changed: 22 additions & 1 deletion
@@ -84,7 +84,28 @@ def gen_random_inputs(
     tkwargs = {"device": train_X.device, "dtype": train_X.dtype}
     X = torch.rand((*batch_shape, train_X.shape[-1]), **tkwargs)
     if isinstance(model, models.MultiTaskGP):
-        num_tasks = model.task_covar_module.raw_var.shape[-1]
+        # Extract task kernel from the product kernel structure
+        from gpytorch.kernels import ProductKernel
+
+        if isinstance(model.covar_module, ProductKernel):
+            # Find the task kernel based on active_dims
+            task_kernel = None
+            for kernel in model.covar_module.kernels:
+                if (
+                    hasattr(kernel, "active_dims")
+                    and kernel.active_dims is not None
+                ):
+                    if model._task_feature in kernel.active_dims:
+                        task_kernel = kernel
+                        break
+
+            if task_kernel is not None and hasattr(task_kernel, "raw_var"):
+                num_tasks = task_kernel.raw_var.shape[-1]
+            else:
+                num_tasks = model.num_tasks
+        else:
+            num_tasks = model.num_tasks
+
         X[..., model._task_feature] = (
             torch.randint(num_tasks, size=X.shape[:-1], **tkwargs)
             if task_id is None
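
The helper's fallback chain (IndexKernel raw_var, then model.num_tasks) can be exercised in isolation; a sketch under the same ProductKernel assumption, using isinstance for brevity instead of the active_dims check above:

from gpytorch.kernels import IndexKernel, ProductKernel, RBFKernel

covar_module = ProductKernel(
    RBFKernel(active_dims=[0, 1]), IndexKernel(num_tasks=4, rank=1, active_dims=[2])
)
fallback_num_tasks = 4  # stands in for model.num_tasks

task_kernel = next((k for k in covar_module.kernels if isinstance(k, IndexKernel)), None)
num_tasks = task_kernel.raw_var.shape[-1] if task_kernel is not None else fallback_num_tasks
assert num_tasks == 4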
