# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import NamedTuple, Optional

import torch

from torch import nn, Tensor
from torch.nn import functional as F
from torchmultimodal.modules.layers.transformer import TransformerOutput


class Blip2Output(NamedTuple):
    """
    BLIP2 model output for loss computation.

    image_embeddings(Tensor): normalized image embeddings returned by the visual encoder
        with shape [bsz x seq_len x embed_dim].
    image_features(Tensor): image features after qformer and projection (for stage 1 training)
        with shape [bsz, num_query_tokens, embed_dim]
    image_qformer_output(Tensor): last hidden state of the qformer for the given image input
    text_features(Optional[Tensor]): text features after qformer and projection, if text input is provided,
        with shape [bsz, embed_dim]
    prediction_scores(Optional[Tensor]): scores computed for next-word prediction
        with shape [bsz, seq_len, vocab_size]
    """

    image_embeddings: Tensor
    image_features: Tensor
    image_qformer_output: Tensor
    text_features: Optional[Tensor] = None
    prediction_scores: Optional[Tensor] = None


class BLIP2(nn.Module):
    """
    BLIP2 (https://arxiv.org/pdf/2301.12597.pdf) provides a pre-training strategy to bootstrap vision-language
    pre-training from frozen image encoders and frozen large language models (LLMs). BLIP-2 bridges the modality
    gap and facilitates cross-modal alignment via the Querying Transformer (Q-Former), a lightweight transformer
    with a set of learnable query vectors that extract visual features from the frozen image encoder.

    Args:
        qformer(nn.Module): Querying Transformer (Q-Former)
        vision_encoder(nn.Module): Frozen image encoder
        dim_q(int): Dimension of the query tensor; this value should be the same as dim_q in qformer.
        image_encoder_embedding_dim(int): Embedding dimension of the image encoder;
            this value should be the same as dim_kv in qformer.
        freeze_vision_encoder(bool): Whether to freeze the vision encoder, defaults to True
        cross_attention_freq(int): Frequency of adding cross-attention blocks in the qformer, defaults to 2
        embedding_dim(int): Embedding dimension of the image and text projections, defaults to 256
        num_query_token(int): Number of query tokens in the qformer, defaults to 32
        init_query_tokens(bool): Whether to initialize the query token parameters, defaults to True
        decoder_bos_token_id(Optional[int]): bos_token_id used in the decoder, defaults to None
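
    Example:
        An illustrative sketch only; ``qformer`` and ``vision_encoder`` are placeholders for
        compatible modules (a Q-Former and a frozen image encoder) and are not constructed here.

        >>> model = BLIP2(
        ...     qformer=qformer,
        ...     vision_encoder=vision_encoder,
        ...     dim_q=768,
        ...     image_encoder_embedding_dim=1024,
        ... )
        >>> out = model(image, input_ids, attention_mask)
        >>> # out.image_features: [bsz, num_query_token, embedding_dim]
        >>> # out.text_features: [bsz, embedding_dim]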
    """

    def __init__(
        self,
        qformer: nn.Module,
        vision_encoder: nn.Module,
        dim_q: int,
        image_encoder_embedding_dim: int,
        freeze_vision_encoder: bool = True,
        cross_attention_freq: int = 2,
        embedding_dim: int = 256,
        num_query_token: int = 32,
        init_query_tokens: bool = True,
        decoder_bos_token_id: Optional[int] = None,
    ):
        super().__init__()
        self.vision_encoder = vision_encoder
        if freeze_vision_encoder:
            for param in self.vision_encoder.parameters():
                param.requires_grad = False
            self.vision_encoder = self.vision_encoder.eval()

        self.qformer = qformer
        self.decoder_bos_token_id = decoder_bos_token_id
        self.dim_q = dim_q
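        # learnable query tokens, shared across the batch and expanded per sample in forward()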
        self.query_tokens = nn.Parameter(torch.zeros(1, num_query_token, self.dim_q))
        if init_query_tokens:
            self.query_tokens.data.normal_(mean=0.0, std=0.02)

        self.vision_proj = nn.Linear(self.dim_q, embedding_dim)
        self.text_proj = nn.Linear(self.dim_q, embedding_dim)
        self.ln_vision = nn.LayerNorm(image_encoder_embedding_dim)

    def forward(
        self,
        image: Tensor,
        input_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
    ) -> Blip2Output:
        """
        Args:
            image(Tensor): Image input tensor with shape [bsz, C, H, W]
            input_ids(Optional[Tensor]): Text input tensor with shape [bsz, seq_len]
            attention_mask(Optional[Tensor]): Attention mask tensor with shape [bsz, seq_len]

        Returns:
            BLIP2 model output (Blip2Output).
        """
        vision_encoder_output = self.vision_encoder(image)
        if isinstance(vision_encoder_output, TransformerOutput):
            vision_encoder_output = vision_encoder_output.last_hidden_state
        assert vision_encoder_output is not None
        image_embeds = self.ln_vision(vision_encoder_output)
        # query tokens: [batch_size, num_query_token, dim_q]
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.qformer.model(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            use_cache=True,
        )
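        # query_output[0] holds the qformer's last hidden state for the query tokens;
        # query_output[1] holds the cached key/value states, reused as past_key_values
        # in the text generation pass below.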

        # image_feats: [batch_size, num_query_token, embedding_dim]
        image_feats = F.normalize(self.vision_proj(query_output[0]), dim=-1)

        text_feats: Optional[Tensor] = None
        prediction_scores: Optional[Tensor] = None
        if input_ids is not None:
            text_output = self.qformer.model(
                input_ids,
                attention_mask=attention_mask,
                use_cache=False,
            )
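            # pool the text representation from the first token (conventionally the
            # [CLS] token) before projecting and normalizing it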
            text_feats = F.normalize(self.text_proj(text_output[0][:, 0, :]), dim=-1)

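            # image-grounded text generation: when a decoder BOS token id is provided,
            # replace the first input token with it to mark the decoding task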
            decoder_input_ids = input_ids.clone()
            if self.decoder_bos_token_id is not None:
                # pyre-ignore
                decoder_input_ids[:, 0] = self.decoder_bos_token_id

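            # all-ones attention mask over the query positions, prepended to the text
            # attention mask so text tokens can attend to the cached query states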
            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
                input_ids.device
            )
            if attention_mask is not None:
                attention_mask = torch.cat([query_atts, attention_mask], dim=1)

            # set use_cache=False since past_key_values were already cached in the query pass above.
            prediction_scores = self.qformer(
                input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                past_key_values=query_output[1],
                use_cache=False,
            )

        return Blip2Output(
            image_embeddings=image_embeds,
            image_features=image_feats,
            image_qformer_output=query_output[0],
            text_features=text_feats,
            prediction_scores=prediction_scores,
        )