Skip to content

Commit

Permalink
ref ggerganov#10 : option to keep context in "stream" example
Browse files Browse the repository at this point in the history
Seems the results become worse when we keep the context, so by default
this is not enabled
  • Loading branch information
ggerganov committed Oct 7, 2022
1 parent 3f15bb8 commit 481cd68
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
5 changes: 5 additions & 0 deletions stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ struct whisper_params {

bool verbose = false;
bool translate = false;
bool no_context = true;
bool print_special_tokens = false;
bool no_timestamps = true;

Expand All @@ -64,6 +65,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
params.verbose = true;
} else if (arg == "--translate") {
params.translate = true;
} else if (arg == "-kc" || arg == "--keep-context") {
params.no_context = false;
} else if (arg == "-l" || arg == "--language") {
params.language = argv[++i];
if (whisper_lang_id(params.language.c_str()) == -1) {
Expand Down Expand Up @@ -103,6 +106,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, " --step N audio step size in milliseconds (default: %d)\n", params.step_ms);
fprintf(stderr, " -v, --verbose verbose output\n");
fprintf(stderr, " --translate translate from source language to english\n");
fprintf(stderr, " -nc, --no-context disable context from earlier audio (default: false)\n");
fprintf(stderr, " -ps, --print_special print special tokens\n");
fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
Expand Down Expand Up @@ -273,6 +277,7 @@ int main(int argc, char ** argv) {
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = params.no_context;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;

Expand Down
19 changes: 12 additions & 7 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,8 @@ struct whisper_context {

std::vector<whisper_result> result_cur;
std::vector<whisper_segment> result_all;

std::vector<whisper_token> prompt_past;
};

// load the model from a ggml file
Expand Down Expand Up @@ -1020,8 +1022,6 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
// - model: the model
// - n_threads: number of threads to use
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
// - mel_inp: input mel spectrogram
// - features: output encoded features
//
bool whisper_encode(
whisper_context & wctx,
Expand Down Expand Up @@ -1405,10 +1405,9 @@ bool whisper_encode(
//
// - model: the model
// - n_threads: number of threads to use
// - n_past: prompt length
// - prompt: text prompt
// - logits_out: output logits
// - probs_out: output probabilities
// - tokens: text prompt
// - n_tokens: number of tokens in the prompt
// - n_past: number of past tokens to prefix the prompt with
//
bool whisper_decode(
whisper_context & wctx,
Expand Down Expand Up @@ -2259,6 +2258,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat
.offset_ms = 0,

.translate = false,
.no_context = false,
.print_special_tokens = false,
.print_progress = true,
.print_realtime = false,
Expand All @@ -2279,6 +2279,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat
.offset_ms = 0,

.translate = false,
.no_context = false,
.print_special_tokens = false,
.print_progress = true,
.print_realtime = false,
Expand All @@ -2297,6 +2298,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat

return result;
}

int whisper_full(
struct whisper_context * ctx,
struct whisper_full_params params,
Expand All @@ -2309,7 +2311,10 @@ int whisper_full(
}

// the accumulated text context so far
std::vector<whisper_token> prompt_past = { };
auto & prompt_past = ctx->prompt_past;
if (params.no_context) {
prompt_past.clear();
}

// these tokens determine the task that will be performed
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
Expand Down
1 change: 1 addition & 0 deletions whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ extern "C" {
int offset_ms;

bool translate;
bool no_context;
bool print_special_tokens;
bool print_progress;
bool print_realtime;
Expand Down

0 comments on commit 481cd68

Please sign in to comment.