Skip to content

Commit 482b94a

Browse files
lixinyufacebook-github-bot
lixinyu
authored andcommitted
move RoutedDecoder Dataset to DataPipe (pytorch#51704)
Summary: Pull Request resolved: pytorch#51704 Test Plan: Imported from OSS Reviewed By: ejguan Differential Revision: D26245910 Pulled By: glaringlee fbshipit-source-id: 91e3c9f8a6c1209c1a1a752ba29a80dbd9bf4119
1 parent 8ab22a0 commit 482b94a

File tree

4 files changed

+344
-1
lines changed

4 files changed

+344
-1
lines changed

test/test_datapipe.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import warnings
66
import tarfile
77
import zipfile
8+
import numpy as np
9+
from PIL import Image
810

911
import torch
1012
from torch.testing._internal.common_utils import (TestCase, run_tests)
@@ -13,6 +15,10 @@
1315

1416
import torch.utils.data.datapipes as dp
1517

18+
from torch.utils.data.datapipes.utils.decoder import (
19+
basichandlers as decoder_basichandlers,
20+
imagehandler as decoder_imagehandler)
21+
1622
def create_temp_dir_and_files():
1723
# The temp dir and files within it will be released and deleted in tearDown().
1824
# Adding `noqa: P201` to avoid mypy's warning on not releasing the dir handle within this function.
@@ -153,6 +159,25 @@ def test_readfilesfromzip_iterable_datapipe(self):
153159
self.assertEqual(data_refs[i][1].read(), open(self.temp_files[i], 'rb').read())
154160

155161

162+
def test_routeddecoder_iterable_datapipe(self):
163+
temp_dir = self.temp_dir.name
164+
temp_pngfile_pathname = os.path.join(temp_dir, "test_png.png")
165+
img = Image.new('RGB', (2, 2), color='red')
166+
img.save(temp_pngfile_pathname)
167+
datapipe1 = dp.iter.ListDirFiles(temp_dir, ['*.png', '*.txt'])
168+
datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1)
169+
datapipe3 = dp.iter.RoutedDecoder(datapipe2, handlers=[decoder_imagehandler('rgb')])
170+
datapipe3.add_handler(decoder_basichandlers)
171+
172+
for rec in datapipe3:
173+
ext = os.path.splitext(rec[0])[1]
174+
if ext == '.png':
175+
expected = np.array([[[1., 0., 0.], [1., 0., 0.]], [[1., 0., 0.], [1., 0., 0.]]], dtype=np.single)
176+
self.assertTrue(np.array_equal(rec[1], expected))
177+
else:
178+
self.assertTrue(rec[1] == open(rec[0], 'rb').read().decode('utf-8'))
179+
180+
156181
class IDP_NoLen(IterDataPipe):
157182
def __init__(self, input_dp):
158183
super().__init__()

torch/utils/data/datapipes/iter/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
from torch.utils.data.datapipes.iter.loadfilesfromdisk import LoadFilesFromDiskIterDataPipe as LoadFilesFromDisk
33
from torch.utils.data.datapipes.iter.readfilesfromtar import ReadFilesFromTarIterDataPipe as ReadFilesFromTar
44
from torch.utils.data.datapipes.iter.readfilesfromzip import ReadFilesFromZipIterDataPipe as ReadFilesFromZip
5+
from torch.utils.data.datapipes.iter.routeddecoder import RoutedDecoderIterDataPipe as RoutedDecoder
56

67
# Functional DataPipe
78
from torch.utils.data.datapipes.iter.batch import BatchIterDataPipe as Batch, BucketBatchIterDataPipe as BucketBatch
89
from torch.utils.data.datapipes.iter.callable import CallableIterDataPipe as Callable, CollateIterDataPipe as Collate
910
from torch.utils.data.datapipes.iter.sampler import SamplerIterDataPipe as Sampler
1011

11-
__all__ = ['ListDirFiles', 'LoadFilesFromDisk', 'ReadFilesFormTar', 'ReadFilesFromZip'
12+
__all__ = ['ListDirFiles', 'LoadFilesFromDisk', 'ReadFilesFormTar', 'ReadFilesFromZip', 'RoutedDecoder',
1213
'Batch', 'BucketBatch', 'Callable', 'Collate', 'Sampler']
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from torch.utils.data import IterDataPipe
2+
from torch.utils.data.datapipes.utils.decoder import (
3+
Decoder,
4+
basichandlers as decoder_basichandlers,
5+
imagehandler as decoder_imagehandler)
6+
7+
from typing import Iterable, Iterator, Union, List, Tuple, Any, Callable
8+
from io import BufferedIOBase
9+
10+
class RoutedDecoderIterDataPipe(IterDataPipe):
11+
r""" :class:`RoutedDecoderIterDataPipe`.
12+
13+
Iterable datapipe to decode binary streams from input iterables,
14+
yield pathname and decoded binary stream in a tuple.
15+
args:
16+
datapipe: Iterable datapipe that provides pathname and binary stream in tuples
17+
handlers: user defined decoder handlers, if None, basic and image decoder handlers will be set as default
18+
length: a nominal length of the datapipe
19+
"""
20+
21+
def __init__(
22+
self,
23+
datapipe : Iterable[Tuple[str, BufferedIOBase]],
24+
*,
25+
handlers : Union[None, List[Callable]] = None,
26+
length: int = -1):
27+
super().__init__()
28+
self.datapipe : Iterable[Tuple[str, BufferedIOBase]] = datapipe
29+
if handlers:
30+
self.decoder = Decoder(handlers)
31+
else:
32+
self.decoder = Decoder([decoder_basichandlers, decoder_imagehandler('torch')])
33+
self.length : int = length
34+
35+
def add_handler(self, handler : Callable) -> None:
36+
self.decoder.add_handler(handler)
37+
38+
def __iter__(self) -> Iterator[Tuple[str, Any]]:
39+
for data in self.datapipe:
40+
pathname = data[0]
41+
result = self.decoder(data)
42+
yield (pathname, result[pathname])
43+
44+
def __len__(self):
45+
if self.length == -1:
46+
raise NotImplementedError
47+
return self.length
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
# This file takes partial of the implementation from NVIDIA's webdataset at here:
2+
# https://github.com/tmbdev/webdataset/blob/master/webdataset/autodecode.py
3+
4+
import pickle
5+
import re
6+
import os
7+
8+
import numpy as np
9+
import PIL
10+
import PIL.Image
11+
import json
12+
import tempfile
13+
import io
14+
15+
16+
################################################################
17+
# handle basic datatypes
18+
################################################################
19+
20+
21+
def basichandlers(key, data):
22+
23+
extension = re.sub(r".*[.]", "", key)
24+
25+
if extension in "txt text transcript":
26+
return data.decode("utf-8")
27+
28+
if extension in "cls cls2 class count index inx id".split():
29+
try:
30+
return int(data)
31+
except ValueError:
32+
return None
33+
34+
if extension in "json jsn":
35+
return json.loads(data)
36+
37+
if extension in "pyd pickle".split():
38+
return pickle.loads(data)
39+
40+
if extension in "pt".split():
41+
import torch
42+
stream = io.BytesIO(data)
43+
return torch.load(stream)
44+
45+
# if extension in "ten tb".split():
46+
# from . import tenbin
47+
# return tenbin.decode_buffer(data)
48+
49+
# if extension in "mp msgpack msg".split():
50+
# import msgpack
51+
# return msgpack.unpackb(data)
52+
53+
return None
54+
55+
56+
################################################################
57+
# handle images
58+
################################################################
59+
60+
imagespecs = {
61+
"l8": ("numpy", "uint8", "l"),
62+
"rgb8": ("numpy", "uint8", "rgb"),
63+
"rgba8": ("numpy", "uint8", "rgba"),
64+
"l": ("numpy", "float", "l"),
65+
"rgb": ("numpy", "float", "rgb"),
66+
"rgba": ("numpy", "float", "rgba"),
67+
"torchl8": ("torch", "uint8", "l"),
68+
"torchrgb8": ("torch", "uint8", "rgb"),
69+
"torchrgba8": ("torch", "uint8", "rgba"),
70+
"torchl": ("torch", "float", "l"),
71+
"torchrgb": ("torch", "float", "rgb"),
72+
"torch": ("torch", "float", "rgb"),
73+
"torchrgba": ("torch", "float", "rgba"),
74+
"pill": ("pil", None, "l"),
75+
"pil": ("pil", None, "rgb"),
76+
"pilrgb": ("pil", None, "rgb"),
77+
"pilrgba": ("pil", None, "rgba"),
78+
}
79+
80+
def handle_extension(extensions, f):
81+
"""
82+
Returns a decoder handler function for the list of extensions.
83+
Extensions can be a space separated list of extensions.
84+
Extensions can contain dots, in which case the corresponding number
85+
of extension components must be present in the key given to f.
86+
Comparisons are case insensitive.
87+
Examples:
88+
handle_extension("jpg jpeg", my_decode_jpg) # invoked for any file.jpg
89+
handle_extension("seg.jpg", special_case_jpg) # invoked only for file.seg.jpg
90+
"""
91+
92+
extensions = extensions.lower().split()
93+
94+
def g(key, data):
95+
extension = key.lower().split(".")
96+
97+
for target in extensions:
98+
target = target.split(".")
99+
if len(target) > len(extension):
100+
continue
101+
102+
if extension[-len(target):] == target:
103+
return f(data)
104+
return None
105+
return g
106+
107+
108+
class ImageHandler:
109+
"""
110+
Decode image data using the given `imagespec`.
111+
The `imagespec` specifies whether the image is decoded
112+
to numpy/torch/pi, decoded to uint8/float, and decoded
113+
to l/rgb/rgba:
114+
115+
- l8: numpy uint8 l
116+
- rgb8: numpy uint8 rgb
117+
- rgba8: numpy uint8 rgba
118+
- l: numpy float l
119+
- rgb: numpy float rgb
120+
- rgba: numpy float rgba
121+
- torchl8: torch uint8 l
122+
- torchrgb8: torch uint8 rgb
123+
- torchrgba8: torch uint8 rgba
124+
- torchl: torch float l
125+
- torchrgb: torch float rgb
126+
- torch: torch float rgb
127+
- torchrgba: torch float rgba
128+
- pill: pil None l
129+
- pil: pil None rgb
130+
- pilrgb: pil None rgb
131+
- pilrgba: pil None rgba
132+
"""
133+
def __init__(self, imagespec):
134+
assert imagespec in list(imagespecs.keys()), "unknown image specification: {}".format(imagespec)
135+
self.imagespec = imagespec.lower()
136+
137+
def __call__(self, key, data):
138+
extension = re.sub(r".*[.]", "", key)
139+
if extension.lower() not in "jpg jpeg png ppm pgm pbm pnm".split():
140+
return None
141+
142+
imagespec = self.imagespec
143+
atype, etype, mode = imagespecs[imagespec]
144+
145+
with io.BytesIO(data) as stream:
146+
img = PIL.Image.open(stream)
147+
img.load()
148+
img = img.convert(mode.upper())
149+
if atype == "pil":
150+
return img
151+
elif atype == "numpy":
152+
result = np.asarray(img)
153+
assert result.dtype == np.uint8, "numpy image array should be type uint8, but got {}".format(result.dtype)
154+
if etype == "uint8":
155+
return result
156+
else:
157+
return result.astype("f") / 255.0
158+
elif atype == "torch":
159+
import torch
160+
161+
result = np.asarray(img)
162+
assert result.dtype == np.uint8, "numpy image array should be type uint8, but got {}".format(result.dtype)
163+
164+
if etype == "uint8":
165+
result = np.array(result.transpose(2, 0, 1))
166+
return torch.tensor(result)
167+
else:
168+
result = np.array(result.transpose(2, 0, 1))
169+
return torch.tensor(result) / 255.0
170+
return None
171+
172+
def imagehandler(imagespec):
173+
return ImageHandler(imagespec)
174+
175+
176+
################################################################
177+
# torch video
178+
################################################################
179+
180+
181+
def torch_video(key, data):
182+
extension = re.sub(r".*[.]", "", key)
183+
if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split():
184+
return None
185+
186+
# add `type: ignore` to avoid mypy's warning on import missing
187+
import torchvision.io # type: ignore
188+
with tempfile.TemporaryDirectory() as dirname:
189+
fname = os.path.join(dirname, f"file.{extension}")
190+
with open(fname, "wb") as stream:
191+
stream.write(data)
192+
return torchvision.io.read_video(fname)
193+
194+
195+
################################################################
196+
# torchaudio
197+
################################################################
198+
199+
200+
def torch_audio(key, data):
201+
extension = re.sub(r".*[.]", "", key)
202+
if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]:
203+
return None
204+
205+
# add `type: ignore` to avoid mypy's warning on import missing
206+
import torchaudio # type: ignore
207+
with tempfile.TemporaryDirectory() as dirname:
208+
fname = os.path.join(dirname, f"file.{extension}")
209+
with open(fname, "wb") as stream:
210+
stream.write(data)
211+
return torchaudio.load(fname)
212+
213+
214+
215+
################################################################
216+
# a sample decoder
217+
################################################################
218+
219+
220+
class Decoder:
221+
"""
222+
Decode key/data sets using a list of handlers.
223+
For each key/data item, this iterates through the list of
224+
handlers until some handler returns something other than None.
225+
"""
226+
227+
def __init__(self, handlers):
228+
self.handlers = handlers
229+
230+
def add_handler(self, handler):
231+
if not handler:
232+
return
233+
if not self.handlers:
234+
self.handlers = [handler]
235+
else:
236+
self.handlers.append(handler)
237+
238+
def decode1(self, key, data):
239+
if not data:
240+
return data
241+
242+
# if data is a stream handle, we need to read all the content before decoding
243+
if isinstance(data, io.BufferedIOBase) or isinstance(data, io.RawIOBase):
244+
data = data.read()
245+
246+
for f in self.handlers:
247+
result = f(key, data)
248+
if result is not None:
249+
return result
250+
return data
251+
252+
def decode(self, data):
253+
result = {}
254+
# single data tuple(pathname, data stream)
255+
if isinstance(data, tuple):
256+
data = [data]
257+
258+
if data is not None:
259+
for k, v in data:
260+
# TODO: xinyu, figure out why Nvidia do this?
261+
if k[0] == "_":
262+
if isinstance(v, bytes):
263+
v = v.decode("utf-8")
264+
result[k] = v
265+
continue
266+
result[k] = self.decode1(k, v)
267+
return result
268+
269+
def __call__(self, data):
270+
return self.decode(data)

0 commit comments

Comments
 (0)