Skip to content

Commit 6647945

Browse files
authored Aug 26, 2024
Add log probs for all tokens (#1755)
* add log probs * fix compilation * fix compilation * fix test * fix black * return logits * fix compilation * fix test * last clean
1 parent 8ba828c commit 6647945

17 files changed

+122
-13
lines changed
 

‎include/ctranslate2/decoding.h

+7-1
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,16 @@ namespace ctranslate2 {
1515
std::vector<std::vector<size_t>> hypotheses;
1616
std::vector<float> scores;
1717
std::vector<std::vector<std::vector<float>>> attention;
18+
std::vector<std::vector<StorageView>> logits_vocab;
1819
};
1920

2021
struct DecodingStepResult {
2122
size_t step;
2223
size_t batch_id;
2324
size_t token_id;
2425
size_t hypothesis_id;
25-
std::optional<float> log_prob;
26+
std::optional<float> score;
27+
std::optional<StorageView> logits;
2628
bool is_last = false;
2729
};
2830

@@ -41,6 +43,7 @@ namespace ctranslate2 {
4143
const dim_t min_length,
4244
const bool return_scores = false,
4345
const bool return_attention = false,
46+
const bool return_logits_vocab = true,
4447
const bool return_prefix = true,
4548
const size_t num_hypotheses = 1,
4649
const bool include_eos_in_hypotheses = true,
@@ -67,6 +70,7 @@ namespace ctranslate2 {
6770
const dim_t min_length,
6871
const bool return_scores = false,
6972
const bool return_attention = false,
73+
const bool return_logits_vocab = true,
7074
const bool return_prefix = true,
7175
const size_t num_hypotheses = 1,
7276
const bool include_eos_in_hypotheses = true,
@@ -118,6 +122,7 @@ namespace ctranslate2 {
118122
const dim_t min_length,
119123
const bool return_scores = false,
120124
const bool return_attention = false,
125+
const bool return_logits_vocab = true,
121126
const bool return_prefix = true,
122127
const size_t num_hypotheses = 1,
123128
const bool include_eos_in_hypotheses = true,
@@ -149,6 +154,7 @@ namespace ctranslate2 {
149154
bool include_eos_in_hypotheses = true;
150155
bool return_scores = false;
151156
bool return_attention = false;
157+
bool return_logits_vocab = false;
152158
bool return_alternatives = false;
153159
bool return_prefix = true;
154160
float min_alternative_expansion_prob = 0;

‎include/ctranslate2/generation.h

+7-2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ namespace ctranslate2 {
5353

5454
// Include scores in the result.
5555
bool return_scores = false;
56+
// Include log probs of each token in the result
57+
bool return_logits_vocab = false;
5658

5759
// Return alternatives at the first unconstrained decoding position. This is typically
5860
// used with a prefix to provide alternatives at a specific location.
@@ -79,6 +81,7 @@ namespace ctranslate2 {
7981
std::vector<std::vector<std::string>> sequences;
8082
std::vector<std::vector<size_t>> sequences_ids;
8183
std::vector<float> scores;
84+
std::vector<std::vector<StorageView>> logits;
8285

8386
size_t num_sequences() const {
8487
return sequences.size();
@@ -95,7 +98,8 @@ namespace ctranslate2 {
9598
size_t token_id;
9699
size_t hypothesis_id;
97100
std::string token;
98-
std::optional<float> log_prob;
101+
std::optional<float> score;
102+
std::optional<StorageView> logits;
99103
bool is_last;
100104

101105
GenerationStepResult() = default;
@@ -105,7 +109,8 @@ namespace ctranslate2 {
105109
, token_id(result.token_id)
106110
, hypothesis_id(result.hypothesis_id)
107111
, token(vocabulary.to_token(result.token_id))
108-
, log_prob(result.log_prob)
112+
, score(result.score)
113+
, logits(result.logits)
109114
, is_last(result.is_last)
110115
{
111116
}

‎include/ctranslate2/models/whisper.h

+4
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ namespace ctranslate2 {
4141
// Include scores in the result.
4242
bool return_scores = false;
4343

44+
// Include log probs of each token in the result
45+
bool return_logits_vocab = false;
46+
4447
// Include the probability of the no speech token in the result.
4548
bool return_no_speech_prob = false;
4649

@@ -59,6 +62,7 @@ namespace ctranslate2 {
5962
std::vector<std::vector<std::string>> sequences;
6063
std::vector<std::vector<size_t>> sequences_ids;
6164
std::vector<float> scores;
65+
std::vector<std::vector<StorageView>> logits;
6266
float no_speech_prob = 0;
6367

6468
size_t num_sequences() const {

‎include/ctranslate2/translation.h

+6-1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ namespace ctranslate2 {
6767
bool return_scores = false;
6868
// Store attention vectors in the TranslationResult class.
6969
bool return_attention = false;
70+
// Store log probs matrix in the TranslationResult class.
71+
bool return_logits_vocab = false;
7072

7173
// Return alternatives at the first unconstrained decoding position. This is typically
7274
// used with a target prefix to provide alternatives at a specific location in the
@@ -87,6 +89,7 @@ namespace ctranslate2 {
8789
std::vector<std::vector<std::string>> hypotheses;
8890
std::vector<float> scores;
8991
std::vector<std::vector<std::vector<float>>> attention;
92+
std::vector<std::vector<StorageView>> logits;
9093

9194
TranslationResult(std::vector<std::vector<std::string>> hypotheses_)
9295
: hypotheses(std::move(hypotheses_))
@@ -95,10 +98,12 @@ namespace ctranslate2 {
9598

9699
TranslationResult(std::vector<std::vector<std::string>> hypotheses_,
97100
std::vector<float> scores_,
98-
std::vector<std::vector<std::vector<float>>> attention_)
101+
std::vector<std::vector<std::vector<float>>> attention_,
102+
std::vector<std::vector<StorageView>> logits_)
99103
: hypotheses(std::move(hypotheses_))
100104
, scores(std::move(scores_))
101105
, attention(std::move(attention_))
106+
, logits(std::move(logits_))
102107
{
103108
}
104109

‎python/cpp/generation_result.cc

+8-2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@ namespace ctranslate2 {
2121
"Index of the hypothesis in the batch.")
2222
.def_readonly("token", &GenerationStepResult::token,
2323
"String value of the generated token.")
24-
.def_readonly("log_prob", &GenerationStepResult::log_prob,
24+
.def_readonly("log_prob", &GenerationStepResult::score,
2525
"Log probability of the token (``None`` if :obj:`return_log_prob` was disabled).")
26+
.def_readonly("logits", &GenerationStepResult::logits,
27+
"Log probability on the vocab of all tokens.")
2628
.def_readonly("is_last", &GenerationStepResult::is_last,
2729
"Whether this step is the last decoding step for this batch.")
2830

@@ -32,7 +34,8 @@ namespace ctranslate2 {
3234
+ ", token_id=" + std::string(py::repr(py::cast(result.token_id)))
3335
+ ", hypothesis_id=" + std::string(py::repr(py::cast(result.hypothesis_id)))
3436
+ ", token=" + std::string(py::repr(py::cast(result.token)))
35-
+ ", log_prob=" + std::string(py::repr(py::cast(result.log_prob)))
37+
+ ", log_prob=" + std::string(py::repr(py::cast(result.score)))
38+
+ ", logits=" + std::string(py::repr(py::cast(result.logits)))
3639
+ ", is_last=" + std::string(py::repr(py::cast(result.is_last)))
3740
+ ")";
3841
})
@@ -46,11 +49,14 @@ namespace ctranslate2 {
4649
"Generated sequences of token IDs.")
4750
.def_readonly("scores", &GenerationResult::scores,
4851
"Score of each sequence (empty if :obj:`return_scores` was disabled).")
52+
.def_readonly("logits", &GenerationResult::logits,
53+
"Score of each sequence (empty if :obj:`return_logits_vocab` was disabled).")
4954

5055
.def("__repr__", [](const GenerationResult& result) {
5156
return "GenerationResult(sequences=" + std::string(py::repr(py::cast(result.sequences)))
5257
+ ", sequences_ids=" + std::string(py::repr(py::cast(result.sequences_ids)))
5358
+ ", scores=" + std::string(py::repr(py::cast(result.scores)))
59+
+ ", logits=" + std::string(py::repr(py::cast(result.logits)))
5460
+ ")";
5561
})
5662
;

‎python/cpp/generator.cc

+4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ namespace ctranslate2 {
3333
bool cache_static_prompt,
3434
bool include_prompt_in_result,
3535
bool return_scores,
36+
bool return_logits_vocab,
3637
bool return_alternatives,
3738
float min_alternative_expansion_prob,
3839
size_t sampling_topk,
@@ -58,6 +59,7 @@ namespace ctranslate2 {
5859
options.num_hypotheses = num_hypotheses;
5960
options.return_end_token = return_end_token;
6061
options.return_scores = return_scores;
62+
options.return_logits_vocab = return_logits_vocab;
6163
options.return_alternatives = return_alternatives;
6264
options.cache_static_prompt = cache_static_prompt;
6365
options.include_prompt_in_result = include_prompt_in_result;
@@ -203,6 +205,7 @@ namespace ctranslate2 {
203205
py::arg("cache_static_prompt")=true,
204206
py::arg("include_prompt_in_result")=true,
205207
py::arg("return_scores")=false,
208+
py::arg("return_logits_vocab")=false,
206209
py::arg("return_alternatives")=false,
207210
py::arg("min_alternative_expansion_prob")=0,
208211
py::arg("sampling_topk")=1,
@@ -260,6 +263,7 @@ namespace ctranslate2 {
260263
reuse it for future generations using the same static prompt.
261264
include_prompt_in_result: Include the :obj:`start_tokens` in the result.
262265
return_scores: Include the scores in the output.
266+
return_logits_vocab: Include log probs for each token in the output
263267
return_alternatives: Return alternatives at the first unconstrained decoding position.
264268
min_alternative_expansion_prob: Minimum initial probability to expand an alternative.
265269
sampling_topk: Randomly sample predictions from the top K candidates.

‎python/cpp/storage_view.cc

+6
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,12 @@ namespace ctranslate2 {
192192
return stream.str();
193193
})
194194

195+
.def("__repr__", [](const StorageView& view) {
196+
std::ostringstream stream;
197+
stream << view;
198+
return stream.str();
199+
})
200+
195201
.def("to",
196202
[](const StorageView& view, DataType dtype) {
197203
ScopedDeviceSetter device_setter(view.device(), view.device_index());

‎python/cpp/translation_result.cc

+3
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,14 @@ namespace ctranslate2 {
1616
"Score of each translation hypothesis (empty if :obj:`return_scores` was disabled).")
1717
.def_readonly("attention", &TranslationResult::attention,
1818
"Attention matrix of each translation hypothesis (empty if :obj:`return_attention` was disabled).")
19+
.def_readonly("logits", &TranslationResult::logits,
20+
"Score of each translation hypothesis (empty if :obj:`return_logits_vocab` was disabled).")
1921

2022
.def("__repr__", [](const TranslationResult& result) {
2123
return "TranslationResult(hypotheses=" + std::string(py::repr(py::cast(result.hypotheses)))
2224
+ ", scores=" + std::string(py::repr(py::cast(result.scores)))
2325
+ ", attention=" + std::string(py::repr(py::cast(result.attention)))
26+
+ ", logits=" + std::string(py::repr(py::cast(result.logits)))
2427
+ ")";
2528
})
2629

‎python/cpp/translator.cc

+4
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ namespace ctranslate2 {
141141
size_t min_decoding_length,
142142
bool use_vmap,
143143
bool return_scores,
144+
bool return_logits_vocab,
144145
bool return_attention,
145146
bool return_alternatives,
146147
float min_alternative_expansion_prob,
@@ -172,6 +173,7 @@ namespace ctranslate2 {
172173
options.use_vmap = use_vmap;
173174
options.return_end_token = return_end_token;
174175
options.return_scores = return_scores;
176+
options.return_logits_vocab = return_logits_vocab;
175177
options.return_attention = return_attention;
176178
options.return_alternatives = return_alternatives;
177179
options.min_alternative_expansion_prob = min_alternative_expansion_prob;
@@ -354,6 +356,7 @@ namespace ctranslate2 {
354356
py::arg("min_decoding_length")=1,
355357
py::arg("use_vmap")=false,
356358
py::arg("return_scores")=false,
359+
py::arg("return_logits_vocab")=false,
357360
py::arg("return_attention")=false,
358361
py::arg("return_alternatives")=false,
359362
py::arg("min_alternative_expansion_prob")=0,
@@ -396,6 +399,7 @@ namespace ctranslate2 {
396399
min_decoding_length: Minimum prediction length.
397400
use_vmap: Use the vocabulary mapping file saved in this model
398401
return_scores: Include the scores in the output.
402+
return_logits_vocab: Include the log probs of each token in the output
399403
return_attention: Include the attention vectors in the output.
400404
return_alternatives: Return alternatives at the first unconstrained decoding position.
401405
min_alternative_expansion_prob: Minimum initial probability to expand an alternative.

‎python/cpp/whisper.cc

+4
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ namespace ctranslate2 {
4040
size_t no_repeat_ngram_size,
4141
size_t max_length,
4242
bool return_scores,
43+
bool return_logits_vocab,
4344
bool return_no_speech_prob,
4445
size_t max_initial_timestamp_index,
4546
bool suppress_blank,
@@ -59,6 +60,7 @@ namespace ctranslate2 {
5960
options.max_length = max_length;
6061
options.num_hypotheses = num_hypotheses;
6162
options.return_scores = return_scores;
63+
options.return_logits_vocab = return_logits_vocab;
6264
options.return_no_speech_prob = return_no_speech_prob;
6365
options.max_initial_timestamp_index = max_initial_timestamp_index;
6466
options.suppress_blank = suppress_blank;
@@ -247,6 +249,7 @@ namespace ctranslate2 {
247249
py::arg("no_repeat_ngram_size")=0,
248250
py::arg("max_length")=448,
249251
py::arg("return_scores")=false,
252+
py::arg("return_logits_vocab")=false,
250253
py::arg("return_no_speech_prob")=false,
251254
py::arg("max_initial_timestamp_index")=50,
252255
py::arg("suppress_blank")=true,
@@ -276,6 +279,7 @@ namespace ctranslate2 {
276279
(set 0 to disable).
277280
max_length: Maximum generation length.
278281
return_scores: Include the scores in the output.
282+
return_logits_vocab: Include the log probs in the output
279283
return_no_speech_prob: Include the probability of the no speech token in the
280284
result.
281285
max_initial_timestamp_index: Maximum index of the first predicted timestamp.

‎python/tests/test_translator.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,12 @@ def test_batch_translation(max_batch_size):
111111
assert output[0].scores[0] < 0
112112
assert not output[0].attention
113113

114-
expected_repr = "TranslationResult(hypotheses=%s, scores=%s, attention=[])" % (
115-
output[0].hypotheses,
116-
output[0].scores,
114+
expected_repr = (
115+
"TranslationResult(hypotheses=%s, scores=%s, attention=[], logits=[])"
116+
% (
117+
output[0].hypotheses,
118+
output[0].scores,
119+
)
117120
)
118121
assert repr(output[0]) == expected_repr
119122

0 commit comments

Comments
 (0)
Please sign in to comment.