Skip to content

Commit

Permalink
Add support for Half tensors to DataParallelTable.
Browse files Browse the repository at this point in the history
  • Loading branch information
gchanan committed Dec 28, 2016
1 parent 4173b22 commit 4a21b71
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions DataParallelTable.lua
Original file line number Diff line number Diff line change
Expand Up @@ -550,8 +550,8 @@ function DataParallelTable:_distributeTensorRecursive(dst, src, idx, n)

assert(torch.isTensor(src), 'input must be a tensor or table of tensors')
if self.typeStr == 'torch.CudaHalfTensor' then
assert(false,
'Half Tensors not supported yet by DataParallelTable')
assert(src:type() == self.typeStr or src:type() == 'torch.HalfTensor',
'input must be a CudaHalf or Half tensor')
elseif self.typeStr == 'torch.CudaDoubleTensor' then
assert(src:type() == self.typeStr or src:type() == 'torch.DoubleTensor',
'input must be a CudaDouble or Double tensor')
Expand Down
2 changes: 1 addition & 1 deletion test_DataParallelTable.lua
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ local t2cpu = {
local function checkHalf()
if cutorch.hasHalf then
table.insert(typenames, 'torch.CudaHalfTensor')
t2cpu['torch.CudaHalfTensor'] = 'torch.FloatTensor'
t2cpu['torch.CudaHalfTensor'] = 'torch.HalfTensor'
end
end

Expand Down

0 comments on commit 4a21b71

Please sign in to comment.