-
Notifications
You must be signed in to change notification settings - Fork 172
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds the ability to "genericize" cunn modules that can exist simultaneously with non-generic modules (i.e. modules can be genericized one at a time). Allowing both generic and non-generic modules simultaneously requires some extra code that can be removed once every module is genericized. Also genericizes SoftPlus in this way.
- Loading branch information
Showing
12 changed files
with
379 additions
and
90 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,52 +1,42 @@ | ||
#include "THCUNN.h" | ||
#include "common.h" | ||
#include "THCHalf.h" | ||
#include "THCHalfAutoNumerics.cuh" | ||
|
||
// Pointwise SoftPlus forward functor, templated on the element type T. | ||
// operator() stores *input unchanged when beta * (*input) exceeds threshold, | ||
// and (1/beta) * log1p(exp(beta * (*input))) otherwise. | ||
// NOTE(review): this is a rendered diff, so the removed float-typed lines and | ||
// the added T-typed lines both appear in the body below. | ||
template <typename T> | ||
struct softPlusupdateOutput_functor | ||
{ | ||
const float threshold; | ||
const float beta; | ||
const T threshold; | ||
const T beta; | ||
|
||
// Constructor argument order is (threshold, beta). | ||
softPlusupdateOutput_functor(float threshold_, float beta_) | ||
softPlusupdateOutput_functor(T threshold_, T beta_) | ||
: threshold(threshold_) | ||
, beta(beta_) | ||
{} | ||
|
||
__device__ void operator()(float *output, const float *input) const | ||
{ | ||
float betain = beta * (*input); | ||
__device__ void operator()(T *output, const T *input) const { | ||
T betain = beta * (*input); | ||
// Above the threshold the input is passed through; presumably this avoids | ||
// evaluating exp() on large arguments - TODO confirm. | ||
*output = ((betain) > threshold) ? *input : (1/beta) * log1p(exp(betain)); | ||
} | ||
}; | ||
|
||
// Float-only SoftPlus forward entry point (shown on the removed side of the diff). | ||
// Resizes output to match input, then runs softPlusupdateOutput_functor over | ||
// (output, input) via THC_pointwiseApply2 (presumably elementwise - confirm). | ||
// The functor is built as (threshold, beta), the reverse of this function's | ||
// (beta, threshold) parameter order. | ||
void THNN_CudaSoftPlus_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output, float beta, float threshold) | ||
{ | ||
THCUNN_assertSameGPU(state, 2, input, output); | ||
THCudaTensor_resizeAs(state, output, input); | ||
THC_pointwiseApply2(state, output, input, softPlusupdateOutput_functor(threshold, beta)); | ||
} | ||
|
||
// Pointwise SoftPlus backward functor, templated on the element type T. | ||
// The gradient is computed from the forward *output*, not the input: | ||
// with bo = beta * (*output), gradInput is gradOutput when bo > threshold, | ||
// and gradOutput * (exp(bo) - 1) / exp(bo) otherwise. | ||
// NOTE(review): this is a rendered diff, so the removed float-typed lines and | ||
// the added T-typed lines both appear in the body below. | ||
template <typename T> | ||
struct softPlusupdateGradInput_functor | ||
{ | ||
const float threshold; | ||
const float beta; | ||
const T threshold; | ||
const T beta; | ||
|
||
// Constructor argument order is (threshold, beta). | ||
softPlusupdateGradInput_functor(float threshold_, float beta_) | ||
softPlusupdateGradInput_functor(T threshold_, T beta_) | ||
: threshold(threshold_) | ||
, beta(beta_) | ||
{} | ||
|
||
__device__ void operator()(float *gradInput, const float *output, const float *gradOutput) const | ||
__device__ void operator()(T *gradInput, const T *output, const T *gradOutput) const | ||
{ | ||
float betaout = beta * (*output); | ||
float exp_bo = exp(betaout); | ||
T betaout = beta * (*output); | ||
// exp_bo is computed unconditionally, even when the threshold branch is taken. | ||
T exp_bo = exp(betaout); | ||
*gradInput = ((betaout) > threshold) ? *gradOutput : *gradOutput * (exp_bo - 1) / exp_bo; | ||
} | ||
}; | ||
|
||
// Float-only SoftPlus backward entry point (shown on the removed side of the diff). | ||
// Resizes gradInput to match output, then runs softPlusupdateGradInput_functor | ||
// over (gradInput, output, gradOutput) via THC_pointwiseApply3. | ||
// `input` is only passed to the same-GPU assertion; the gradient itself is | ||
// derived from `output` and `gradOutput`. | ||
void THNN_CudaSoftPlus_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput, | ||
THCudaTensor *output, float beta, float threshold) | ||
{ | ||
THCUNN_assertSameGPU(state, 4, input, output, gradOutput, gradInput); | ||
THCudaTensor_resizeAs(state, gradInput, output); | ||
THC_pointwiseApply3(state, gradInput, output, gradOutput, softPlusupdateGradInput_functor(threshold, beta)); | ||
} | ||
#include "generic/SoftPlus.cu" | ||
#include "THCGenerateFloatTypes.h" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#ifndef THC_HALF_AUTO_NUMERICS_INC | ||
#define THC_HALF_AUTO_NUMERICS_INC | ||
|
||
#include "THCHalf.h" | ||
#include "THCNumerics.cuh" | ||
|
||
// Half numerics functions defined as free functions, so cunn code can be | ||
// written generically, i.e. without calling THCNumerics<half> functions. | ||
|
||
#ifdef CUDA_HALF_TENSOR | ||
|
||
// half + half: forwards to THCNumerics<half>::add so generic code can use `+`. | ||
inline __host__ __device__ half operator+(half a, half b) { | ||
return THCNumerics<half>::add(a, b); | ||
} | ||
|
||
// half - int: converts b to half, negates it, and adds it to a | ||
// (i.e. a + (-b)), all through THCNumerics<half> / ScalarConvert. | ||
inline __host__ __device__ half operator-(half a, int b) { | ||
return THCNumerics<half>::add(a, THCNumerics<half>::neg(ScalarConvert<int, half>::to(b))); | ||
} | ||
|
||
// half * half. | ||
// Device compilation (__CUDA_ARCH__ defined): uses __hmul when | ||
// CUDA_HALF_INSTRUCTIONS is defined, otherwise multiplies in float and | ||
// converts the product back to half. | ||
// Host compilation: multiplies in float via THC_half2float / THC_float2half. | ||
// This implementation could move to THCNumerics | ||
inline __host__ __device__ half operator*(half a, half b) { | ||
#ifdef __CUDA_ARCH__ | ||
#ifdef CUDA_HALF_INSTRUCTIONS | ||
return __hmul(a, b); | ||
#else | ||
float fa = __half2float(a); | ||
float fb = __half2float(b); | ||
return __float2half( fa * fb ); | ||
#endif | ||
#else // __CUDA_ARCH__ | ||
return THC_float2half(THC_half2float(a) * THC_half2float(b)); | ||
#endif | ||
} | ||
|
||
// half > half: forwards to THCNumerics<half>::gt so generic code can use `>`. | ||
inline __host__ __device__ bool operator>(half a, half b) { | ||
return THCNumerics<half>::gt(a, b); | ||
} | ||
|
||
// log1p overload for half: forwards to THCNumerics<half>::log1p. | ||
inline __host__ __device__ half log1p(half a) { | ||
return THCNumerics<half>::log1p(a); | ||
} | ||
|
||
// exp overload for half: forwards to THCNumerics<half>::exp. | ||
inline __host__ __device__ half exp(half a) { | ||
return THCNumerics<half>::exp(a); | ||
} | ||
|
||
// half / half. | ||
// Device compilation (__CUDA_ARCH__ defined): uses __hdiv when | ||
// CUDA_HALF_INSTRUCTIONS is defined, otherwise divides in float and | ||
// converts the quotient back to half. | ||
// Host compilation: divides in float via THC_half2float / THC_float2half. | ||
// This implementation could move to THCNumerics | ||
inline __host__ __device__ half operator/(half a, half b) { | ||
#ifdef __CUDA_ARCH__ | ||
#ifdef CUDA_HALF_INSTRUCTIONS | ||
return __hdiv(a, b); | ||
#else | ||
float fa = __half2float(a); | ||
float fb = __half2float(b); | ||
return __float2half( fa / fb ); | ||
#endif | ||
#else // __CUDA_ARCH__ | ||
return THC_float2half(THC_half2float(a) / THC_half2float(b)); | ||
#endif | ||
} | ||
|
||
// int / half: converts a to half with ScalarConvert, then applies the | ||
// half / half operator. Lets generic code write expressions such as 1/beta. | ||
inline __host__ __device__ half operator/(int a, half b) { | ||
return ScalarConvert<int, half>::to(a) / b; | ||
} | ||
|
||
#endif | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#ifndef THC_GENERIC_FILE | ||
#define THC_GENERIC_FILE "generic/SoftPlus.cu" | ||
#else | ||
|
||
#include "../common.h" | ||
|
||
// Generic (per-`real` type) SoftPlus forward entry point. | ||
// Resizes output to match input, then runs softPlusupdateOutput_functor<real> | ||
// over (output, input) via THC_pointwiseApply2. | ||
// The functor is built as (threshold, beta), the reverse of this function's | ||
// (beta, threshold) parameter order. | ||
void THNN_(SoftPlus_updateOutput)( | ||
THCState *state, | ||
THCTensor *input, | ||
THCTensor *output, | ||
real beta, | ||
real threshold) | ||
{ | ||
THCUNN_assertSameGPU_generic(state, 2, input, output); | ||
THCTensor_(resizeAs)(state, output, input); | ||
THC_pointwiseApply2(state, output, input, softPlusupdateOutput_functor<real>(threshold, beta)); | ||
} | ||
|
||
// Generic (per-`real` type) SoftPlus backward entry point. | ||
// Resizes gradInput to match output, then runs | ||
// softPlusupdateGradInput_functor<real> over (gradInput, output, gradOutput) | ||
// via THC_pointwiseApply3. | ||
// `input` is only passed to the same-GPU assertion; the gradient itself is | ||
// derived from `output` and `gradOutput`. | ||
void THNN_(SoftPlus_updateGradInput)( | ||
THCState *state, | ||
THCTensor *input, | ||
THCTensor *gradOutput, | ||
THCTensor *gradInput, | ||
THCTensor *output, | ||
real beta, | ||
real threshold) | ||
{ | ||
THCUNN_assertSameGPU_generic(state, 4, input, output, gradOutput, gradInput); | ||
THCTensor_(resizeAs)(state, gradInput, output); | ||
THC_pointwiseApply3(state, gradInput, output, gradOutput, softPlusupdateGradInput_functor<real>(threshold, beta)); | ||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
asdf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#ifndef THC_GENERIC_FILE | ||
#define THC_GENERIC_FILE "generic/THCUNN.h" | ||
#else | ||
|
||
// Declaration of the generic SoftPlus forward entry point. | ||
// NOTE(review): THNN_, THCTensor and real are presumably expanded once per | ||
// generated type by the THC_GENERIC_FILE include mechanism - confirm. | ||
TH_API void THNN_(SoftPlus_updateOutput)( | ||
THCState *state, | ||
THCTensor *input, | ||
THCTensor *output, | ||
real beta, | ||
real threshold); | ||
|
||
// Declaration of the generic SoftPlus backward entry point. | ||
// NOTE(review): THNN_, THCTensor and real are presumably expanded once per | ||
// generated type by the THC_GENERIC_FILE include mechanism - confirm. | ||
TH_API void THNN_(SoftPlus_updateGradInput)( | ||
THCState *state, | ||
THCTensor *input, | ||
THCTensor *gradOutput, | ||
THCTensor *gradInput, | ||
THCTensor *output, | ||
real beta, | ||
real threshold); | ||
|
||
#endif |
Oops, something went wrong.