-
Notifications
You must be signed in to change notification settings - Fork 172
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #434 from bottler/master
VolumetricFractionalMaxPooling like spatial
- Loading branch information
Showing
3 changed files
with
306 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
#include "THCUNN.h" | ||
#include "common.h" | ||
#include "THCDeviceTensor.cuh" | ||
#include "THCDeviceTensorUtils.cuh" | ||
#include "THCDeviceUtils.cuh" | ||
#include "THCHalf.h" | ||
#include "THCHalfAutoNumerics.cuh" | ||
#include "THCAtomics.cuh" | ||
|
||
#include <cfloat> | ||
|
||
template <typename Dtype, typename Acctype> | ||
__device__ inline int getInterval(Acctype sample, | ||
int index, | ||
int inputSize, | ||
int outputSize, | ||
int poolSize) { | ||
Acctype alpha = (Acctype)(inputSize - poolSize) / (Acctype) (outputSize - 1); | ||
if (index == outputSize - 1) { | ||
return inputSize - poolSize; | ||
} else { | ||
return (int) ((index + sample) * alpha) - (int) (sample * alpha); | ||
} | ||
} | ||
|
||
// We template on poolSizeW to allow the innermost loop to be unrolled | ||
template <int PoolSizeTStatic, typename Dtype, typename Acctype> | ||
__global__ void VolumetricFractionalMaxPooling_updateOutput( | ||
THCDeviceTensor<Dtype, 5> input, | ||
THCDeviceTensor<Dtype, 5> output, | ||
THCDeviceTensor<THCIndex_t, 5> indices, | ||
THCDeviceTensor<Dtype, 3> samples, | ||
int poolSizeT, int poolSizeW, int poolSizeH) { | ||
|
||
// Output (h, w) point that this thread is responsible for | ||
int ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x; | ||
int plane = blockIdx.y; | ||
int batch = blockIdx.z; | ||
|
||
// Each thread generates a specific output point | ||
if (ourOutputPoint < output.getSize(2) * output.getSize(3) * output.getSize(4)){ | ||
int outputT = ourOutputPoint % output.getSize(4); | ||
int outputW = (ourOutputPoint / output.getSize(4)) % output.getSize(3); | ||
int outputH = ourOutputPoint / (output.getSize(3)*output.getSize(4)); | ||
|
||
int poolT = getInterval<Dtype, Acctype>(ScalarConvert<Dtype, Acctype>::to(samples[batch][plane][0]), outputT, | ||
input.getSize(4), output.getSize(4), poolSizeT); | ||
int poolW = getInterval<Dtype, Acctype>(ScalarConvert<Dtype, Acctype>::to(samples[batch][plane][1]), outputW, | ||
input.getSize(3), output.getSize(3), poolSizeW); | ||
int poolH = getInterval<Dtype, Acctype>(ScalarConvert<Dtype, Acctype>::to(samples[batch][plane][2]), outputH, | ||
input.getSize(2), output.getSize(2), poolSizeH); | ||
|
||
Dtype maxVal = THCNumerics<Dtype>::min(); | ||
int maxIndex = -1; | ||
|
||
for (int h = poolH; h < poolH + poolSizeH; ++h) { | ||
for (int w = poolW; w < poolW + poolSizeW; ++w) { | ||
if (PoolSizeTStatic == -1) { | ||
for (int t = poolT; t < poolT + poolSizeT; ++t) { | ||
Dtype val = input[batch][plane][h][w][t]; | ||
// for consistency with THNN, favor the first max | ||
if (val > maxVal) { | ||
maxIndex = h * input.getSize(3)*input.getSize(4) + w * input.getSize(4) + t; | ||
maxVal = val; | ||
} | ||
} | ||
} else { | ||
#pragma unroll | ||
for (int i = 0; i < PoolSizeTStatic; ++i) { | ||
int t = i + poolT; | ||
Dtype val = input[batch][plane][h][w][t]; | ||
// for consistency with THNN, favor the first max | ||
if (val > maxVal) { | ||
maxIndex = h * input.getSize(3)*input.getSize(4) + w * input.getSize(4) + t; | ||
maxVal = val; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
assert(THCNumerics<Dtype>::ne(maxVal, THCNumerics<Dtype>::min())); | ||
assert(maxIndex != -1); | ||
|
||
// +1 for Lua index | ||
indices[batch][plane][outputH][outputW][outputT] = maxIndex + TH_INDEX_BASE; | ||
output[batch][plane][outputH][outputW][outputT] = maxVal; | ||
} | ||
} | ||
|
||
template <typename Dtype> | ||
__global__ void VolumetricFractionalMaxPooling_updateGradInput( | ||
THCDeviceTensor<Dtype, 5> gradInput, | ||
THCDeviceTensor<Dtype, 5> gradOutput, | ||
THCDeviceTensor<THCIndex_t, 5> indices) { | ||
// Output (h, w) point that this thread is responsible for | ||
int ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x; | ||
int plane = blockIdx.y; | ||
int batch = blockIdx.z; | ||
|
||
// Each thread generates a specific output point | ||
if (ourOutputPoint < gradOutput.getSize(2) * gradOutput.getSize(3) * gradOutput.getSize(4)) { | ||
int outputT = ourOutputPoint % gradOutput.getSize(4); | ||
int outputW = (ourOutputPoint / gradOutput.getSize(4)) % gradOutput.getSize(3); | ||
int outputH = ourOutputPoint / (gradOutput.getSize(3)*gradOutput.getSize(4)); | ||
|
||
int index = indices[batch][plane][outputH][outputW][outputT] - TH_INDEX_BASE; | ||
assert(index >= 0); | ||
int inputT = index % gradInput.getSize(4); | ||
int inputW = (index / gradInput.getSize(4)) % gradInput.getSize(3); | ||
int inputH = index / (gradInput.getSize(3) * gradInput.getSize(4)); | ||
assert(inputH < gradInput.getSize(2)); | ||
|
||
atomicAdd(gradInput[batch][plane][inputH][inputW][inputT].data(), | ||
gradOutput[batch][plane][outputH][outputW][outputT]); | ||
} | ||
} | ||
|
||
#include "generic/VolumetricFractionalMaxPooling.cu" | ||
#include "THCGenerateFloatTypes.h" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
#ifndef THC_GENERIC_FILE | ||
#define THC_GENERIC_FILE "generic/VolumetricFractionalMaxPooling.cu" | ||
#else | ||
|
||
void THNN_(VolumetricFractionalMaxPooling_updateOutput)( | ||
THCState *state, | ||
THCTensor *input, | ||
THCTensor *output, | ||
int outputT, int outputW, int outputH, | ||
int poolSizeT, int poolSizeW, int poolSizeH, | ||
THCIndexTensor *indices, | ||
THCTensor *randomSamples) | ||
{ | ||
int planeDim = 0; | ||
int dimh = 1; | ||
int dimw = 2; | ||
int dimt = 3; | ||
long numBatch = 1; | ||
|
||
long numInputDims = THCTensor_(nDimension)(state, input); | ||
THCUNN_argCheck(state, numInputDims == 4 || numInputDims == 5, 2, input, | ||
"4D or 5D (batch mode) tensor expected for input, but got: %s"); | ||
|
||
if (numInputDims == 5) { | ||
numBatch = THCTensor_(size)(state, input, 0); | ||
planeDim++; | ||
dimh++; | ||
dimw++; | ||
dimt++; | ||
} | ||
|
||
/* sizes */ | ||
long numPlanes = THCTensor_(size)(state, input, planeDim); | ||
long inputH = THCTensor_(size)(state, input, dimh); | ||
long inputW = THCTensor_(size)(state, input, dimw); | ||
long inputT = THCTensor_(size)(state, input, dimt); | ||
|
||
THArgCheck(outputH + poolSizeH - 1 < inputH, 7, | ||
"poolSizeH (%d) too large relative to input height (%d)", | ||
poolSizeH, inputH); | ||
THArgCheck(outputW + poolSizeW - 1 < inputW, 6, | ||
"poolSizeW (%d) too large relative to input width (%d)", | ||
poolSizeW, inputW); | ||
THArgCheck(outputT + poolSizeT - 1 < inputW, 5, | ||
"poolSizeT (%d) too large relative to input time (%d)", | ||
poolSizeT, inputT); | ||
|
||
THCDeviceTensor<real, 5> devInput; | ||
THCDeviceTensor<real, 5> devOutput; | ||
THCDeviceTensor<THCIndex_t, 5> devIndices; | ||
THCDeviceTensor<real, 3> devSamples = | ||
toDeviceTensor<real, 3>(state, randomSamples); | ||
|
||
if (numInputDims == 4) { | ||
/* resize output */ | ||
THCTensor_(resize4d)(state, output, numPlanes, outputH, outputW, outputT); | ||
/* indices will contain the locations for each output point */ | ||
THCIndexTensor_(resize4d)(state, indices, numPlanes, outputH, outputW, outputT); | ||
|
||
devInput = toDeviceTensor<real, 4>(state, input).upcastOuter<5>(); | ||
devOutput = toDeviceTensor<real, 4>(state, output).upcastOuter<5>(); | ||
devIndices = toDeviceTensor<THCIndex_t, 4>(state, indices).upcastOuter<5>(); | ||
} else { | ||
THCTensor_(resize5d)(state, output, numBatch, numPlanes, outputH, outputW, outputT); | ||
/* indices will contain the locations for each output point */ | ||
THCIndexTensor_(resize5d)(state, indices, numBatch, numPlanes, outputH, outputW, outputT); | ||
|
||
devInput = toDeviceTensor<real, 5>(state, input); | ||
devOutput = toDeviceTensor<real, 5>(state, output); | ||
devIndices = toDeviceTensor<THCIndex_t, 5>(state, indices); | ||
} | ||
|
||
// block is limited to 4 warps | ||
// grid handles overflow per each plane | ||
int outputPlaneSize = devOutput.getSize(2) * devOutput.getSize(3) * devOutput.getSize(4); | ||
dim3 grid(THCCeilDiv(outputPlaneSize, 128), | ||
devInput.getSize(1), | ||
devInput.getSize(0)); | ||
dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize); | ||
|
||
#define SFMP_UPDATE_OUTPUT(POOL_W) \ | ||
VolumetricFractionalMaxPooling_updateOutput<POOL_W, real, accreal> \ | ||
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \ | ||
devInput, devOutput, devIndices, devSamples, poolSizeT, poolSizeW, poolSizeH); | ||
|
||
#define SFMP_UPDATE_OUTPUT_CASE(POOL_W) \ | ||
case POOL_W: SFMP_UPDATE_OUTPUT(POOL_W); break | ||
|
||
switch (poolSizeW) { | ||
SFMP_UPDATE_OUTPUT_CASE(2); | ||
SFMP_UPDATE_OUTPUT_CASE(3); | ||
SFMP_UPDATE_OUTPUT_CASE(4); | ||
SFMP_UPDATE_OUTPUT_CASE(5); | ||
SFMP_UPDATE_OUTPUT_CASE(6); | ||
SFMP_UPDATE_OUTPUT_CASE(7); | ||
default: | ||
// dynamic pool width | ||
SFMP_UPDATE_OUTPUT_CASE(-1); | ||
} | ||
THCudaCheck(cudaGetLastError()); | ||
} | ||
|
||
void THNN_(VolumetricFractionalMaxPooling_updateGradInput)( | ||
THCState *state, | ||
THCTensor *input, | ||
THCTensor *gradOutput, | ||
THCTensor *gradInput, | ||
int outputT, int outputW, int outputH, | ||
int poolSizeT, int poolSizeW, int poolSizeH, | ||
THCIndexTensor *indices) | ||
{ | ||
int dimh = 1; | ||
int dimw = 2; | ||
int dimt = 3; | ||
|
||
long numInputDims = THCTensor_(nDimension)(state, input); | ||
if (numInputDims == 5) { | ||
dimh++; | ||
dimw++; | ||
dimt++; | ||
} | ||
|
||
/* sizes */ | ||
long inputH = THCTensor_(size)(state, input, dimh); | ||
long inputW = THCTensor_(size)(state, input, dimw); | ||
long inputT = THCTensor_(size)(state, input, dimt); | ||
|
||
THArgCheck(outputH == THCTensor_(size)(state, gradOutput, dimh), 3, | ||
"gradOutput height unexpected"); | ||
THArgCheck(outputW == THCTensor_(size)(state, gradOutput, dimw), 3, | ||
"gradOutput width unexpected"); | ||
THArgCheck(outputT == THCTensor_(size)(state, gradOutput, dimt), 3, | ||
"gradOutput time unexpected"); | ||
|
||
/* resize */ | ||
THCTensor_(resizeAs)(state, gradInput, input); | ||
THCTensor_(zero)(state, gradInput); | ||
|
||
THCDeviceTensor<real, 5> devGradInput; | ||
THCDeviceTensor<real, 5> devGradOutput; | ||
THCDeviceTensor<THCIndex_t, 5> devIndices; | ||
|
||
/* backprop */ | ||
if (numInputDims == 4) { | ||
devGradInput = toDeviceTensor<real, 4>(state, gradInput).upcastOuter<5>(); | ||
devGradOutput = toDeviceTensor<real, 4>(state, gradOutput).upcastOuter<5>(); | ||
devIndices = toDeviceTensor<THCIndex_t, 4>(state, indices).upcastOuter<5>(); | ||
} else { | ||
devGradInput = toDeviceTensor<real, 5>(state, gradInput); | ||
devGradOutput = toDeviceTensor<real, 5>(state, gradOutput); | ||
devIndices = toDeviceTensor<THCIndex_t, 5>(state, indices); | ||
} | ||
|
||
// block is limited to 4 warps | ||
// grid handles overflow per each plane | ||
int outputPlaneSize = devGradOutput.getSize(2) * devGradOutput.getSize(3) * devGradOutput.getSize(4); | ||
dim3 grid(THCCeilDiv(outputPlaneSize, 128), | ||
devGradInput.getSize(1), | ||
devGradInput.getSize(0)); | ||
dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize); | ||
|
||
VolumetricFractionalMaxPooling_updateGradInput | ||
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( | ||
devGradInput, devGradOutput, devIndices); | ||
THCudaCheck(cudaGetLastError()); | ||
} | ||
|
||
#endif |