data_provider.py (forked from imclab/neuraltalk)
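
The provider below assumes a folder data/<dataset>/ containing dataset.json, vgg_feats.mat and an imgs/ subdirectory (see the "!assumptions on folder structure" note in the constructor). As a purely illustrative sketch, here is the shape of dataset.json written as a Python literal, showing only the fields that BasicDataProvider actually reads; all values and the inner sentence fields are made up for illustration:

# hypothetical dataset.json contents, as a Python literal (illustrative only)
example_dataset = {
  'images': [
    {
      'filename': 'example.jpg',   # image file expected under data/<dataset>/imgs/
      'imgid': 0,                  # integer column index into the 'feats' matrix of vgg_feats.mat
      'split': 'train',            # which split this image belongs to ('train', 'val' or 'test')
      'sentences': [               # caption structs; their inner fields are not inspected by this class
        {'raw': 'a dog runs on the grass'},
      ],
    },
  ],
}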
import json
import os
import random
import scipy.io
import codecs
from collections import defaultdict

class BasicDataProvider:
  def __init__(self, dataset):
    print 'Initializing data provider for dataset %s...' % (dataset, )

    # !assumptions on folder structure
    self.dataset_root = os.path.join('data', dataset)
    self.image_root = os.path.join('data', dataset, 'imgs')

    # load the dataset into memory
    dataset_path = os.path.join(self.dataset_root, 'dataset.json')
    print 'BasicDataProvider: reading %s' % (dataset_path, )
    self.dataset = json.load(open(dataset_path, 'r'))

    # load the image features into memory
    features_path = os.path.join(self.dataset_root, 'vgg_feats.mat')
    print 'BasicDataProvider: reading %s' % (features_path, )
    features_struct = scipy.io.loadmat(features_path)
    self.features = features_struct['feats']

    # group images by their train/val/test split into a dictionary -> list structure
    self.split = defaultdict(list)
    for img in self.dataset['images']:
      self.split[img['split']].append(img)

  # "PRIVATE" FUNCTIONS
  # in future we may want to create copies here so that we don't touch the
  # data provider class data, but for now lets do the simple thing and
  # just return raw internal img sent structs. This also has the advantage
  # that the driver could store various useful caching stuff in these structs
  # and they will be returned in the future with the cache present
  def _getImage(self, img):
    """ create an image structure for the driver """
    # lazily fill in some attributes
    if not 'local_file_path' in img: img['local_file_path'] = os.path.join(self.image_root, img['filename'])
    if not 'feat' in img: # also fill in the features
      feature_index = img['imgid'] # NOTE: imgid is an integer, and it indexes into features
      img['feat'] = self.features[:,feature_index]
    return img

  def _getSentence(self, sent):
    """ create a sentence structure for the driver """
    # NOOP for now
    return sent

  # PUBLIC FUNCTIONS

  def getSplitSize(self, split, ofwhat = 'sentences'):
    """ return size of a split, either number of sentences or number of images """
    if ofwhat == 'sentences':
      return sum(len(img['sentences']) for img in self.split[split])
    else: # assume images
      return len(self.split[split])

  def sampleImageSentencePair(self, split = 'train'):
    """ sample image sentence pair from a split """
    images = self.split[split]
    img = random.choice(images)
    sent = random.choice(img['sentences'])

    out = {}
    out['image'] = self._getImage(img)
    out['sentence'] = self._getSentence(sent)
    return out

  def iterImageSentencePair(self, split = 'train', max_images = -1):
    for i,img in enumerate(self.split[split]):
      if max_images >= 0 and i >= max_images: break
      for sent in img['sentences']:
        out = {}
        out['image'] = self._getImage(img)
        out['sentence'] = self._getSentence(sent)
        yield out

  def iterImageSentencePairBatch(self, split = 'train', max_images = -1, max_batch_size = 100):
    batch = []
    for i,img in enumerate(self.split[split]):
      if max_images >= 0 and i >= max_images: break
      for sent in img['sentences']:
        out = {}
        out['image'] = self._getImage(img)
        out['sentence'] = self._getSentence(sent)
        batch.append(out)
        if len(batch) >= max_batch_size:
          yield batch
          batch = []
    if batch:
      yield batch

  def iterSentences(self, split = 'train'):
    for img in self.split[split]:
      for sent in img['sentences']:
        yield self._getSentence(sent)

  def iterImages(self, split = 'train', shuffle = False, max_images = -1):
    imglist = self.split[split]
    ix = range(len(imglist))
    if shuffle:
      random.shuffle(ix)
    if max_images > 0:
      ix = ix[:min(len(ix),max_images)] # crop the list
    for i in ix:
      yield self._getImage(imglist[i])

def getDataProvider(dataset):
  """ we could intercept a special dataset and return different data providers """
  assert dataset in ['flickr8k', 'flickr30k', 'coco'], 'dataset %s unknown' % (dataset, )
  return BasicDataProvider(dataset)
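
# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): it assumes the
# flickr8k data has been unpacked under data/flickr8k/ with dataset.json and
# vgg_feats.mat in place, as BasicDataProvider expects above.
if __name__ == '__main__':
  dp = getDataProvider('flickr8k')

  # report split sizes in both units supported by getSplitSize
  print 'train split: %d images, %d sentences' % \
        (dp.getSplitSize('train', ofwhat = 'images'),
         dp.getSplitSize('train', ofwhat = 'sentences'))

  # draw one random (image, sentence) pair; 'feat' and 'local_file_path'
  # are filled in lazily by _getImage
  pair = dp.sampleImageSentencePair('train')
  print 'sampled image %s with feature vector of length %d' % \
        (pair['image']['filename'], len(pair['image']['feat']))

  # stream pairs in batches, capped at 5 images
  for batch in dp.iterImageSentencePairBatch('train', max_images = 5, max_batch_size = 10):
    print 'got a batch of %d image/sentence pairs' % (len(batch), )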