diff --git a/torchmultimodal/models/flava/model.py b/torchmultimodal/models/flava/model.py
index e2431c8e..b696a390 100644
--- a/torchmultimodal/models/flava/model.py
+++ b/torchmultimodal/models/flava/model.py
@@ -62,7 +62,7 @@ FLAVA_FOR_PRETRAINED_MAPPING = {
     # This will no longer load with the updated model, but keeping here just in case
     # "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
-    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_itm.pt",
+    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_itm_mp.pt",
 }
 
 FLAVA_MODEL_MAPPING = {
@@ -314,6 +314,50 @@ def forward(self, hidden_states: Tensor):
         return logits
 
 
+class MaskedPredictionHead(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int = 768,
+        vocab_size: int = 30522,
+        transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu,
+        layer_norm_eps: float = 1e-5,
+        use_fp32_layer_norm: bool = True,
+        ignore_index: int = -1,
+        **kwargs: Any,
+    ):
+        super().__init__()
+
+        self.dense = nn.Linear(hidden_size, hidden_size)
+        self.transform_act_fn = transform_act_fn
+
+        self.layer_norm: nn.LayerNorm
+        if use_fp32_layer_norm:
+            self.layer_norm = Fp32LayerNorm(hidden_size, eps=layer_norm_eps)
+        else:
+            self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+
+        # The output weights are the same as the input embeddings, but there is
+        # an output-only bias for each token.
+        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
+
+        self.bias = nn.Parameter(torch.zeros(vocab_size))
+
+        # Need a link between the two variables so that the bias is
+        # correctly resized with `resize_token_embeddings`
+        self.decoder.bias = self.bias
+        self.ignore_index = ignore_index
+
+    def forward(self, hidden_states: Tensor, masked_labels: Tensor) -> Tensor:
+        masked_tokens = masked_labels.ne(self.ignore_index)
+        sequence_output = hidden_states[masked_tokens, :]
+
+        head_output = self.dense(sequence_output)
+        head_output = self.transform_act_fn(head_output)
+        head_output = self.layer_norm(head_output)
+        head_output = self.decoder(head_output)
+        return head_output
+
+
 class FLAVAForPreTraining(nn.Module, PretrainedMixin):
     # TODOs:
     # 1. Expose logit scale
@@ -325,12 +369,20 @@ def __init__(
         image_codebook: nn.Module,
         loss: nn.Module,
         itm_head: nn.Module,
+        mlm_head: nn.Module,
+        mim_head: nn.Module,
+        mmm_mlm_head: nn.Module,
+        mmm_mim_head: nn.Module,
     ):
         super().__init__()
         self.model = model
         self.image_codebook = image_codebook
         self.loss = loss
         self.itm_head = itm_head
+        self.mlm_head = mlm_head
+        self.mim_head = mim_head
+        self.mmm_mlm_head = mmm_mlm_head
+        self.mmm_mim_head = mmm_mim_head
 
     def encode_image(
         self,
@@ -380,24 +432,83 @@ def forward(
         )
         multimodal_masked_sequence = flava_output.multimodal_masked.last_hidden_state
         itm_logits = None
+
+        image_masked_sequence = flava_output.image_masked.last_hidden_state
+        text_masked_sequence = flava_output.text_masked.last_hidden_state
+        mlm_head_output = (
+            mim_head_output
+        ) = mmm_mlm_head_output = mmm_mim_head_output = None
+        pos_mask = None
+        if image_masked_sequence is not None and multimodal_masked_sequence is None:
+            # Remove CLS token from image_masked_sequence
+            start_index = -image_labels.size(1) if image_labels is not None else 1
+            mim_head_output = self.mim_head(
+                image_masked_sequence[:, start_index:, :], image_labels
+            )
+
+        if text_masked_sequence is not None and multimodal_masked_sequence is None:
+            start_index = -mlm_labels.size(1) if mlm_labels is not None else 1
+            mlm_head_output = self.mlm_head(
+                text_masked_sequence[:, start_index:, :], mlm_labels
+            )
+
+        mmm_mlm_labels = mlm_labels
+        mmm_mim_labels = image_labels
+
         if multimodal_masked_sequence is not None:
+            if itm_labels is not None:
+                pos_pairs = itm_labels.ne(0)
+                pos_mask = torch.where(
+                    pos_pairs.any(), pos_pairs, pos_pairs.new([True])
+                )
+            else:
+                pos_mask = torch.ones(
+                    multimodal_masked_sequence.size(0),
+                    device=multimodal_masked_sequence.device,
+                ).bool()
             itm_logits = self.itm_head(multimodal_masked_sequence)
+            multimodal_masked_sequence = multimodal_masked_sequence[pos_mask]
+            if mlm_labels is not None:
+                mmm_mlm_labels = mlm_labels[pos_mask]
+            if image_labels is not None:
+                mmm_mim_labels = image_labels[pos_mask]
+
+        if multimodal_masked_sequence is not None:
+            start_index = (
+                -mmm_mlm_labels.size(1)
+                if mmm_mlm_labels is not None
+                else -(text_masked_sequence.size(1) - 1)
+            )
+            sequence_for_text = multimodal_masked_sequence[:, start_index:, :]
+            mmm_mlm_head_output = self.mmm_mlm_head(sequence_for_text, mmm_mlm_labels)
+
+        if multimodal_masked_sequence is not None:
+            # Starts from index 2 because there are 2 CLS tokens: one from the
+            # multimodal encoder and one from the image encoder.
+            total_indices = (
+                mmm_mim_labels.size(1)
+                if mmm_mim_labels is not None
+                else (image_masked_sequence.size(1) - 1)
+            )
+            sequence_for_image = multimodal_masked_sequence[:, 2 : 2 + total_indices, :]
+            mmm_mim_head_output = self.mmm_mim_head(sequence_for_image, mmm_mim_labels)
+
         return self.loss(
-            image_sequence=flava_output.image.last_hidden_state,
-            text_sequence=flava_output.text.last_hidden_state,
-            image_masked_sequence=flava_output.image_masked.last_hidden_state,
-            text_masked_sequence=flava_output.text_masked.last_hidden_state,
-            multimodal_sequence=flava_output.multimodal.last_hidden_state
-            if not skip_unmasked_mm_encoder
-            else None,
             multimodal_masked_sequence=flava_output.multimodal_masked.last_hidden_state,
+            pos_mask=pos_mask,
             itm_labels=itm_labels,
             mim_labels=image_labels,
             mlm_labels=mlm_labels,
+            mmm_mlm_labels=mmm_mlm_labels,
+            mmm_mim_labels=mmm_mim_labels,
             projected_image_embeddings=flava_output.projected_image_embeddings,
             projected_text_embeddings=flava_output.projected_text_embeddings,
             itm_logits=itm_logits,
+            mlm_head_output=mlm_head_output,
+            mim_head_output=mim_head_output,
+            mmm_mlm_head_output=mmm_mlm_head_output,
+            mmm_mim_head_output=mmm_mim_head_output,
         )
@@ -548,17 +659,36 @@ def flava_model(
 def flava_model_for_pretraining(
     codebook_image_size: int = 112,
     pretrained_model_key: Optional[str] = None,
+    image_vocab_size: int = 8192,
     **flava_model_kwargs: Any,
     # TODO: Add parameters for loss here
 ) -> FLAVAForPreTraining:
     model = flava_model(**flava_model_kwargs)
     hidden_size = flava_model_kwargs.get("hidden_size") or 768
+    text_vocab_size = flava_model_kwargs.get("vocab_size") or 30522
     itm_head = ITMHead(hidden_size)
+    mlm_head = MaskedPredictionHead(hidden_size=hidden_size, vocab_size=text_vocab_size)
+    mim_head = MaskedPredictionHead(
+        hidden_size=hidden_size, vocab_size=image_vocab_size
+    )
+    mmm_mlm_head = MaskedPredictionHead(
+        hidden_size=hidden_size, vocab_size=text_vocab_size
+    )
+    mmm_mim_head = MaskedPredictionHead(
+        hidden_size=hidden_size, vocab_size=image_vocab_size
+    )
     losses = FLAVAPretrainingLoss()
     codebook = DalleVAEEncoder(image_size=codebook_image_size)
     flava = FLAVAForPreTraining(
-        model=model, image_codebook=codebook, loss=losses, itm_head=itm_head
+        model=model,
+        image_codebook=codebook,
+        loss=losses,
+        itm_head=itm_head,
+        mlm_head=mlm_head,
+        mim_head=mim_head,
+        mmm_mlm_head=mmm_mlm_head,
+        mmm_mim_head=mmm_mim_head,
     )
 
     if pretrained_model_key is not None:
diff --git a/torchmultimodal/modules/losses/flava.py b/torchmultimodal/modules/losses/flava.py
index 834b8e2c..6839b45e 100644
--- a/torchmultimodal/modules/losses/flava.py
+++ b/torchmultimodal/modules/losses/flava.py
@@ -7,11 +7,10 @@ import math
 import warnings
 from dataclasses import dataclass, field
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from torch import nn, Tensor
-from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm
 from torchmultimodal.modules.losses.contrastive_loss_with_temperature import (
     contrastive_loss_with_temperature,
     ContrastiveLossOutput,
 )
@@ -119,83 +118,28 @@ def forward(
         return ITMLossOutput(logits=scores, loss=loss)
 
 
-class MaskedPredictionHead(nn.Module):
-    def __init__(
-        self,
-        hidden_size: int = 768,
-        vocab_size: int = 30522,
-        transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu,
-        layer_norm_eps: float = 1e-5,
-        use_fp32_layer_norm: bool = True,
-        **kwargs: Any,
-    ):
-        super().__init__()
-
-        self.dense = nn.Linear(hidden_size, hidden_size)
-        self.transform_act_fn = transform_act_fn
-
-        self.layer_norm: nn.LayerNorm
-        if use_fp32_layer_norm:
-            self.layer_norm = Fp32LayerNorm(hidden_size, eps=layer_norm_eps)
-        else:
-            self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
-
-        # The output weights are the same as the input embeddings, but there is
-        # an output-only bias for each token.
-        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
-
-        self.bias = nn.Parameter(torch.zeros(vocab_size))
-
-        # Need a link between the two variables so that the bias is
-        # correctly resized with `resize_token_embeddings`
-        self.decoder.bias = self.bias
-
-    def forward(self, hidden_states: Tensor) -> Tensor:
-        hidden_states = self.dense(hidden_states)
-        hidden_states = self.transform_act_fn(hidden_states)
-        hidden_states = self.layer_norm(hidden_states)
-        hidden_states = self.decoder(hidden_states)
-        return hidden_states
-
-
 class MaskedPredictionLoss(nn.Module):
     def __init__(
         self,
-        hidden_size: int = 768,
         vocab_size: int = 30522,
-        transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu,
-        layer_norm_eps: float = 1e-5,
         ignore_index: int = -1,
         ignore_nan: bool = False,
         **kwargs: Any,
    ):
         super().__init__()
-
-        self.cls = MaskedPredictionHead(
-            hidden_size=hidden_size,
-            vocab_size=vocab_size,
-            transform_act_fn=transform_act_fn,
-            layer_norm_eps=layer_norm_eps,
-        )
         self.ignore_index = ignore_index
         self.vocab_size = vocab_size
         self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
         self.ignore_nan = ignore_nan
 
     def forward(
-        self, hidden_states: Tensor, masked_labels: Optional[Tensor] = None
+        self,
+        prediction: Tensor,
+        masked_labels: Optional[Tensor] = None,
     ) -> MaskedPredictionLossOutput:
-        if self.training:
-            assert_labels_are_present(masked_labels, "masked labels")
-
-        if masked_labels is not None:
-            masked_tokens = masked_labels.ne(self.ignore_index)
-            masked_labels = masked_labels[masked_tokens]
-            sequence_output = hidden_states[masked_tokens, :]
-        else:
-            sequence_output = hidden_states
-        prediction = self.cls(sequence_output)
+        masked_tokens = masked_labels.ne(self.ignore_index)
+        masked_labels = masked_labels[masked_tokens]
 
         if masked_labels is None:
             masked_loss = prediction.sum() * 0
@@ -221,11 +165,6 @@ class FLAVAGlobalContrastiveLoss(nn.Module):
     def __init__(
         self,
         logit_scale: Union[float, nn.Parameter] = None,
-        image_embedding_size: int = 768,
-        text_embedding_size: int = 768,
-        projection_size: int = 768,
-        image_embedding_index: int = 0,
-        text_embedding_index: int = 0,
     ):
         super().__init__()
         if logit_scale is None:
@@ -277,11 +216,8 @@ class FLAVAPretrainingLoss(nn.Module):
     def __init__(
         self,
         logit_scale: Union[float, nn.Parameter] = None,
-        hidden_size: int = 768,
         text_vocab_size: int = 30522,
         image_vocab_size: int = 8192,
-        transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu,
-        layer_norm_eps: float = 1e-5,
         ignore_index: int = -1,
         mlm_weight: float = 1.0,
         mim_weight: float = 1.0,
@@ -297,39 +233,24 @@ def __init__(
         )
         self.contrastive_loss = FLAVAGlobalContrastiveLoss(
             logit_scale=logit_scale,
-            image_embedding_size=hidden_size,
-            text_embedding_size=hidden_size,
-            projection_size=hidden_size,
         )
         self.mlm_loss = MaskedPredictionLoss(
-            hidden_size=hidden_size,
             vocab_size=text_vocab_size,
-            transform_act_fn=transform_act_fn,
-            layer_norm_eps=layer_norm_eps,
             ignore_index=ignore_index,
         )
         self.mim_loss = MaskedPredictionLoss(
-            hidden_size=hidden_size,
             vocab_size=image_vocab_size,
-            transform_act_fn=transform_act_fn,
-            layer_norm_eps=layer_norm_eps,
             ignore_index=ignore_index,
         )
         # Create separate weights for MMM loss
         self.mmm_loss = nn.ModuleDict(
             {
                 "mlm": MaskedPredictionLoss(
-                    hidden_size=hidden_size,
                     vocab_size=text_vocab_size,
-                    transform_act_fn=transform_act_fn,
-                    layer_norm_eps=layer_norm_eps,
                     ignore_index=ignore_index,
                 ),
                 "mim": MaskedPredictionLoss(
-                    hidden_size=hidden_size,
                     vocab_size=image_vocab_size,
-                    transform_act_fn=transform_act_fn,
-                    layer_norm_eps=layer_norm_eps,
                     ignore_index=ignore_index,
                 ),
             }
@@ -347,101 +268,63 @@
     # for better usability
     def forward(
         self,
-        image_sequence: Optional[Tensor] = None,
-        text_sequence: Optional[Tensor] = None,
-        image_masked_sequence: Optional[Tensor] = None,
-        text_masked_sequence: Optional[Tensor] = None,
-        multimodal_sequence: Optional[Tensor] = None,
         multimodal_masked_sequence: Optional[Tensor] = None,
+        pos_mask: Optional[Tensor] = None,
         itm_labels: Optional[Tensor] = None,
         mim_labels: Optional[Tensor] = None,
         mlm_labels: Optional[Tensor] = None,
+        mmm_mim_labels: Optional[Tensor] = None,
+        mmm_mlm_labels: Optional[Tensor] = None,
         projected_image_embeddings: Optional[Tensor] = None,
         projected_text_embeddings: Optional[Tensor] = None,
         itm_logits: Optional[Tensor] = None,
+        mlm_head_output: Optional[Tensor] = None,
+        mim_head_output: Optional[Tensor] = None,
+        mmm_mlm_head_output: Optional[Tensor] = None,
+        mmm_mim_head_output: Optional[Tensor] = None,
     ) -> FLAVAPretrainingLossOutput:
         outputs = FLAVAPretrainingLossOutput()
-        pos_mask = None
 
         # Check multimodal_masked_sequence to make sure this is unimodal case
         # This specific case can though be backpropagated directly as MIM is independent of
         # text, but that is a research question :)
         if (
-            image_masked_sequence is not None
+            mim_head_output is not None
             and self.mim_weight > 0
             and multimodal_masked_sequence is None
         ):
-            # Remove CLS token from image_masked_sequence
-
-            start_index = -mim_labels.size(1) if mim_labels is not None else 1
-            outputs.mim_output = self.mim_loss(
-                image_masked_sequence[:, start_index:, :], mim_labels
-            )
+            outputs.mim_output = self.mim_loss(mim_head_output, mim_labels)
             outputs.mim_output.loss *= self.mim_weight
             outputs.losses.mim_loss = outputs.mim_output.loss
 
         # Check multimodal_masked_sequence to make sure this is unimodal case
+
         if (
-            text_masked_sequence is not None
+            mlm_head_output is not None
             and self.mlm_weight > 0
             and multimodal_masked_sequence is None
         ):
-            start_index = -mlm_labels.size(1) if mlm_labels is not None else 1
-            outputs.mlm_output = self.mlm_loss(
-                text_masked_sequence[:, start_index:, :], mlm_labels
-            )
+            outputs.mlm_output = self.mlm_loss(mlm_head_output, mlm_labels)
             outputs.mlm_output.loss *= self.mlm_weight
             outputs.losses.mlm_loss = outputs.mlm_output.loss
 
         if multimodal_masked_sequence is not None and self.itm_loss_weight > 0:
             assert itm_logits is not None
-            if itm_labels is not None:
-                pos_pairs = itm_labels.ne(0)
-                pos_mask = torch.where(
-                    pos_pairs.any(), pos_pairs, pos_pairs.new([True])
-                )
-            else:
-                pos_mask = torch.ones(
-                    multimodal_masked_sequence.size(0),
-                    device=multimodal_masked_sequence.device,
-                ).bool()
             outputs.itm_output = self.itm_loss(itm_logits, itm_labels)
             outputs.itm_output.loss *= self.itm_loss_weight
             outputs.losses.itm_loss = outputs.itm_output.loss
 
-            multimodal_masked_sequence = multimodal_masked_sequence[pos_mask]
-            if mlm_labels is not None:
-                mlm_labels = mlm_labels[pos_mask]
-            if mim_labels is not None:
-                mim_labels = mim_labels[pos_mask]
-
-        if multimodal_masked_sequence is not None and self.mmm_text_loss_weight > 0:
-            start_index = (
-                -mlm_labels.size(1)
-                if mlm_labels is not None
-                else -(text_masked_sequence.size(1) - 1)
-            )
-            sequence_for_text = multimodal_masked_sequence[:, start_index:, :]
+        if mmm_mlm_head_output is not None and self.mmm_text_loss_weight > 0:
             outputs.mmm_text_output = self.mmm_loss.mlm(
-                sequence_for_text,
-                mlm_labels,
+                mmm_mlm_head_output, mmm_mlm_labels
             )  # type: ignore
             outputs.mmm_text_output.loss *= self.mmm_text_loss_weight
             outputs.losses.mmm_text_loss = outputs.mmm_text_output.loss
 
-        if multimodal_masked_sequence is not None and self.mmm_image_loss_weight > 0:
-            # Starts from 2 because of 2 CLS, one for multimodal encoder and one
-            # that comes from image encoder.
-            total_indices = (
-                mim_labels.size(1)
-                if mlm_labels is not None
-                else (image_masked_sequence.size(1) - 1)
-            )
-            sequence_for_image = multimodal_masked_sequence[:, 2 : 2 + total_indices, :]
+        if mmm_mim_head_output is not None and self.mmm_image_loss_weight > 0:
             outputs.mmm_image_output = self.mmm_loss.mim(
-                sequence_for_image,
-                mim_labels,
+                mmm_mim_head_output, mmm_mim_labels
            )  # type: ignore
             outputs.mmm_image_output.loss *= self.mmm_image_loss_weight
             outputs.losses.mmm_image_loss = outputs.mmm_image_output.loss
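
Illustrative note (not part of the patch): a minimal sketch of the contract the refactor relies on. MaskedPredictionHead, now defined in torchmultimodal/models/flava/model.py, gathers hidden states at the positions where masked_labels differs from ignore_index and projects them to vocabulary logits, while the slimmed-down MaskedPredictionLoss in torchmultimodal/modules/losses/flava.py selects the labels with the same mask, so logits and targets stay aligned row for row. The tensors below are made up, and the single linear layer is only a stand-in for the head's dense -> activation -> LayerNorm -> decoder stack; shapes and ignore_index=-1 follow the defaults used above.

import torch
from torch import nn

batch_size, seq_len, hidden_size, vocab_size, ignore_index = 2, 6, 768, 30522, -1

hidden_states = torch.randn(batch_size, seq_len, hidden_size)
masked_labels = torch.full((batch_size, seq_len), ignore_index)
masked_labels[0, 2] = 1037  # pretend these two positions were masked during preprocessing
masked_labels[1, 4] = 2003

# What MaskedPredictionHead.forward does with (hidden_states, masked_labels):
masked_tokens = masked_labels.ne(ignore_index)   # [batch_size, seq_len] bool mask
selected = hidden_states[masked_tokens, :]       # [num_masked, hidden_size]
decoder = nn.Linear(hidden_size, vocab_size)     # stand-in for dense/act/LayerNorm/decoder
prediction = decoder(selected)                   # [num_masked, vocab_size]

# What MaskedPredictionLoss.forward now does with (prediction, masked_labels):
target = masked_labels[masked_tokens]            # [num_masked]
loss = nn.CrossEntropyLoss(ignore_index=ignore_index)(prediction, target)
print(prediction.shape, target.shape, loss.item())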