
Commit 015cabf

lixinyu authored and facebook-github-bot committed
move GroupByFilename Dataset to DataPipe (pytorch#51709)
Summary: Pull Request resolved: pytorch#51709

Move GroupByFilename Dataset to DataPipe

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D26263585

Pulled By: glaringlee

fbshipit-source-id: 00e3e13b47b89117f1ccfc4cd6239940a40d071e
1 parent 482b94a commit 015cabf

File tree: 3 files changed (+144, -1 lines changed)

test/test_datapipe.py (+33)
@@ -19,6 +19,7 @@
     basichandlers as decoder_basichandlers,
     imagehandler as decoder_imagehandler)
 
+
 def create_temp_dir_and_files():
     # The temp dir and files within it will be released and deleted in tearDown().
     # Adding `noqa: P201` to avoid mypy's warning on not releasing the dir handle within this function.
@@ -178,6 +179,38 @@ def test_routeddecoder_iterable_datapipe(self):
         self.assertTrue(rec[1] == open(rec[0], 'rb').read().decode('utf-8'))
 
 
+    def test_groupbykey_iterable_datapipe(self):
+        temp_dir = self.temp_dir.name
+        temp_tarfile_pathname = os.path.join(temp_dir, "test_tar.tar")
+        file_list = [
+            "a.png", "b.png", "c.json", "a.json", "c.png", "b.json", "d.png",
+            "d.json", "e.png", "f.json", "g.png", "f.png", "g.json", "e.json",
+            "h.txt", "h.json"]
+        with tarfile.open(temp_tarfile_pathname, "w:gz") as tar:
+            for file_name in file_list:
+                file_pathname = os.path.join(temp_dir, file_name)
+                with open(file_pathname, 'w') as f:
+                    f.write('12345abcde')
+                tar.add(file_pathname)
+
+        datapipe1 = dp.iter.ListDirFiles(temp_dir, '*.tar')
+        datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1)
+        datapipe3 = dp.iter.ReadFilesFromTar(datapipe2)
+        datapipe4 = dp.iter.GroupByKey(datapipe3, group_size=2)
+
+        expected_result = [("a.png", "a.json"), ("c.png", "c.json"), ("b.png", "b.json"), ("d.png", "d.json"), (
+            "f.png", "f.json"), ("g.png", "g.json"), ("e.png", "e.json"), ("h.json", "h.txt")]
+
+        count = 0
+        for rec, expected in zip(datapipe4, expected_result):
+            count = count + 1
+            self.assertEqual(os.path.basename(rec[0][0]), expected[0])
+            self.assertEqual(os.path.basename(rec[1][0]), expected[1])
+            self.assertEqual(rec[0][1].read(), b'12345abcde')
+            self.assertEqual(rec[1][1].read(), b'12345abcde')
+        self.assertEqual(count, 8)
+
+
 class IDP_NoLen(IterDataPipe):
     def __init__(self, input_dp):
         super().__init__()
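The expected pairs follow default_sort_data_fn from the new groupbykey.py below: a non-text extension such as .png sorts ahead of a text extension, while .txt and .json both count as text extensions and therefore fall back to alphabetical order, which is why the last group is ("h.json", "h.txt"). A minimal sketch of that behavior (the None payloads are placeholders):

from torch.utils.data.datapipes.iter.groupbykey import default_sort_data_fn

# .png is not a text extension, so it sorts ahead of .json
print(default_sort_data_fn([("a.json", None), ("a.png", None)]))
# -> [('a.png', None), ('a.json', None)]

# .txt and .json are both text extensions, so plain alphabetical order applies
print(default_sort_data_fn([("h.txt", None), ("h.json", None)]))
# -> [('h.json', None), ('h.txt', None)]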

torch/utils/data/datapipes/iter/__init__.py (+2, -1)
@@ -3,11 +3,12 @@
 from torch.utils.data.datapipes.iter.readfilesfromtar import ReadFilesFromTarIterDataPipe as ReadFilesFromTar
 from torch.utils.data.datapipes.iter.readfilesfromzip import ReadFilesFromZipIterDataPipe as ReadFilesFromZip
 from torch.utils.data.datapipes.iter.routeddecoder import RoutedDecoderIterDataPipe as RoutedDecoder
+from torch.utils.data.datapipes.iter.groupbykey import GroupByKeyIterDataPipe as GroupByKey
 
 # Functional DataPipe
 from torch.utils.data.datapipes.iter.batch import BatchIterDataPipe as Batch, BucketBatchIterDataPipe as BucketBatch
 from torch.utils.data.datapipes.iter.callable import CallableIterDataPipe as Callable, CollateIterDataPipe as Collate
 from torch.utils.data.datapipes.iter.sampler import SamplerIterDataPipe as Sampler
 
-__all__ = ['ListDirFiles', 'LoadFilesFromDisk', 'ReadFilesFormTar', 'ReadFilesFromZip', 'RoutedDecoder',
+__all__ = ['ListDirFiles', 'LoadFilesFromDisk', 'ReadFilesFormTar', 'ReadFilesFromZip', 'RoutedDecoder', 'GroupByKey',
            'Batch', 'BucketBatch', 'Callable', 'Collate', 'Sampler']
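With the GroupByKey alias exported here, the pipe is reachable through the dp.iter namespace used in the test above. A small sketch, assuming dp is bound the same way test_datapipe.py binds it (the toy in-memory source is hypothetical; a real pipeline would feed (pathname, stream) tuples):

import torch.utils.data.datapipes as dp

# a toy in-memory source standing in for a real (pathname, stream) datapipe
source = [("x.png", b""), ("x.json", b"")]
grouped = dp.iter.GroupByKey(source, group_size=2)
print(list(grouped))  # [[('x.png', b''), ('x.json', b'')]]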
torch/utils/data/datapipes/iter/groupbykey.py (+109, new file)
@@ -0,0 +1,109 @@
+from torch.utils.data import IterDataPipe
+from typing import Dict, List, Tuple, Any, Callable, Iterable, Iterator, Union
+
+import os
+import functools
+
+
+# The default group key is the file pathname without the extension.
+# It assumes the passed-in data is a tuple whose first item is the file's pathname.
+def default_group_key_fn(dataitem: Tuple[str, Any]):
+    return os.path.splitext(dataitem[0])[0]
+
+
+def default_sort_data_fn(datalist: List[Tuple[str, Any]]):
+    txt_ext = ['.json', '.jsn', '.txt', '.text']
+
+    def cmp_fn(a: Tuple[str, Any], b: Tuple[str, Any]):
+        a_is_txt = os.path.splitext(a[0])[1] in txt_ext
+        b_is_txt = os.path.splitext(b[0])[1] in txt_ext
+
+        # if a is a text file but b is not, b goes first
+        if a_is_txt and not b_is_txt:
+            return 1
+        # if a is not a text file but b is, a goes first
+        if not a_is_txt and b_is_txt:
+            return -1
+        # if a and b are both text files, or both not, sort in alphabetical order
+        if a[0] < b[0]:
+            return -1
+        elif a[0] > b[0]:
+            return 1
+        return 0
+
+    return sorted(datalist, key=functools.cmp_to_key(cmp_fn))
+
+
+class GroupByKeyIterDataPipe(IterDataPipe):
+    r""" :class:`GroupByKeyIterDataPipe`.
+
+    Iterable datapipe that groups data from the input iterable by keys generated from `group_key_fn`,
+    and yields a list with `group_size` items in it; each item in the list is a tuple of key and data.
+
+    args:
+        datapipe: Iterable datapipe that provides the data (typically tuples of a str key (e.g. pathname) and a data stream)
+        group_size: the size of each group
+        max_buffer_size: the maximum size of the stream buffer, which holds data that has been iterated but not yet grouped
+        group_key_fn: a function used to generate the group key from the data in the input datapipe
+        sort_data_fn: a function used to sort the grouped data before it is yielded
+        length: a nominal length of the datapipe
+    """
+
+    def __init__(
+            self,
+            datapipe: Iterable[Tuple[str, Any]],
+            *,
+            group_size: int,
+            max_buffer_size: Union[int, None] = None,
+            group_key_fn: Callable = default_group_key_fn,
+            sort_data_fn: Callable = default_sort_data_fn,
+            length: int = -1):
+        super().__init__()
+
+        assert group_size > 0
+        self.datapipe: Iterable[Tuple[str, Any]] = datapipe
+        self.group_size: int = group_size
+
+        # the default max buffer size is group_size * 10
+        self.max_buffer_size = max_buffer_size if max_buffer_size is not None else group_size * 10
+        assert self.max_buffer_size >= self.group_size
+
+        self.group_key_fn: Callable = group_key_fn
+        self.sort_data_fn: Callable = sort_data_fn
+        self.curr_buffer_size: int = 0
+        self.stream_buffer: Dict[str, List[Tuple[str, Any]]] = {}
+        self.length: int = length
+
+
+    def __iter__(self) -> Iterator[list]:
+        if self.group_size == 1:
+            for data in self.datapipe:
+                yield [data]
+        else:
+            for data in self.datapipe:
+                key = self.group_key_fn(data)
+                if key not in self.stream_buffer:
+                    self.stream_buffer[key] = []
+                res = self.stream_buffer[key]
+                res.append(data)
+                if len(res) == self.group_size:
+                    yield self.sort_data_fn(res)
+                    del self.stream_buffer[key]
+                    self.curr_buffer_size = self.curr_buffer_size - self.group_size + 1
+                else:
+                    if self.curr_buffer_size == self.max_buffer_size:
+                        raise OverflowError(
+                            "stream_buffer overflowed; please adjust the order of data "
+                            "in the input datapipe or increase the buffer size!")
+                    self.curr_buffer_size = self.curr_buffer_size + 1
+
+        if self.curr_buffer_size > 0:
+            msg = "Not able to group [{}] with group size {}.".format(
+                ','.join([v[0] for _, vs in self.stream_buffer.items() for v in vs]), str(self.group_size))
+            raise RuntimeError(msg)
+
+
+    def __len__(self):
+        if self.length == -1:
+            raise NotImplementedError
+        return self.length
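The buffer bookkeeping in __iter__ is easiest to see on an out-of-order source: items wait in stream_buffer until group_size entries share a key, and anything still unmatched when the source is exhausted raises the RuntimeError at the bottom of __iter__. A minimal sketch (the filenames and payloads are made up for illustration):

from torch.utils.data.datapipes.iter.groupbykey import GroupByKeyIterDataPipe

# hypothetical out-of-order source of (pathname, payload) tuples
source = [("a.png", 0), ("b.png", 1), ("a.json", 2), ("b.json", 3)]
pipe = GroupByKeyIterDataPipe(source, group_size=2)
for group in pipe:
    print([name for name, _ in group])
# -> ['a.png', 'a.json']
# -> ['b.png', 'b.json']

# an item whose partner never arrives surfaces once the source is exhausted:
try:
    list(GroupByKeyIterDataPipe([("only.png", 0)], group_size=2))
except RuntimeError as e:
    print(e)  # Not able to group [only.png] with group size 2.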
