
Commit b53b2c1

sarahtranfb authored and facebook-github-bot committed
Add enable_cross_tensor_attribution flag to attribute_future (#1546)
Summary: reserved

Differential Revision: D73464680
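For context, a minimal sketch of how the new flag appears at the call site; the toy forward function and tensors here are hypothetical, and `attribute_future` expects a forward function that returns a `torch.futures.Future`. Per this diff, passing `True` raises `NotImplementedError` for now:

import torch
from captum.attr import FeatureAblation

# Hypothetical toy setup for illustration only.
linear = torch.nn.Linear(4, 2)

def forward_func(x: torch.Tensor) -> torch.futures.Future:
    # attribute_future drives an async evaluation, so the forward
    # function must hand back a Future rather than a plain Tensor.
    fut: torch.futures.Future = torch.futures.Future()
    fut.set_result(linear(x))
    return fut

ablator = FeatureAblation(forward_func)
inputs = torch.randn(8, 4)

# Default (False): the existing per-tensor masking path, which this
# commit factors into _attribute_with_independent_feature_masks_future.
attributions = ablator.attribute_future(inputs, target=0).wait()

# enable_cross_tensor_attribution=True is reserved and currently raises:
# ablator.attribute_future(inputs, target=0, enable_cross_tensor_attribution=True)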
1 parent 5248929 commit b53b2c1

File tree: 1 file changed (+115, −78)

captum/attr/_core/feature_ablation.py
@@ -729,6 +729,7 @@ def attribute_future(
         feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
         perturbations_per_eval: int = 1,
         show_progress: bool = False,
+        enable_cross_tensor_attribution: bool = False,
         **kwargs: Any,
     ) -> Future[TensorOrTupleOfTensorsGeneric]:
         r"""
@@ -743,17 +744,18 @@ def attribute_future(
         formatted_additional_forward_args = _format_additional_forward_args(
             additional_forward_args
         )
-        num_examples = formatted_inputs[0].shape[0]
         formatted_feature_mask = _format_feature_mask(feature_mask, formatted_inputs)

         assert (
             isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1
         ), "Perturbations per evaluation must be an integer and at least 1."
         with torch.no_grad():
+            attr_progress = None
             if show_progress:
                 attr_progress = self._attribute_progress_setup(
                     formatted_inputs,
                     formatted_feature_mask,
+                    enable_cross_tensor_attribution,
                     **kwargs,
                     perturbations_per_eval=perturbations_per_eval,
                 )
@@ -768,7 +770,7 @@ def attribute_future(
                 formatted_additional_forward_args,
             )

-            if show_progress:
+            if attr_progress is not None:
                 attr_progress.update()

             processed_initial_eval_fut: Optional[
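The two hunks above introduce a sentinel for the progress bar: `attr_progress` is now bound to `None` on every path, and each use is guarded with `is not None` instead of re-checking `show_progress`, so the variable is always defined and the same guard can travel into the extracted helper below, which never receives `show_progress` at all. A standalone sketch of the idiom, with hypothetical names:

from typing import List, Optional

from tqdm import tqdm

def process(items: List[int], show_progress: bool = False) -> None:
    # Bind the progress bar on every path so later guards are well-defined.
    progress: Optional[tqdm] = None
    if show_progress:
        progress = tqdm(total=len(items))

    for _item in items:
        # ... do the real work here ...
        if progress is not None:  # guard each use, not the original flag
            progress.update()

    if progress is not None:
        progress.close()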
@@ -788,101 +790,136 @@ def attribute_future(
                 )
             )

-            # There will be the same number of futures as modified_eval below,
-            # since we cannot add up the evaluation results ad hoc in async mode.
-            all_modified_eval_futures: List[
-                List[Future[Tuple[List[Tensor], List[Tensor]]]]
-            ] = [[] for _ in range(len(inputs))]
-            # Iterate through each feature tensor for ablation
-            for i in range(len(formatted_inputs)):
-                # Skip any empty input tensors
-                if torch.numel(formatted_inputs[i]) == 0:
-                    continue
-
-                for (
-                    current_inputs,
-                    current_add_args,
-                    current_target,
-                    current_mask,
-                ) in self._ith_input_ablation_generator(
-                    i,
+            if enable_cross_tensor_attribution:
+                raise NotImplementedError("Not supported yet")
+            else:
+                # pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric
+                # <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got
+                # `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]`
+                return self._attribute_with_independent_feature_masks_future(
                     formatted_inputs,
                     formatted_additional_forward_args,
                     target,
                     baselines,
                     formatted_feature_mask,
                     perturbations_per_eval,
+                    attr_progress,
+                    processed_initial_eval_fut,
+                    is_inputs_tuple,
                     **kwargs,
-                ):
-                    # modified_eval has (n_feature_perturbed * n_outputs) elements
-                    # shape:
-                    # agg mode: (*initial_eval.shape)
-                    # non-agg mode:
-                    # (feature_perturbed * batch_size, *initial_eval.shape[1:])
-                    modified_eval: Union[Tensor, Future[Tensor]] = _run_forward(
-                        self.forward_func,
-                        current_inputs,
-                        current_target,
-                        current_add_args,
-                    )
+                )

-                    if show_progress:
-                        attr_progress.update()
+    def _attribute_with_independent_feature_masks_future(
+        self,
+        formatted_inputs: Tuple[Tensor, ...],
+        formatted_additional_forward_args: Optional[Tuple[object, ...]],
+        target: TargetType,
+        baselines: BaselineType,
+        formatted_feature_mask: Tuple[Tensor, ...],
+        perturbations_per_eval: int,
+        attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
+        processed_initial_eval_fut: Future[
+            Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]
+        ],
+        is_inputs_tuple: bool,
+        **kwargs: Any,
+    ) -> Future[Union[Tensor, Tuple[Tensor, ...]]]:
+        num_examples = formatted_inputs[0].shape[0]
+        # There will be the same number of futures as modified_eval below,
+        # since we cannot add up the evaluation results ad hoc in async mode.
+        all_modified_eval_futures: List[
+            List[Future[Tuple[List[Tensor], List[Tensor]]]]
+        ] = [[] for _ in range(len(formatted_inputs))]
+        # Iterate through each feature tensor for ablation
+        for i in range(len(formatted_inputs)):
+            # Skip any empty input tensors
+            if torch.numel(formatted_inputs[i]) == 0:
+                continue

-                    if not isinstance(modified_eval, torch.Future):
-                        raise AssertionError(
-                            "when using attribute_future, modified_eval should have "
-                            f"Future type rather than {type(modified_eval)}"
-                        )
-                    if processed_initial_eval_fut is None:
-                        raise AssertionError(
-                            "processed_initial_eval_fut should not be None"
-                        )
+            for (
+                current_inputs,
+                current_add_args,
+                current_target,
+                current_mask,
+            ) in self._ith_input_ablation_generator(
+                i,
+                formatted_inputs,
+                formatted_additional_forward_args,
+                target,
+                baselines,
+                formatted_feature_mask,
+                perturbations_per_eval,
+                **kwargs,
+            ):
+                # modified_eval has (n_feature_perturbed * n_outputs) elements
+                # shape:
+                # agg mode: (*initial_eval.shape)
+                # non-agg mode:
+                # (feature_perturbed * batch_size, *initial_eval.shape[1:])
+                modified_eval: Union[Tensor, Future[Tensor]] = _run_forward(
+                    self.forward_func,
+                    current_inputs,
+                    current_target,
+                    current_add_args,
+                )

-                    # Need to collect both initial eval and modified_eval
-                    eval_futs: Future[
-                        List[
-                            Future[
-                                Union[
-                                    Tuple[
-                                        List[Tensor],
-                                        List[Tensor],
-                                        Tensor,
-                                        Tensor,
-                                        int,
-                                        dtype,
-                                    ],
+                if attr_progress is not None:
+                    attr_progress.update()
+
+                if not isinstance(modified_eval, torch.Future):
+                    raise AssertionError(
+                        "when using attribute_future, modified_eval should have "
+                        f"Future type rather than {type(modified_eval)}"
+                    )
+                if processed_initial_eval_fut is None:
+                    raise AssertionError(
+                        "processed_initial_eval_fut should not be None"
+                    )
+
+                # Need to collect both initial eval and modified_eval
+                eval_futs: Future[
+                    List[
+                        Future[
+                            Union[
+                                Tuple[
+                                    List[Tensor],
+                                    List[Tensor],
+                                    Tensor,
                                     Tensor,
-                                ]
+                                    int,
+                                    dtype,
+                                ],
+                                Tensor,
                             ]
                         ]
-                    ] = collect_all(
-                        [
-                            processed_initial_eval_fut,
-                            modified_eval,
-                        ]
-                    )
+                    ]
+                ] = collect_all(
+                    [
+                        processed_initial_eval_fut,
+                        modified_eval,
+                    ]
+                )

-                    ablated_out_fut: Future[Tuple[List[Tensor], List[Tensor]]] = (
-                        eval_futs.then(
-                            lambda eval_futs, current_inputs=current_inputs, current_mask=current_mask, i=i: self._eval_fut_to_ablated_out_fut(  # type: ignore # noqa: E501 line too long
-                                eval_futs=eval_futs,
-                                current_inputs=current_inputs,
-                                current_mask=current_mask,
-                                i=i,
-                                perturbations_per_eval=perturbations_per_eval,
-                                num_examples=num_examples,
-                                formatted_inputs=formatted_inputs,
-                            )
+                ablated_out_fut: Future[Tuple[List[Tensor], List[Tensor]]] = (
+                    eval_futs.then(
+                        lambda eval_futs, current_inputs=current_inputs, current_mask=current_mask, i=i: self._eval_fut_to_ablated_out_fut(  # type: ignore # noqa: E501 line too long
+                            eval_futs=eval_futs,
+                            current_inputs=current_inputs,
+                            current_mask=current_mask,
+                            i=i,
+                            perturbations_per_eval=perturbations_per_eval,
+                            num_examples=num_examples,
+                            formatted_inputs=formatted_inputs,
+                        )
                     )
+                )

-                    all_modified_eval_futures[i].append(ablated_out_fut)
+                all_modified_eval_futures[i].append(ablated_out_fut)

-            if show_progress:
-                attr_progress.close()
+        if attr_progress is not None:
+            attr_progress.close()

-            return self._generate_async_result(all_modified_eval_futures, is_inputs_tuple)  # type: ignore # noqa: E501 line too long
+        return self._generate_async_result(all_modified_eval_futures, is_inputs_tuple)  # type: ignore # noqa: E501 line too long

     # pyre-fixme[3] return type must be annotated
     def _attribute_progress_setup(
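The moved loop body leans on two PyTorch future primitives: `torch.futures.collect_all`, which bundles several futures into one future of a list, and `Future.then`, which chains a callback onto a completed future. Note the `lambda eval_futs, current_inputs=current_inputs, ..., i=i:` default arguments; they snapshot each iteration's loop variables, since a bare closure would see only the final values by the time the futures complete. A self-contained sketch of the pattern, independent of Captum:

import torch
from torch.futures import Future, collect_all

def async_square(x: int) -> Future:
    # Stand-in for an async forward pass that yields a Future.
    fut: Future = torch.futures.Future()
    fut.set_result(x * x)
    return fut

initial_fut = async_square(10)  # plays the role of processed_initial_eval_fut

result_futs = []
for i in range(3):
    modified_fut = async_square(i)  # plays the role of modified_eval

    # Bundle both futures, as the diff does with collect_all([...]).
    pair: Future = collect_all([initial_fut, modified_fut])

    # Bind i via a default argument so each callback keeps its own i;
    # a plain closure over i would read 2 in every callback.
    done = pair.then(
        lambda futs, i=i: (i, futs.value()[1].wait() - futs.value()[0].wait())
    )
    result_futs.append(done)

print([f.wait() for f in result_futs])  # [(0, -100), (1, -99), (2, -96)]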
