diff --git a/captum/attr/_core/lime.py b/captum/attr/_core/lime.py index 8b5a6f86b..5b754f85f 100644 --- a/captum/attr/_core/lime.py +++ b/captum/attr/_core/lime.py @@ -1038,7 +1038,12 @@ def attribute( # type: ignore coefficient of the corresponding interpretale feature. All elements with the same value in the feature mask will contain the same coefficient in the returned - attributions. If return_input_shape is False, a 1D + attributions. + If forward_func returns a single element per batch, then the + first dimension of each tensor will be 1, and the remaining + dimensions will have the same shape as the original input + tensor. + If return_input_shape is False, a 1D tensor is returned, containing only the coefficients of the trained interpreatable models, with length num_interp_features. @@ -1242,6 +1247,7 @@ def _attribute_kwargs( # type: ignore coefs, num_interp_features, is_inputs_tuple, + leading_dim_one=(bsz > 1), ) else: return coefs @@ -1254,6 +1260,7 @@ def _convert_output_shape( coefs: Tensor, num_interp_features: int, is_inputs_tuple: Literal[True], + leading_dim_one: bool = False, ) -> Tuple[Tensor, ...]: ... @typing.overload @@ -1264,6 +1271,7 @@ def _convert_output_shape( # type: ignore coefs: Tensor, num_interp_features: int, is_inputs_tuple: Literal[False], + leading_dim_one: bool = False, ) -> Tensor: ... @typing.overload @@ -1274,6 +1282,7 @@ def _convert_output_shape( coefs: Tensor, num_interp_features: int, is_inputs_tuple: bool, + leading_dim_one: bool = False, ) -> Union[Tensor, Tuple[Tensor, ...]]: ... def _convert_output_shape( @@ -1283,6 +1292,7 @@ def _convert_output_shape( coefs: Tensor, num_interp_features: int, is_inputs_tuple: bool, + leading_dim_one: bool = False, ) -> Union[Tensor, Tuple[Tensor, ...]]: coefs = coefs.flatten() attr = [ @@ -1295,4 +1305,7 @@ def _convert_output_shape( coefs[single_feature].item() * (feature_mask[tensor_ind] == single_feature).float() ) + if leading_dim_one: + for i in range(len(attr)): + attr[i] = attr[i][0:1] return _format_output(is_inputs_tuple, tuple(attr)) diff --git a/tests/attr/test_kernel_shap.py b/tests/attr/test_kernel_shap.py index fa34d369d..61bd66397 100644 --- a/tests/attr/test_kernel_shap.py +++ b/tests/attr/test_kernel_shap.py @@ -348,9 +348,9 @@ def _multi_input_scalar_kernel_shap_assert(self, func: Callable) -> None: mask2 = torch.tensor([[0, 1, 2]]) mask3 = torch.tensor([[0, 1, 2]]) expected = ( - [[3850.6666, 3850.6666, 3850.6666]] * 2, - [[306.6666, 3850.6666, 410.6666]] * 2, - [[306.6666, 3850.6666, 410.6666]] * 2, + [[3850.6666, 3850.6666, 3850.6666]], + [[306.6666, 3850.6666, 410.6666]], + [[306.6666, 3850.6666, 410.6666]], ) self._kernel_shap_test_assert( diff --git a/tests/attr/test_lime.py b/tests/attr/test_lime.py index e3fccb279..095ef9cf0 100644 --- a/tests/attr/test_lime.py +++ b/tests/attr/test_lime.py @@ -494,9 +494,9 @@ def _multi_input_scalar_lime_assert(self, func: Callable) -> None: mask2 = torch.tensor([[0, 1, 2]]) mask3 = torch.tensor([[0, 1, 2]]) expected = ( - [[3850.6666, 3850.6666, 3850.6666]] * 2, - [[305.5, 3850.6666, 410.1]] * 2, - [[305.5, 3850.6666, 410.1]] * 2, + [[3850.6666, 3850.6666, 3850.6666]], + [[305.5, 3850.6666, 410.1]], + [[305.5, 3850.6666, 410.1]], ) self._lime_test_assert(