@@ -729,6 +729,7 @@ def attribute_future(
729
729
feature_mask : Union [None , Tensor , Tuple [Tensor , ...]] = None ,
730
730
perturbations_per_eval : int = 1 ,
731
731
show_progress : bool = False ,
732
+ enable_cross_tensor_attribution : bool = False ,
732
733
** kwargs : Any ,
733
734
) -> Future [TensorOrTupleOfTensorsGeneric ]:
734
735
r"""
@@ -743,17 +744,18 @@ def attribute_future(
743
744
formatted_additional_forward_args = _format_additional_forward_args (
744
745
additional_forward_args
745
746
)
746
- num_examples = formatted_inputs [0 ].shape [0 ]
747
747
formatted_feature_mask = _format_feature_mask (feature_mask , formatted_inputs )
748
748
749
749
assert (
750
750
isinstance (perturbations_per_eval , int ) and perturbations_per_eval >= 1
751
751
), "Perturbations per evaluation must be an integer and at least 1."
752
752
with torch .no_grad ():
753
+ attr_progress = None
753
754
if show_progress :
754
755
attr_progress = self ._attribute_progress_setup (
755
756
formatted_inputs ,
756
757
formatted_feature_mask ,
758
+ enable_cross_tensor_attribution ,
757
759
** kwargs ,
758
760
perturbations_per_eval = perturbations_per_eval ,
759
761
)
@@ -768,7 +770,7 @@ def attribute_future(
768
770
formatted_additional_forward_args ,
769
771
)
770
772
771
- if show_progress :
773
+ if attr_progress is not None :
772
774
attr_progress .update ()
773
775
774
776
processed_initial_eval_fut : Optional [
@@ -788,101 +790,136 @@ def attribute_future(
788
790
)
789
791
)
790
792
791
- # The will be the same amount futures as modified_eval down there,
792
- # since we cannot add up the evaluation result adhoc under async mode.
793
- all_modified_eval_futures : List [
794
- List [Future [Tuple [List [Tensor ], List [Tensor ]]]]
795
- ] = [[] for _ in range (len (inputs ))]
796
- # Iterate through each feature tensor for ablation
797
- for i in range (len (formatted_inputs )):
798
- # Skip any empty input tensors
799
- if torch .numel (formatted_inputs [i ]) == 0 :
800
- continue
801
-
802
- for (
803
- current_inputs ,
804
- current_add_args ,
805
- current_target ,
806
- current_mask ,
807
- ) in self ._ith_input_ablation_generator (
808
- i ,
793
+ if enable_cross_tensor_attribution :
794
+ raise NotImplementedError ("Not supported yet" )
795
+ else :
796
+ # pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric
797
+ # <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got
798
+ # `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]`
799
+ return self ._attribute_with_independent_feature_masks_future (
809
800
formatted_inputs ,
810
801
formatted_additional_forward_args ,
811
802
target ,
812
803
baselines ,
813
804
formatted_feature_mask ,
814
805
perturbations_per_eval ,
806
+ attr_progress ,
807
+ processed_initial_eval_fut ,
808
+ is_inputs_tuple ,
815
809
** kwargs ,
816
- ):
817
- # modified_eval has (n_feature_perturbed * n_outputs) elements
818
- # shape:
819
- # agg mode: (*initial_eval.shape)
820
- # non-agg mode:
821
- # (feature_perturbed * batch_size, *initial_eval.shape[1:])
822
- modified_eval : Union [Tensor , Future [Tensor ]] = _run_forward (
823
- self .forward_func ,
824
- current_inputs ,
825
- current_target ,
826
- current_add_args ,
827
- )
810
+ )
828
811
829
- if show_progress :
830
- attr_progress .update ()
812
+ def _attribute_with_independent_feature_masks_future (
813
+ self ,
814
+ formatted_inputs : Tuple [Tensor , ...],
815
+ formatted_additional_forward_args : Optional [Tuple [object , ...]],
816
+ target : TargetType ,
817
+ baselines : BaselineType ,
818
+ formatted_feature_mask : Tuple [Tensor , ...],
819
+ perturbations_per_eval : int ,
820
+ attr_progress : Optional [Union [SimpleProgress [IterableType ], tqdm ]],
821
+ processed_initial_eval_fut : Future [
822
+ Tuple [List [Tensor ], List [Tensor ], Tensor , Tensor , int , dtype ]
823
+ ],
824
+ is_inputs_tuple : bool ,
825
+ ** kwargs : Any ,
826
+ ) -> Future [Union [Tensor , Tuple [Tensor , ...]]]:
827
+ num_examples = formatted_inputs [0 ].shape [0 ]
828
+ # The will be the same amount futures as modified_eval down there,
829
+ # since we cannot add up the evaluation result adhoc under async mode.
830
+ all_modified_eval_futures : List [
831
+ List [Future [Tuple [List [Tensor ], List [Tensor ]]]]
832
+ ] = [[] for _ in range (len (formatted_inputs ))]
833
+ # Iterate through each feature tensor for ablation
834
+ for i in range (len (formatted_inputs )):
835
+ # Skip any empty input tensors
836
+ if torch .numel (formatted_inputs [i ]) == 0 :
837
+ continue
831
838
832
- if not isinstance (modified_eval , torch .Future ):
833
- raise AssertionError (
834
- "when using attribute_future, modified_eval should have "
835
- f"Future type rather than { type (modified_eval )} "
836
- )
837
- if processed_initial_eval_fut is None :
838
- raise AssertionError (
839
- "processed_initial_eval_fut should not be None"
840
- )
839
+ for (
840
+ current_inputs ,
841
+ current_add_args ,
842
+ current_target ,
843
+ current_mask ,
844
+ ) in self ._ith_input_ablation_generator (
845
+ i ,
846
+ formatted_inputs ,
847
+ formatted_additional_forward_args ,
848
+ target ,
849
+ baselines ,
850
+ formatted_feature_mask ,
851
+ perturbations_per_eval ,
852
+ ** kwargs ,
853
+ ):
854
+ # modified_eval has (n_feature_perturbed * n_outputs) elements
855
+ # shape:
856
+ # agg mode: (*initial_eval.shape)
857
+ # non-agg mode:
858
+ # (feature_perturbed * batch_size, *initial_eval.shape[1:])
859
+ modified_eval : Union [Tensor , Future [Tensor ]] = _run_forward (
860
+ self .forward_func ,
861
+ current_inputs ,
862
+ current_target ,
863
+ current_add_args ,
864
+ )
841
865
842
- # Need to collect both initial eval and modified_eval
843
- eval_futs : Future [
844
- List [
845
- Future [
846
- Union [
847
- Tuple [
848
- List [Tensor ],
849
- List [Tensor ],
850
- Tensor ,
851
- Tensor ,
852
- int ,
853
- dtype ,
854
- ],
866
+ if attr_progress is not None :
867
+ attr_progress .update ()
868
+
869
+ if not isinstance (modified_eval , torch .Future ):
870
+ raise AssertionError (
871
+ "when using attribute_future, modified_eval should have "
872
+ f"Future type rather than { type (modified_eval )} "
873
+ )
874
+ if processed_initial_eval_fut is None :
875
+ raise AssertionError (
876
+ "processed_initial_eval_fut should not be None"
877
+ )
878
+
879
+ # Need to collect both initial eval and modified_eval
880
+ eval_futs : Future [
881
+ List [
882
+ Future [
883
+ Union [
884
+ Tuple [
885
+ List [Tensor ],
886
+ List [Tensor ],
887
+ Tensor ,
855
888
Tensor ,
856
- ]
889
+ int ,
890
+ dtype ,
891
+ ],
892
+ Tensor ,
857
893
]
858
894
]
859
- ] = collect_all (
860
- [
861
- processed_initial_eval_fut ,
862
- modified_eval ,
863
- ]
864
- )
895
+ ]
896
+ ] = collect_all (
897
+ [
898
+ processed_initial_eval_fut ,
899
+ modified_eval ,
900
+ ]
901
+ )
865
902
866
- ablated_out_fut : Future [Tuple [List [Tensor ], List [Tensor ]]] = (
867
- eval_futs .then (
868
- lambda eval_futs , current_inputs = current_inputs , current_mask = current_mask , i = i : self ._eval_fut_to_ablated_out_fut ( # type: ignore # noqa: E501 line too long
869
- eval_futs = eval_futs ,
870
- current_inputs = current_inputs ,
871
- current_mask = current_mask ,
872
- i = i ,
873
- perturbations_per_eval = perturbations_per_eval ,
874
- num_examples = num_examples ,
875
- formatted_inputs = formatted_inputs ,
876
- )
903
+ ablated_out_fut : Future [Tuple [List [Tensor ], List [Tensor ]]] = (
904
+ eval_futs .then (
905
+ lambda eval_futs , current_inputs = current_inputs , current_mask = current_mask , i = i : self ._eval_fut_to_ablated_out_fut ( # type: ignore # noqa: E501 line too long
906
+ eval_futs = eval_futs ,
907
+ current_inputs = current_inputs ,
908
+ current_mask = current_mask ,
909
+ i = i ,
910
+ perturbations_per_eval = perturbations_per_eval ,
911
+ num_examples = num_examples ,
912
+ formatted_inputs = formatted_inputs ,
877
913
)
878
914
)
915
+ )
879
916
880
- all_modified_eval_futures [i ].append (ablated_out_fut )
917
+ all_modified_eval_futures [i ].append (ablated_out_fut )
881
918
882
- if show_progress :
883
- attr_progress .close ()
919
+ if attr_progress is not None :
920
+ attr_progress .close ()
884
921
885
- return self ._generate_async_result (all_modified_eval_futures , is_inputs_tuple ) # type: ignore # noqa: E501 line too long
922
+ return self ._generate_async_result (all_modified_eval_futures , is_inputs_tuple ) # type: ignore # noqa: E501 line too long
886
923
887
924
# pyre-fixme[3] return type must be annotated
888
925
def _attribute_progress_setup (
0 commit comments