Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: implement grouped Conv1D for CUDA
Browse files: browse the repository at this point in the history
ebraraktas committed Jul 29, 2024
1 parent b8d059d commit e29cb28
Showing 2 changed files with 5 additions and 6 deletions.
5 changes: 4 additions & 1 deletion src/ops/conv1d_gpu.cu
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@ namespace ctranslate2 {
const int input_length = input.dim(2);
const int output_length = output.dim(2);
const int out_channels = weight.dim(0);
const int in_channels_per_group = weight.dim(1);
const int kernel_size = weight.dim(2);

cudnnDataType_t data_type = cuda::get_cudnn_data_type(input.dtype());
@@ -45,7 +46,7 @@ namespace ctranslate2 {
cudnnFilterDescriptor_t weight_desc;
CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc));
CUDNN_CHECK(cudnnSetFilter4dDescriptor(weight_desc, data_type, CUDNN_TENSOR_NCHW,
out_channels, in_channels, 1, kernel_size));
out_channels, in_channels_per_group, 1, kernel_size));

cudnnConvolutionDescriptor_t conv_desc;
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
@@ -57,6 +58,8 @@ namespace ctranslate2 {
CUDNN_DATA_FLOAT));

CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_DEFAULT_MATH));
if (_groups > 1)
CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, _groups));
if (data_type == CUDNN_DATA_HALF)
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH));

6 changes: 1 addition & 5 deletions tests/ops_test.cc
Original file line number | Diff line number | Diff line change
@@ -1105,8 +1105,6 @@ TEST_P(OpDeviceFPTest, Conv1DPaddingAndStride) {

TEST_P(OpDeviceFPTest, Conv1DGroupNoBias) {
const Device device = GetParam().device;
if (device != Device::CPU)
GTEST_SKIP() << "Grouped convolution is not implemented for CUDA.";
const DataType dtype = GetParam().dtype;
const float error = GetParam().error;
const StorageView expected({2, 2, 2}, std::vector<float>{
@@ -1136,7 +1134,7 @@ TEST_P(OpDeviceFPTest, Conv1DGroupNoBiasQuantized) {
#endif
const Device device = GetParam().device;
if (device != Device::CPU)
GTEST_SKIP() << "Grouped convolution is not implemented for CUDA.";
GTEST_SKIP() << "Grouped quantized convolution is not implemented for CUDA.";
const DataType dtype = GetParam().dtype;
const float error = std::max(GetParam().error, float(3e-3));
const StorageView expected({2, 2, 2}, std::vector<float>{
@@ -1166,8 +1164,6 @@ TEST_P(OpDeviceFPTest, Conv1DGroupNoBiasQuantized) {

TEST_P(OpDeviceFPTest, Conv1DGroup) {
const Device device = GetParam().device;
if (device != Device::CPU)
GTEST_SKIP() << "Grouped convolution is not implemented for CUDA.";
const DataType dtype = GetParam().dtype;
const float error = GetParam().error;
const StorageView expected({2, 2, 2}, std::vector<float>{

0 comments on commit e29cb28

Please sign in to comment.