Commit

Fix application of max_decoding_length in return_alternatives mode (#866)

guillaumekln authored Jul 6, 2022
1 parent 0598ef9 commit 739a5b1
Showing 2 changed files with 191 additions and 126 deletions.
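
For context, the fix concerns how max_decoding_length is enforced when return_alternatives is enabled. Below is a minimal usage sketch of the affected API: the model directory path is a placeholder, and the source tokens and option values mirror the new test added by this commit.

#include <string>
#include <vector>

#include <ctranslate2/translator.h>

int main() {
  // Placeholder model directory: any converted CTranslate2 model works here.
  ctranslate2::Translator translator("ende_ctranslate2/", ctranslate2::Device::CPU);

  ctranslate2::TranslationOptions options;
  options.return_alternatives = true;  // expand alternatives at one target position
  options.num_hypotheses = 4;
  options.max_decoding_length = 2;     // hard cap on hypothesis length, prefix included

  const std::vector<std::string> source = {"آ", "ت", "ز", "م", "و", "ن"};
  const std::vector<std::string> target_prefix = {"a", "t"};

  // With this fix, each returned hypothesis holds at most
  // options.max_decoding_length tokens, even when the target prefix
  // alone reaches that limit.
  const auto result = translator.translate_with_prefix(source, target_prefix, options);
  return 0;
}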
283 changes: 157 additions & 126 deletions src/decoding.cc
@@ -915,6 +915,133 @@ namespace ctranslate2 {
options.allow_early_exit);
}

static DecodingResult
decode_alternatives(layers::Decoder& decoder,
layers::DecoderState& state,
const size_t start_token,
const std::vector<size_t>& prefix_tokens,
const size_t end_id,
const DecodingOptions& options,
const std::vector<size_t>* output_ids_map) {
DecodingResult result;
result.hypotheses.resize(options.num_hypotheses);
if (options.return_scores)
result.scores.resize(options.num_hypotheses, 0);
if (options.return_attention)
result.attention.resize(options.num_hypotheses);

std::vector<size_t> start_ids{start_token};
std::vector<size_t> prefix_ids(prefix_tokens);
if (prefix_ids.size() > options.max_length)
prefix_ids.resize(options.max_length);

const dim_t min_length = options.min_length;
const dim_t max_length = options.max_length;
dim_t start_step = 0;

if (!prefix_ids.empty()) {
// Initialize the decoder state with the prefix.
std::vector<std::vector<float>> prefix_attention;
initialize_decoder_with_prefix(decoder,
state,
start_ids[0],
prefix_ids,
options.return_attention ? &prefix_attention : nullptr);

for (size_t i = 0; i < options.num_hypotheses; ++i) {
result.hypotheses[i] = prefix_ids;
if (options.return_attention)
result.attention[i] = prefix_attention;
}

start_ids[0] = prefix_ids.back();
start_step += prefix_ids.size();
if (start_step == max_length)
return result;
}

// Expand the next "num_hypotheses" candidate words using the beam search.
BeamSearch beam(options.num_hypotheses);
DecodingResult expansion_result = beam.search(decoder,
state,
BestSampler(),
start_ids,
end_id,
start_step,
/*max_length=*/1,
/*min_length=*/1,
output_ids_map,
options.normalize_scores,
options.return_scores,
options.return_attention,
options.num_hypotheses)[0];

start_ids.resize(options.num_hypotheses);

for (size_t i = 0; i < options.num_hypotheses; ++i) {
// Add expanded word to the result.
result.hypotheses[i].emplace_back(expansion_result.hypotheses[i].back());
if (options.return_attention)
result.attention[i].emplace_back(std::move(expansion_result.attention[i].back()));
if (options.return_scores)
result.scores[i] = expansion_result.scores[i];

// The next inputs are the words we just expanded.
start_ids[i] = result.hypotheses[i].back();
}

start_step += 1;
if (start_step == max_length)
return result;

// Continue the decoding from each alternative word independently.
const auto search_strategy = make_search_strategy(options);
const auto sampler = make_sampler(options);
auto suffix_results = search_strategy->search(decoder,
state,
*sampler,
start_ids,
end_id,
start_step,
std::max(max_length - start_step, dim_t(0)),
std::max(min_length - start_step, dim_t(0)),
output_ids_map,
options.normalize_scores,
options.return_scores,
options.return_attention,
/*num_hypotheses=*/1,
options.repetition_penalty,
options.no_repeat_ngram_size);

// Update the result with the suffix decoding.
for (size_t i = 0; i < options.num_hypotheses; ++i) {
auto& suffix = suffix_results[i];

if (options.return_scores) {
if (options.normalize_scores) {
const auto prefix_length = result.hypotheses[i].size();
const auto suffix_length = suffix.hypotheses[0].size();
result.scores[i] = (
(result.scores[i] * prefix_length + suffix.scores[0] * suffix_length)
/ (prefix_length + suffix_length));
} else {
result.scores[i] += suffix.scores[0];
}
}

if (options.return_attention)
result.attention[i].insert(result.attention[i].end(),
std::make_move_iterator(suffix.attention[0].begin()),
std::make_move_iterator(suffix.attention[0].end()));

result.hypotheses[i].insert(result.hypotheses[i].end(),
std::make_move_iterator(suffix.hypotheses[0].begin()),
std::make_move_iterator(suffix.hypotheses[0].end()));
}

return result;
}

std::vector<DecodingResult>
decode(layers::Decoder& decoder,
layers::DecoderState& state,
@@ -925,140 +1052,44 @@ namespace ctranslate2 {
validate_decoding_options(options, output_ids_map);
const size_t batch_size = start_tokens.size();

if (options.return_alternatives && batch_size > 1) {
// return_alternatives mode currently does not support batch decoding.
std::vector<DecodingResult> results;
results.reserve(batch_size);
for (size_t i = 0; i < batch_size; ++i) {
layers::DecoderState batch_state = get_batch_state(state, i);
results.emplace_back(decode(decoder,
batch_state,
{start_tokens[i]},
end_id,
options,
output_ids_map)[0]);
}
return results;
}
std::vector<DecodingResult> results;

std::vector<size_t> start_ids;
std::vector<std::vector<size_t>> prefix_ids;
std::tie(start_ids, prefix_ids) = split_start_tokens(start_tokens);

dim_t start_step = 0;
dim_t min_length = options.min_length;
dim_t max_length = options.max_length;

std::vector<DecodingResult> expansion_results;
if (options.return_alternatives) {
std::vector<std::vector<std::vector<float>>> prefix_attention;
if (!prefix_ids.empty()) {
if (options.return_attention)
prefix_attention.resize(1);
initialize_decoder_with_prefix(decoder,
state,
start_ids[0],
prefix_ids[0],
options.return_attention ? &prefix_attention[0] : nullptr);
start_ids[0] = prefix_ids[0].back();
const dim_t prefix_length = prefix_ids[0].size();
start_step += prefix_length;
max_length = std::max(max_length - prefix_length, dim_t(0));
min_length = std::max(min_length - prefix_length, dim_t(0));
}

// In this translation mode, we first expand the next "num_hypotheses" candidate words
// before running the full decoding on each prefix. This is to ensure that we get unique
// alternatives at this decoding position.
expansion_results = BeamSearch(options.num_hypotheses).search(decoder,
state,
BestSampler(),
start_ids,
end_id,
start_step,
/*max_length=*/1,
/*min_length=*/1,
output_ids_map,
options.normalize_scores,
options.return_scores,
options.return_attention,
options.num_hypotheses);

start_ids.resize(batch_size * options.num_hypotheses);
for (size_t b = 0; b < batch_size; ++b) {
auto& result = expansion_results[b];

for (size_t i = 0; i < options.num_hypotheses; ++i) {
// The next input is the words we just expanded.
start_ids[b * options.num_hypotheses + i] = result.hypotheses[i].back();

// Prepend expansion result with the prefix.
if (!prefix_ids.empty()) {
result.hypotheses[i].insert(result.hypotheses[i].begin(),
prefix_ids[b].begin(),
prefix_ids[b].end());
if (options.return_attention) {
result.attention[i].insert(result.attention[i].begin(),
prefix_attention[b].begin(),
prefix_attention[b].end());
}
}
}
results.reserve(batch_size);
for (size_t i = 0; i < batch_size; ++i) {
layers::DecoderState batch_state = get_batch_state(state, i);
results.emplace_back(decode_alternatives(decoder,
batch_state,
start_ids[i],
prefix_ids[i],
end_id,
options,
output_ids_map));
}

start_step += 1;
max_length = std::max(max_length - 1, dim_t(0));
min_length = std::max(min_length - 1, dim_t(0));
}

const auto search_strategy = make_search_strategy(options);
const auto sampler = make_sampler(options);
auto results = search_strategy->search(decoder,
state,
*sampler,
start_ids,
end_id,
start_step,
max_length,
min_length,
output_ids_map,
options.normalize_scores,
options.return_scores,
options.return_attention,
options.return_alternatives ? 1 : options.num_hypotheses,
options.repetition_penalty,
options.no_repeat_ngram_size,
options.return_alternatives ? nullptr : &prefix_ids);

if (options.return_alternatives) {
// Append to expansion results.
for (size_t b = 0; b < batch_size; ++b) {
auto& prefix = expansion_results[b];
for (size_t i = 0; i < options.num_hypotheses; ++i) {
auto& suffix = results[b * options.num_hypotheses + i];

if (!prefix.scores.empty()) {
if (options.normalize_scores) {
const auto prefix_length = prefix.hypotheses[i].size();
const auto suffix_length = suffix.hypotheses[0].size();
prefix.scores[i] = (
(prefix.scores[i] * prefix_length + suffix.scores[0] * suffix_length)
/ (prefix_length + suffix_length));
} else {
prefix.scores[i] += suffix.scores[0];
}
}

prefix.hypotheses[i].insert(prefix.hypotheses[i].end(),
std::make_move_iterator(suffix.hypotheses[0].begin()),
std::make_move_iterator(suffix.hypotheses[0].end()));
if (!prefix.attention.empty())
prefix.attention[i].insert(prefix.attention[i].end(),
std::make_move_iterator(suffix.attention[0].begin()),
std::make_move_iterator(suffix.attention[0].end()));
}
}
results = std::move(expansion_results);
} else {
const auto search_strategy = make_search_strategy(options);
const auto sampler = make_sampler(options);
results = search_strategy->search(decoder,
state,
*sampler,
start_ids,
end_id,
/*start_step=*/0,
options.max_length,
options.min_length,
output_ids_map,
options.normalize_scores,
options.return_scores,
options.return_attention,
options.num_hypotheses,
options.repetition_penalty,
options.no_repeat_ngram_size,
&prefix_ids);
}

// Remove EOS token.
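
A note on the score bookkeeping in decode_alternatives above: when normalize_scores is enabled, the prefix (expansion) score and the suffix score are each averages over their own token counts, so they cannot simply be added. The code recombines them as a token-count-weighted mean. Here is a standalone sketch of that rule; merge_normalized_scores is a hypothetical helper, not part of the codebase.

#include <cstddef>

// Mirrors the merge in decode_alternatives(): each input score is an
// average log-probability over its own tokens, so the combined score is
// the token-count-weighted mean of the two averages.
float merge_normalized_scores(float prefix_score, std::size_t prefix_length,
                              float suffix_score, std::size_t suffix_length) {
  return (prefix_score * static_cast<float>(prefix_length)
          + suffix_score * static_cast<float>(suffix_length))
         / static_cast<float>(prefix_length + suffix_length);
}

Without normalization, the scores are plain sums of log-probabilities, which is why the else branch in the code above simply adds them.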
34 changes: 34 additions & 0 deletions tests/translator_test.cc
@@ -793,6 +793,40 @@ TEST(TranslatorTest, AlternativesFromFullTarget) {
EXPECT_EQ(result.hypotheses[0], (std::vector<std::string>{"a", "t", "z", "m", "o", "n", "e"}));
}

TEST(TranslatorTest, AlternativesMaxDecodingLength) {
Translator translator = default_translator();
TranslationOptions options;
options.num_hypotheses = 4;
options.max_decoding_length = 2;
options.return_alternatives = true;
options.return_scores = true;
options.return_attention = true;

const std::vector<std::string> input = {"آ" ,"ت" ,"ز" ,"م" ,"و" ,"ن"};
const std::vector<std::vector<std::string>> target_samples = {
{}, {"a"}, {"a", "t"}, {"a", "t", "z"}
};

for (const auto& target : target_samples) {
const auto result = translator.translate_with_prefix(input, target, options);

for (size_t i = 0; i < result.num_hypotheses(); ++i) {
EXPECT_EQ(result.hypotheses[i].size(), options.max_decoding_length);
EXPECT_EQ(result.attention[i].size(), options.max_decoding_length);

for (size_t t = 0; t < std::min(target.size(), options.max_decoding_length); ++t) {
EXPECT_EQ(result.hypotheses[i][t], target[t]);
}

if (target.size() < options.max_decoding_length) {
EXPECT_NE(result.scores[i], 0);
} else {
EXPECT_EQ(result.scores[i], 0);
}
}
}
}

TEST(TranslatorTest, DetachModel) {
const std::vector<std::string> input = {"آ" ,"ت" ,"ز" ,"م" ,"و" ,"ن"};
Translator translator = default_translator();
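
The new test pins down the intended semantics: whatever share of the budget the target prefix consumes (including a prefix that meets or exceeds it), every returned hypothesis and its attention history are exactly max_decoding_length tokens long, and a zero score marks the case where the prefix alone fills the budget. The sketch below restates the budgeting rule applied in decode_alternatives; the helper is illustrative only, not part of the codebase.

#include <algorithm>
#include <cstddef>

// Simplified restatement of the length budget in decode_alternatives():
// the prefix is first truncated to the budget, the expansion step consumes
// one token, and the suffix search receives whatever remains.
std::size_t remaining_suffix_budget(std::size_t max_length, std::size_t prefix_length) {
  const std::size_t used = std::min(prefix_length, max_length); // truncated prefix
  if (used == max_length)
    return 0;  // budget exhausted by the prefix alone; decoding stops there
  const std::size_t after_expansion = used + 1;  // one token for the expansion
  return max_length - std::min(after_expansion, max_length);
}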
