feat: implement grouped Conv1D for DNNL
ebraraktas committed Jul 29, 2024
1 parent 9581bce commit b8d059d
Showing 2 changed files with 5 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/ops/conv1d_cpu.cc
@@ -19,7 +19,7 @@ namespace ctranslate2 {
 
       dnnl::memory::dims input_dims(input.shape().begin(), input.shape().end());
       dnnl::memory::dims output_dims(output.shape().begin(), output.shape().end());
-      dnnl::memory::dims weight_dims(weight.shape().begin(), weight.shape().end());
+      dnnl::memory::dims weight_dims{_groups, weight.dim(0) / _groups, weight.dim(1), weight.dim(2)};
 
       using tag = dnnl::memory::format_tag;
       using dt = dnnl::memory::data_type;
@@ -32,7 +32,7 @@
                              const_cast<void*>(input.buffer()));
       dnnl::memory output_mem({output_dims, dt::f32, tag::ncw}, engine,
                               output.buffer());
-      dnnl::memory weight_mem({weight_dims, dt::f32, tag::oiw}, engine,
+      dnnl::memory weight_mem({weight_dims, dt::f32, tag::goiw}, engine,
                               const_cast<void*>(weight.buffer()));
 
       dnnl::memory::dims stride{_stride};
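For context: judging from the dimensions used above, CTranslate2 stores the Conv1D weight as a 3-D tensor of shape (out_channels, in_channels / groups, kernel_width), and the change reinterprets it as oneDNN's 4-D grouped layout (groups, out_channels / groups, in_channels / groups, kernel_width), selected by the format tag goiw instead of the ungrouped oiw. Below is a minimal standalone sketch of that mapping with made-up shapes, not code from the repository:

    // Sketch: describe a grouped Conv1D weight for oneDNN (dnnl).
    // Shapes are hypothetical; the weight buffer is assumed to hold
    // out_channels * in_channels_per_group * kernel_width floats.
    #include <dnnl.hpp>
    #include <vector>

    int main() {
      const dnnl::memory::dim groups = 2;
      const dnnl::memory::dim out_channels = 8;
      const dnnl::memory::dim in_channels_per_group = 3;
      const dnnl::memory::dim kernel_width = 5;

      dnnl::engine engine(dnnl::engine::kind::cpu, 0);

      // Ungrouped layout: (O, I, W) with format_tag::oiw.
      // Grouped layout:   (G, O/G, I/G, W) with format_tag::goiw.
      dnnl::memory::dims weight_dims{groups,
                                     out_channels / groups,
                                     in_channels_per_group,
                                     kernel_width};
      dnnl::memory::desc weight_md(weight_dims,
                                   dnnl::memory::data_type::f32,
                                   dnnl::memory::format_tag::goiw);

      // A memory object can then wrap the existing weight buffer in place.
      std::vector<float> weights(out_channels * in_channels_per_group * kernel_width);
      dnnl::memory weight_mem(weight_md, engine, weights.data());
      return 0;
    }

With groups == 1 the goiw dimensions collapse to (1, out_channels, in_channels, kernel_width), which describes the same dense layout as oiw, so the descriptor can be built this way unconditionally.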
3 changes: 3 additions & 0 deletions tests/ops_test.cc
@@ -1131,6 +1131,9 @@ TEST_P(OpDeviceFPTest, Conv1DGroupNoBias) {
 }
 
 TEST_P(OpDeviceFPTest, Conv1DGroupNoBiasQuantized) {
+#ifdef CT2_WITH_DNNL
+  GTEST_SKIP() << "Quantized convolution is not implemented for DNNL.";
+#endif
   const Device device = GetParam().device;
   if (device != Device::CPU)
     GTEST_SKIP() << "Grouped convolution is not implemented for CUDA.";
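The new guard compiles the skip into the test only when the build defines CT2_WITH_DNNL; GTEST_SKIP() returns from the test body immediately and reports the test as skipped rather than failed. A tiny self-contained illustration of the same pattern (the macro and test names below are invented for the example):

    #include <gtest/gtest.h>

    // Hypothetical feature macro standing in for CT2_WITH_DNNL.
    // #define MY_BACKEND_WITHOUT_FEATURE

    TEST(BackendTest, FeatureNotSupported) {
    #ifdef MY_BACKEND_WITHOUT_FEATURE
      // GTEST_SKIP() ends the test body here and marks the test as skipped.
      GTEST_SKIP() << "Feature is not implemented for this backend.";
    #endif
      // Reached only when the macro is not defined.
      EXPECT_TRUE(true);
    }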
