@@ -149,19 +149,59 @@ def _draw_kernel_feature_paths_MultiTaskGP(
149149 else model ._task_feature
150150 )
151151
152- # NOTE: May want to use a `ProductKernel` instead in `MultiTaskGP`
153- base_kernel = deepcopy (model .covar_module )
154- base_kernel .active_dims = torch .LongTensor (
155- [index for index in range (train_X .shape [- 1 ]) if index != task_index ],
156- device = base_kernel .device ,
157- )
158-
159- task_kernel = deepcopy (model .task_covar_module )
160- task_kernel .active_dims = torch .tensor ([task_index ], device = base_kernel .device )
152+ # Extract kernels from the product kernel structure
153+ # model.covar_module is a ProductKernel
154+ # containing data_covar_module * task_covar_module
155+ from gpytorch .kernels import ProductKernel
156+
157+ if isinstance (model .covar_module , ProductKernel ):
158+ # Get the individual kernels from the product kernel
159+ kernels = model .covar_module .kernels
160+
161+ # Find data and task kernels based on their active_dims
162+ data_kernel = None
163+ task_kernel = None
164+
165+ for kernel in kernels :
166+ if hasattr (kernel , "active_dims" ) and kernel .active_dims is not None :
167+ if task_index in kernel .active_dims :
168+ task_kernel = deepcopy (kernel )
169+ else :
170+ data_kernel = deepcopy (kernel )
171+ else :
172+ # If no active_dims, it's likely the data kernel
173+ data_kernel = deepcopy (kernel )
174+ data_kernel .active_dims = torch .LongTensor (
175+ [
176+ index
177+ for index in range (train_X .shape [- 1 ])
178+ if index != task_index
179+ ],
180+ device = data_kernel .device ,
181+ )
182+
183+ # If we couldn't find the task kernel, create it based on the structure
184+ if task_kernel is None :
185+ from gpytorch .kernels import IndexKernel
186+
187+ task_kernel = IndexKernel (
188+ num_tasks = model .num_tasks ,
189+ rank = model ._rank ,
190+ active_dims = [task_index ],
191+ ).to (device = model .covar_module .device , dtype = model .covar_module .dtype )
192+
193+ # Set task kernel active dims correctly
194+ task_kernel .active_dims = torch .tensor ([task_index ], device = task_kernel .device )
195+
196+ # Use the existing product kernel structure
197+ combined_kernel = data_kernel * task_kernel
198+ else :
199+ # Fallback to using the original covar_module directly
200+ combined_kernel = model .covar_module
161201
162202 return _draw_kernel_feature_paths_fallback (
163203 mean_module = model .mean_module ,
164- covar_module = base_kernel * task_kernel ,
204+ covar_module = combined_kernel ,
165205 input_transform = get_input_transform (model ),
166206 output_transform = get_output_transform (model ),
167207 num_ambient_inputs = num_ambient_inputs ,