
 import torch
 from torch import nn, Tensor
-from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm
 from torchmultimodal.modules.losses.contrastive_loss_with_temperature import (
     contrastive_loss_with_temperature,
     ContrastiveLossOutput,
@@ -130,45 +129,6 @@ def forward(
         return ITMLossOutput(logits=scores, loss=loss)


-class MaskedPredictionHead(nn.Module):
-    def __init__(
-        self,
-        hidden_size: int = 768,
-        vocab_size: int = 30522,
-        transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu,
-        layer_norm_eps: float = 1e-5,
-        use_fp32_layer_norm: bool = True,
-        **kwargs: Any,
-    ):
-        super().__init__()
-
-        self.dense = nn.Linear(hidden_size, hidden_size)
-        self.transform_act_fn = transform_act_fn
-
-        self.layer_norm: nn.LayerNorm
-        if use_fp32_layer_norm:
-            self.layer_norm = Fp32LayerNorm(hidden_size, eps=layer_norm_eps)
-        else:
-            self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
-
-        # The output weights are the same as the input embeddings, but there is
-        # an output-only bias for each token.
-        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
-
-        self.bias = nn.Parameter(torch.zeros(vocab_size))
-
-        # Need a link between the two variables so that the bias is
-        # correctly resized with `resize_token_embeddings`
-        self.decoder.bias = self.bias
-
-    def forward(self, hidden_states: Tensor) -> Tensor:
-        hidden_states = self.dense(hidden_states)
-        hidden_states = self.transform_act_fn(hidden_states)
-        hidden_states = self.layer_norm(hidden_states)
-        hidden_states = self.decoder(hidden_states)
-        return hidden_states
-
-
 class MaskedPredictionLoss(nn.Module):
     def __init__(
         self,
@@ -181,36 +141,29 @@ def __init__(
         **kwargs: Any,
     ):
         super().__init__()
-
-        self.cls = MaskedPredictionHead(
-            hidden_size=hidden_size,
-            vocab_size=vocab_size,
-            transform_act_fn=transform_act_fn,
-            layer_norm_eps=layer_norm_eps,
-        )
         self.ignore_index = ignore_index
         self.vocab_size = vocab_size
         self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
         self.ignore_nan = ignore_nan

     def forward(
-        self, hidden_states: Tensor, masked_labels: Optional[Tensor] = None
+        self,
+        prediction: Tensor,
+        masked_labels: Optional[Tensor] = None,
+        pos_mask: Optional[Tensor] = None,
     ) -> MaskedPredictionLossOutput:
-        if self.training:
-            assert_labels_are_present(masked_labels, "masked labels")
-
-        if masked_labels is not None:
-            masked_tokens = masked_labels.ne(self.ignore_index)
-            masked_labels = masked_labels[masked_tokens]
-            sequence_output = hidden_states[masked_tokens, :]
-        else:
-            sequence_output = hidden_states

-        prediction = self.cls(sequence_output)
+        if masked_labels is not None:
+            if pos_mask is not None:
+                masked_labels = masked_labels[pos_mask]
+            masked_tokens = masked_labels.ne(self.ignore_index)
+            masked_labels = masked_labels[masked_tokens]

         if masked_labels is None:
             masked_loss = prediction.sum() * 0
         else:
+            if pos_mask is not None:
+                prediction = prediction[pos_mask]
+            prediction = prediction[masked_tokens, :]
             masked_loss = self.ce_loss(
                 prediction.view(-1, self.vocab_size),
                 masked_labels.view(-1),
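
With MaskedPredictionHead removed from the loss (the head, including its weight-tied decoder bias, is expected to live in the model from now on), MaskedPredictionLoss consumes precomputed vocabulary logits. A minimal sketch of the new call pattern; the shapes, the masked position, and the ignore_index value are illustrative assumptions, not taken from this diff:

import torch

# Hypothetical shapes: batch 2, sequence length 4, vocab 30522.
vocab_size = 30522
prediction = torch.randn(2, 4, vocab_size)  # head output, computed outside the loss
masked_labels = torch.full((2, 4), -1)      # -1 assumed as ignore_index
masked_labels[0, 1] = 42                    # a single masked position to score
pos_mask = torch.tensor([True, False])      # keep only ITM-positive samples

loss_fn = MaskedPredictionLoss(vocab_size=vocab_size, ignore_index=-1)
out = loss_fn(prediction, masked_labels, pos_mask)
print(out.loss)  # cross-entropy over the one labeled position
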
@@ -371,6 +324,10 @@ def forward(
         projected_image_embeddings: Optional[Tensor] = None,
         projected_text_embeddings: Optional[Tensor] = None,
         itm_logits: Optional[Tensor] = None,
+        mlm_head_output: Optional[Tensor] = None,
+        mim_head_output: Optional[Tensor] = None,
+        mmm_mlm_head_output: Optional[Tensor] = None,
+        mmm_mim_head_output: Optional[Tensor] = None,
     ) -> FLAVAPretrainingLossOutput:
         outputs = FLAVAPretrainingLossOutput()
         pos_mask = None
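
These four new keyword arguments are the crux of the refactor: FLAVAPretrainingLoss stops running prediction heads on raw encoder sequences and instead consumes their outputs. A hedged sketch of how a caller might wire this up; mlm_head, mim_head, and the sliced sequence names are invented for illustration, not part of this diff:

# Sketch only: the heads now live outside the loss (e.g. in the FLAVA model).
loss_output = flava_pretraining_loss(
    mlm_head_output=mlm_head(text_masked_sequence),       # unimodal text
    mim_head_output=mim_head(image_masked_sequence),      # unimodal image
    mmm_mlm_head_output=mlm_head(mm_text_slice),          # multimodal text part
    mmm_mim_head_output=mim_head(mm_image_slice),         # multimodal image part
    # ...plus the existing inputs (labels, projections, itm_logits, ...)
)
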
@@ -380,28 +337,28 @@ def forward(
         # text, but that is a research question :)

         if (
-            image_masked_sequence is not None
+            mim_head_output is not None
             and self.mim_weight > 0
             and multimodal_masked_sequence is None
         ):
             # Remove CLS token from image_masked_sequence
-
             start_index = -mim_labels.size(1) if mim_labels is not None else 1
             outputs.mim_output = self.mim_loss(
-                image_masked_sequence[:, start_index:, :], mim_labels
+                mim_head_output[:, start_index:, :], mim_labels
             )
             outputs.mim_output.loss *= self.mim_weight
             outputs.losses.mim_loss = outputs.mim_output.loss

         # Check multimodal_masked_sequence to make sure this is unimodal case
+
         if (
-            text_masked_sequence is not None
+            mlm_head_output is not None
             and self.mlm_weight > 0
             and multimodal_masked_sequence is None
         ):
             start_index = -mlm_labels.size(1) if mlm_labels is not None else 1
             outputs.mlm_output = self.mlm_loss(
-                text_masked_sequence[:, start_index:, :], mlm_labels
+                mlm_head_output[:, start_index:, :], mlm_labels
             )
             outputs.mlm_output.loss *= self.mlm_weight
             outputs.losses.mlm_loss = outputs.mlm_output.loss
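
The unchanged start_index arithmetic is what makes passing the full-length head output safe here: a negative index keeps exactly the last labels.size(1) positions and drops the CLS prefix. A worked instance with assumed FLAVA-like sizes (196 patches plus one CLS token; the 8192 codebook size is also an assumption):

import torch

mim_head_output = torch.randn(8, 197, 8192)  # [batch, 1 CLS + 196 patches, codebook]
mim_labels = torch.full((8, 196), -1)

start_index = -mim_labels.size(1) if mim_labels is not None else 1  # -196
assert mim_head_output[:, start_index:, :].shape == (8, 196, 8192)  # CLS dropped
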
@@ -422,38 +379,23 @@ def forward(
             outputs.itm_output.loss *= self.itm_loss_weight
             outputs.losses.itm_loss = outputs.itm_output.loss

-            multimodal_masked_sequence = multimodal_masked_sequence[pos_mask]
-            if mlm_labels is not None:
-                mlm_labels = mlm_labels[pos_mask]
-            if mim_labels is not None:
-                mim_labels = mim_labels[pos_mask]
-
-        if multimodal_masked_sequence is not None and self.mmm_text_loss_weight > 0:
-            start_index = (
-                -mlm_labels.size(1)
-                if mlm_labels is not None
-                else -(text_masked_sequence.size(1) - 1)
-            )
-            sequence_for_text = multimodal_masked_sequence[:, start_index:, :]
+        if mmm_mlm_head_output is not None and self.mmm_text_loss_weight > 0:
             outputs.mmm_text_output = self.mmm_loss.mlm(
-                sequence_for_text,
-                mlm_labels,
+                mmm_mlm_head_output, mlm_labels, pos_mask
             )  # type: ignore
             outputs.mmm_text_output.loss *= self.mmm_text_loss_weight
             outputs.losses.mmm_text_loss = outputs.mmm_text_output.loss

-        if multimodal_masked_sequence is not None and self.mmm_image_loss_weight > 0:
+        if mmm_mim_head_output is not None and self.mmm_image_loss_weight > 0:
             # Starts from 2 because of 2 CLS, one for multimodal encoder and one
             # that comes from image encoder.
             total_indices = (
                 mim_labels.size(1)
                 if mlm_labels is not None
                 else (image_masked_sequence.size(1) - 1)
             )
-            sequence_for_image = multimodal_masked_sequence[:, 2 : 2 + total_indices, :]
             outputs.mmm_image_output = self.mmm_loss.mim(
-                sequence_for_image,
-                mim_labels,
+                mmm_mim_head_output, mim_labels, pos_mask
             )  # type: ignore
             outputs.mmm_image_output.loss *= self.mmm_image_loss_weight
             outputs.losses.mmm_image_loss = outputs.mmm_image_output.loss
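
For the MMM losses, the deleted sequence_for_text / sequence_for_image slicing now happens before the external heads run, and the pos_mask filtering moves into MaskedPredictionLoss, which is why mmm_loss.mlm and mmm_loss.mim gain a third argument. A rough sketch of the layout assumption behind the "2 CLS" comment above, with invented sizes:

import torch

# Assumed multimodal layout: [CLS_mm, CLS_image, 196 patches, text tokens].
batch, n_patches, n_text, hidden = 8, 196, 128, 768
multimodal_masked_sequence = torch.randn(batch, 2 + n_patches + n_text, hidden)

sequence_for_image = multimodal_masked_sequence[:, 2 : 2 + n_patches, :]  # patch slots
sequence_for_text = multimodal_masked_sequence[:, -n_text:, :]            # text slots
# These slices would be fed through the heads to produce mmm_mim_head_output
# and mmm_mlm_head_output before reaching the loss.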