diff --git a/src/ops/conv1d_gpu.cu b/src/ops/conv1d_gpu.cu index 6f4d10b39..11ba48a1f 100644 --- a/src/ops/conv1d_gpu.cu +++ b/src/ops/conv1d_gpu.cu @@ -28,6 +28,7 @@ namespace ctranslate2 { const int input_length = input.dim(2); const int output_length = output.dim(2); const int out_channels = weight.dim(0); + const int in_channels_per_group = weight.dim(1); const int kernel_size = weight.dim(2); cudnnDataType_t data_type = cuda::get_cudnn_data_type(input.dtype()); @@ -45,7 +46,7 @@ namespace ctranslate2 { cudnnFilterDescriptor_t weight_desc; CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc)); CUDNN_CHECK(cudnnSetFilter4dDescriptor(weight_desc, data_type, CUDNN_TENSOR_NCHW, - out_channels, in_channels, 1, kernel_size)); + out_channels, in_channels_per_group, 1, kernel_size)); cudnnConvolutionDescriptor_t conv_desc; CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc)); @@ -57,6 +58,8 @@ namespace ctranslate2 { CUDNN_DATA_FLOAT)); CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_DEFAULT_MATH)); + if (_groups > 1) + CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, _groups)); if (data_type == CUDNN_DATA_HALF) CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH)); diff --git a/tests/ops_test.cc b/tests/ops_test.cc index 8c18fd6da..7d7b376fa 100644 --- a/tests/ops_test.cc +++ b/tests/ops_test.cc @@ -1105,8 +1105,6 @@ TEST_P(OpDeviceFPTest, Conv1DPaddingAndStride) { TEST_P(OpDeviceFPTest, Conv1DGroupNoBias) { const Device device = GetParam().device; - if (device != Device::CPU) - GTEST_SKIP() << "Grouped convolution is not implemented for CUDA."; const DataType dtype = GetParam().dtype; const float error = GetParam().error; const StorageView expected({2, 2, 2}, std::vector{ @@ -1136,7 +1134,7 @@ TEST_P(OpDeviceFPTest, Conv1DGroupNoBiasQuantized) { #endif const Device device = GetParam().device; if (device != Device::CPU) - GTEST_SKIP() << "Grouped convolution is not implemented for CUDA."; + GTEST_SKIP() << "Grouped quantized convolution is not implemented for CUDA."; const DataType dtype = GetParam().dtype; const float error = std::max(GetParam().error, float(3e-3)); const StorageView expected({2, 2, 2}, std::vector{ @@ -1166,8 +1164,6 @@ TEST_P(OpDeviceFPTest, Conv1DGroupNoBiasQuantized) { TEST_P(OpDeviceFPTest, Conv1DGroup) { const Device device = GetParam().device; - if (device != Device::CPU) - GTEST_SKIP() << "Grouped convolution is not implemented for CUDA."; const DataType dtype = GetParam().dtype; const float error = GetParam().error; const StorageView expected({2, 2, 2}, std::vector{