Skip to content

Commit

Permalink
fix shape of c in gemm ops with a transpose
Browse files Browse the repository at this point in the history
  • Loading branch information
minhthuc2502 committed Mar 7, 2024
1 parent b4b3ac0 commit 309ab75
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/ops/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ namespace ctranslate2 {
{
Shape output_shape(a.shape());
output_shape[output_shape.size() - 1] = n;
output_shape[output_shape.size() - 2] = a.dim(_trans_a ? -1 : -2);
c.resize(std::move(output_shape));
}

Expand Down
32 changes: 32 additions & 0 deletions tests/ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,38 @@ TEST_P(OpDeviceFPTest, Gemm) {
expect_storage_eq(y.to_float32(), expected, error);
};

TEST_P(OpDeviceFPTest, GemmATranspose) {
const Device device = GetParam().device;
const DataType dtype = GetParam().dtype;
const float error = GetParam().error;
StorageView a(
{2, 4}, std::vector<float>{1, 0, 0, 0, 0, 1, 0, 0}, device);
StorageView b(a);
StorageView y({4, 4}, 2.f, device);
StorageView expected(
{2, 2}, std::vector<float>{3, 2, 2, 3}, device);
ops::Gemm op(1.0, 1.0, false, true);
y = y.to(dtype);
op(a.to(dtype), b.to(dtype), y);
expect_storage_eq(y.to_float32(), expected, error);
}

TEST_P(OpDeviceFPTest, GemmBTranspose) {
const Device device = GetParam().device;
const DataType dtype = GetParam().dtype;
const float error = GetParam().error;
StorageView a(
{2, 4}, std::vector<float>{1, 0, 0, 0, 0, 1, 0, 0}, device);
StorageView b(a);
StorageView y({4, 4}, 2.f, device);
StorageView expected(
{4, 4}, std::vector<float>{3, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}, device);
ops::Gemm op(1.0, 1.0, true, false);
y = y.to(dtype);
op(a.to(dtype), b.to(dtype), y);
expect_storage_eq(y.to_float32(), expected, error);
}

TEST_P(OpDeviceTest, GemmInt8) {
Device device = GetParam();
if (!mayiuse_int8(device))
Expand Down

0 comments on commit 309ab75

Please sign in to comment.