Skip to content

Commit e29cb28

Browse files
committed
feat: implement grouped Conv1D for CUDA
1 parent b8d059d commit e29cb28

File tree

2 files changed: +5 additions, -6 deletions

src/ops/conv1d_gpu.cu

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@ namespace ctranslate2 {
       const int input_length = input.dim(2);
       const int output_length = output.dim(2);
       const int out_channels = weight.dim(0);
+      const int in_channels_per_group = weight.dim(1);
       const int kernel_size = weight.dim(2);
 
       cudnnDataType_t data_type = cuda::get_cudnn_data_type(input.dtype());
@@ -45,7 +46,7 @@ namespace ctranslate2 {
       cudnnFilterDescriptor_t weight_desc;
       CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc));
       CUDNN_CHECK(cudnnSetFilter4dDescriptor(weight_desc, data_type, CUDNN_TENSOR_NCHW,
-                                             out_channels, in_channels, 1, kernel_size));
+                                             out_channels, in_channels_per_group, 1, kernel_size));
 
       cudnnConvolutionDescriptor_t conv_desc;
       CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
@@ -57,6 +58,8 @@ namespace ctranslate2 {
                                                   CUDNN_DATA_FLOAT));
 
       CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_DEFAULT_MATH));
+      if (_groups > 1)
+        CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, _groups));
       if (data_type == CUDNN_DATA_HALF)
         CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH));
 
tests/ops_test.cc

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1105,8 +1105,6 @@ TEST_P(OpDeviceFPTest, Conv1DPaddingAndStride) {
 
 TEST_P(OpDeviceFPTest, Conv1DGroupNoBias) {
   const Device device = GetParam().device;
-  if (device != Device::CPU)
-    GTEST_SKIP() << "Grouped convolution is not implemented for CUDA.";
   const DataType dtype = GetParam().dtype;
   const float error = GetParam().error;
   const StorageView expected({2, 2, 2}, std::vector<float>{
@@ -1136,7 +1134,7 @@ TEST_P(OpDeviceFPTest, Conv1DGroupNoBiasQuantized) {
 #endif
   const Device device = GetParam().device;
   if (device != Device::CPU)
-    GTEST_SKIP() << "Grouped convolution is not implemented for CUDA.";
+    GTEST_SKIP() << "Grouped quantized convolution is not implemented for CUDA.";
   const DataType dtype = GetParam().dtype;
   const float error = std::max(GetParam().error, float(3e-3));
   const StorageView expected({2, 2, 2}, std::vector<float>{
@@ -1166,8 +1164,6 @@ TEST_P(OpDeviceFPTest, Conv1DGroupNoBiasQuantized) {
 
 TEST_P(OpDeviceFPTest, Conv1DGroup) {
   const Device device = GetParam().device;
-  if (device != Device::CPU)
-    GTEST_SKIP() << "Grouped convolution is not implemented for CUDA.";
   const DataType dtype = GetParam().dtype;
   const float error = GetParam().error;
   const StorageView expected({2, 2, 2}, std::vector<float>{

0 commit comments

Comments
 (0)