Skip to content

updated transforms.ToPILImage, see #105 #122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 23, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,24 @@ def test_tensor_to_pil_image(self):
expected_output = img_data.mul(255).int().float().div(255)
assert np.allclose(expected_output[0].numpy(), to_tensor(l).numpy())

def test_tensor_gray_to_pil_image(self):
    """Single-channel tensors map to grayscale PIL modes 'L' / 'I;16' / 'I'."""
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()

    byte_data = torch.ByteTensor(1, 4, 4).random_(0, 255)
    short_data = torch.ShortTensor(1, 4, 4).random_()
    int_data = torch.IntTensor(1, 4, 4).random_()

    byte_img = to_pil(byte_data)
    short_img = to_pil(short_data)
    int_img = to_pil(int_data)

    assert byte_img.mode == 'L'
    assert short_img.mode == 'I;16'
    assert int_img.mode == 'I'

    # Integer modes keep their raw values, so they round-trip through ToTensor.
    assert np.allclose(short_data.numpy(), to_tensor(short_img).numpy())
    assert np.allclose(int_data.numpy(), to_tensor(int_img).numpy())

def test_ndarray_to_pil_image(self):
trans = transforms.ToPILImage()
img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
Expand All @@ -169,6 +187,35 @@ def test_ndarray_to_pil_image(self):
l, = img.split()
assert np.allclose(l, img_data[:, :, 0])

def test_ndarray_bad_types_to_pil_image(self):
    """ToPILImage must reject ndarrays whose dtype has no PIL mode mapping."""
    trans = transforms.ToPILImage()
    # Each conversion needs its own assertRaises context: in the original,
    # all four calls shared one `with` block, so the first raise skipped the
    # remaining three dtypes and they were never actually tested.
    for dtype in (np.int64, np.uint16, np.uint32, np.float64):
        with self.assertRaises(AssertionError):
            trans(np.ones([4, 4, 1], dtype))

def test_ndarray_gray_float32_to_pil_image(self):
    """An H x W x 1 float32 ndarray becomes a mode-'F' image, values intact."""
    to_pil = transforms.ToPILImage()
    data = torch.FloatTensor(4, 4, 1).random_().numpy()
    pil_img = to_pil(data)
    assert pil_img.mode == 'F'
    assert np.allclose(pil_img, data[:, :, 0])

def test_ndarray_gray_int16_to_pil_image(self):
    """An H x W x 1 int16 ndarray becomes a mode-'I;16' image, values intact."""
    to_pil = transforms.ToPILImage()
    data = torch.ShortTensor(4, 4, 1).random_().numpy()
    pil_img = to_pil(data)
    assert pil_img.mode == 'I;16'
    assert np.allclose(pil_img, data[:, :, 0])

def test_ndarray_gray_int32_to_pil_image(self):
    """An H x W x 1 int32 ndarray becomes a mode-'I' image, values intact."""
    to_pil = transforms.ToPILImage()
    data = torch.IntTensor(4, 4, 1).random_().numpy()
    pil_img = to_pil(data)
    assert pil_img.mode == 'I'
    assert np.allclose(pil_img, data[:, :, 0])


# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
60 changes: 41 additions & 19 deletions torchvision/transforms.py
Original file line number Diff line number Diff line change
def __call__(self, pic):
    """Convert ``pic`` (PIL.Image or H x W x C numpy ndarray) to a tensor.

    Returns a C x H x W torch Tensor. uint8 data is scaled to floats in
    [0, 1] (backward compatibility); integer PIL modes 'I' and 'I;16'
    keep their raw values and native integer dtype.
    """
    if isinstance(pic, np.ndarray):
        # handle numpy array (assumed H x W x C -> C x H x W)
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility: ndarrays are always scaled to [0, 1]
        return img.float().div(255)
    # handle PIL Image; integer modes go through numpy to keep their dtype
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    # only byte images are rescaled; integer images keep raw values
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    return img


class ToPILImage(object):
    """Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL.Image while preserving value range.

    FloatTensors are treated as holding values in [0, 1] and are rescaled to
    byte range; integer tensors/arrays keep their raw values. Single-channel
    inputs map to mode 'L' (uint8), 'I;16' (int16), 'I' (int32) or 'F'
    (float32); three-channel uint8 inputs map to 'RGB'.
    """

    def __call__(self, pic):
        """Convert ``pic`` to a PIL.Image.

        Raises:
            AssertionError: if ``pic`` is neither a Tensor nor an ndarray,
                or if its dtype has no corresponding PIL mode.
        """
        npimg = pic
        mode = None
        # Float tensors are in [0, 1]; rescale so values land in byte range.
        if isinstance(pic, torch.FloatTensor):
            pic = pic.mul(255).byte()
        if torch.is_tensor(pic):
            # CHW -> HWC, the layout PIL expects.
            npimg = np.transpose(pic.numpy(), (1, 2, 0))
        assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'
        if npimg.shape[2] == 1:
            # Grayscale: drop the channel axis and pick the mode by dtype.
            npimg = npimg[:, :, 0]
            if npimg.dtype == np.uint8:
                mode = 'L'
            elif npimg.dtype == np.int16:
                mode = 'I;16'
            elif npimg.dtype == np.int32:
                mode = 'I'
            elif npimg.dtype == np.float32:
                mode = 'F'
        else:
            # Multi-channel: only 8-bit RGB is supported here.
            if npimg.dtype == np.uint8:
                mode = 'RGB'
        assert mode is not None, '{} is not supported'.format(npimg.dtype)
        return Image.fromarray(npimg, mode=mode)


Expand Down