@@ -689,7 +689,9 @@ def pack_image_features(self, image_features, image_sizes, vision_feature_select
                     image_feature = torch.cat(
                         (
                             image_feature,
-                            image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.dtype),
+                            image_newline[:, None, None]
+                            .expand(*image_feature.shape[:-1], 1)
+                            .to(image_feature.device, image_feature.dtype),
                         ),
                         dim=-1,
                     )
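This first hunk widens the cast on the newline column so it matches the device as well as the dtype of `image_feature`, which matters when the `image_newline` parameter is sharded onto a different device. A minimal sketch of the pattern, with made-up toy shapes (the real tensors in `pack_image_features` are patch grids):

```python
import torch

# Toy stand-ins: shapes are invented for illustration only.
image_feature = torch.randn(4, 3, 5, dtype=torch.float16)
image_newline = torch.randn(4, dtype=torch.float32)  # plays the role of the learned parameter

# Old pattern: matches dtype only, so torch.cat can fail if image_newline
# lives on a different device than image_feature.
col_old = image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.dtype)

# New pattern: matches device and dtype in one .to() call.
col = (
    image_newline[:, None, None]
    .expand(*image_feature.shape[:-1], 1)
    .to(image_feature.device, image_feature.dtype)
)
out = torch.cat((image_feature, col), dim=-1)
print(out.shape)  # torch.Size([4, 3, 6])
```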
@@ -835,18 +837,9 @@ def forward(
                 "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
             )
 
-        legacy_processing = False
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)
 
-            # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
-            # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
-            # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
-            legacy_processing = (
-                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-            ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
-
-        image_features = None
         if pixel_values is not None and pixel_values.size(0) > 0:
             image_features = self.get_image_features(
                 pixel_values,
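For reference, the `legacy_processing` heuristic deleted above compared the per-prompt count of image placeholder tokens against `config.image_seq_length` (and additionally treated a single-token decode step with `pixel_values` present as legacy). A minimal sketch of that check with invented toy values:

```python
import torch

# Invented toy values, only to illustrate the removed heuristic.
image_token_index = 7
image_seq_length = 3

input_ids = torch.tensor([[1, 7, 2, 0]])  # one unexpanded image placeholder
legacy_processing = (
    (input_ids == image_token_index).sum(1).max() < image_seq_length
)
print(bool(legacy_processing))  # True -> would have taken the legacy merge path
```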
@@ -863,70 +856,14 @@ def forward(
                 image_newline=self.image_newline,
             )
 
-        if legacy_processing:
-            logger.warning_once(
-                "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
-                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-                "Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
-            )
-            if input_ids.shape[1] != 1:
-                inputs_embeds = inputs_embeds.to(image_features.dtype)
-                inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
-                    image_features,
-                    feature_lens,
-                    inputs_embeds,
-                    input_ids,
-                    attention_mask,
-                    position_ids,
-                    labels=labels,
-                )
-                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
-            else:
-                # Retrieve the first layer to inspect the logits and mask out the hidden states
-                # that are set to 0
-                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
-                # Get the target length
-                target_length = input_ids.shape[1]
-                past_length = first_layer_past_key_value.shape[-1]
-
-                extended_attention_mask = torch.ones(
-                    (attention_mask.shape[0], past_length),
-                    dtype=attention_mask.dtype,
-                    device=attention_mask.device,
-                )
-
-                # Filter out only the tokens that can be un-attended, this can happen
-                # if one uses Llava + Fused modules where the cache on the
-                # first iteration is already big enough, or if one passes custom cache
-                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-                new_batch_index = batch_index[valid_indices]
-                new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-                # Zero-out the places where we don't need to attend
-                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
-
-        # TODO: @raushan retain only the new behavior after v4.47
-        elif image_features is not None:
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                 )
-            special_image_mask = (
-                (input_ids == self.config.image_token_index)
-                .unsqueeze(-1)
-                .expand_as(inputs_embeds)
-                .to(inputs_embeds.device)
-            )
+            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
 
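The surviving path splices the vision features directly into the text embeddings via `masked_scatter`. A self-contained sketch of that merge, using an invented `image_token_index` and toy tensors in place of the real embeddings:

```python
import torch

# Invented toy setup: token id 7 plays the role of config.image_token_index.
image_token_index = 7
hidden = 4

input_ids = torch.tensor([[1, 7, 7, 2]])   # two image placeholder tokens
inputs_embeds = torch.zeros(1, 4, hidden)  # stand-in for the text embeddings
image_features = torch.ones(2, hidden)     # one feature row per placeholder

# Same pattern as the new forward(): build a boolean mask over the embedding
# tensor, then scatter the image rows into the placeholder positions in order.
special_image_mask = (input_ids == image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)

image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

print(inputs_embeds[0, :, 0])  # tensor([0., 1., 1., 0.]) -> rows 1 and 2 replaced
```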