diff --git a/include/ctranslate2/layers/attention.h b/include/ctranslate2/layers/attention.h index db4056083..5778a028c 100644 --- a/include/ctranslate2/layers/attention.h +++ b/include/ctranslate2/layers/attention.h @@ -8,9 +8,12 @@ namespace ctranslate2 { StorageView make_relative_positions(dim_t queries_length, dim_t keys_length, - dim_t max_position, - dim_t left_max_position, - dim_t right_max_position); + dim_t max_position); + + StorageView make_asymmetric_relative_positions(dim_t queries_length, + dim_t keys_length, + dim_t left_max_position, + dim_t right_max_position); class RotaryEmbeddings; class Alibi; diff --git a/src/layers/attention.cc b/src/layers/attention.cc index f2ff94995..6ad344410 100644 --- a/src/layers/attention.cc +++ b/src/layers/attention.cc @@ -15,22 +15,35 @@ namespace ctranslate2 { StorageView make_relative_positions(dim_t queries_length, dim_t keys_length, - dim_t max_position, - dim_t left_max_position = 0, - dim_t right_max_position = 0) { + dim_t max_position) { StorageView positions({queries_length, keys_length}, DataType::INT32); auto* positions_data = positions.data(); const dim_t offset = keys_length - queries_length; - bool asymmetric = (left_max_position != 0 || right_max_position != 0); for (dim_t i = 0; i < queries_length; ++i) { auto* row = positions_data + i * keys_length; for (dim_t j = 0; j < keys_length; ++j) { - if (asymmetric) - row[j] = std::max(std::min(j - i, right_max_position), -left_max_position) + left_max_position; - else - row[j] = std::min(std::max(j - (i + offset), -max_position), max_position) + max_position; + row[j] = std::min(std::max(j - (i + offset), -max_position), max_position) + max_position; + } + } + + return positions; + } + + StorageView make_asymmetric_relative_positions(dim_t queries_length, + dim_t keys_length, + dim_t left_max_position, + dim_t right_max_position) { + StorageView positions({queries_length, keys_length}, DataType::INT32); + auto* positions_data = positions.data(); + + const dim_t offset = keys_length - queries_length; + + for (dim_t i = 0; i < queries_length; ++i) { + auto* row = positions_data + i * keys_length; + for (dim_t j = 0; j < keys_length; ++j) { + row[j] = std::max(std::min(j - i, right_max_position), -left_max_position) + left_max_position; } } @@ -192,17 +205,14 @@ namespace ctranslate2 { const dim_t key_length = keys.dim(2); if (relative_asymmetric_position_keys) relative_positions = std::make_unique( - make_relative_positions(query_length, + make_asymmetric_relative_positions(query_length, key_length, - /*maximum_relative_position=*/0, relative_left_max_position, relative_right_max_position).to(queries.device())); else relative_positions = std::make_unique( make_relative_positions(query_length, key_length, - maximum_relative_position, - /*relative_left_max_position=*/0, - /*relative_right_max_position=*/0).to(queries.device())); + maximum_relative_position).to(queries.device())); } const ops::MatMul keys_matmul(/*trans_a=*/false, /*trans_b=*/true, queries_scale); diff --git a/tests/layers_test.cc b/tests/layers_test.cc index 222ae1549..cbbaa2d72 100644 --- a/tests/layers_test.cc +++ b/tests/layers_test.cc @@ -22,7 +22,7 @@ TEST(LayerTest, MakeRelativePositions2D) { } TEST(LayerTest, MakeAsymmetricRelativePositions2D) { - const StorageView positions = layers::make_relative_positions(4, 4, 0, 3, 2); + const StorageView positions = layers::make_asymmetric_relative_positions(4, 4, 3, 2); const StorageView expected({4, 4}, std::vector{ 3, 4, 5, 5, 2, 3, 4, 5,