Skip to content

Commit df5491d

Browse files
..
1 parent 62eaf33 commit df5491d

File tree

6 files changed

+9
-4
lines changed

6 files changed

+9
-4
lines changed

Diff for: include/ctranslate2/layers/attention.h

+1-1
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ namespace ctranslate2 {
3030
Alibi* alibi = nullptr);
3131
DataType output_type() const override;
3232
dim_t output_size() const override;
33-
void operator()(const StorageView& queries,
33+
virtual void operator()(const StorageView& queries,
3434
const StorageView& values,
3535
const StorageView* values_lengths,
3636
const StorageView* values_offsets,

Diff for: include/ctranslate2/layers/attention_layer.h

+1
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@ namespace ctranslate2 {
2929
virtual void operator()(const StorageView& queries,
3030
const StorageView& values,
3131
const StorageView* values_lengths,
32+
const StorageView* values_offsets,
3233
StorageView& output,
3334
StorageView* cached_keys = nullptr,
3435
StorageView* cached_values = nullptr,

Diff for: include/ctranslate2/layers/flash_attention.h

+1
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ namespace ctranslate2 {
2121
void operator()(const StorageView& queries,
2222
const StorageView& values,
2323
const StorageView* values_lengths,
24+
const StorageView* values_offsets,
2425
StorageView& output,
2526
StorageView* cached_keys = nullptr,
2627
StorageView* cached_values = nullptr,

Diff for: src/layers/transformer.cc

+4-1
Original file line number | Diff line number | Diff line change
@@ -165,6 +165,7 @@ namespace ctranslate2 {
165165
(*_self_attention)(hidden,
166166
hidden,
167167
input_length,
168+
input_offsets,
168169
context,
169170
cached_self_attn_keys,
170171
cached_self_attn_values,
@@ -576,6 +577,8 @@ namespace ctranslate2 {
576577

577578
std::unique_ptr<StorageView> left_padding_self_attn;
578579

580+
bool multi_query = _layers.front()->get_self_attention().multi_query();
581+
579582
if (lengths) {
580583
if (allow_padding_removal) {
581584
input_padder = std::make_unique<Padder>(*lengths, left_padding, max_time);
@@ -592,7 +595,7 @@ namespace ctranslate2 {
592595
num_heads = SAFE_DIVIDE(num_heads, ScopedMPISetter::getNRanks());
593596
}
594597

595-
StorageView lengths_mask = layers::MultiHeadAttention::prepare_values_mask(
598+
StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask(
596599
*lengths,
597600
num_heads,
598601
max_time,

Diff for: src/models/whisper.cc

+1-1
Original file line number | Diff line number | Diff line change
@@ -536,7 +536,7 @@ namespace ctranslate2 {
536536
std::vector<int32_t>(num_frames.begin(), num_frames.end()),
537537
device);
538538
const StorageView frame_sizes_mask(
539-
layers::MultiHeadAttention::prepare_values_mask(frame_sizes,
539+
layers::MultiHeadAttention::prepare_length_mask(frame_sizes,
540540
attention_weights.dim(1),
541541
attention_weights.dim(2)));
542542

Diff for: tests/ops_test.cc

+1-1
Original file line number | Diff line number | Diff line change
@@ -697,7 +697,7 @@ TEST_P(OpDeviceFPTest, MaskedSoftMaxTriangular) {
697697
0.8421174, 0.9135181, 0.77135813
698698
}, device);
699699
StorageView lengths({2}, std::vector<int32_t>{3, 2}, device);
700-
StorageView mask = layers::MultiHeadAttention::prepare_values_mask(lengths, 2, 3, true);
700+
StorageView mask = layers::MultiHeadAttention::prepare_length_mask(lengths, 2, 3, true);
701701
StorageView expected({2, 2, 3, 3}, std::vector<float>{
702702
1, 0, 0,
703703
0.28861094, 0.71138906, 0,

0 commit comments

Comments
 (0)