Skip to content

Commit

Permalink
remove limitations on output_padding in Conv* routines
Browse files Browse the repository at this point in the history
  • Loading branch information
soumith committed Aug 3, 2017
1 parent bbebfdc commit e9ef2d5
Show file tree
Hide file tree
Showing 9 changed files with 23 additions and 22 deletions.
2 changes: 1 addition & 1 deletion lib/THCUNN/generic/SpatialConvolutionLocal.cu
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ void THNN_(SpatialConvolutionLocal_updateGradInput)(
col2im<real, accreal>(
THCState_getCurrentStream(state),
THCTensor_(data)(state, fgradInput_n),
nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
1, 1, THCTensor_(data)(state, gradInput_n)
);

Expand Down
2 changes: 1 addition & 1 deletion lib/THCUNN/generic/SpatialConvolutionMM.cu
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ void THNN_(SpatialConvolutionMM_updateGradInput)(
col2im<real, accreal>(
THCState_getCurrentStream(state),
THCTensor_(data)(state, gradColumns),
nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
1, 1, THCTensor_(data)(state, gradInput_n)
);
}
Expand Down
2 changes: 1 addition & 1 deletion lib/THCUNN/generic/SpatialDepthWiseConvolution.cu
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ void THNN_(SpatialDepthWiseConvolution_updateGradInput)(
col2im<real, accreal>(
THCState_getCurrentStream(state),
THCTensor_(data)(state, gradColumns),
1, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
1, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
1, 1, THCTensor_(data)(state, gradInput_i)
);
}
Expand Down
2 changes: 1 addition & 1 deletion lib/THCUNN/generic/SpatialDilatedConvolution.cu
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ void THNN_(SpatialDilatedConvolution_updateGradInput)(
col2im<real, accreal>(
THCState_getCurrentStream(state),
THCTensor_(data)(state, gradColumns),
nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
dilationH, dilationW,
THCTensor_(data)(state, gradInput_n)
);
Expand Down
8 changes: 4 additions & 4 deletions lib/THCUNN/generic/SpatialFullDilatedConvolution.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ static inline void THNN_(SpatialFullDilatedConvolution_shapeCheck)(
"kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW);
THArgCheck(dW > 0 && dH > 0, 11,
"stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
THArgCheck(adjW < dW && adjH < dH, 15,
"output adjustment must be smaller than stride, but got adjH: %d adjW: %d dH: %d dW: %d",
adjH, adjW, dH, dW);
THArgCheck(dilationW > 0 && dilationH > 0, 15,
"dilation should be greater than zero, but got dilationH: %d, dilationW: %d",
dilationH, dilationW);
THArgCheck((adjW < dW || adjW < dilationW) && (adjH < dH || adjH < dilationH), 15,
"output padding must be smaller than either stride or dilation, but got adjH: %d adjW: %d dH: %d dW: %d dilationH: %d dilationW: %d",
adjH, adjW, dH, dW, dilationH, dilationW);
THArgCheck(THCTensor_(isContiguous)(state, weight), 4,
"weight tensor has to be contiguous");
THArgCheck(!bias || THCTensor_(isContiguous)(state, bias), 5,
Expand Down Expand Up @@ -160,7 +160,7 @@ void THNN_(SpatialFullDilatedConvolution_updateOutput)(
col2im<real, accreal>(
THCState_getCurrentStream(state),
THCTensor_(data)(state, columns),
nOutputPlane, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
nOutputPlane, outputHeight, outputWidth, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
dilationH, dilationW, THCTensor_(data)(state, output_n)
);

Expand Down
1 change: 1 addition & 0 deletions lib/THCUNN/generic/VolumetricDilatedConvolution.cu
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ void THNN_(VolumetricDilatedConvolution_updateGradInput)(
THCState_getCurrentStream(state),
THCTensor_(data)(state, gradColumns),
nInputPlane, inputDepth, inputHeight, inputWidth,
outputDepth, outputHeight, outputWidth,
kT, kH, kW, padT, padH, padW, dT, dH, dW,
dilationT, dilationH, dilationW,
THCTensor_(data)(state, gradInput_n)
Expand Down
15 changes: 10 additions & 5 deletions lib/THCUNN/generic/VolumetricFullDilatedConvolution.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,16 @@ static inline void THNN_(VolumetricFullDilatedConvolution_shapeCheck)(
"bias tensor has to be contiguous");
THArgCheck(dT > 0 && dW > 0 && dH > 0, 8,
"stride should be greater than zero, but got dT: %d dH: %d dW: %d", dT, dH, dW);
THArgCheck(adjT < dT && adjW < dW && adjH < dH, 14,
"output adjustment must be smaller than stride, but got "
"adjT: %d adjH: %d adjW: %d dT: %d dH: %d dW: %d",
adjT, adjH, adjW, dT, dH, dW);
THArgCheck(dilationT > 0 && dilationW > 0 && dilationH > 0, 15,
"dilation should be greater than zero, but got dilationT: %d, dilationH: %d, dilationW: %d",
dilationT, dilationH, dilationW);
THArgCheck((adjT < dT || adjT < dilationT)
&& (adjW < dW || adjW < dilationW)
&& (adjH < dH || adjH < dilationH), 15,
"output padding must be smaller than either stride or dilation,"
" but got adjT: %d adjH: %d adjW: %d dT: %d dH: %d dW: %d "
"dilationT: %d dilationH: %d dilationW: %d",
adjT, adjH, adjW, dT, dH, dW, dilationT, dilationH, dilationW);

int ndim = input->nDimension;
int nInputPlane = THCTensor_(size)(state, weight, 0);
Expand Down Expand Up @@ -178,7 +181,9 @@ void THNN_(VolumetricFullDilatedConvolution_updateOutput)(
col2vol<real, accreal>(
THCState_getCurrentStream(state),
THCTensor_(data)(state, columns),
nOutputPlane, outputDepth, outputHeight, outputWidth, kT, kH, kW, padT, padH, padW, dT, dH, dW,
nOutputPlane, outputDepth, outputHeight, outputWidth,
inputDepth, inputHeight, inputWidth,
kT, kH, kW, padT, padH, padW, dT, dH, dW,
dilationT, dilationH, dilationW,
THCTensor_(data)(state, output_n)
);
Expand Down
7 changes: 2 additions & 5 deletions lib/THCUNN/im2col.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,21 +104,18 @@ __global__ void col2im_kernel(const int n, const Dtype* data_col,
template <typename Dtype, typename Acctype>
// Scatters a column buffer (the layout produced by im2col) back into an
// image tensor, summing the contributions of overlapping patches.
//
// data_col        : input column buffer of shape
//                   (channels * patch_h * patch_w) x (output_height * output_width)
// channels        : number of image channels
// height, width   : spatial size of the destination image `data_im`
// output_height/
// output_width    : spatial size of the column grid. Passed in explicitly
//                   (rather than recomputed from the conv shape formula) so
//                   callers may use any valid output size — this is what
//                   lifts the old `output_padding < stride` restriction.
// patch_h/patch_w : kernel extent; pad_*, stride_*, dilation_* as usual.
// data_im         : destination image buffer (accumulated into by the kernel).
//
// Launches on `stream`; checks for launch-configuration errors afterwards.
void col2im(cudaStream_t stream, const Dtype* data_col, const int channels,
            const int height, const int width,
            const int output_height, const int output_width,
            const int patch_h, const int patch_w, const int pad_h,
            const int pad_w, const int stride_h, const int stride_w,
            const int dilation_h, const int dilation_w, Dtype* data_im) {
  // One thread per element of the *image* (bottom) tensor.
  int num_kernels = channels * height * width;
  // To avoid involving atomic operations, we will launch one kernel per
  // bottom dimension, and then in the kernel add up the top dimensions.
  col2im_kernel<Dtype, Acctype> <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>> (
      num_kernels, data_col, height, width, channels,
      patch_h, patch_w, pad_h, pad_w, stride_h, stride_w,
      dilation_h, dilation_w,
      output_height, output_width, data_im
  );
  // Surface invalid launch configurations immediately.
  THCudaCheck(cudaGetLastError());
}
Expand Down
6 changes: 2 additions & 4 deletions lib/THCUNN/vol2col.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,22 +120,20 @@ __global__ void vol2im_kernel(const int n, const Dtype* data_col,
template <typename Dtype, typename Acctype>
// 3D analogue of col2im: scatters a column buffer (vol2col layout) back
// into a volume tensor, summing overlapping patch contributions.
//
// data_col         : column buffer of shape
//                    (channels * patch_t * patch_h * patch_w) x
//                    (output_depth * output_height * output_width)
// channels         : number of volume channels
// depth/height/
// width            : spatial size of the destination volume `data_vol`
// output_depth/
// output_height/
// output_width     : spatial size of the column grid, supplied by the caller
//                    instead of being recomputed from the conv shape formula,
//                    so any valid output size (arbitrary output_padding) works.
// patch_*          : kernel extent; pad_*, stride_*, dilation_* as usual.
// data_vol         : destination volume buffer (accumulated into by the kernel).
//
// Launches on `stream`; checks for launch-configuration errors afterwards.
void col2vol(cudaStream_t stream, const Dtype* data_col, const int channels,
             const int depth, const int height, const int width,
             const int output_depth, const int output_height, const int output_width,
             const int patch_t, const int patch_h, const int patch_w,
             const int pad_t, const int pad_h, const int pad_w,
             const int stride_t, const int stride_h, const int stride_w,
             const int dilation_t, const int dilation_h, const int dilation_w,
             Dtype* data_vol) {
  // One thread per element of the *volume* (bottom) tensor.
  int num_kernels = channels * depth * height * width;
  // To avoid involving atomic operations, we will launch one kernel per
  // bottom dimension, and then in the kernel add up the top dimensions.
  vol2im_kernel<Dtype, Acctype> <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>> (
      num_kernels, data_col, depth, height, width, channels,
      patch_t, patch_h, patch_w, pad_t, pad_h, pad_w, stride_t, stride_h, stride_w,
      dilation_t, dilation_h, dilation_w,
      output_depth, output_height, output_width, data_vol
  );
  // Surface invalid launch configurations immediately.
  THCudaCheck(cudaGetLastError());
}
Expand Down

0 comments on commit e9ef2d5

Please sign in to comment.