Skip to content

Commit afe67cd

Browse files
committed
initial commit
0 parents  commit afe67cd

18 files changed

+5247
-0
lines changed

Diff for: config.py

+472
Large diffs are not rendered by default.

Diff for: configs/args.txt

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
--parallel
2+
--evalTrain
3+
--retainVal
4+
--useEMA
5+
--lrReduce
6+
--adam
7+
--clip
8+
--memoryVariationalDropout
9+
--relu=ELU
10+
--encBi
11+
--wrdEmbRandom
12+
--wrdEmbUniform
13+
--outQuestion
14+
--initCtrl=Q
15+
--controlContextual
16+
--controlInputUnshared
17+
--readProjInputs
18+
--readMemConcatKB
19+
--readMemConcatProj
20+
--readMemProj
21+
--readCtrl
22+
--writeMemProj

Diff for: configs/args1.txt

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
--netLength=8
2+
--parallel
3+
--evalTrain
4+
--retainVal
5+
--useEMA
6+
--lrReduce
7+
--adam
8+
--clip
9+
--memoryVariationalDropout
10+
--relu=ELU
11+
--encBi
12+
--wrdEmbRandom
13+
--wrdEmbUniform
14+
--outQuestion
15+
--controlContextual
16+
--readProjInputs
17+
--readMemConcatKB
18+
--readMemConcatProj
19+
--readMemProj
20+
--readCtrl
21+
--writeMemProj
22+
--initCtrl=PRM
23+
--controlFeedPrev
24+
--controlFeedPrevAtt
25+
--controlFeedInputs
26+
--controlContAct=TANH

Diff for: configs/args2.txt

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
--netLength=12
2+
--parallel
3+
--evalTrain
4+
--retainVal
5+
--useEMA
6+
--lrReduce
7+
--adam
8+
--clip
9+
--memoryVariationalDropout
10+
--relu=ELU
11+
--encBi
12+
--wrdEmbRandom
13+
--wrdEmbUniform
14+
--outQuestion
15+
--initCtrl=Q
16+
--controlContextual
17+
--controlInputUnshared
18+
--readProjInputs
19+
--readMemConcatKB
20+
--readMemConcatProj
21+
--readMemProj
22+
--readCtrl
23+
--writeMemProj
24+
--qDropout=0.85
25+
--stemDropout=0.85
26+
--noBucket
27+
--noRebucket

Diff for: configs/args3.txt

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
--netLength=8
2+
--parallel
3+
--evalTrain
4+
--retainVal
5+
--useEMA
6+
--lrReduce
7+
--adam
8+
--clip
9+
--memoryVariationalDropout
10+
--relu=ELU
11+
--encBi
12+
--wrdEmbRandom
13+
--wrdEmbUniform
14+
--outQuestion
15+
--initCtrl=Q
16+
--controlContextual
17+
--controlInputUnshared
18+
--readProjInputs
19+
--readMemConcatKB
20+
--readMemConcatProj
21+
--readMemProj
22+
--readCtrl
23+
--writeMemProj
24+
--writeSelfAtt
25+
--writeSelfAttMod CONT

Diff for: configs/args4.txt

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
--netLength=8
2+
--parallel
3+
--evalTrain
4+
--retainVal
5+
--useEMA
6+
--lrReduce
7+
--adam
8+
--clip
9+
--memoryVariationalDropout
10+
--relu=ELU
11+
--encBi
12+
--wrdEmbRandom
13+
--wrdEmbUniform
14+
--outQuestion
15+
--initCtrl=Q
16+
--controlContextual
17+
--controlInputUnshared
18+
--readProjInputs
19+
--readMemConcatKB
20+
--readMemConcatProj
21+
--readMemProj
22+
--readCtrl
23+
--writeMemProj
24+
--writeGate

Diff for: extract_features.py

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import argparse, os, json
8+
import h5py
9+
import numpy as np
10+
from scipy.misc import imread, imresize
11+
12+
import torch
13+
import torchvision
14+
15+
16+
parser = argparse.ArgumentParser()
17+
parser.add_argument('--input_image_dir', required=True)
18+
parser.add_argument('--max_images', default=None, type=int)
19+
parser.add_argument('--output_h5_file', required=True)
20+
21+
parser.add_argument('--image_height', default=224, type=int)
22+
parser.add_argument('--image_width', default=224, type=int)
23+
24+
parser.add_argument('--model', default='resnet101')
25+
parser.add_argument('--model_stage', default=3, type=int)
26+
parser.add_argument('--batch_size', default=128, type=int)
27+
28+
29+
def build_model(args):
30+
if not hasattr(torchvision.models, args.model):
31+
raise ValueError('Invalid model "%s"' % args.model)
32+
if not 'resnet' in args.model:
33+
raise ValueError('Feature extraction only supports ResNets')
34+
cnn = getattr(torchvision.models, args.model)(pretrained=True)
35+
layers = [
36+
cnn.conv1,
37+
cnn.bn1,
38+
cnn.relu,
39+
cnn.maxpool,
40+
]
41+
for i in range(args.model_stage):
42+
name = 'layer%d' % (i + 1)
43+
layers.append(getattr(cnn, name))
44+
model = torch.nn.Sequential(*layers)
45+
model.cuda()
46+
model.eval()
47+
return model
48+
49+
50+
def run_batch(cur_batch, model):
51+
mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
52+
std = np.array([0.229, 0.224, 0.224]).reshape(1, 3, 1, 1)
53+
54+
image_batch = np.concatenate(cur_batch, 0).astype(np.float32)
55+
image_batch = (image_batch / 255.0 - mean) / std
56+
image_batch = torch.FloatTensor(image_batch).cuda()
57+
image_batch = torch.autograd.Variable(image_batch, volatile=True)
58+
59+
feats = model(image_batch)
60+
feats = feats.data.cpu().clone().numpy()
61+
62+
return feats
63+
64+
65+
def main(args):
66+
input_paths = []
67+
idx_set = set()
68+
for fn in os.listdir(args.input_image_dir):
69+
if not fn.endswith('.png'): continue
70+
idx = int(os.path.splitext(fn)[0].split('_')[-1])
71+
input_paths.append((os.path.join(args.input_image_dir, fn), idx))
72+
idx_set.add(idx)
73+
input_paths.sort(key=lambda x: x[1])
74+
assert len(idx_set) == len(input_paths)
75+
assert min(idx_set) == 0 and max(idx_set) == len(idx_set) - 1
76+
if args.max_images is not None:
77+
input_paths = input_paths[:args.max_images]
78+
print(input_paths[0])
79+
print(input_paths[-1])
80+
81+
model = build_model(args)
82+
83+
img_size = (args.image_height, args.image_width)
84+
with h5py.File(args.output_h5_file, 'w') as f:
85+
feat_dset = None
86+
i0 = 0
87+
cur_batch = []
88+
for i, (path, idx) in enumerate(input_paths):
89+
img = imread(path, mode='RGB')
90+
img = imresize(img, img_size, interp='bicubic')
91+
img = img.transpose(2, 0, 1)[None]
92+
cur_batch.append(img)
93+
if len(cur_batch) == args.batch_size:
94+
feats = run_batch(cur_batch, model)
95+
if feat_dset is None:
96+
N = len(input_paths)
97+
_, C, H, W = feats.shape
98+
feat_dset = f.create_dataset('features', (N, C, H, W),
99+
dtype=np.float32)
100+
i1 = i0 + len(cur_batch)
101+
feat_dset[i0:i1] = feats
102+
i0 = i1
103+
print('Processed %d / %d images' % (i1, len(input_paths)))
104+
cur_batch = []
105+
if len(cur_batch) > 0:
106+
feats = run_batch(cur_batch, model)
107+
i1 = i0 + len(cur_batch)
108+
feat_dset[i0:i1] = feats
109+
print('Processed %d / %d images' % (i1, len(input_paths)))
110+
111+
112+
if __name__ == '__main__':
113+
args = parser.parse_args()
114+
main(args)

0 commit comments

Comments
 (0)