From 739a5b1a3166f6fb0433ef5a102eaba0b9a61408 Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Wed, 6 Jul 2022 17:32:19 +0200
Subject: [PATCH] Fix application of max_decoding_length in return_alternatives
 mode (#866)

---
 src/decoding.cc          | 283 ++++++++++++++++++++++-----------------
 tests/translator_test.cc |  34 +++++
 2 files changed, 191 insertions(+), 126 deletions(-)

diff --git a/src/decoding.cc b/src/decoding.cc
index 126e88897..ceac800a4 100644
--- a/src/decoding.cc
+++ b/src/decoding.cc
@@ -915,6 +915,133 @@ namespace ctranslate2 {
                                        options.allow_early_exit);
   }
 
+  static DecodingResult
+  decode_alternatives(layers::Decoder& decoder,
+                      layers::DecoderState& state,
+                      const size_t start_token,
+                      const std::vector<size_t>& prefix_tokens,
+                      const size_t end_id,
+                      const DecodingOptions& options,
+                      const std::vector<size_t>* output_ids_map) {
+    DecodingResult result;
+    result.hypotheses.resize(options.num_hypotheses);
+    if (options.return_scores)
+      result.scores.resize(options.num_hypotheses, 0);
+    if (options.return_attention)
+      result.attention.resize(options.num_hypotheses);
+
+    std::vector<size_t> start_ids{start_token};
+    std::vector<size_t> prefix_ids(prefix_tokens);
+    if (prefix_ids.size() > options.max_length)
+      prefix_ids.resize(options.max_length);
+
+    const dim_t min_length = options.min_length;
+    const dim_t max_length = options.max_length;
+    dim_t start_step = 0;
+
+    if (!prefix_ids.empty()) {
+      // Initialize the decoder state with the prefix.
+      std::vector<std::vector<float>> prefix_attention;
+      initialize_decoder_with_prefix(decoder,
+                                     state,
+                                     start_ids[0],
+                                     prefix_ids,
+                                     options.return_attention ? &prefix_attention : nullptr);
+
+      for (size_t i = 0; i < options.num_hypotheses; ++i) {
+        result.hypotheses[i] = prefix_ids;
+        if (options.return_attention)
+          result.attention[i] = prefix_attention;
+      }
+
+      start_ids[0] = prefix_ids.back();
+      start_step += prefix_ids.size();
+      if (start_step == max_length)
+        return result;
+    }
+
+    // Expand the next "num_hypotheses" candidate words using the beam search.
+    BeamSearch beam(options.num_hypotheses);
+    DecodingResult expansion_result = beam.search(decoder,
+                                                  state,
+                                                  BestSampler(),
+                                                  start_ids,
+                                                  end_id,
+                                                  start_step,
+                                                  /*max_length=*/1,
+                                                  /*min_length=*/1,
+                                                  output_ids_map,
+                                                  options.normalize_scores,
+                                                  options.return_scores,
+                                                  options.return_attention,
+                                                  options.num_hypotheses)[0];
+
+    start_ids.resize(options.num_hypotheses);
+
+    for (size_t i = 0; i < options.num_hypotheses; ++i) {
+      // Add the expanded word to the result.
+      result.hypotheses[i].emplace_back(expansion_result.hypotheses[i].back());
+      if (options.return_attention)
+        result.attention[i].emplace_back(std::move(expansion_result.attention[i].back()));
+      if (options.return_scores)
+        result.scores[i] = expansion_result.scores[i];
+
+      // The next input is the word we just expanded.
+      start_ids[i] = result.hypotheses[i].back();
+    }
+
+    start_step += 1;
+    if (start_step == max_length)
+      return result;
+
+    // Continue the decoding from each alternative independently.
+    const auto search_strategy = make_search_strategy(options);
+    const auto sampler = make_sampler(options);
+    auto suffix_results = search_strategy->search(decoder,
+                                                  state,
+                                                  *sampler,
+                                                  start_ids,
+                                                  end_id,
+                                                  start_step,
+                                                  std::max(max_length - start_step, dim_t(0)),
+                                                  std::max(min_length - start_step, dim_t(0)),
+                                                  output_ids_map,
+                                                  options.normalize_scores,
+                                                  options.return_scores,
+                                                  options.return_attention,
+                                                  /*num_hypotheses=*/1,
+                                                  options.repetition_penalty,
+                                                  options.no_repeat_ngram_size);
+
+    // Update the result with the suffix decoding.
+    for (size_t i = 0; i < options.num_hypotheses; ++i) {
+      auto& suffix = suffix_results[i];
+
+      if (options.return_scores) {
+        if (options.normalize_scores) {
+          const auto prefix_length = result.hypotheses[i].size();
+          const auto suffix_length = suffix.hypotheses[0].size();
+          result.scores[i] = (
+            (result.scores[i] * prefix_length + suffix.scores[0] * suffix_length)
+            / (prefix_length + suffix_length));
+        } else {
+          result.scores[i] += suffix.scores[0];
+        }
+      }
+
+      if (options.return_attention)
+        result.attention[i].insert(result.attention[i].end(),
+                                   std::make_move_iterator(suffix.attention[0].begin()),
+                                   std::make_move_iterator(suffix.attention[0].end()));
+
+      result.hypotheses[i].insert(result.hypotheses[i].end(),
+                                  std::make_move_iterator(suffix.hypotheses[0].begin()),
+                                  std::make_move_iterator(suffix.hypotheses[0].end()));
+    }
+
+    return result;
+  }
+
   std::vector<DecodingResult>
   decode(layers::Decoder& decoder,
          layers::DecoderState& state,
@@ -925,140 +1052,44 @@ namespace ctranslate2 {
     validate_decoding_options(options, output_ids_map);
 
     const size_t batch_size = start_tokens.size();
-    if (options.return_alternatives && batch_size > 1) {
-      // return_alternatives mode currently does not support batch decoding.
-      std::vector<DecodingResult> results;
-      results.reserve(batch_size);
-      for (size_t i = 0; i < batch_size; ++i) {
-        layers::DecoderState batch_state = get_batch_state(state, i);
-        results.emplace_back(decode(decoder,
-                                    batch_state,
-                                    {start_tokens[i]},
-                                    end_id,
-                                    options,
-                                    output_ids_map)[0]);
-      }
-      return results;
-    }
+    std::vector<DecodingResult> results;
 
     std::vector<size_t> start_ids;
     std::vector<std::vector<size_t>> prefix_ids;
     std::tie(start_ids, prefix_ids) = split_start_tokens(start_tokens);
 
-    dim_t start_step = 0;
-    dim_t min_length = options.min_length;
-    dim_t max_length = options.max_length;
-
-    std::vector<DecodingResult> expansion_results;
     if (options.return_alternatives) {
-      std::vector<std::vector<std::vector<float>>> prefix_attention;
-      if (!prefix_ids.empty()) {
-        if (options.return_attention)
-          prefix_attention.resize(1);
-        initialize_decoder_with_prefix(decoder,
-                                       state,
-                                       start_ids[0],
-                                       prefix_ids[0],
-                                       options.return_attention ? &prefix_attention[0] : nullptr);
-        start_ids[0] = prefix_ids[0].back();
-        const dim_t prefix_length = prefix_ids[0].size();
-        start_step += prefix_length;
-        max_length = std::max(max_length - prefix_length, dim_t(0));
-        min_length = std::max(min_length - prefix_length, dim_t(0));
-      }
-
-      // In this translation mode, we first expand the next "num_hypotheses" candidate words
-      // before running the full decoding on each prefix. This is to ensure that we get unique
-      // alternatives at this decoding position.
-      expansion_results = BeamSearch(options.num_hypotheses).search(decoder,
-                                                                    state,
-                                                                    BestSampler(),
-                                                                    start_ids,
-                                                                    end_id,
-                                                                    start_step,
-                                                                    /*max_length=*/1,
-                                                                    /*min_length=*/1,
-                                                                    output_ids_map,
-                                                                    options.normalize_scores,
-                                                                    options.return_scores,
-                                                                    options.return_attention,
-                                                                    options.num_hypotheses);
-
-      start_ids.resize(batch_size * options.num_hypotheses);
-      for (size_t b = 0; b < batch_size; ++b) {
-        auto& result = expansion_results[b];
-
-        for (size_t i = 0; i < options.num_hypotheses; ++i) {
-          // The next input is the words we just expanded.
-          start_ids[b * options.num_hypotheses + i] = result.hypotheses[i].back();
-
-          // Prepend expansion result with the prefix.
-          if (!prefix_ids.empty()) {
-            result.hypotheses[i].insert(result.hypotheses[i].begin(),
-                                        prefix_ids[b].begin(),
-                                        prefix_ids[b].end());
-            if (options.return_attention) {
-              result.attention[i].insert(result.attention[i].begin(),
-                                         prefix_attention[b].begin(),
-                                         prefix_attention[b].end());
-            }
-          }
-        }
+      results.reserve(batch_size);
+      for (size_t i = 0; i < batch_size; ++i) {
+        layers::DecoderState batch_state = get_batch_state(state, i);
+        results.emplace_back(decode_alternatives(decoder,
+                                                 batch_state,
+                                                 start_ids[i],
+                                                 prefix_ids[i],
+                                                 end_id,
+                                                 options,
+                                                 output_ids_map));
       }
-      start_step += 1;
-      max_length = std::max(max_length - 1, dim_t(0));
-      min_length = std::max(min_length - 1, dim_t(0));
-    }
-
-    const auto search_strategy = make_search_strategy(options);
-    const auto sampler = make_sampler(options);
-    auto results = search_strategy->search(decoder,
-                                           state,
-                                           *sampler,
-                                           start_ids,
-                                           end_id,
-                                           start_step,
-                                           max_length,
-                                           min_length,
-                                           output_ids_map,
-                                           options.normalize_scores,
-                                           options.return_scores,
-                                           options.return_attention,
-                                           options.return_alternatives ? 1 : options.num_hypotheses,
-                                           options.repetition_penalty,
-                                           options.no_repeat_ngram_size,
-                                           options.return_alternatives ? nullptr : &prefix_ids);
-
-    if (options.return_alternatives) {
-      // Append to expansion results.
-      for (size_t b = 0; b < batch_size; ++b) {
-        auto& prefix = expansion_results[b];
-        for (size_t i = 0; i < options.num_hypotheses; ++i) {
-          auto& suffix = results[b * options.num_hypotheses + i];
-
-          if (!prefix.scores.empty()) {
-            if (options.normalize_scores) {
-              const auto prefix_length = prefix.hypotheses[i].size();
-              const auto suffix_length = suffix.hypotheses[0].size();
-              prefix.scores[i] = (
-                (prefix.scores[i] * prefix_length + suffix.scores[0] * suffix_length)
-                / (prefix_length + suffix_length));
-            } else {
-              prefix.scores[i] += suffix.scores[0];
-            }
-          }
-
-          prefix.hypotheses[i].insert(prefix.hypotheses[i].end(),
-                                      std::make_move_iterator(suffix.hypotheses[0].begin()),
-                                      std::make_move_iterator(suffix.hypotheses[0].end()));
-          if (!prefix.attention.empty())
-            prefix.attention[i].insert(prefix.attention[i].end(),
-                                       std::make_move_iterator(suffix.attention[0].begin()),
-                                       std::make_move_iterator(suffix.attention[0].end()));
-        }
-      }
-      results = std::move(expansion_results);
+    } else {
+      const auto search_strategy = make_search_strategy(options);
+      const auto sampler = make_sampler(options);
+      results = search_strategy->search(decoder,
+                                        state,
+                                        *sampler,
+                                        start_ids,
+                                        end_id,
+                                        /*start_step=*/0,
+                                        options.max_length,
+                                        options.min_length,
+                                        output_ids_map,
+                                        options.normalize_scores,
+                                        options.return_scores,
+                                        options.return_attention,
+                                        options.num_hypotheses,
+                                        options.repetition_penalty,
+                                        options.no_repeat_ngram_size,
+                                        &prefix_ids);
     }
 
     // Remove EOS token.
diff --git a/tests/translator_test.cc b/tests/translator_test.cc
index fda3a9f5e..be95a3824 100644
--- a/tests/translator_test.cc
+++ b/tests/translator_test.cc
@@ -793,6 +793,40 @@ TEST(TranslatorTest, AlternativesFromFullTarget) {
   EXPECT_EQ(result.hypotheses[0], (std::vector<std::string>{"a", "t", "z", "m", "o", "n", "e"}));
 }
 
+TEST(TranslatorTest, AlternativesMaxDecodingLength) {
+  Translator translator = default_translator();
+  TranslationOptions options;
+  options.num_hypotheses = 4;
+  options.max_decoding_length = 2;
+  options.return_alternatives = true;
+  options.return_scores = true;
+  options.return_attention = true;
+
+  const std::vector<std::string> input = {"آ" ,"ت" ,"ز" ,"م" ,"و" ,"ن"};
+  const std::vector<std::vector<std::string>> target_samples = {
+    {}, {"a"}, {"a", "t"}, {"a", "t", "z"}
+  };
+
+  for (const auto& target : target_samples) {
+    const auto result = translator.translate_with_prefix(input, target, options);
+
+    for (size_t i = 0; i < result.num_hypotheses(); ++i) {
+      EXPECT_EQ(result.hypotheses[i].size(), options.max_decoding_length);
+      EXPECT_EQ(result.attention[i].size(), options.max_decoding_length);
+
+      for (size_t t = 0; t < std::min(target.size(), options.max_decoding_length); ++t) {
+        EXPECT_EQ(result.hypotheses[i][t], target[t]);
+      }
+
+      if (target.size() < options.max_decoding_length) {
+        EXPECT_NE(result.scores[i], 0);
+      } else {
+        EXPECT_EQ(result.scores[i], 0);
+      }
+    }
+  }
+}
+
 TEST(TranslatorTest, DetachModel) {
   const std::vector<std::string> input = {"آ" ,"ت" ,"ز" ,"م" ,"و" ,"ن"};
   Translator translator = default_translator();
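
Note on the fix: decode_alternatives now treats options.max_length (the user-facing max_decoding_length) as a budget over the whole output. The target prefix is truncated to the budget, every consumed step is counted in start_step, and the suffix search only receives what remains. A minimal standalone sketch of that arithmetic; the dim_t alias and the literal values here are illustrative, not part of the patch:

  #include <algorithm>
  #include <cstdint>
  #include <iostream>

  using dim_t = int64_t;  // illustrative stand-in for the ctranslate2 index type

  int main() {
    const dim_t max_length = 2;  // options.max_length (max_decoding_length)
    dim_t start_step = 0;

    // A 3-token target prefix is first truncated to max_length tokens, so it
    // consumes the whole budget and decode_alternatives returns early.
    const dim_t prefix_length = std::min<dim_t>(3, max_length);
    start_step += prefix_length;  // start_step == 2 == max_length -> early return

    // With a 1-token prefix instead, one step goes to the prefix and one to the
    // beam expansion; start_step reaches max_length again, so the function
    // returns before the suffix search. Otherwise the remaining budget would be:
    start_step = 1 + 1;
    const dim_t suffix_budget = std::max(max_length - start_step, dim_t(0));
    std::cout << suffix_budget << "\n";  // prints 0
    return 0;
  }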
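
Note on the score merging in decode_alternatives: when options.normalize_scores is set, the prefix and suffix scores are each averages over their own token counts, so they are recombined as a length-weighted average rather than a plain sum. A self-contained sketch of that formula; the helper name and the sample values are illustrative:

  #include <cstddef>
  #include <iostream>

  // Length-weighted average of two per-token average log-probabilities,
  // mirroring the normalize_scores branch of the patch.
  static float merge_normalized_scores(float prefix_score, std::size_t prefix_length,
                                       float suffix_score, std::size_t suffix_length) {
    return (prefix_score * prefix_length + suffix_score * suffix_length)
           / (prefix_length + suffix_length);
  }

  int main() {
    // A 2-token prefix averaging -0.5 and a 3-token suffix averaging -1.0
    // combine to ((-0.5 * 2) + (-1.0 * 3)) / 5 = -0.8, the average over all 5 tokens.
    std::cout << merge_normalized_scores(-0.5f, 2, -1.0f, 3) << "\n";  // prints -0.8
    return 0;
  }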
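
For context, this is roughly how the fixed behavior is exercised through the public API, following the new test above. A hedged sketch only: the header path and the surrounding setup are assumptions, while translate_with_prefix and the option names are taken from the test:

  #include <iostream>
  #include <string>
  #include <vector>

  #include <ctranslate2/translator.h>  // assumed header location

  void show_alternatives(ctranslate2::Translator& translator) {
    ctranslate2::TranslationOptions options;
    options.num_hypotheses = 4;
    options.max_decoding_length = 2;  // now a hard cap in return_alternatives mode
    options.return_alternatives = true;

    const std::vector<std::string> input = {"آ", "ت", "ز", "م", "و", "ن"};
    const std::vector<std::string> target_prefix = {"a"};

    // Before this fix, hypotheses could grow past max_decoding_length when a
    // target prefix was supplied; now each hypothesis holds at most 2 tokens.
    const auto result = translator.translate_with_prefix(input, target_prefix, options);
    for (const auto& hypothesis : result.hypotheses) {
      for (const auto& token : hypothesis)
        std::cout << token << ' ';
      std::cout << '\n';
    }
  }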