
Commit a01e332

Peng Chen authored and facebook-github-bot committed
add blip2 layer under torchmm/models (#484)
Summary:
Pull Request resolved: #484

as title

Differential Revision: D50145708

fbshipit-source-id: 63f11c36beb72eefc79c414db48e944906bcd3b7
1 parent 1d40a81 commit a01e332

File tree

2 files changed: +294 -0 lines changed


tests/models/blip2/test_blip2.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import pytest
import torch
import torch.nn as nn
from tests.test_utils import assert_expected, init_weights_with_constant
from torchmultimodal.models.blip2.blip2 import BLIP2
from torchmultimodal.models.blip2.qformer_model import QformerForCLM
from torchmultimodal.modules.encoders.vision_transformer import VisionTransformer
from torchmultimodal.modules.layers.patch_embedding import PatchEmbeddings
from torchmultimodal.modules.layers.transformer import TransformerEncoder


@pytest.fixture
def dim_q():
    return 4


@pytest.fixture
def dim_kv():
    return 2


@pytest.fixture
def dim_feedforward():
    return 6


@pytest.fixture
def num_hidden_layers():
    return 2


@pytest.fixture
def num_heads():
    return 2


@pytest.fixture
def vocab_size():
    return 20


@pytest.fixture
def qformer_model_for_clm(
    dim_q,
    dim_kv,
    dim_feedforward,
    num_hidden_layers,
    num_heads,
    vocab_size,
):
    qformer_for_clm = QformerForCLM(
        dim_q=dim_q,
        dim_kv=dim_kv,
        dim_feedforward=dim_feedforward,
        num_heads=num_heads,
        attn_dropout=0.0,
        dropout=0.0,
        num_hidden_layers=num_hidden_layers,
        max_position_embeddings=512,
        vocab_size=vocab_size,
    )
    return qformer_for_clm


@pytest.fixture
def vit():
    embedding = PatchEmbeddings(image_size=2, patch_size=1, hidden_size=2)
    encoder = TransformerEncoder(
        n_layer=1,
        d_model=2,
        n_head=1,
        dim_feedforward=1,
        activation=nn.GELU,
        norm_first=True,
        final_layer_norm_eps=1e-5,
    )
    image_encoder = VisionTransformer(
        embeddings=embedding,
        encoder=encoder,
    )
    init_weights_with_constant(image_encoder)
    image_encoder.eval()
    return image_encoder


@pytest.fixture
def blip2(dim_q, dim_kv, qformer_model_for_clm, vit):
    blip2 = BLIP2(
        dim_q=dim_q,
        image_encoder_embedding_dim=dim_kv,
        qformer=qformer_model_for_clm,
        vision_encoder=vit,
        embedding_dim=4,
        decoder_bos_token_id=19,
    )
    init_weights_with_constant(blip2)
    blip2.eval()
    return blip2


@pytest.fixture
def attn_mask():
    return torch.Tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0]])


class TestBLIP2:
    def test_blip2(self, blip2, attn_mask):
        image = torch.ones(2, 3, 2, 2)
        input_ids = torch.ones(2, 4).long()
        output = blip2(image, input_ids, attn_mask)
        assert_expected(
            output.image_features, torch.ones([2, 32, 4]) * 0.5, rtol=0, atol=1e-4
        )
        assert_expected(
            output.text_features, torch.ones([2, 4]) * 0.5, rtol=0, atol=1e-4
        )
        assert_expected(
            output.image_embeddings, torch.ones([2, 5, 2]), rtol=0, atol=1e-4
        )
        assert_expected(
            output.prediction_scores, torch.ones([2, 4, 20]) * 5, rtol=0, atol=1e-4
        )

    def test_blip2_scripting(self, blip2, attn_mask):
        image = torch.ones(2, 3, 2, 2)
        input_ids = torch.ones(2, 4).long()
        scripted_model = torch.jit.script(blip2)
        actual = scripted_model(image, input_ids, attn_mask)
        expected = blip2(image, input_ids, attn_mask)
        assert_expected(actual, expected)
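
For reference, the expected tensor shapes asserted in test_blip2 follow from the fixture configuration above. PatchEmbeddings with image_size=2 and patch_size=1 produces (2/1)^2 = 4 patch tokens, plus what appears to be one class token, so image_embeddings has shape [2, 5, 2] for batch size 2 and hidden size 2. BLIP2's default num_query_token=32 together with embedding_dim=4 gives image_features of [2, 32, 4]; text_features is [bsz, embedding_dim] = [2, 4]; and prediction_scores is [bsz, seq_len, vocab_size] = [2, 4, 20].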
torchmultimodal/models/blip2/blip2.py

Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import NamedTuple, Optional

import torch

from torch import nn, Tensor
from torch.nn import functional as F
from torchmultimodal.modules.layers.transformer import TransformerOutput


class Blip2Output(NamedTuple):
    """
    BLIP2 model output for loss computation.

    image_embeddings(Tensor): normalized image embeddings returned by the visual encoder
        with shape [bsz x seq_len x embed_dim].
    image_features(Tensor): image features after qformer and projection (for stage 1 training)
        with shape [bsz, num_query_tokens, embed_dim]
    image_qformer_output(Tensor): last hidden state of the qformer for the given image input
    text_features(Optional[Tensor]): text features after qformer and projection if text input is provided
        with shape [bsz, embed_dim]
    prediction_scores(Optional[Tensor]): scores computed for next-word prediction
        with shape [bsz, seq_len, vocab_size]
    """

    image_embeddings: Tensor
    image_features: Tensor
    image_qformer_output: Tensor
    text_features: Optional[Tensor] = None
    prediction_scores: Optional[Tensor] = None


class BLIP2(nn.Module):
    """
    BLIP2 (https://arxiv.org/pdf/2301.12597.pdf) provides a pre-training strategy to bootstrap vision-language
    pre-training from frozen image encoders and frozen large language models (LLMs). BLIP-2 bridges the modality gap
    and facilitates cross-modal alignment via the Querying Transformer (Q-former). The Q-former is a lightweight
    transformer with a set of learnable query vectors to extract visual features from the frozen image encoder.

    Args:
        qformer(nn.Module): Querying Transformer (Q-former)
        vision_encoder(nn.Module): frozen image encoder
        dim_q(int): dimension of the query tensor; this value should be the same as dim_q in qformer.
        image_encoder_embedding_dim(int): embedding dimension of the image encoder;
            this value should be the same as dim_kv in qformer.
        freeze_vision_encoder(bool): whether to freeze the vision encoder, defaults to True
        cross_attention_freq(int): frequency of adding cross-attention blocks in the qformer, defaults to 2
        embedding_dim(int): embedding dimension
        num_query_token(int): number of query tokens in the qformer, defaults to 32
        init_query_tokens(bool): whether to initialize query token parameters, defaults to True
        decoder_bos_token_id(Optional[int]): bos_token_id used in the decoder, defaults to None
    """

    def __init__(
        self,
        qformer: nn.Module,
        vision_encoder: nn.Module,
        dim_q: int,
        image_encoder_embedding_dim: int,
        freeze_vision_encoder: bool = True,
        cross_attention_freq: int = 2,
        embedding_dim: int = 256,
        num_query_token: int = 32,
        init_query_tokens: bool = True,
        decoder_bos_token_id: Optional[int] = None,
    ):
        super().__init__()
        self.vision_encoder = vision_encoder
        if freeze_vision_encoder:
            for param in self.vision_encoder.parameters():
                param.requires_grad = False
            self.vision_encoder = self.vision_encoder.eval()

        self.qformer = qformer
        self.decoder_bos_token_id = decoder_bos_token_id
        self.dim_q = dim_q
        self.query_tokens = nn.Parameter(torch.zeros(1, num_query_token, self.dim_q))
        if init_query_tokens:
            self.query_tokens.data.normal_(mean=0.0, std=0.02)

        self.vision_proj = nn.Linear(self.dim_q, embedding_dim)
        self.text_proj = nn.Linear(self.dim_q, embedding_dim)
        self.ln_vision = nn.LayerNorm(image_encoder_embedding_dim)

    def forward(
        self,
        image: Tensor,
        input_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
    ) -> Blip2Output:
        """
        Args:
            image(Tensor): image input tensor with shape [B, C, H, W]
            input_ids(Optional[Tensor]): text input tensor with shape [bsz, seq_len]
            attention_mask(Optional[Tensor]): attention mask tensor with shape [bsz, seq_len]

        Returns:
            BLIP2 model output (Blip2Output).
        """
        vision_encoder_output = self.vision_encoder(image)
        if isinstance(vision_encoder_output, TransformerOutput):
            vision_encoder_output = vision_encoder_output.last_hidden_state
        assert vision_encoder_output is not None
        image_embeds = self.ln_vision(vision_encoder_output)
        # query tokens: [batch_size, num_query_token, encoder_hidden_size]
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.qformer.model(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            use_cache=True,
        )

        # image_feats: [batch_size, num_query_token, embedding_dim]
        image_feats = F.normalize(self.vision_proj(query_output[0]), dim=-1)

        text_feats: Optional[Tensor] = None
        prediction_scores: Optional[Tensor] = None
        if input_ids is not None:
            text_output = self.qformer.model(
                input_ids,
                attention_mask=attention_mask,
                use_cache=False,
            )
            text_feats = F.normalize(self.text_proj(text_output[0][:, 0, :]), dim=-1)

            decoder_input_ids = input_ids.clone()
            if self.decoder_bos_token_id is not None:
                # pyre-ignore
                decoder_input_ids[:, 0] = self.decoder_bos_token_id

            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
                input_ids.device
            )
            if attention_mask is not None:
                attention_mask = torch.cat([query_atts, attention_mask], dim=1)

            # set use_cache = False since past_key_values were already cached in the image pass above.
            prediction_scores = self.qformer(
                input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                past_key_values=query_output[1],
                use_cache=False,
            )

        return Blip2Output(
            image_embeddings=image_embeds,
            image_features=image_feats,
            image_qformer_output=query_output[0],
            text_features=text_feats,
            prediction_scores=prediction_scores,
        )
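
For context, the following is a minimal usage sketch of the module added in this commit, wired up with the same toy Q-former and ViT configuration used by tests/models/blip2/test_blip2.py above; the tiny sizes are illustrative only, not a recommended configuration.

# Minimal sketch: construct BLIP2 with the toy components from the test above
# and run a forward pass; only shapes are checked, weights are left at their
# default initialization.
import torch
import torch.nn as nn
from torchmultimodal.models.blip2.blip2 import BLIP2
from torchmultimodal.models.blip2.qformer_model import QformerForCLM
from torchmultimodal.modules.encoders.vision_transformer import VisionTransformer
from torchmultimodal.modules.layers.patch_embedding import PatchEmbeddings
from torchmultimodal.modules.layers.transformer import TransformerEncoder

qformer = QformerForCLM(
    dim_q=4,
    dim_kv=2,
    dim_feedforward=6,
    num_heads=2,
    attn_dropout=0.0,
    dropout=0.0,
    num_hidden_layers=2,
    max_position_embeddings=512,
    vocab_size=20,
)
vit = VisionTransformer(
    embeddings=PatchEmbeddings(image_size=2, patch_size=1, hidden_size=2),
    encoder=TransformerEncoder(
        n_layer=1,
        d_model=2,
        n_head=1,
        dim_feedforward=1,
        activation=nn.GELU,
        norm_first=True,
        final_layer_norm_eps=1e-5,
    ),
)
model = BLIP2(
    dim_q=4,
    image_encoder_embedding_dim=2,
    qformer=qformer,
    vision_encoder=vit,
    embedding_dim=4,
    decoder_bos_token_id=19,
)
model.eval()

image = torch.ones(2, 3, 2, 2)       # [bsz, channels, height, width]
input_ids = torch.ones(2, 4).long()  # [bsz, seq_len]
attention_mask = torch.ones(2, 4)    # [bsz, seq_len]
with torch.no_grad():
    out = model(image, input_ids, attention_mask)

print(out.image_embeddings.shape)   # torch.Size([2, 5, 2])
print(out.image_features.shape)     # torch.Size([2, 32, 4])
print(out.text_features.shape)      # torch.Size([2, 4])
print(out.prediction_scores.shape)  # torch.Size([2, 4, 20])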
