Skip to content

Commit

Permalink
separate the asymmetric relative positions
Browse files Browse the repository at this point in the history
  • Loading branch information
hkwon committed Sep 11, 2024
1 parent 4deda1f commit c8f74b5
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 17 deletions.
9 changes: 6 additions & 3 deletions include/ctranslate2/layers/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ namespace ctranslate2 {

StorageView make_relative_positions(dim_t queries_length,
dim_t keys_length,
dim_t max_position,
dim_t left_max_position,
dim_t right_max_position);
dim_t max_position);

// Returns a {queries_length, keys_length} INT32 StorageView of relative
// position indices with independent clipping distances on each side.
// Entry (i, j) is (j - i) limited to [-left_max_position, right_max_position]
// and then shifted by left_max_position.
StorageView make_asymmetric_relative_positions(dim_t queries_length,
dim_t keys_length,
dim_t left_max_position,
dim_t right_max_position);

class RotaryEmbeddings;
class Alibi;
Expand Down
36 changes: 23 additions & 13 deletions src/layers/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,35 @@ namespace ctranslate2 {

StorageView make_relative_positions(dim_t queries_length,
dim_t keys_length,
dim_t max_position,
dim_t left_max_position = 0,
dim_t right_max_position = 0) {
dim_t max_position) {
StorageView positions({queries_length, keys_length}, DataType::INT32);
auto* positions_data = positions.data<int32_t>();

const dim_t offset = keys_length - queries_length;
bool asymmetric = (left_max_position != 0 || right_max_position != 0);

for (dim_t i = 0; i < queries_length; ++i) {
auto* row = positions_data + i * keys_length;
for (dim_t j = 0; j < keys_length; ++j) {
if (asymmetric)
row[j] = std::max(std::min(j - i, right_max_position), -left_max_position) + left_max_position;
else
row[j] = std::min(std::max(j - (i + offset), -max_position), max_position) + max_position;
row[j] = std::min(std::max(j - (i + offset), -max_position), max_position) + max_position;
}
}

return positions;
}

// Builds a {queries_length, keys_length} INT32 StorageView of relative
// position indices with independent clipping distances on each side.
// Entry (i, j) is (j - i) limited to [-left_max_position, right_max_position]
// and then shifted by left_max_position, so values fall in
// [0, left_max_position + right_max_position] when both bounds are
// non-negative.
StorageView make_asymmetric_relative_positions(dim_t queries_length,
dim_t keys_length,
dim_t left_max_position,
dim_t right_max_position) {
StorageView positions({queries_length, keys_length}, DataType::INT32);
auto* positions_data = positions.data<int32_t>();

// NOTE(review): offset is computed here but never read in the loop below,
// which uses (j - i) rather than (j - (i + offset)). Results therefore do not
// account for keys_length != queries_length — confirm this is intended, or
// drop the unused local.
const dim_t offset = keys_length - queries_length;

for (dim_t i = 0; i < queries_length; ++i) {
// Row i of the row-major output: keys_length consecutive elements.
auto* row = positions_data + i * keys_length;
for (dim_t j = 0; j < keys_length; ++j) {
// min() caps the distance at right_max_position, max() floors it at
// -left_max_position; adding left_max_position maps the floor to index 0.
row[j] = std::max(std::min(j - i, right_max_position), -left_max_position) + left_max_position;
}
}

Expand Down Expand Up @@ -192,17 +205,14 @@ namespace ctranslate2 {
const dim_t key_length = keys.dim(2);
if (relative_asymmetric_position_keys)
relative_positions = std::make_unique<StorageView>(
make_relative_positions(query_length,
make_asymmetric_relative_positions(query_length,
key_length,
/*maximum_relative_position=*/0,
relative_left_max_position,
relative_right_max_position).to(queries.device()));
else relative_positions = std::make_unique<StorageView>(
make_relative_positions(query_length,
key_length,
maximum_relative_position,
/*relative_left_max_position=*/0,
/*relative_right_max_position=*/0).to(queries.device()));
maximum_relative_position).to(queries.device()));
}

const ops::MatMul keys_matmul(/*trans_a=*/false, /*trans_b=*/true, queries_scale);
Expand Down
2 changes: 1 addition & 1 deletion tests/layers_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ TEST(LayerTest, MakeRelativePositions2D) {
}

TEST(LayerTest, MakeAsymmetricRelativePositions2D) {
const StorageView positions = layers::make_relative_positions(4, 4, 0, 3, 2);
const StorageView positions = layers::make_asymmetric_relative_positions(4, 4, 3, 2);
const StorageView expected({4, 4}, std::vector<int32_t>{
3, 4, 5, 5,
2, 3, 4, 5,
Expand Down

0 comments on commit c8f74b5

Please sign in to comment.