Skip to content

Commit 239310a

Browse files
committed
add log probs
1 parent e6a8f94 commit 239310a

11 files changed

+100
-7
lines changed

include/ctranslate2/decoding.h

+7-1
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,16 @@ namespace ctranslate2 {
1515
std::vector<std::vector<size_t>> hypotheses;
1616
std::vector<float> scores;
1717
std::vector<std::vector<std::vector<float>>> attention;
18+
std::vector<std::vector<StorageView>> log_probs_vocab;
1819
};
1920

2021
struct DecodingStepResult {
2122
size_t step;
2223
size_t batch_id;
2324
size_t token_id;
2425
size_t hypothesis_id;
25-
std::optional<float> log_prob;
26+
std::optional<float> score;
27+
std::optional<StorageView> log_probs;
2628
bool is_last = false;
2729
};
2830

@@ -41,6 +43,7 @@ namespace ctranslate2 {
4143
const dim_t min_length,
4244
const bool return_scores = false,
4345
const bool return_attention = false,
46+
const bool return_log_probs_vocab = true,
4447
const bool return_prefix = true,
4548
const size_t num_hypotheses = 1,
4649
const bool include_eos_in_hypotheses = true,
@@ -67,6 +70,7 @@ namespace ctranslate2 {
6770
const dim_t min_length,
6871
const bool return_scores = false,
6972
const bool return_attention = false,
73+
const bool return_log_probs_vocab = true,
7074
const bool return_prefix = true,
7175
const size_t num_hypotheses = 1,
7276
const bool include_eos_in_hypotheses = true,
@@ -118,6 +122,7 @@ namespace ctranslate2 {
118122
const dim_t min_length,
119123
const bool return_scores = false,
120124
const bool return_attention = false,
125+
const bool return_log_probs_vocab = true,
121126
const bool return_prefix = true,
122127
const size_t num_hypotheses = 1,
123128
const bool include_eos_in_hypotheses = true,
@@ -149,6 +154,7 @@ namespace ctranslate2 {
149154
bool include_eos_in_hypotheses = true;
150155
bool return_scores = false;
151156
bool return_attention = false;
157+
bool return_log_probs_vocab = false;
152158
bool return_alternatives = false;
153159
bool return_prefix = true;
154160
float min_alternative_expansion_prob = 0;

include/ctranslate2/generation.h

+7-2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ namespace ctranslate2 {
5353

5454
// Include scores in the result.
5555
bool return_scores = false;
56+
// Include log probs of each token in the result
57+
bool return_log_probs_vocab = false;
5658

5759
// Return alternatives at the first unconstrained decoding position. This is typically
5860
// used with a prefix to provide alternatives at a specifc location.
@@ -79,6 +81,7 @@ namespace ctranslate2 {
7981
std::vector<std::vector<std::string>> sequences;
8082
std::vector<std::vector<size_t>> sequences_ids;
8183
std::vector<float> scores;
84+
std::vector<std::vector<StorageView>> log_probs;
8285

8386
size_t num_sequences() const {
8487
return sequences.size();
@@ -95,7 +98,8 @@ namespace ctranslate2 {
9598
size_t token_id;
9699
size_t hypothesis_id;
97100
std::string token;
98-
std::optional<float> log_prob;
101+
std::optional<float> score;
102+
std::optional<StorageView> log_probs;
99103
bool is_last;
100104

101105
GenerationStepResult() = default;
@@ -105,7 +109,8 @@ namespace ctranslate2 {
105109
, token_id(result.token_id)
106110
, hypothesis_id(result.hypothesis_id)
107111
, token(vocabulary.to_token(result.token_id))
108-
, log_prob(result.log_prob)
112+
, score(result.score)
113+
, log_probs(result.log_probs)
109114
, is_last(result.is_last)
110115
{
111116
}

include/ctranslate2/models/whisper.h

+4
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ namespace ctranslate2 {
4141
// Include scores in the result.
4242
bool return_scores = false;
4343

44+
// Include log probs of each token in the result
45+
bool return_log_probs_vocab = false;
46+
4447
// Include the probability of the no speech token in the result.
4548
bool return_no_speech_prob = false;
4649

@@ -59,6 +62,7 @@ namespace ctranslate2 {
5962
std::vector<std::vector<std::string>> sequences;
6063
std::vector<std::vector<size_t>> sequences_ids;
6164
std::vector<float> scores;
65+
std::vector<std::vector<StorageView>> log_probs;
6266
float no_speech_prob = 0;
6367

6468
size_t num_sequences() const {

include/ctranslate2/translation.h

+6-1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ namespace ctranslate2 {
6767
bool return_scores = false;
6868
// Store attention vectors in the TranslationResult class.
6969
bool return_attention = false;
70+
// Store log probs matrix in the TranslationResult class.
71+
bool return_log_probs_vocab = false;
7072

7173
// Return alternatives at the first unconstrained decoding position. This is typically
7274
// used with a target prefix to provide alternatives at a specifc location in the
@@ -87,6 +89,7 @@ namespace ctranslate2 {
8789
std::vector<std::vector<std::string>> hypotheses;
8890
std::vector<float> scores;
8991
std::vector<std::vector<std::vector<float>>> attention;
92+
std::vector<std::vector<StorageView>> log_probs;
9093

9194
TranslationResult(std::vector<std::vector<std::string>> hypotheses_)
9295
: hypotheses(std::move(hypotheses_))
@@ -95,10 +98,12 @@ namespace ctranslate2 {
9598

9699
TranslationResult(std::vector<std::vector<std::string>> hypotheses_,
97100
std::vector<float> scores_,
98-
std::vector<std::vector<std::vector<float>>> attention_)
101+
std::vector<std::vector<std::vector<float>>> attention_,
102+
std::vector<std::vector<StorageView>> log_probs_)
99103
: hypotheses(std::move(hypotheses_))
100104
, scores(std::move(scores_))
101105
, attention(std::move(attention_))
106+
, log_probs(std::move(log_probs_))
102107
{
103108
}
104109

python/cpp/generator.cc

+4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ namespace ctranslate2 {
3333
bool cache_static_prompt,
3434
bool include_prompt_in_result,
3535
bool return_scores,
36+
bool return_log_probs_vocab,
3637
bool return_alternatives,
3738
float min_alternative_expansion_prob,
3839
size_t sampling_topk,
@@ -58,6 +59,7 @@ namespace ctranslate2 {
5859
options.num_hypotheses = num_hypotheses;
5960
options.return_end_token = return_end_token;
6061
options.return_scores = return_scores;
62+
options.return_log_probs_vocab = return_log_probs_vocab;
6163
options.return_alternatives = return_alternatives;
6264
options.cache_static_prompt = cache_static_prompt;
6365
options.include_prompt_in_result = include_prompt_in_result;
@@ -203,6 +205,7 @@ namespace ctranslate2 {
203205
py::arg("cache_static_prompt")=true,
204206
py::arg("include_prompt_in_result")=true,
205207
py::arg("return_scores")=false,
208+
py::arg("return_log_probs_vocab")=false,
206209
py::arg("return_alternatives")=false,
207210
py::arg("min_alternative_expansion_prob")=0,
208211
py::arg("sampling_topk")=1,
@@ -260,6 +263,7 @@ namespace ctranslate2 {
260263
reuse it for future generations using the same static prompt.
261264
include_prompt_in_result: Include the :obj:`start_tokens` in the result.
262265
return_scores: Include the scores in the output.
266+
return_log_probs_vocab: Include log probs for each token in the output
263267
return_alternatives: Return alternatives at the first unconstrained decoding position.
264268
min_alternative_expansion_prob: Minimum initial probability to expand an alternative.
265269
sampling_topk: Randomly sample predictions from the top K candidates.

python/cpp/translator.cc

+5
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ namespace ctranslate2 {
5252
size_t min_decoding_length,
5353
bool use_vmap,
5454
bool with_scores,
55+
bool return_log_probs_vocab,
5556
size_t sampling_topk,
5657
float sampling_topp,
5758
float sampling_temperature,
@@ -141,6 +142,7 @@ namespace ctranslate2 {
141142
size_t min_decoding_length,
142143
bool use_vmap,
143144
bool return_scores,
145+
bool return_log_probs_vocab,
144146
bool return_attention,
145147
bool return_alternatives,
146148
float min_alternative_expansion_prob,
@@ -172,6 +174,7 @@ namespace ctranslate2 {
172174
options.use_vmap = use_vmap;
173175
options.return_end_token = return_end_token;
174176
options.return_scores = return_scores;
177+
options.return_log_probs_vocab = return_log_probs_vocab;
175178
options.return_attention = return_attention;
176179
options.return_alternatives = return_alternatives;
177180
options.min_alternative_expansion_prob = min_alternative_expansion_prob;
@@ -354,6 +357,7 @@ namespace ctranslate2 {
354357
py::arg("min_decoding_length")=1,
355358
py::arg("use_vmap")=false,
356359
py::arg("return_scores")=false,
360+
py::arg("return_log_probs_vocab")=false,
357361
py::arg("return_attention")=false,
358362
py::arg("return_alternatives")=false,
359363
py::arg("min_alternative_expansion_prob")=0,
@@ -396,6 +400,7 @@ namespace ctranslate2 {
396400
min_decoding_length: Minimum prediction length.
397401
use_vmap: Use the vocabulary mapping file saved in this model
398402
return_scores: Include the scores in the output.
403+
return_log_probs_vocab: Include the log probs of each token in the output
399404
return_attention: Include the attention vectors in the output.
400405
return_alternatives: Return alternatives at the first unconstrained decoding position.
401406
min_alternative_expansion_prob: Minimum initial probability to expand an alternative.

python/cpp/whisper.cc

+4
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ namespace ctranslate2 {
4040
size_t no_repeat_ngram_size,
4141
size_t max_length,
4242
bool return_scores,
43+
bool return_log_probs_vocab,
4344
bool return_no_speech_prob,
4445
size_t max_initial_timestamp_index,
4546
bool suppress_blank,
@@ -59,6 +60,7 @@ namespace ctranslate2 {
5960
options.max_length = max_length;
6061
options.num_hypotheses = num_hypotheses;
6162
options.return_scores = return_scores;
63+
options.return_log_probs_vocab = return_log_probs_vocab;
6264
options.return_no_speech_prob = return_no_speech_prob;
6365
options.max_initial_timestamp_index = max_initial_timestamp_index;
6466
options.suppress_blank = suppress_blank;
@@ -247,6 +249,7 @@ namespace ctranslate2 {
247249
py::arg("no_repeat_ngram_size")=0,
248250
py::arg("max_length")=448,
249251
py::arg("return_scores")=false,
252+
py::arg("return_log_probs_vocab")=false,
250253
py::arg("return_no_speech_prob")=false,
251254
py::arg("max_initial_timestamp_index")=50,
252255
py::arg("suppress_blank")=true,
@@ -276,6 +279,7 @@ namespace ctranslate2 {
276279
(set 0 to disable).
277280
max_length: Maximum generation length.
278281
return_scores: Include the scores in the output.
282+
return_log_probs_vocab: Include the log probs in the output
279283
return_no_speech_prob: Include the probability of the no speech token in the
280284
result.
281285
max_initial_timestamp_index: Maximum index of the first predicted timestamp.

0 commit comments

Comments
 (0)