From 4a21b71b58f30cc7ce7474f1db54e1d6d55a41df Mon Sep 17 00:00:00 2001 From: Gregory Chanan Date: Wed, 28 Dec 2016 13:06:00 -0800 Subject: [PATCH] Add support for Half tensors to DataParallelTable. --- DataParallelTable.lua | 4 ++-- test_DataParallelTable.lua | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/DataParallelTable.lua b/DataParallelTable.lua index e0194d48..9e07978c 100644 --- a/DataParallelTable.lua +++ b/DataParallelTable.lua @@ -550,8 +550,8 @@ function DataParallelTable:_distributeTensorRecursive(dst, src, idx, n) assert(torch.isTensor(src), 'input must be a tensor or table of tensors') if self.typeStr == 'torch.CudaHalfTensor' then - assert(false, - 'Half Tensors not supported yet by DataParallelTable') + assert(src:type() == self.typeStr or src:type() == 'torch.HalfTensor', + 'input must be a CudaHalf or Half tensor') elseif self.typeStr == 'torch.CudaDoubleTensor' then assert(src:type() == self.typeStr or src:type() == 'torch.DoubleTensor', 'input must be a CudaDouble or Double tensor') diff --git a/test_DataParallelTable.lua b/test_DataParallelTable.lua index 2b25cf22..ec91b78e 100644 --- a/test_DataParallelTable.lua +++ b/test_DataParallelTable.lua @@ -25,7 +25,7 @@ local t2cpu = { local function checkHalf() if cutorch.hasHalf then table.insert(typenames, 'torch.CudaHalfTensor') - t2cpu['torch.CudaHalfTensor'] = 'torch.FloatTensor' + t2cpu['torch.CudaHalfTensor'] = 'torch.HalfTensor' end end