From 172f8351a383119f07b68f22b7a81864ee68a223 Mon Sep 17 00:00:00 2001
From: Vivek Miglani
Date: Tue, 25 Feb 2025 14:04:03 -0800
Subject: [PATCH] Fix Lime output dimension in batch forward (#1513)

Summary:
Currently, when a batch of inputs is provided with a forward function that
returns a single scalar per batch, Lime and KernelShap still return output
matching the input shape. This behavior is inconsistent with other
perturbation-based methods, particularly Feature Ablation and Shapley Value
Sampling.

This change breaks backward compatibility for OSS users, but since it affects
only a specific case (a single scalar per batch), it should be fine to land
with only a documentation update.

Reviewed By: craymichael

Differential Revision: D70096644
---
 captum/attr/_core/lime.py      | 15 ++++++++++++++-
 tests/attr/test_kernel_shap.py |  6 +++---
 tests/attr/test_lime.py        |  6 +++---
 3 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/captum/attr/_core/lime.py b/captum/attr/_core/lime.py
index 8b5a6f86b1..5b754f85fe 100644
--- a/captum/attr/_core/lime.py
+++ b/captum/attr/_core/lime.py
@@ -1038,7 +1038,12 @@ def attribute(  # type: ignore
                         coefficient of the corresponding interpretale feature.
                         All elements with the same value in the feature mask
                         will contain the same coefficient in the returned
-                        attributions. If return_input_shape is False, a 1D
+                        attributions.
+                        If forward_func returns a single element per batch, then the
+                        first dimension of each tensor will be 1, and the remaining
+                        dimensions will have the same shape as the original input
+                        tensor.
+                        If return_input_shape is False, a 1D
                         tensor is returned, containing only the coefficients
                         of the trained interpreatable models, with length
                         num_interp_features.
@@ -1242,6 +1247,7 @@ def _attribute_kwargs(  # type: ignore
                 coefs,
                 num_interp_features,
                 is_inputs_tuple,
+                leading_dim_one=(bsz > 1),
             )
         else:
             return coefs
@@ -1254,6 +1260,7 @@ def _convert_output_shape(
         coefs: Tensor,
         num_interp_features: int,
         is_inputs_tuple: Literal[True],
+        leading_dim_one: bool = False,
     ) -> Tuple[Tensor, ...]: ...

     @typing.overload
@@ -1264,6 +1271,7 @@ def _convert_output_shape(  # type: ignore
         coefs: Tensor,
         num_interp_features: int,
         is_inputs_tuple: Literal[False],
+        leading_dim_one: bool = False,
     ) -> Tensor: ...

     @typing.overload
@@ -1274,6 +1282,7 @@ def _convert_output_shape(
         coefs: Tensor,
         num_interp_features: int,
         is_inputs_tuple: bool,
+        leading_dim_one: bool = False,
     ) -> Union[Tensor, Tuple[Tensor, ...]]: ...

     def _convert_output_shape(
@@ -1283,6 +1292,7 @@ def _convert_output_shape(
         coefs: Tensor,
         num_interp_features: int,
         is_inputs_tuple: bool,
+        leading_dim_one: bool = False,
     ) -> Union[Tensor, Tuple[Tensor, ...]]:
         coefs = coefs.flatten()
         attr = [
@@ -1295,4 +1305,7 @@ def _convert_output_shape(
                 coefs[single_feature].item()
                 * (feature_mask[tensor_ind] == single_feature).float()
             )
+        if leading_dim_one:
+            for i in range(len(attr)):
+                attr[i] = attr[i][0:1]
         return _format_output(is_inputs_tuple, tuple(attr))
diff --git a/tests/attr/test_kernel_shap.py b/tests/attr/test_kernel_shap.py
index fa34d369d9..61bd66397f 100644
--- a/tests/attr/test_kernel_shap.py
+++ b/tests/attr/test_kernel_shap.py
@@ -348,9 +348,9 @@ def _multi_input_scalar_kernel_shap_assert(self, func: Callable) -> None:
         mask2 = torch.tensor([[0, 1, 2]])
         mask3 = torch.tensor([[0, 1, 2]])
         expected = (
-            [[3850.6666, 3850.6666, 3850.6666]] * 2,
-            [[306.6666, 3850.6666, 410.6666]] * 2,
-            [[306.6666, 3850.6666, 410.6666]] * 2,
+            [[3850.6666, 3850.6666, 3850.6666]],
+            [[306.6666, 3850.6666, 410.6666]],
+            [[306.6666, 3850.6666, 410.6666]],
         )

         self._kernel_shap_test_assert(
diff --git a/tests/attr/test_lime.py b/tests/attr/test_lime.py
index e3fccb2794..095ef9cf0d 100644
--- a/tests/attr/test_lime.py
+++ b/tests/attr/test_lime.py
@@ -494,9 +494,9 @@ def _multi_input_scalar_lime_assert(self, func: Callable) -> None:
         mask2 = torch.tensor([[0, 1, 2]])
         mask3 = torch.tensor([[0, 1, 2]])
         expected = (
-            [[3850.6666, 3850.6666, 3850.6666]] * 2,
-            [[305.5, 3850.6666, 410.1]] * 2,
-            [[305.5, 3850.6666, 410.1]] * 2,
+            [[3850.6666, 3850.6666, 3850.6666]],
+            [[305.5, 3850.6666, 410.1]],
+            [[305.5, 3850.6666, 410.1]],
        )

        self._lime_test_assert(
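
Illustration of the behavior change (not part of the patch): a minimal sketch
of the scalar-per-batch case this commit addresses. The forward function
`forward_per_batch` and the tensor shapes below are illustrative assumptions,
not code from the Captum repo.

    import torch
    from captum.attr import Lime

    def forward_per_batch(inputs: torch.Tensor) -> torch.Tensor:
        # Hypothetical forward function: returns a single element for the
        # entire batch (shape (1,)), i.e. the "scalar per batch" case.
        return inputs.sum().reshape(1)

    inputs = torch.randn(4, 3)                # batch of 4 examples, 3 features each
    feature_mask = torch.tensor([[0, 1, 2]])  # grouping broadcast across the batch

    lime = Lime(forward_per_batch)
    attr = lime.attribute(inputs, feature_mask=feature_mask, n_samples=200)

    # Before this change: attr.shape == (4, 3), with the same coefficients
    # repeated for every example. After: attr.shape == (1, 3), consistent
    # with FeatureAblation and ShapleyValueSampling for this case.
    print(attr.shape)

Note that `_attribute_kwargs` passes `leading_dim_one=(bsz > 1)`: for a batch
of size one the output already has a leading dimension of 1, so the slicing in
`_convert_output_shape` only applies when a larger batch collapses to a single
scalar output.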