Add generic support for SoftPlus.
Adds the ability to "genericize" cunn modules that can exist
simultaneously with non-generic modules (i.e. modules can
be genericized one at a time).  Allowing both generic and
non-generic modules simultaneously requires some extra code
that can be removed once every module is genericized.
Also genericizes SoftPlus in this way.
gchanan committed Nov 8, 2016
1 parent aa256bc commit 69491a1
Showing 12 changed files with 379 additions and 90 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
@@ -14,6 +14,11 @@ FILE(WRITE THCUNN_h.lua "return [[")
FILE(APPEND THCUNN_h.lua ${THCUNN_headers})
FILE(APPEND THCUNN_h.lua "]]")

FILE(STRINGS lib/THCUNN/generic/THCUNN.h THCUNN_generic_headers NEWLINE_CONSUME)
FILE(WRITE THCUNN_generic_h.lua "return [[")
FILE(APPEND THCUNN_generic_h.lua ${THCUNN_generic_headers})
FILE(APPEND THCUNN_generic_h.lua "]]")

FILE(GLOB luasrc *.lua)

ADD_SUBDIRECTORY(lib)
86 changes: 86 additions & 0 deletions THCUNN.lua
@@ -19,7 +19,17 @@ local THCUNN_h = require 'cunn.THCUNN_h'
THCUNN_h = THCUNN_h:gsub("\n#[^\n]*", "")
THCUNN_h = THCUNN_h:gsub("^#[^\n]*\n", "")

local THCUNN_generic_h = require 'cunn.THCUNN_generic_h'
-- strip all lines starting with #
-- to remove preprocessor directives originally present
-- in generic/THCUNN.h
THCUNN_generic_h = THCUNN_generic_h:gsub("\n#[^\n]*", "")
THCUNN_generic_h = THCUNN_generic_h:gsub("^#[^\n]*\n", "")

ffi.cdef("half THC_float2half(float a);")

local preprocessed = string.gsub(THCUNN_h, 'TH_API ', '')
local preprocessed_generic = string.gsub(THCUNN_generic_h, 'TH_API void THNN_%(([%a%d_]+)%)', 'void THNN_TYPE%1')
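
For reference, the gsub above turns each macro-call declaration from generic/THCUNN.h into a plain prototype with a TYPE placeholder. Shown here for SoftPlus, whose generic declaration appears later in this diff:

/* As written in generic/THCUNN.h (macro-call form): */
TH_API void THNN_(SoftPlus_updateOutput)(
    THCState *state, THCTensor *input, THCTensor *output,
    real beta, real threshold);

/* After the gsub above (placeholder prototype): */
void THNN_TYPESoftPlus_updateOutput(
    THCState *state, THCTensor *input, THCTensor *output,
    real beta, real threshold);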

local replacements =
{
@@ -31,6 +41,38 @@ local replacements =
}
}

local cct2lt = {
['THCudaFloatTensor'] = 'torch.CudaTensor',
['THCudaDoubleTensor'] = 'torch.CudaDoubleTensor',
}

local replacements_generic =
{
{
['THCTensor'] = 'THCudaTensor',
['THIndexTensor'] = 'THCudaLongTensor',
['TYPE'] = 'Cuda',
['real'] = 'float'
},
{
['THCTensor'] = 'THCudaDoubleTensor',
['THIndexTensor'] = 'THCudaLongTensor',
['TYPE'] = 'CudaDouble',
['real'] = 'double',
}
}

if cutorch.hasHalf then
cct2lt['THCudaHalfTensor'] = 'torch.CudaHalfTensor'
local half_replacement = {
['THCTensor'] = 'THCudaHalfTensor',
['THIndexTensor'] = 'THCudaLongTensor',
['TYPE'] = 'CudaHalf',
['real'] = 'half'
}
table.insert(replacements_generic, half_replacement)
end

for i=1,#replacements do
local r = replacements[i]
local s = preprocessed
@@ -40,6 +82,15 @@ for i=1,#replacements do
ffi.cdef(s)
end

for i=1,#replacements_generic do
local r = replacements_generic[i]
local s = preprocessed_generic
for k,v in pairs(r) do
s = string.gsub(s, k, v)
end
ffi.cdef(s)
end
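
Each pass of this loop substitutes one row of replacements_generic into the placeholder prototypes and registers the result with ffi.cdef. A sketch of the declarations LuaJIT ends up seeing for SoftPlus (float and double rows shown; the half row is analogous):

/* TYPE = Cuda, THCTensor = THCudaTensor, real = float: */
void THNN_CudaSoftPlus_updateOutput(
    THCState *state, THCudaTensor *input, THCudaTensor *output,
    float beta, float threshold);

/* TYPE = CudaDouble, THCTensor = THCudaDoubleTensor, real = double: */
void THNN_CudaDoubleSoftPlus_updateOutput(
    THCState *state, THCudaDoubleTensor *input, THCudaDoubleTensor *output,
    double beta, double threshold);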

local function extract_function_names(s)
local t = {}
for n in string.gmatch(s, 'TH_API void THNN_Cuda([%a%d_]+)') do
@@ -48,10 +99,45 @@ local function extract_function_names(s)
return t
end

local function extract_function_names_generic(s)
local t = {}
for n in string.gmatch(s, 'TH_API void THNN_%(([%a%d_]+)%)') do
t[#t+1] = n
end
return t
end

-- build function table
local function_names = extract_function_names(THCUNN_h)
local function_names_generic = extract_function_names_generic(THCUNN_generic_h)

-- combine function names for CudaTensor
for k,v in pairs(function_names_generic) do
function_names[#function_names+1] = v
end

THNN.kernels['torch.CudaTensor'] = THNN.bind(THCUNN.C, function_names, 'Cuda', THCUNN.getState)
torch.getmetatable('torch.CudaTensor').THNN = THNN.kernels['torch.CudaTensor']

THNN.kernels['torch.CudaDoubleTensor'] = THNN.bind(THCUNN.C, function_names_generic, 'CudaDouble', THCUNN.getState)
torch.getmetatable('torch.CudaDoubleTensor').THNN = THNN.kernels['torch.CudaDoubleTensor']

-- in order to call 'half' functions from lua, convert number arguments from
-- float to half
local transform_args_to_half = function(t)
for k,v in pairs(t) do
if torch.type(v) == 'number' then
t[k] = ffi.C.THC_float2half(t[k])
end
end
return t
end

local raw_half_functions = THNN.bind(THCUNN.C, function_names_generic, 'CudaHalf', THCUNN.getState)
for k,v in pairs(raw_half_functions) do
raw_half_functions[k] = function(...) v(unpack(transform_args_to_half({...}))) end
end
THNN.kernels['torch.CudaHalfTensor'] = raw_half_functions
torch.getmetatable('torch.CudaHalfTensor').THNN = THNN.kernels['torch.CudaHalfTensor']
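
The float-to-half conversion is needed because the half instantiation of each prototype takes half scalars, while numbers coming from Lua arrive as doubles. For SoftPlus the cdef'd half signature looks like this (a sketch, following the replacement table above):

/* TYPE = CudaHalf, THCTensor = THCudaHalfTensor, real = half.
   The scalar args are 'half', hence the THC_float2half wrapper. */
void THNN_CudaHalfSoftPlus_updateOutput(
    THCState *state, THCudaHalfTensor *input, THCudaHalfTensor *output,
    half beta, half threshold);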

return THCUNN
1 change: 1 addition & 0 deletions lib/THCUNN/CMakeLists.txt
@@ -31,6 +31,7 @@ ENDIF()

FILE(GLOB src-cuda *.cu)

CUDA_INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR})
CUDA_ADD_LIBRARY(THCUNN MODULE ${src-cuda})

INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR})
44 changes: 17 additions & 27 deletions lib/THCUNN/SoftPlus.cu
@@ -1,52 +1,42 @@
#include "THCUNN.h"
#include "common.h"
#include "THCHalf.h"
#include "THCHalfAutoNumerics.cuh"

template <typename T>
struct softPlusupdateOutput_functor
{
const float threshold;
const float beta;
const T threshold;
const T beta;

softPlusupdateOutput_functor(float threshold_, float beta_)
softPlusupdateOutput_functor(T threshold_, T beta_)
: threshold(threshold_)
, beta(beta_)
{}

__device__ void operator()(float *output, const float *input) const
{
float betain = beta * (*input);
__device__ void operator()(T *output, const T *input) const {
T betain = beta * (*input);
*output = ((betain) > threshold) ? *input : (1/beta) * log1p(exp(betain));
}
};

void THNN_CudaSoftPlus_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output, float beta, float threshold)
{
THCUNN_assertSameGPU(state, 2, input, output);
THCudaTensor_resizeAs(state, output, input);
THC_pointwiseApply2(state, output, input, softPlusupdateOutput_functor(threshold, beta));
}

template <typename T>
struct softPlusupdateGradInput_functor
{
const float threshold;
const float beta;
const T threshold;
const T beta;

softPlusupdateGradInput_functor(float threshold_, float beta_)
softPlusupdateGradInput_functor(T threshold_, T beta_)
: threshold(threshold_)
, beta(beta_)
{}

__device__ void operator()(float *gradInput, const float *output, const float *gradOutput) const
__device__ void operator()(T *gradInput, const T *output, const T *gradOutput) const
{
float betaout = beta * (*output);
float exp_bo = exp(betaout);
T betaout = beta * (*output);
T exp_bo = exp(betaout);
*gradInput = ((betaout) > threshold) ? *gradOutput : *gradOutput * (exp_bo - 1) / exp_bo;
}
};

void THNN_CudaSoftPlus_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput,
THCudaTensor *output, float beta, float threshold)
{
THCUNN_assertSameGPU(state, 4, input, output, gradOutput, gradInput);
THCudaTensor_resizeAs(state, gradInput, output);
THC_pointwiseApply3(state, gradInput, output, gradOutput, softPlusupdateGradInput_functor(threshold, beta));
}
#include "generic/SoftPlus.cu"
#include "THCGenerateFloatTypes.h"
67 changes: 67 additions & 0 deletions lib/THCUNN/THCHalfAutoNumerics.cuh
@@ -0,0 +1,67 @@
#ifndef THC_HALF_AUTO_NUMERICS_INC
#define THC_HALF_AUTO_NUMERICS_INC

#include "THCHalf.h"
#include "THCNumerics.cuh"

// Half numerics functions defined as free functions, so cunn code can be
// written generically, i.e. without calling THCNumerics<half> functions.

#ifdef CUDA_HALF_TENSOR

inline __host__ __device__ half operator+(half a, half b) {
return THCNumerics<half>::add(a, b);
}

inline __host__ __device__ half operator-(half a, int b) {
return THCNumerics<half>::add(a, THCNumerics<half>::neg(ScalarConvert<int, half>::to(b)));
}

// This implementation could move to THCNumerics
inline __host__ __device__ half operator*(half a, half b) {
#ifdef __CUDA_ARCH__
#ifdef CUDA_HALF_INSTRUCTIONS
return __hmul(a, b);
#else
float fa = __half2float(a);
float fb = __half2float(b);
return __float2half( fa * fb );
#endif
#else // __CUDA_ARCH__
return THC_float2half(THC_half2float(a) * THC_half2float(b));
#endif
}

inline __host__ __device__ bool operator>(half a, half b) {
return THCNumerics<half>::gt(a, b);
}

inline __host__ __device__ half log1p(half a) {
return THCNumerics<half>::log1p(a);
}

inline __host__ __device__ half exp(half a) {
return THCNumerics<half>::exp(a);
}

// This implementation could move to THCNumerics
inline __host__ __device__ half operator/(half a, half b) {
#ifdef __CUDA_ARCH__
#ifdef CUDA_HALF_INSTRUCTIONS
return __hdiv(a, b);
#else
float fa = __half2float(a);
float fb = __half2float(b);
return __float2half( fa / fb );
#endif
#else // __CUDA_ARCH__
return THC_float2half(THC_half2float(a) / THC_half2float(b));
#endif
}

inline __host__ __device__ half operator/(int a, half b) {
return ScalarConvert<int, half>::to(a) / b;
}

#endif
#endif
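
These overloads are what let the templated SoftPlus functors compile unchanged at T = half: each expression in the functor bodies resolves to one of the free functions above instead of a built-in float operation.

// Forward functor body, resolved for T = half:
//   T betain = beta * (*input);    -> operator*(half, half)
//   betain > threshold             -> operator>(half, half)
//   (1/beta)                       -> operator/(int, half)
//   log1p(exp(betain))             -> log1p(half), exp(half)
//
// Backward functor body:
//   exp_bo - 1                     -> operator-(half, int)
//   *gradOutput * (...) / exp_bo   -> operator*(half, half), operator/(half, half)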
20 changes: 5 additions & 15 deletions lib/THCUNN/THCUNN.h
@@ -4,6 +4,8 @@
#define THIndexTensor THCudaLongTensor
#define THIndexTensor_(NAME) THCudaLongTensor_ ## NAME

#define THNN_(NAME) TH_CONCAT_3(THNN_, CReal, NAME)

TH_API void THNN_CudaAbs_updateOutput(
THCState *state,
THCudaTensor *input,
@@ -342,21 +344,6 @@ TH_API void THNN_CudaSoftMax_updateGradInput(
THCudaTensor *gradInput,
THCudaTensor *output);

TH_API void THNN_CudaSoftPlus_updateOutput(
THCState *state,
THCudaTensor *input,
THCudaTensor *output,
float beta,
float threshold);
TH_API void THNN_CudaSoftPlus_updateGradInput(
THCState *state,
THCudaTensor *input,
THCudaTensor *gradOutput,
THCudaTensor *gradInput,
THCudaTensor *output,
float beta,
float threshold);

TH_API void THNN_CudaSoftShrink_updateOutput(
THCState *state,
THCudaTensor *input,
@@ -1088,3 +1075,6 @@ TH_API void THNN_CudaVolumetricReplicationPadding_updateGradInput(
int pleft, int pright,
int ptop, int pbottom,
int pfront, int pback);

#include "generic/THCUNN.h"
#include "THCGenerateFloatTypes.h"
4 changes: 4 additions & 0 deletions lib/THCUNN/common.h
@@ -8,6 +8,10 @@
#define THCUNN_assertSameGPU(...) THAssertMsg(THCudaTensor_checkGPU(__VA_ARGS__), \
"Some of weight/gradient/input tensors are located on different GPUs. Please move them to a single one.")

// _generic can be removed once everything is genericized
#define THCUNN_assertSameGPU_generic(...) THAssertMsg(THCTensor_(checkGPU)(__VA_ARGS__), \
"Some of weight/gradient/input tensors are located on different GPUs. Please move them to a single one.")

// Use 1024 threads per block, which requires cuda sm_2x or above
const int CUDA_NUM_THREADS = 1024;

33 changes: 33 additions & 0 deletions lib/THCUNN/generic/SoftPlus.cu
@@ -0,0 +1,33 @@
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/SoftPlus.cu"
#else

#include "../common.h"

void THNN_(SoftPlus_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
real beta,
real threshold)
{
THCUNN_assertSameGPU_generic(state, 2, input, output);
THCTensor_(resizeAs)(state, output, input);
THC_pointwiseApply2(state, output, input, softPlusupdateOutput_functor<real>(threshold, beta));
}

void THNN_(SoftPlus_updateGradInput)(
THCState *state,
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *output,
real beta,
real threshold)
{
THCUNN_assertSameGPU_generic(state, 4, input, output, gradOutput, gradInput);
THCTensor_(resizeAs)(state, gradInput, output);
THC_pointwiseApply3(state, gradInput, output, gradOutput, softPlusupdateGradInput_functor<real>(threshold, beta));
}

#endif
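
For reference, the math these generic kernels implement, with y = f(x) the saved output (a worked restatement of the functors above):

f(x) =
\begin{cases}
x, & \beta x > \text{threshold} \\
\frac{1}{\beta}\,\log\!\left(1 + e^{\beta x}\right), & \text{otherwise}
\end{cases}
\qquad
\frac{dy}{dx} = \frac{e^{\beta x}}{1 + e^{\beta x}}
              = \frac{e^{\beta y} - 1}{e^{\beta y}}
\quad\text{since } e^{\beta y} = 1 + e^{\beta x}.

Rewriting the derivative in terms of y is why updateGradInput reads the saved output rather than the input: gradInput = gradOutput * (exp_bo - 1) / exp_bo with exp_bo = exp(beta * output).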
1 change: 1 addition & 0 deletions lib/THCUNN/generic/SparseLinear.cu
@@ -0,0 +1 @@
asdf
21 changes: 21 additions & 0 deletions lib/THCUNN/generic/THCUNN.h
@@ -0,0 +1,21 @@
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/THCUNN.h"
#else

TH_API void THNN_(SoftPlus_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
real beta,
real threshold);

TH_API void THNN_(SoftPlus_updateGradInput)(
THCState *state,
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *output,
real beta,
real threshold);

#endif