diff --git a/DataParallelTable.lua b/DataParallelTable.lua
index e0194d48..9e07978c 100644
--- a/DataParallelTable.lua
+++ b/DataParallelTable.lua
@@ -550,8 +550,8 @@ function DataParallelTable:_distributeTensorRecursive(dst, src, idx, n)
    assert(torch.isTensor(src), 'input must be a tensor or table of tensors')
 
    if self.typeStr == 'torch.CudaHalfTensor' then
-      assert(false,
-         'Half Tensors not supported yet by DataParallelTable')
+      assert(src:type() == self.typeStr or src:type() == 'torch.HalfTensor',
+         'input must be a CudaHalf or Half tensor')
    elseif self.typeStr == 'torch.CudaDoubleTensor' then
       assert(src:type() == self.typeStr or src:type() == 'torch.DoubleTensor',
          'input must be a CudaDouble or Double tensor')
diff --git a/test_DataParallelTable.lua b/test_DataParallelTable.lua
index 2b25cf22..ec91b78e 100644
--- a/test_DataParallelTable.lua
+++ b/test_DataParallelTable.lua
@@ -25,7 +25,7 @@ local t2cpu = {
 
 local function checkHalf()
    if cutorch.hasHalf then
       table.insert(typenames, 'torch.CudaHalfTensor')
-      t2cpu['torch.CudaHalfTensor'] = 'torch.FloatTensor'
+      t2cpu['torch.CudaHalfTensor'] = 'torch.HalfTensor'
    end
 end
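
For reference, a minimal sketch of the path this patch unblocks: typing a DataParallelTable as torch.CudaHalfTensor and feeding it half-precision input, which previously tripped the assert(false). This sketch is not part of the patch; the layer sizes and GPU ids are illustrative assumptions, and it presumes a cutorch/cunn build with half support (cutorch.hasHalf) whose kernels cover the wrapped module.

-- Illustrative sketch (not part of the patch): exercising the new half path.
-- Assumes cutorch/cunn built with half support and at least two visible GPUs.
require 'cunn'

if cutorch.hasHalf and cutorch.getDeviceCount() >= 2 then
   local net = nn.Linear(4, 2)

   local dpt = nn.DataParallelTable(1)   -- split the minibatch along dim 1
      :add(net, {1, 2})                  -- replicate onto GPUs 1 and 2
      :type('torch.CudaHalfTensor')      -- previously hit assert(false)

   -- Per the new assert, input may be a CudaHalf tensor or a CPU Half tensor.
   local input = torch.randn(8, 4):cudaHalf()
   local output = dpt:forward(input)
   print(torch.type(output))             -- expect 'torch.CudaHalfTensor'
end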