Skip to content

Commit

Permalink
Add cudaHalf, cudaDouble functions to nn.Module. (#395)
Browse files Browse the repository at this point in the history
* Add cudaHalf, cudaDouble functions to nn.Module.
  • Loading branch information
gchanan authored and soumith committed Dec 9, 2016
1 parent ff37975 commit 3e652e3
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
10 changes: 10 additions & 0 deletions THCUNN.lua
Original file line number Diff line number Diff line change
Expand Up @@ -155,4 +155,14 @@ THNN.kernels['torch.CudaHalfTensor'] = raw_half_functions
torch.getmetatable('torch.CudaHalfTensor').THNN = THNN.kernels['torch.CudaHalfTensor']
end

local function Module__converter(type)
return function(self)
return self:type(type)
end
end

rawset(torch.getmetatable('nn.Module'), 'cudaDouble', Module__converter('torch.CudaDoubleTensor'))
if cutorch.hasHalf then
rawset(torch.getmetatable('nn.Module'), 'cudaHalf', Module__converter('torch.CudaHalfTensor'))
end
return THCUNN
19 changes: 19 additions & 0 deletions test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5439,6 +5439,25 @@ function cunntest.VolumetricReplicationPadding_backward()
end
end

function cunntest.ModuleConversionFunctions()
local module = nn.Tanh() -- arbitrary module
local input = torch.randn(10)

module:cuda()
mytester:assert(module:type() == 'torch.CudaTensor')
module:forward(input:type('torch.CudaTensor'))

module:cudaDouble()
mytester:assert(module:type() == 'torch.CudaDoubleTensor')
module:forward(input:type('torch.CudaDoubleTensor'))

if cutorch.hasHalf then
module:cudaHalf()
mytester:assert(module:type() == 'torch.CudaHalfTensor')
module:forward(input:type('torch.CudaHalfTensor'))
end
end

function cunntest.GPU()
local ndevice = cutorch.getDeviceCount()
if ndevice < 2 then
Expand Down

0 comments on commit 3e652e3

Please sign in to comment.