diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3ce811f3..501a2de3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -14,6 +14,11 @@ FILE(WRITE THCUNN_h.lua "return [[")
 FILE(APPEND THCUNN_h.lua ${THCUNN_headers})
 FILE(APPEND THCUNN_h.lua "]]")
 
+FILE(STRINGS lib/THCUNN/generic/THCUNN.h THCUNN_generic_headers NEWLINE_CONSUME)
+FILE(WRITE THCUNN_generic_h.lua "return [[")
+FILE(APPEND THCUNN_generic_h.lua ${THCUNN_generic_headers})
+FILE(APPEND THCUNN_generic_h.lua "]]")
+
 FILE(GLOB luasrc *.lua)
 
 ADD_SUBDIRECTORY(lib)
diff --git a/THCUNN.lua b/THCUNN.lua
index 771cd0ce..f4a8af30 100644
--- a/THCUNN.lua
+++ b/THCUNN.lua
@@ -19,7 +19,17 @@ local THCUNN_h = require 'cunn.THCUNN_h'
 THCUNN_h = THCUNN_h:gsub("\n#[^\n]*", "")
 THCUNN_h = THCUNN_h:gsub("^#[^\n]*\n", "")
 
+local THCUNN_generic_h = require 'cunn.THCUNN_generic_h'
+-- strip all lines starting with #
+-- to remove preprocessor directives originally present
+-- in THNN.h
+THCUNN_generic_h = THCUNN_generic_h:gsub("\n#[^\n]*", "")
+THCUNN_generic_h = THCUNN_generic_h:gsub("^#[^\n]*\n", "")
+
+ffi.cdef("half THC_float2half(float a);")
+
 local preprocessed = string.gsub(THCUNN_h, 'TH_API ', '')
+local preprocessed_generic = string.gsub(THCUNN_generic_h, 'TH_API void THNN_%(([%a%d_]+)%)', 'void THNN_TYPE%1')
 
 local replacements =
 {
@@ -31,6 +41,38 @@ local replacements =
    }
 }
 
+local cct2lt = {
+   ['THCudaFloatTensor'] = 'torch.CudaTensor',
+   ['THCudaDoubleTensor'] = 'torch.CudaDoubleTensor',
+}
+
+local replacements_generic =
+{
+   {
+      ['THCTensor'] = 'THCudaTensor',
+      ['THIndexTensor'] = 'THCudaLongTensor',
+      ['TYPE'] = 'Cuda',
+      ['real'] = 'float'
+   },
+   {
+      ['THCTensor'] = 'THCudaDoubleTensor',
+      ['THIndexTensor'] = 'THCudaLongTensor',
+      ['TYPE'] = 'CudaDouble',
+      ['real'] = 'double',
+   }
+}
+
+if cutorch.hasHalf then
+   cct2lt['THCudaHalfTensor'] = 'torch.CudaHalfTensor'
+   local half_replacement = {
+      ['THCTensor'] = 'THCudaHalfTensor',
+      ['THIndexTensor'] = 'THCudaLongTensor',
+      ['TYPE'] = 'CudaHalf',
+      ['real'] = 'half'
+   }
+   table.insert(replacements_generic, half_replacement)
+end
+
 for i=1,#replacements do
    local r = replacements[i]
    local s = preprocessed
@@ -40,6 +82,15 @@ for i=1,#replacements do
    ffi.cdef(s)
 end
 
+for i=1,#replacements_generic do
+   local r = replacements_generic[i]
+   local s = preprocessed_generic
+   for k,v in pairs(r) do
+      s = string.gsub(s, k, v)
+   end
+   ffi.cdef(s)
+end
+
 local function extract_function_names(s)
    local t = {}
    for n in string.gmatch(s, 'TH_API void THNN_Cuda([%a%d_]+)') do
@@ -48,10 +99,45 @@ local function extract_function_names(s)
    return t
 end
 
+local function extract_function_names_generic(s)
+   local t = {}
+   for n in string.gmatch(s, 'TH_API void THNN_%(([%a%d_]+)%)') do
+      t[#t+1] = n
+   end
+   return t
+end
+
 -- build function table
 local function_names = extract_function_names(THCUNN_h)
+local function_names_generic = extract_function_names_generic(THCUNN_generic_h)
+
+-- combine function names for CudaTensor
+for k,v in pairs(function_names_generic) do
+   function_names[#function_names+1] = v
+end
 
 THNN.kernels['torch.CudaTensor'] = THNN.bind(THCUNN.C, function_names, 'Cuda', THCUNN.getState)
 torch.getmetatable('torch.CudaTensor').THNN = THNN.kernels['torch.CudaTensor']
 
+THNN.kernels['torch.CudaDoubleTensor'] = THNN.bind(THCUNN.C, function_names_generic, 'CudaDouble', THCUNN.getState)
+torch.getmetatable('torch.CudaDoubleTensor').THNN = THNN.kernels['torch.CudaDoubleTensor']
+
+-- in order to call 'half' functions from lua, convert number arguments from
+-- float to half
+local transform_args_to_half = function(t)
+   for k,v in pairs(t) do
+      if torch.type(v) == 'number' then
+         t[k] = ffi.C.THC_float2half(t[k])
+      end
+   end
+   return t
+end
+
+local raw_half_functions = THNN.bind(THCUNN.C, function_names_generic, 'CudaHalf', THCUNN.getState)
+for k,v in pairs(raw_half_functions) do
+   raw_half_functions[k] = function(...) v(unpack(transform_args_to_half({...}))) end
+end
+THNN.kernels['torch.CudaHalfTensor'] = raw_half_functions
+torch.getmetatable('torch.CudaHalfTensor').THNN = THNN.kernels['torch.CudaHalfTensor']
+
 return THCUNN
diff --git a/lib/THCUNN/CMakeLists.txt b/lib/THCUNN/CMakeLists.txt
index 84925037..5dd87126 100644
--- a/lib/THCUNN/CMakeLists.txt
+++ b/lib/THCUNN/CMakeLists.txt
@@ -31,6 +31,7 @@ ENDIF()
 
 FILE(GLOB src-cuda *.cu)
 
+CUDA_INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR})
 CUDA_ADD_LIBRARY(THCUNN MODULE ${src-cuda})
 
 INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR})
diff --git a/lib/THCUNN/SoftPlus.cu b/lib/THCUNN/SoftPlus.cu
index 0d1609ae..cb9ecb7d 100644
--- a/lib/THCUNN/SoftPlus.cu
+++ b/lib/THCUNN/SoftPlus.cu
@@ -1,52 +1,42 @@
 #include "THCUNN.h"
-#include "common.h"
+#include "THCHalf.h"
+#include "THCHalfAutoNumerics.cuh"
 
+template <typename T>
 struct softPlusupdateOutput_functor
 {
-  const float threshold;
-  const float beta;
+  const T threshold;
+  const T beta;
 
-  softPlusupdateOutput_functor(float threshold_, float beta_)
+  softPlusupdateOutput_functor(T threshold_, T beta_)
     : threshold(threshold_)
     , beta(beta_)
   {}
 
-  __device__ void operator()(float *output, const float *input) const
-  {
-    float betain = beta * (*input);
+  __device__ void operator()(T *output, const T *input) const {
+    T betain = beta * (*input);
     *output = ((betain) > threshold) ? *input : (1/beta) * log1p(exp(betain));
   }
 };
 
-void THNN_CudaSoftPlus_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output, float beta, float threshold)
-{
-  THCUNN_assertSameGPU(state, 2, input, output);
-  THCudaTensor_resizeAs(state, output, input);
-  THC_pointwiseApply2(state, output, input, softPlusupdateOutput_functor(threshold, beta));
-}
-
+template <typename T>
 struct softPlusupdateGradInput_functor
 {
-  const float threshold;
-  const float beta;
+  const T threshold;
+  const T beta;
 
-  softPlusupdateGradInput_functor(float threshold_, float beta_)
+  softPlusupdateGradInput_functor(T threshold_, T beta_)
     : threshold(threshold_)
     , beta(beta_)
   {}
 
-  __device__ void operator()(float *gradInput, const float *output, const float *gradOutput) const
+  __device__ void operator()(T *gradInput, const T *output, const T *gradOutput) const
   {
-    float betaout = beta * (*output);
-    float exp_bo = exp(betaout);
+    T betaout = beta * (*output);
+    T exp_bo = exp(betaout);
     *gradInput = ((betaout) > threshold) ? *gradOutput : *gradOutput * (exp_bo - 1) / exp_bo;
   }
 };
 
-void THNN_CudaSoftPlus_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput,
-  THCudaTensor *output, float beta, float threshold)
-{
-  THCUNN_assertSameGPU(state, 4, input, output, gradOutput, gradInput);
-  THCudaTensor_resizeAs(state, gradInput, output);
-  THC_pointwiseApply3(state, gradInput, output, gradOutput, softPlusupdateGradInput_functor(threshold, beta));
-}
+#include "generic/SoftPlus.cu"
+#include "THCGenerateFloatTypes.h"
diff --git a/lib/THCUNN/THCHalfAutoNumerics.cuh b/lib/THCUNN/THCHalfAutoNumerics.cuh
new file mode 100644
index 00000000..ccbef0ad
--- /dev/null
+++ b/lib/THCUNN/THCHalfAutoNumerics.cuh
@@ -0,0 +1,67 @@
+#ifndef THC_HALF_AUTO_NUMERICS_INC
+#define THC_HALF_AUTO_NUMERICS_INC
+
+#include "THCHalf.h"
+#include "THCNumerics.cuh"
+
+// Half numerics functions defined as free functions, so cunn code can be
+// written generically, i.e. without calling THCNumerics functions.
+
+#ifdef CUDA_HALF_TENSOR
+
+inline __host__ __device__ half operator+(half a, half b) {
+  return THCNumerics<half>::add(a, b);
+}
+
+inline __host__ __device__ half operator-(half a, int b) {
+  return THCNumerics<half>::add(a, THCNumerics<half>::neg(ScalarConvert<int, half>::to(b)));
+}
+
+// This implementation could move to THCNumerics
+inline __host__ __device__ half operator*(half a, half b) {
+  #ifdef __CUDA_ARCH__
+    #ifdef CUDA_HALF_INSTRUCTIONS
+      return __hmul(a, b);
+    #else
+      float fa = __half2float(a);
+      float fb = __half2float(b);
+      return __float2half( fa * fb );
+    #endif
+  #else // __CUDA_ARCH__
+    return THC_float2half(THC_half2float(a) * THC_half2float(b));
+  #endif
+}
+
+inline __host__ __device__ bool operator>(half a, half b) {
+  return THCNumerics<half>::gt(a, b);
+}
+
+inline __host__ __device__ half log1p(half a) {
+  return THCNumerics<half>::log1p(a);
+}
+
+inline __host__ __device__ half exp(half a) {
+  return THCNumerics<half>::exp(a);
+}
+
+// This implementation could move to THCNumerics
+inline __host__ __device__ half operator/(half a, half b) {
+  #ifdef __CUDA_ARCH__
+    #ifdef CUDA_HALF_INSTRUCTIONS
+      return __hdiv(a, b);
+    #else
+      float fa = __half2float(a);
+      float fb = __half2float(b);
+      return __float2half( fa / fb );
+    #endif
+  #else // __CUDA_ARCH__
+    return THC_float2half(THC_half2float(a) / THC_half2float(b));
+  #endif
+}
+
+inline __host__ __device__ half operator/(int a, half b) {
+  return ScalarConvert<int, half>::to(a) / b;
+}
+
+#endif
+#endif
diff --git a/lib/THCUNN/THCUNN.h b/lib/THCUNN/THCUNN.h
index 9aeef2df..a04b9182 100644
--- a/lib/THCUNN/THCUNN.h
+++ b/lib/THCUNN/THCUNN.h
@@ -4,6 +4,8 @@
 #define THIndexTensor THCudaLongTensor
 #define THIndexTensor_(NAME) THCudaLongTensor_ ## NAME
 
+#define THNN_(NAME) TH_CONCAT_3(THNN_, CReal, NAME)
+
 TH_API void THNN_CudaAbs_updateOutput(
           THCState *state,
           THCudaTensor *input,
@@ -342,21 +344,6 @@ TH_API void THNN_CudaSoftMax_updateGradInput(
           THCudaTensor *gradInput,
           THCudaTensor *output);
 
-TH_API void THNN_CudaSoftPlus_updateOutput(
-          THCState *state,
-          THCudaTensor *input,
-          THCudaTensor *output,
-          float beta,
-          float threshold);
-TH_API void THNN_CudaSoftPlus_updateGradInput(
-          THCState *state,
-          THCudaTensor *input,
-          THCudaTensor *gradOutput,
-          THCudaTensor *gradInput,
-          THCudaTensor *output,
-          float beta,
-          float threshold);
-
 TH_API void THNN_CudaSoftShrink_updateOutput(
           THCState *state,
           THCudaTensor *input,
@@ -1088,3 +1075,6 @@ TH_API void THNN_CudaVolumetricReplicationPadding_updateGradInput(
           int pleft, int pright,
           int ptop, int pbottom,
           int pfront, int pback);
+
+#include "generic/THCUNN.h"
+#include "THCGenerateFloatTypes.h"
diff --git a/lib/THCUNN/common.h b/lib/THCUNN/common.h
index e0975356..52897062 100644
--- a/lib/THCUNN/common.h
+++ b/lib/THCUNN/common.h
@@ -8,6 +8,10 @@
 #define THCUNN_assertSameGPU(...) THAssertMsg(THCudaTensor_checkGPU(__VA_ARGS__), \
   "Some of weight/gradient/input tensors are located on different GPUs. Please move them to a single one.")
 
+// _generic can be removed once everything is genericized
+#define THCUNN_assertSameGPU_generic(...) THAssertMsg(THCTensor_(checkGPU)(__VA_ARGS__), \
+  "Some of weight/gradient/input tensors are located on different GPUs. Please move them to a single one.")
+
 // Use 1024 threads per block, which requires cuda sm_2x or above
 const int CUDA_NUM_THREADS = 1024;
 
diff --git a/lib/THCUNN/generic/SoftPlus.cu b/lib/THCUNN/generic/SoftPlus.cu
new file mode 100644
index 00000000..39794b00
--- /dev/null
+++ b/lib/THCUNN/generic/SoftPlus.cu
@@ -0,0 +1,33 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/SoftPlus.cu"
+#else
+
+#include "../common.h"
+
+void THNN_(SoftPlus_updateOutput)(
+          THCState *state,
+          THCTensor *input,
+          THCTensor *output,
+          real beta,
+          real threshold)
+{
+  THCUNN_assertSameGPU_generic(state, 2, input, output);
+  THCTensor_(resizeAs)(state, output, input);
+  THC_pointwiseApply2(state, output, input, softPlusupdateOutput_functor<real>(threshold, beta));
+}
+
+void THNN_(SoftPlus_updateGradInput)(
+          THCState *state,
+          THCTensor *input,
+          THCTensor *gradOutput,
+          THCTensor *gradInput,
+          THCTensor *output,
+          real beta,
+          real threshold)
+{
+  THCUNN_assertSameGPU_generic(state, 4, input, output, gradOutput, gradInput);
+  THCTensor_(resizeAs)(state, gradInput, output);
+  THC_pointwiseApply3(state, gradInput, output, gradOutput, softPlusupdateGradInput_functor<real>(threshold, beta));
+}
+
+#endif
diff --git a/lib/THCUNN/generic/SparseLinear.cu b/lib/THCUNN/generic/SparseLinear.cu
new file mode 100644
index 00000000..5e40c087
--- /dev/null
+++ b/lib/THCUNN/generic/SparseLinear.cu
@@ -0,0 +1 @@
+asdf
\ No newline at end of file
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
new file mode 100644
index 00000000..d70f0f44
--- /dev/null
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -0,0 +1,21 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCUNN.h"
+#else
+
+TH_API void THNN_(SoftPlus_updateOutput)(
+          THCState *state,
+          THCTensor *input,
+          THCTensor *output,
+          real beta,
+          real threshold);
+
+TH_API void THNN_(SoftPlus_updateGradInput)(
+          THCState *state,
+          THCTensor *input,
+          THCTensor *gradOutput,
+          THCTensor *gradInput,
+          THCTensor *output,
+          real beta,
+          real threshold);
+
+#endif
diff --git a/lib/THCUNN/generic/VolumetricMaxPooling.cu b/lib/THCUNN/generic/VolumetricMaxPooling.cu
new file mode 100644
index 00000000..cd913a3a
--- /dev/null
+++ b/lib/THCUNN/generic/VolumetricMaxPooling.cu
@@ -0,0 +1,45 @@
+/*#include "THCUNN.h"
+#include "common.h"
+#include "THCDeviceTensor.cuh"
+#include "THCDeviceTensorUtils.cuh"
+#include "THCDeviceUtils.cuh"
+
+#include <cfloat>*/
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/VolumetricMaxPooling.cu"
+#else
+
+#include "../common.h"
+
+void THNN_(VolumetricMaxPooling_updateOutput)(
+          THCState *state,
+          THCTensor *input,
+          THCTensor *output,
+          THCTensor *indices,
+          int kT, int kW, int kH,
+          int dT, int dW, int dH,
+          int padT, int padW, int padH,
+          bool ceilMode)
+{
+  THNN_(VolumetricDilatedMaxPooling_updateOutput)(
+        state, input, output, indices,
+        kT, kW, kH, dT, dW, dH, padT, padW, padH, 1, 1, 1, ceilMode);
+
+}
+
+void THNN_(VolumetricMaxPooling_updateGradInput)(
+          THCState *state,
+          THCTensor *input,
+          THCTensor *gradOutput,
+          THCTensor *gradInput,
+          THCTensor *indices,
+          int dT, int dW, int dH,
+          int padT, int padW, int padH)
+{
+  THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
+        state, input, gradOutput, gradInput, indices,
+        dT, dW, dH, padT, padW, padH, 1, 1, 1);
+
+}
+
+#endif
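Note (not part of the patch): the Lua snippet below is a minimal, self-contained sketch of the text substitution that THCUNN.lua performs above. It takes one declaration copied from generic/THCUNN.h and applies replacement tables shaped like replacements_generic to produce the concrete per-type signatures that are handed to ffi.cdef. The local names decl and types are illustrative only and do not appear in the patch.

-- illustration: expand one generic declaration into per-type ffi signatures
local decl = [[
TH_API void THNN_(SoftPlus_updateOutput)(
    THCState *state,
    THCTensor *input,
    THCTensor *output,
    real beta,
    real threshold);
]]

-- same preprocessing as THCUNN.lua: THNN_(Name) -> THNN_TYPEName
local generic = decl:gsub('TH_API void THNN_%(([%a%d_]+)%)', 'void THNN_TYPE%1')

local types = {
   { TYPE = 'Cuda',       THCTensor = 'THCudaTensor',       real = 'float'  },
   { TYPE = 'CudaDouble', THCTensor = 'THCudaDoubleTensor', real = 'double' },
}

for _, r in ipairs(types) do
   local s = generic
   for k, v in pairs(r) do
      s = s:gsub(k, v)
   end
   -- prints e.g. "void THNN_CudaSoftPlus_updateOutput( ... THCudaTensor *input, ... float beta, ...)"
   print(s)
end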
diff --git a/test.lua b/test.lua
index 48beced0..68a24c31 100644
--- a/test.lua
+++ b/test.lua
@@ -6,6 +6,41 @@ local times = {}
 
 --e.g.: th -lcunn -e "nn.testcuda{'Sigmoid_forward'}"
 
+local typenames = {
+   'torch.CudaTensor',
+   'torch.CudaDoubleTensor',
+}
+
+local t2cpu = {
+   ['torch.CudaTensor'] = 'torch.FloatTensor',
+   ['torch.CudaDoubleTensor'] = 'torch.DoubleTensor',
+
+}
+
+local function checkHalf()
+   if cutorch.hasHalf then
+      table.insert(typenames, 'torch.CudaHalfTensor')
+      t2cpu['torch.CudaHalfTensor'] = 'torch.FloatTensor'
+   end
+end
+
+-- half has additional error on top of double/float
+local function precision_forward_type(tensor_type)
+   if (tensor_type == 'torch.CudaHalfTensor') then
+      return 1e-2 + precision_forward;
+   else
+      return precision_forward
+   end
+end
+
+local function precision_backward_type(tensor_type)
+   if (tensor_type == 'torch.CudaHalfTensor') then
+      return 1e-2 + precision_backward;
+   else
+      return precision_backward
+   end
+end
+
 local function pointwise_forward(proto_module, name, max_error)
    local size = math.random(1,100)
 
@@ -4210,67 +4245,77 @@ end
 
 function cunntest.SoftPlus_forward()
    local size = math.random(1,100)
+   local input = torch.randn(size)
 
-   local tm = {}
-   local title = string.format('SoftPlus forward %d -> %d', size, size)
-   times[title] = tm
+   for k, typename in ipairs(typenames) do
+      local tm = {}
+      local title = string.format('SoftPlus (%s) forward %d -> %d', typename, size, size)
+      times[title] = tm
 
-   local input = torch.randn(size)
-   local sconv = nn.SoftPlus()
-   local groundtruth = sconv:forward(input)
-   local a = torch.Timer()
-   for i = 1,nloop do
-      groundtruth = sconv:forward(input)
-   end
-   tm.cpu = a:time().real
+      local ctype = t2cpu[typename]
+      local input = input:type(ctype)
+      local sconv = nn.SoftPlus():type(ctype)
+      local groundtruth = sconv:forward(input)
+      local a = torch.Timer()
+      for i = 1,nloop do
+         groundtruth = sconv:forward(input)
+      end
+      tm.cpu = a:time().real
 
-   input = input:cuda()
-   local gconv = nn.SoftPlus():cuda()
-   local rescuda = gconv:forward(input)
-   a:reset()
-   for i = 1,nloop do
-      rescuda = gconv:forward(input)
-   end
-   cutorch.synchronize()
-   tm.gpu = a:time().real
+      input = input:type(typename)
+      local gconv = nn.SoftPlus():type(typename)
+      local rescuda = gconv:forward(input)
+      a:reset()
+      for i = 1,nloop do
+         rescuda = gconv:forward(input)
+      end
+      cutorch.synchronize()
+      tm.gpu = a:time().real
 
-   local error = rescuda:float() - groundtruth
-   mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ')
+      local error = rescuda:double() - groundtruth:double()
+      mytester:assertlt(error:abs():max(), precision_forward_type(typename),
+         string.format('error on state (forward) with %s', typename))
+   end
 end
 
 function cunntest.SoftPlus_backward()
    local size = math.random(1,100)
-
-   local tm = {}
-   local title = string.format('SoftPlus.backward %d -> %d', size, size)
-   times[title] = tm
-
    local input = torch.randn(size)
    local gradOutput = torch.randn(size)
-   local sconv = nn.SoftPlus()
-   sconv:forward(input)
-   local groundgrad = sconv:backward(input, gradOutput)
-   local a = torch.Timer()
-   for i = 1,nloop do
-      groundgrad = sconv:backward(input, gradOutput)
-   end
-   tm.cpu = a:time().real
 
-   input = input:cuda()
-   gradOutput = gradOutput:cuda()
-   local gconv = sconv:clone():cuda()
-   gconv:forward(input)
-   local rescuda = gconv:backward(input, gradOutput)
-   a:reset()
-   for i = 1,nloop do
-      rescuda = gconv:backward(input, gradOutput)
-   end
-   cutorch.synchronize()
-   tm.gpu = a:time().real
+   for k, typename in ipairs(typenames) do
+      local tm = {}
+      local title = string.format('SoftPlus.backward (%s) %d -> %d', typename, size, size)
+      times[title] = tm
 
-   local error = rescuda:float() - groundgrad
+      local ctype = t2cpu[typename]
+      local input = input:type(ctype)
+      local gradOutput = gradOutput:type(ctype)
+      local sconv = nn.SoftPlus():type(ctype)
+      sconv:forward(input)
+      local groundgrad = sconv:backward(input, gradOutput)
+      local a = torch.Timer()
+      for i = 1,nloop do
+         groundgrad = sconv:backward(input, gradOutput)
+      end
+      tm.cpu = a:time().real
 
-   mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ')
+      input = input:type(typename)
+      gradOutput = gradOutput:type(typename)
+      local gconv = sconv:clone():type(typename)
+      gconv:forward(input)
+      local rescuda = gconv:backward(input, gradOutput)
+      a:reset()
+      for i = 1,nloop do
+         rescuda = gconv:backward(input, gradOutput)
+      end
+      cutorch.synchronize()
+      tm.gpu = a:time().real
+
+      local error = rescuda:double() - groundgrad:double()
+      mytester:assertlt(error:abs():max(), precision_backward_type(typename),
+         string.format('error on state (backward) with %s', typename))
+   end
 end
 
 function cunntest.SpatialUpSamplingNearest_forward()
@@ -6508,6 +6553,7 @@ function nn.testcuda(tests, print_timing, n_loop, seed)
    nloop = n_loop or nloop
    local oldtype = torch.getdefaulttensortype()
    torch.setdefaulttensortype('torch.FloatTensor')
+   checkHalf()
    initSeed(seed)
    mytester = torch.Tester()
    mytester:add(cunntest)
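Note (not part of the patch): assuming cunn is rebuilt with the changes above, the sketch below shows how the new per-type kernels are reached from Lua. The same nn module is simply cast with :type(), which re-dispatches through THNN.kernels[typename]; 'torch.CudaHalfTensor' is only exercised when cutorch.hasHalf is true, mirroring checkHalf() in test.lua. Scalar arguments such as beta and threshold are converted to half by the transform_args_to_half wrapper registered in THCUNN.lua.

-- illustration: run the same module in every registered CUDA type
require 'cunn'

local input = torch.randn(8)
local types = { 'torch.CudaTensor', 'torch.CudaDoubleTensor' }
if cutorch.hasHalf then table.insert(types, 'torch.CudaHalfTensor') end

for _, typename in ipairs(types) do
   local m = nn.SoftPlus():type(typename)
   local out = m:forward(input:type(typename))
   print(typename, out:type())
end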