
Commit 01f749f

committed: added comments
1 parent 9e6dabb commit 01f749f

6 files changed, +254 -46 lines changed


README.md

+5-5
@@ -1,6 +1,6 @@
- This is a **PyTorch Tutorial to Image Captioning**.
+ This is a **[PyTorch](https://pytorch.org) Tutorial to Image Captioning**.

- This is the first of a series of tutorials I plan to write about _implementing_ cool models on your own with the amazing [PyTorch](https://pytorch.org) library.
+ This is the first in a series of tutorials I plan to write about _implementing_ cool models on your own with the amazing PyTorch library.

  Basic knowledge of PyTorch, convolutional and recurrent neural networks is assumed.
@@ -24,9 +24,9 @@ I'm using `PyTorch 0.4` in `Python 3.6`.
  **To build a model that can generate a descriptive caption for an image we provide it.**

- In the interest of keeping things simple, let's choose to implement the [_Show, Attend, and Tell_](https://arxiv.org/abs/1502.03044) paper. This is by no means the current state-of-the-art, but is still pretty darn amazing.
+ In the interest of keeping things simple, let's implement the [_Show, Attend, and Tell_](https://arxiv.org/abs/1502.03044) paper. This is by no means the current state-of-the-art, but is still pretty darn amazing.

- **This model learns _where_ to look.**
+ This model learns _where_ to look.

  As you generate a caption, word by word, you can see the model's gaze shifting across the image.
@@ -78,7 +78,7 @@ There are more examples at the [end of the tutorial](https://github.com/sgrvinod
  # Overview

- In this section, I will present a broad overview of this model. I don't really get into the _minutiae_ here - feel free to skip to the implementation section and commented code for details.
+ In this section, I will present a broad overview of this model. If you're already familiar with it, you can skip straight to the implementation section or the commented code.

  ### Encoder
caption.py

+26-3
@@ -14,6 +14,21 @@
  def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=3):
+     """
+     Reads an image and captions it with beam search.
+
+     :param encoder: encoder model
+     :param decoder: decoder model
+     :param image_path: path to image
+     :param word_map: word map
+     :param beam_size: number of sequences to consider at each decode-step
+     :return: caption, weights for visualization
+     """
+
+     k = beam_size
+     vocab_size = len(word_map)
+
+     # Read image and process
      img = imread(image_path)
      if len(img.shape) == 2:
          img = img[:, :, np.newaxis]
@@ -27,9 +42,6 @@ def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=
      transform = transforms.Compose([normalize])
      image = transform(img)  # (3, 256, 256)

-     k = beam_size
-     vocab_size = len(word_map)
-
      # Encode
      image = image.unsqueeze(0)  # (1, 3, 256, 256)
      encoder_out = encoder(image)  # (1, enc_image_size, enc_image_size, encoder_dim)
@@ -136,6 +148,17 @@ def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=
  def visualize_att(image_path, seq, alphas, rev_word_map, smooth=True):
+     """
+     Visualizes caption with weights at every word.
+
+     Adapted from paper authors' repo: https://github.com/kelvinxu/arctic-captions/blob/master/alpha_visualization.ipynb
+
+     :param image_path: path to image that has been captioned
+     :param seq: caption
+     :param alphas: weights
+     :param rev_word_map: reverse word mapping, i.e. ix2word
+     :param smooth: smooth weights?
+     """
      image = Image.open(image_path)
      image = image.resize([14 * 24, 14 * 24], Image.LANCZOS)
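For orientation, here is a rough sketch of how these two functions might be driven end to end. The checkpoint layout (whole `encoder`/`decoder` models stored under those keys) and the file names are assumptions for illustration, not part of this commit:

import json

import torch

# Hypothetical checkpoint and word-map files
checkpoint = torch.load('BEST_checkpoint.pth.tar', map_location='cpu')
encoder = checkpoint['encoder'].eval()   # assumed to hold full models, not state_dicts
decoder = checkpoint['decoder'].eval()

with open('WORDMAP.json') as j:
    word_map = json.load(j)                           # word -> index
rev_word_map = {v: k for k, v in word_map.items()}    # index -> word, i.e. ix2word

# Caption with a beam of 5, then overlay the attention weights on the image
seq, alphas = caption_image_beam_search(encoder, decoder, 'example.jpg', word_map, beam_size=5)
visualize_att('example.jpg', seq, torch.FloatTensor(alphas), rev_word_map, smooth=True)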

datasets.py

+11
@@ -6,7 +6,18 @@
  class CaptionDataset(Dataset):
+     """
+     A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches.
+     """
+
      def __init__(self, data_folder, data_name, split, transform=None):
+         """
+
+         :param data_folder: folder where data files are stored
+         :param data_name: base name of processed datasets
+         :param split: split, one of 'TRAIN', 'VAL', or 'TEST'
+         :param transform: image transform pipeline
+         """
          self.split = split
          assert self.split in {'TRAIN', 'VAL', 'TEST'}
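To show how the class is meant to be consumed, a minimal DataLoader sketch follows. The data folder, the base name, and the assumption that the 'TRAIN' split yields (image, caption, caption length) batches are illustrative, not taken from this commit:

import torch.utils.data
import torchvision.transforms as transforms

# Hypothetical folder and base name produced by the preprocessing step
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
train_loader = torch.utils.data.DataLoader(
    CaptionDataset('/path/to/data', 'coco_5_cap_per_img_5_min_word_freq', 'TRAIN',
                   transform=transforms.Compose([normalize])),
    batch_size=32, shuffle=True, pin_memory=True)

for imgs, caps, caplens in train_loader:   # assumed per-batch contents for the TRAIN split
    break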

models.py

+87-19
@@ -6,26 +6,43 @@
  class Encoder(nn.Module):
+     """
+     Encoder.
+     """
+
      def __init__(self, encoded_image_size=14):
          super(Encoder, self).__init__()
          self.enc_image_size = encoded_image_size

-         resnet = torchvision.models.resnet101(pretrained=True)
+         resnet = torchvision.models.resnet101(pretrained=True)  # pretrained ImageNet ResNet-101
+
+         # Remove linear and pool layers (since we're not doing classification)
          modules = list(resnet.children())[:-2]
          self.resnet = nn.Sequential(*modules)

+         # Resize image to fixed size to allow input images of variable size
          self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))

          self.fine_tune()

      def forward(self, images):
-         # images.shape = (batch_size, 3, image_size, image_size)
+         """
+         Forward propagation.
+
+         :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size)
+         :return: encoded images
+         """
          out = self.resnet(images)  # (batch_size, 2048, image_size/32, image_size/32)
          out = self.adaptive_pool(out)  # (batch_size, 2048, encoded_image_size, encoded_image_size)
          out = out.permute(0, 2, 3, 1)  # (batch_size, encoded_image_size, encoded_image_size, 2048)
          return out

      def fine_tune(self, fine_tune=True):
+         """
+         Allow or prevent the computation of gradients for convolutional blocks 2 through 4 of the encoder.
+
+         :param fine_tune: Allow?
+         """
          for p in self.resnet.parameters():
              p.requires_grad = False
          # If fine-tuning, only fine-tune convolutional blocks 2 through 4
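As a quick, hypothetical sanity check of the shapes documented above (not part of the commit), a 256x256 batch should come out as a 14x14 grid of 2048-dimensional vectors:

import torch

encoder = Encoder(encoded_image_size=14)
encoder.fine_tune(False)   # freeze all convolutional blocks

images = torch.randn(8, 3, 256, 256)   # (batch_size, 3, image_size, image_size)
encoder_out = encoder(images)
print(encoder_out.shape)                # expected: torch.Size([8, 14, 14, 2048])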
@@ -35,18 +52,31 @@ def fine_tune(self, fine_tune=True):
  class Attention(nn.Module):
+     """
+     Attention Network.
+     """
+
      def __init__(self, encoder_dim, decoder_dim, attention_dim):
+         """
+         :param encoder_dim: feature size of encoded images
+         :param decoder_dim: size of decoder's RNN
+         :param attention_dim: size of the attention network
+         """
          super(Attention, self).__init__()
-         self.encoder_att = nn.Linear(encoder_dim, attention_dim)
-         self.decoder_att = nn.Linear(decoder_dim, attention_dim)
-         self.full_att = nn.Linear(attention_dim, 1)
+         self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # linear layer to transform encoded image
+         self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # linear layer to transform decoder's output
+         self.full_att = nn.Linear(attention_dim, 1)  # linear layer to calculate values to be softmax-ed
          self.relu = nn.ReLU()
-         self.softmax = nn.Softmax(dim=1)
+         self.softmax = nn.Softmax(dim=1)  # softmax layer to calculate weights

      def forward(self, encoder_out, decoder_hidden):
-         # encoder_out.shape = (batch_size, num_pixels, encoder_dim)
-         # decoder_hidden.shape = (batch_size, decoder_dim)
+         """
+         Forward propagation.

+         :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
+         :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim)
+         :return: attention weighted encoding, weights
+         """
          att1 = self.encoder_att(encoder_out)  # (batch_size, num_pixels, attention_dim)
          att2 = self.decoder_att(decoder_hidden)  # (batch_size, attention_dim)
          att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2)  # (batch_size, num_pixels)
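A small, hypothetical shape check for the attention module described by the new docstrings (again, not part of the commit; dimensions are made up):

import torch

attention = Attention(encoder_dim=2048, decoder_dim=512, attention_dim=512)

encoder_out = torch.randn(8, 196, 2048)   # (batch_size, num_pixels, encoder_dim)
decoder_hidden = torch.randn(8, 512)      # (batch_size, decoder_dim)

awe, alpha = attention(encoder_out, decoder_hidden)
print(awe.shape)         # attention weighted encoding: (batch_size, encoder_dim)
print(alpha.sum(dim=1))  # weights softmax to 1 over the pixels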
@@ -57,8 +87,21 @@ def forward(self, encoder_out, decoder_hidden):
  class DecoderWithAttention(nn.Module):
+     """
+     Decoder.
+     """
+
      def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, decoder_layers=1, encoder_dim=2048,
                   dropout=0.5):
+         """
+         :param attention_dim: size of attention network
+         :param embed_dim: embedding size
+         :param decoder_dim: size of decoder's RNN
+         :param vocab_size: size of vocabulary
+         :param decoder_layers: number of layers in the decoder
+         :param encoder_dim: feature size of encoded images
+         :param dropout: dropout
+         """
          super(DecoderWithAttention, self).__init__()

          self.encoder_dim = encoder_dim
@@ -69,40 +112,65 @@ def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, decoder_la
          self.decoder_layers = decoder_layers
          self.dropout = dropout

-         self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
+         self.attention = Attention(encoder_dim, decoder_dim, attention_dim)  # attention network

-         self.embedding = nn.Embedding(vocab_size, embed_dim)
+         self.embedding = nn.Embedding(vocab_size, embed_dim)  # embedding layer
          self.dropout = nn.Dropout(p=self.dropout)
-         self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, decoder_layers)
-         self.init_h = nn.Linear(encoder_dim, decoder_dim)
-         self.init_c = nn.Linear(encoder_dim, decoder_dim)
-         self.f_beta = nn.Linear(decoder_dim, encoder_dim)
+         self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, decoder_layers)  # decoding LSTMCell
+         self.init_h = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial hidden state of LSTMCell
+         self.init_c = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial cell state of LSTMCell
+         self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # linear layer to create a sigmoid-activated gate
          self.sigmoid = nn.Sigmoid()
-         self.fc = nn.Linear(decoder_dim, vocab_size)
-         self.init_weights()
+         self.fc = nn.Linear(decoder_dim, vocab_size)  # linear layer to find scores over vocabulary
+         self.init_weights()  # initialize some layers with the uniform distribution

      def init_weights(self):
+         """
+         Initializes some parameters with values from the uniform distribution, for easier convergence.
+         """
          self.embedding.weight.data.uniform_(-0.1, 0.1)
          self.fc.bias.data.fill_(0)
          self.fc.weight.data.uniform_(-0.1, 0.1)

      def load_pretrained_embeddings(self, embeddings):
+         """
+         Loads embedding layer with pre-trained embeddings.
+
+         :param embeddings: pre-trained embeddings
+         """
          self.embedding.weight = nn.Parameter(embeddings)

      def fine_tune_embeddings(self, fine_tune=True):
+         """
+         Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings).
+
+         :param fine_tune: Allow?
+         """
          for p in self.embedding.parameters():
              p.requires_grad = fine_tune

      def init_hidden_state(self, encoder_out):
+         """
+         Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images.
+
+         :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
+         :return: hidden state, cell state
+         """
          mean_encoder_out = encoder_out.mean(dim=1)
          h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
          c = self.init_c(mean_encoder_out)
          return h, c

      def forward(self, encoder_out, encoded_captions, caption_lengths):
-         # encoder_out.shape = (batch_size, image_size, image_size, encoder_dim), image_size being the pixel width/height
-         # encoded_captions.shape = (batch_size, max_caption_length)
-         # caption_lengths.shape = (batch_size, 1)
+         """
+         Forward propagation.
+
+         :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim)
+         :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length)
+         :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1)
+         :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices
+         """
+
          batch_size = encoder_out.size(0)
          encoder_dim = encoder_out.size(-1)
          vocab_size = self.vocab_size
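Finally, a hypothetical teacher-forcing smoke test for the decoder's forward pass, with made-up dimensions, vocabulary size, and caption lengths; it is not part of the commit and only assumes the inputs and returns listed in the new docstring:

import torch

vocab_size = 5000   # hypothetical vocabulary size
encoder = Encoder()
decoder = DecoderWithAttention(attention_dim=512, embed_dim=512, decoder_dim=512,
                               vocab_size=vocab_size)

images = torch.randn(4, 3, 256, 256)
encoded_captions = torch.randint(0, vocab_size, (4, 22))       # (batch_size, max_caption_length)
caption_lengths = torch.LongTensor([[22], [18], [15], [12]])   # (batch_size, 1)

encoder_out = encoder(images)   # (4, 14, 14, 2048)
scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(encoder_out, encoded_captions,
                                                                caption_lengths)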
