diff --git a/DataParallelTable.lua b/DataParallelTable.lua
index 9b5f3923..c3d7d739 100644
--- a/DataParallelTable.lua
+++ b/DataParallelTable.lua
@@ -231,8 +231,8 @@ function DataParallelTable:__backward(method, input, gradOutput, scale)
       self:_distribute(self.gradOutputGpu, gradOutput)
 
       self.gradInputGpu = self.impl:exec(function(m, i)
-         if torch.isTensor(inputGpu[i]) and inputGpu[i]:numel() == 0 then
-            return torch.CudaTensor()
+         if not _hasData(inputGpu[i]) then
+            return inputGpu[i]
          else
             return m[method](m, inputGpu[i], gradOutputGpu[i], scale)
          end
@@ -246,8 +246,8 @@ function DataParallelTable:__backward(method, input, gradOutput, scale)
 
    if method == 'accGradParameters' then
       self.impl:exec(function(m, i)
-         if torch.isTensor(inputGpu[i]) and inputGpu[i]:numel() == 0 then
-            return torch.CudaTensor()
+         if not _hasData(inputGpu[i]) then
+            return inputGpu[i]
          else
             return m:accGradParameters(inputGpu[i], gradOutputGpu[i], scale)
          end
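Note: the new branches call a _hasData helper that is not shown in these hunks; its definition lives elsewhere in the patch. As a minimal, non-authoritative sketch of what such a helper would need to do, assuming inputs may be either tensors or (possibly nested) tables of tensors:

-- Hypothetical sketch only: the actual _hasData in this patch may differ.
-- Treat an input as having data if it is a non-empty tensor, or a table
-- with at least one entry that has data. DataParallelTable inputs can be
-- nested tables of tensors, which the old numel()-only check did not cover.
local function _hasData(input)
   if torch.isTensor(input) then
      return input:numel() ~= 0
   else
      assert(type(input) == 'table')
      for i = 1, #input do
         if _hasData(input[i]) then
            return true
         end
      end
      return false
   end
end

Returning inputGpu[i] instead of a fresh torch.CudaTensor() also means an empty gradInput keeps the same type and structure as the empty input it corresponds to, rather than always collapsing to an empty CUDA tensor.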