Skip to content

Commit

Permalink
Merge pull request #169 from asappresearch/3.0.0-dev-tao
Browse files Browse the repository at this point in the history
Speed up data loading / batching for ONE BILLION WORD experiment
  • Loading branch information
taoleicn authored Mar 5, 2021
2 parents 7200f94 + 84ef01c commit 81a657b
Showing 1 changed file with 4 additions and 11 deletions.
15 changes: 4 additions & 11 deletions experiments/srupp_experiments/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,22 +421,20 @@ def __init__(self, paths, vocab, bsz, bptt, n_nodes=1, rank=0,

def get_sent_stream(self, path):
    """Return an iterator over this node's encoded sentences from *path*.

    The file is tokenized by the vocabulary (presumably one sentence per
    entry; `add_double_eos=True` wraps each sentence with EOS markers —
    TODO confirm against Vocab.encode_file), then sharded round-robin
    across distributed nodes so each rank sees a disjoint subset.
    """
    sents = self.vocab.encode_file(path, add_double_eos=True)
    # Round-robin shard by sentence index: rank r keeps every
    # n_nodes-th sentence starting at offset r.
    sents = [s for i, s in enumerate(sents) if i % self.n_nodes == self.rank]
    if self.shuffle:
        # In-place shuffle using the global NumPy RNG state; ranks may
        # shuffle differently unless seeds are synchronized elsewhere.
        np.random.shuffle(sents)
    sent_stream = iter(sents)

    return sent_stream

def __iter__(self):
'''
Iterate over all splits of data *repeatively*
'''
if self.shuffle:
np.random.shuffle(self.paths)

paths = self.paths * self.n_nodes
paths_this_node = [path for i, path in enumerate(paths)
if i % self.n_nodes == self.rank]
while True:
for path in self.paths:
for path in paths_this_node:
# sent_stream is an iterator
sent_stream = self.get_sent_stream(path)
for batch in self.stream_iterator(sent_stream):
Expand Down Expand Up @@ -510,11 +508,6 @@ def get_distributed_iterator(self, split, *args, **kwargs):
data_iter = DistributedLMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
elif split in ['valid', 'test']:
raise NotImplementedError()
# data = self.valid if split == 'valid' else self.test
# if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
# data_iter = DistributedLMOrderedIterator(data, *args, **kwargs)
# elif self.dataset == 'lm1b':
# data_iter = LMShuffledIterator(data, *args, **kwargs)

return data_iter

Expand Down

0 comments on commit 81a657b

Please sign in to comment.