diff --git a/include/ctranslate2/decoding_utils.h b/include/ctranslate2/decoding_utils.h index 3d146c4a1..fc4d862a1 100644 --- a/include/ctranslate2/decoding_utils.h +++ b/include/ctranslate2/decoding_utils.h @@ -38,6 +38,8 @@ namespace ctranslate2 { DisableTokens(StorageView& logits, const float disable_value = std::numeric_limits::lowest()); + void reserve(dim_t additional); + void add(dim_t batch_id, dim_t token_id) { const auto flat_index = batch_id * _vocabulary_size + token_id; @@ -46,10 +48,8 @@ namespace ctranslate2 { _logits_data[flat_index] = _disable_value; } else { - // On GPU we prepare a list of unique index to disable. - const auto it = std::lower_bound(_flat_indices.begin(), _flat_indices.end(), flat_index); - if (it == _flat_indices.end() || *it != flat_index) - _flat_indices.insert(it, flat_index); + // On GPU we collect indices that will be processed in a single kernel. + _flat_indices.push_back(static_cast(flat_index)); } } diff --git a/src/decoding_utils.cc b/src/decoding_utils.cc index fed4670d3..f3d6ba646 100644 --- a/src/decoding_utils.cc +++ b/src/decoding_utils.cc @@ -1,5 +1,6 @@ #include "ctranslate2/decoding_utils.h" +#include #include #include "ctranslate2/ops/ops.h" @@ -16,7 +17,21 @@ namespace ctranslate2 { { } + void DisableTokens::reserve(dim_t additional) { + if (!_logits_data && additional > 0) + _flat_indices.reserve(_flat_indices.size() + static_cast(additional)); + } + void DisableTokens::apply() { + if (_logits_data) + return; + + if (_flat_indices.empty()) + return; + + std::sort(_flat_indices.begin(), _flat_indices.end()); + _flat_indices.erase(std::unique(_flat_indices.begin(), _flat_indices.end()), _flat_indices.end()); + const dim_t num_indices = _flat_indices.size(); if (num_indices == 0) return; @@ -159,11 +174,17 @@ namespace ctranslate2 { } void SuppressTokens::apply(dim_t, - StorageView&, + StorageView& logits, DisableTokens& disable_tokens, const StorageView&, const std::vector&, const std::vector>*) { + if (_ids.empty()) + return; + + const dim_t batch_size = logits.dim(0); + disable_tokens.reserve(batch_size * static_cast(_ids.size())); + for (const auto token_id : _ids) disable_tokens.add(token_id); } @@ -182,6 +203,11 @@ namespace ctranslate2 { const std::vector>* prefix) { const dim_t batch_size = logits.dim(0); + if (_ids.empty()) + return; + + disable_tokens.reserve(batch_size * static_cast(_ids.size())); + for (dim_t batch_id = 0; batch_id < batch_size; ++batch_id) { const dim_t sample_begin = get_sample_begin(batch_size, batch_id, batch_offset, prefix);