diff --git a/THCUNN.lua b/THCUNN.lua
index ac05ee05..8bbca17e 100644
--- a/THCUNN.lua
+++ b/THCUNN.lua
@@ -155,4 +155,14 @@ THNN.kernels['torch.CudaHalfTensor'] = raw_half_functions
    torch.getmetatable('torch.CudaHalfTensor').THNN = THNN.kernels['torch.CudaHalfTensor']
 end
 
+local function Module__converter(type)
+   return function(self)
+      return self:type(type)
+   end
+end
+
+rawset(torch.getmetatable('nn.Module'), 'cudaDouble', Module__converter('torch.CudaDoubleTensor'))
+if cutorch.hasHalf then
+   rawset(torch.getmetatable('nn.Module'), 'cudaHalf', Module__converter('torch.CudaHalfTensor'))
+end
 return THCUNN
diff --git a/test.lua b/test.lua
index 2820baf4..598198f7 100644
--- a/test.lua
+++ b/test.lua
@@ -5439,6 +5439,25 @@ function cunntest.VolumetricReplicationPadding_backward()
    end
 end
 
+function cunntest.ModuleConversionFunctions()
+   local module = nn.Tanh() -- arbitrary module
+   local input = torch.randn(10)
+
+   module:cuda()
+   mytester:assert(module:type() == 'torch.CudaTensor')
+   module:forward(input:type('torch.CudaTensor'))
+
+   module:cudaDouble()
+   mytester:assert(module:type() == 'torch.CudaDoubleTensor')
+   module:forward(input:type('torch.CudaDoubleTensor'))
+
+   if cutorch.hasHalf then
+      module:cudaHalf()
+      mytester:assert(module:type() == 'torch.CudaHalfTensor')
+      module:forward(input:type('torch.CudaHalfTensor'))
+   end
+end
+
 function cunntest.GPU()
    local ndevice = cutorch.getDeviceCount()
    if ndevice < 2 then
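
For context: the new converters are plain closures over a tensor type name that delegate to Module:type(), installed via rawset on the nn.Module metatable so every module subclass inherits them, mirroring the existing :cuda()/:float() helpers. A minimal usage sketch follows; the two-layer network is hypothetical, and it assumes cunn is loaded and (for the half branch) that the device reports cutorch.hasHalf:

require 'cunn'

-- Hypothetical network, just to exercise the new converters.
local net = nn.Sequential()
net:add(nn.Linear(10, 5))
net:add(nn.Tanh())

-- Move parameters and buffers to double-precision CUDA tensors.
net:cudaDouble()
local out = net:forward(torch.randn(4, 10):type('torch.CudaDoubleTensor'))
print(out:type())  -- torch.CudaDoubleTensor

-- cudaHalf is only registered when the device supports half precision.
if cutorch.hasHalf then
   net:cudaHalf()
   print(net:type())  -- torch.CudaHalfTensor
end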