From 8de452c18bf83bb0d8c5735a86c7dc40156f5792 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 15 Jan 2023 11:29:57 +0200 Subject: [PATCH] Improve decoding (#291) * whisper : prepare infra for new decoding strategies * whisper : apply logit filters and compute logprobs * whisper : add whisper_get_logits() * whisper : separate self and cross attention memory Initial step needed for supporting parallel decoders * whisper : move probs_id buffer to whisper_context * whisper : refactor kv cache into separate struct * whisper : move self-attention kv cache to whisper_decoder * whisper : wip decoding parameters + strategies * whisper : wip decoding parameters + strategies (part 2) * whisper : wip decoding parameters + strategies (part 3) * whisper : wip decoding parameters + strategies (part 4) * whisper : fix prompt_past update to not include prompt_init * whisper : temperature + best_of support * whisper : support for compression_ration_threshold We actually use entropy, but it is similar * command : fix example to use logits instead of obsolete probs * whisper : handle empty sequence ranking * whisper : add WHISPER_DEBUG + diagnostic prints + new main args * whisper : minor fixes * whisper : add beam-search support * whisper : bug fix when there no previous context * whisper : add comments * stream : disable temperature fallback For real-time processing, we always want a single decoder running at T=0 * whisper.swiftui : update example - fix paths + add empty folders --- .gitignore | 2 + README.md | 12 +- examples/command/command.cpp | 107 +- examples/main/main.cpp | 91 +- examples/stream.wasm/emscripten.cpp | 3 + examples/stream/stream.cpp | 3 + .../Resources/models/.gitignore | 0 .../Resources/samples/.gitignore | 0 .../whisper.swiftui.xcodeproj/project.pbxproj | 11 +- whisper.cpp | 2042 +++++++++++------ whisper.h | 60 +- 11 files changed, 1539 insertions(+), 792 deletions(-) create mode 100644 examples/whisper.swiftui/whisper.swiftui.demo/Resources/models/.gitignore create mode 100644 examples/whisper.swiftui/whisper.swiftui.demo/Resources/samples/.gitignore diff --git a/.gitignore b/.gitignore index 8a495199e75..5ca3702c331 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ build/ build-em/ build-debug/ build-release/ +build-static/ build-sanitize-addr/ build-sanitize-thread/ @@ -18,6 +19,7 @@ build-sanitize-thread/ /bench sync.sh +libwhisper.a libwhisper.so compile_commands.json diff --git a/README.md b/README.md index f22724a5054..448e7588059 100644 --- a/README.md +++ b/README.md @@ -212,17 +212,7 @@ make large ## Limitations - Inference only -- No GPU support -- Very basic greedy sampling scheme - always pick up the token with highest probability. - This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274) - from the original python implementation, so in order to make a fair comparison between the 2 implementations, make sure - to run the python code with the following parameters: - - ``` - whisper --best_of None --beam_size None ... - ``` - - In the future, `whisper.cpp` will support more sampling strategies. +- No GPU support (yet) ## Another example diff --git a/examples/command/command.cpp b/examples/command/command.cpp index 3dae3a5e31c..2bdaf87c45c 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -671,56 +671,81 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const break; } - const auto * probs = whisper_get_probs(ctx); - std::vector> probs_id; - - double psum = 0.0; - for (int i = 0; i < (int) allowed_commands.size(); ++i) { - probs_id.emplace_back(probs[allowed_tokens[i][0]], i); - for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) { - probs_id.back().first += probs[allowed_tokens[i][j]]; - } - probs_id.back().first /= allowed_tokens[i].size(); - psum += probs_id.back().first; - } + // estimate command probability + // NOTE: not optimal + { + const auto * logits = whisper_get_logits(ctx); - // normalize - for (auto & p : probs_id) { - p.first /= psum; - } + std::vector probs(whisper_n_vocab(ctx), 0.0f); - // sort descending - { - using pair_type = decltype(probs_id)::value_type; - std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) { - return a.first > b.first; - }); - } + // compute probs from logits via softmax + { + float max = -1e9; + for (int i = 0; i < (int) probs.size(); ++i) { + max = std::max(max, logits[i]); + } - // print the commands and the respective probabilities - { - fprintf(stdout, "\n"); - for (const auto & cmd : probs_id) { - fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first); - for (int token : allowed_tokens[cmd.second]) { - fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]); + float sum = 0.0f; + for (int i = 0; i < (int) probs.size(); ++i) { + probs[i] = expf(logits[i] - max); + sum += probs[i]; + } + + for (int i = 0; i < (int) probs.size(); ++i) { + probs[i] /= sum; } + } + + std::vector> probs_id; + + double psum = 0.0; + for (int i = 0; i < (int) allowed_commands.size(); ++i) { + probs_id.emplace_back(probs[allowed_tokens[i][0]], i); + for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) { + probs_id.back().first += probs[allowed_tokens[i][j]]; + } + probs_id.back().first /= allowed_tokens[i].size(); + psum += probs_id.back().first; + } + + // normalize + for (auto & p : probs_id) { + p.first /= psum; + } + + // sort descending + { + using pair_type = decltype(probs_id)::value_type; + std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) { + return a.first > b.first; + }); + } + + // print the commands and the respective probabilities + { fprintf(stdout, "\n"); + for (const auto & cmd : probs_id) { + fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first); + for (int token : allowed_tokens[cmd.second]) { + fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]); + } + fprintf(stdout, "\n"); + } } - } - // best command - { - const auto t_end = std::chrono::high_resolution_clock::now(); + // best command + { + const auto t_end = std::chrono::high_resolution_clock::now(); - const float prob = probs_id[0].first; - const int index = probs_id[0].second; + const float prob = probs_id[0].first; + const int index = probs_id[0].second; - fprintf(stdout, "\n"); - fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__, - "\033[1m", allowed_commands[index].c_str(), "\033[0m", prob, - (int) std::chrono::duration_cast(t_end - t_start).count()); - fprintf(stdout, "\n"); + fprintf(stdout, "\n"); + fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__, + "\033[1m", allowed_commands[index].c_str(), "\033[0m", prob, + (int) std::chrono::duration_cast(t_end - t_start).count()); + fprintf(stdout, "\n"); + } } audio.clear(); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 48e02923d01..65b06ca516a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -59,8 +59,12 @@ struct whisper_params { int32_t duration_ms = 0; int32_t max_context = -1; int32_t max_len = 0; + int32_t best_of = 5; + int32_t beam_size = -1; - float word_thold = 0.01f; + float word_thold = 0.01f; + float entropy_thold = 2.4f; + float logprob_thold = -1.0f; bool speed_up = false; bool translate = false; @@ -104,7 +108,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } + else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } + else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } + else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } + else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } @@ -136,31 +144,35 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help [default] show this help message and exit\n"); - fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); - fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); - fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); - fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); - fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); - fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); - fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); - fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); - fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); - fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); - fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); - fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); - fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); - fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); - fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); - fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false"); - fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); - fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); - fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); - fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true"); - fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); - fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); - fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); - fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); + fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); + fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); + fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); + fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); + fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); + fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); + fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); + fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); + fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); + fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); + fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); + fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); + fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); + fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); + fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false"); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); + fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); + fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); + fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); fprintf(stderr, "\n"); } @@ -235,7 +247,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi const char * text = whisper_full_get_token_text(ctx, i, j); const float p = whisper_full_get_token_p (ctx, i, j); - const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); + const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size())))); printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); } @@ -331,20 +343,19 @@ bool output_csv(struct whisper_context * ctx, const char * fname) { const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { const char * text = whisper_full_get_segment_text(ctx, i); - if (text[0] == ' ') - text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character. + if (text[0] == ' ') { + text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character. + } const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds. - fout << 10 * t0 << ", " - << 10 * t1 << ", \"" - << text << "\"\n"; + + //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds. + fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n"; } return true; } - // karaoke video generation // outputs a bash script that uses ffmpeg to generate a video with the subtitles // TODO: font parameter adjustments @@ -620,6 +631,8 @@ int main(int argc, char ** argv) { { whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY; + wparams.print_realtime = false; wparams.print_progress = params.print_progress; wparams.print_timestamps = !params.no_timestamps; @@ -633,12 +646,18 @@ int main(int argc, char ** argv) { wparams.token_timestamps = params.output_wts || params.max_len > 0; wparams.thold_pt = params.word_thold; + wparams.entropy_thold = params.entropy_thold; + wparams.logprob_thold = params.logprob_thold; wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; wparams.speed_up = params.speed_up; - wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data(); - wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size(); + wparams.greedy.best_of = params.best_of; + wparams.beam_search.beam_size = params.beam_size; + wparams.temperature_inc = -1; + + wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data(); + wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size(); whisper_print_user_data user_data = { ¶ms, &pcmf32s }; diff --git a/examples/stream.wasm/emscripten.cpp b/examples/stream.wasm/emscripten.cpp index e4cdf639a40..144a14d268f 100644 --- a/examples/stream.wasm/emscripten.cpp +++ b/examples/stream.wasm/emscripten.cpp @@ -49,6 +49,9 @@ void stream_main(size_t index) { wparams.max_tokens = 32; wparams.audio_ctx = 768; // partial encoder context for better performance + // disable temperature fallback + wparams.temperature_inc = -1.0f; + wparams.language = "en"; printf("stream: using %d threads\n", wparams.n_threads); diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index 9f0c16c669a..e1251704f5d 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -615,6 +615,9 @@ int main(int argc, char ** argv) { wparams.audio_ctx = params.audio_ctx; wparams.speed_up = params.speed_up; + // disable temperature fallback + wparams.temperature_inc = -1.0f; + wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data(); wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size(); diff --git a/examples/whisper.swiftui/whisper.swiftui.demo/Resources/models/.gitignore b/examples/whisper.swiftui/whisper.swiftui.demo/Resources/models/.gitignore new file mode 100644 index 00000000000..e69de29bb2d diff --git a/examples/whisper.swiftui/whisper.swiftui.demo/Resources/samples/.gitignore b/examples/whisper.swiftui/whisper.swiftui.demo/Resources/samples/.gitignore new file mode 100644 index 00000000000..e69de29bb2d diff --git a/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj b/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj index 9cc09c09b52..cc0afbcae4f 100644 --- a/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj +++ b/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj @@ -35,10 +35,10 @@ 0AAC5DA029539CD0003032C3 /* WhisperCppDemo.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = WhisperCppDemo.entitlements; sourceTree = ""; }; 0AAC5DA229539CD0003032C3 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = ""; }; 0AAC5DC629539EAF003032C3 /* WhisperCppDemo-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "WhisperCppDemo-Bridging-Header.h"; sourceTree = ""; }; - 0AAC5DC729539EB0003032C3 /* whisper.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = whisper.cpp; path = ../../../whisper.cpp; sourceTree = ""; }; - 0AAC5DC829539EB0003032C3 /* whisper.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = whisper.h; path = ../../../whisper.h; sourceTree = ""; }; - 0AAC5DC929539EB0003032C3 /* ggml.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = ggml.c; path = ../../../ggml.c; sourceTree = ""; }; - 0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = ggml.h; path = ../../../ggml.h; sourceTree = ""; }; + 0AAC5DC729539EB0003032C3 /* whisper.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = whisper.cpp; sourceTree = ""; }; + 0AAC5DC829539EB0003032C3 /* whisper.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = whisper.h; sourceTree = ""; }; + 0AAC5DC929539EB0003032C3 /* ggml.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = ggml.c; sourceTree = ""; }; + 0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ggml.h; sourceTree = ""; }; 0AAC5DCD2953A05C003032C3 /* WhisperState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperState.swift; sourceTree = ""; }; 0AAC5DD02953A394003032C3 /* LibWhisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibWhisper.swift; sourceTree = ""; }; /* End PBXFileReference section */ @@ -129,7 +129,8 @@ 0AAC5DC729539EB0003032C3 /* whisper.cpp */, 0AAC5DC829539EB0003032C3 /* whisper.h */, ); - path = whisper.cpp; + name = whisper.cpp; + path = ../..; sourceTree = ""; }; 0AAC5DCF2953A36C003032C3 /* whisper.cpp.swift */ = { diff --git a/whisper.cpp b/whisper.cpp index a64505693f7..c40085675ba 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -15,9 +15,31 @@ #include #include #include +#include + +#define WHISPER_ASSERT(x) \ + do { \ + if (!(x)) { \ + fprintf(stderr, "WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + abort(); \ + } \ + } while (0) + +// define this to enable verbose trace logging - useful for debugging purposes +//#define WHISPER_DEBUG + +#if defined(WHISPER_DEBUG) +#define WHISPER_PRINT_DEBUG(...) \ + do { \ + fprintf(stderr, __VA_ARGS__); \ + } while (0) +#else +#define WHISPER_PRINT_DEBUG(...) +#endif -#define USE_FLASH_ATTN -//#define USE_FLASH_FF +#define WHISPER_USE_FLASH_ATTN +//#define WHISPER_USE_FLASH_FF +#define WHISPER_MAX_DECODERS 16 // available whisper models enum e_model { @@ -141,12 +163,20 @@ static const std::map MEM_REQ_MODEL = { { MODEL_LARGE, 2952ull*MB }, }; -static const std::map MEM_REQ_MEMORY = { - { MODEL_TINY, 12ull*MB }, - { MODEL_BASE, 24ull*MB }, - { MODEL_SMALL, 70ull*MB }, - { MODEL_MEDIUM, 184ull*MB }, - { MODEL_LARGE, 306ull*MB }, +static const std::map MEM_REQ_KV_SELF = { + { MODEL_TINY, 3ull*MB }, + { MODEL_BASE, 6ull*MB }, + { MODEL_SMALL, 16ull*MB }, + { MODEL_MEDIUM, 43ull*MB }, + { MODEL_LARGE, 71ull*MB }, +}; + +static const std::map MEM_REQ_KV_CROSS = { + { MODEL_TINY, 9ull*MB }, + { MODEL_BASE, 18ull*MB }, + { MODEL_SMALL, 53ull*MB }, + { MODEL_MEDIUM, 141ull*MB }, + { MODEL_LARGE, 235ull*MB }, }; static const std::map MEM_REQ_ENCODE = { @@ -204,10 +234,6 @@ struct whisper_vocab { std::map token_to_id; std::map id_to_token; - // used to avoid memory allocations during sampling - // TODO: move to whisper_context in the future - std::vector> probs_id; - id token_eot = 50256; id token_sot = 50257; id token_prev = 50360; @@ -349,6 +375,17 @@ struct whisper_layer_decoder { struct ggml_tensor * mlp_1_b; }; +struct whisper_kv_cache { + struct ggml_tensor * k; + struct ggml_tensor * v; + + struct ggml_context * ctx; + + std::vector buf; + + int n; // number of tokens currently in the cache +}; + struct whisper_model { e_model type = MODEL_UNKNOWN; @@ -371,34 +408,64 @@ struct whisper_model { struct ggml_tensor * e_ln_b; // decoder.positional_embedding - struct ggml_tensor * d_pe; // DD + struct ggml_tensor * d_pe; // decoder.token_embedding - struct ggml_tensor * d_te; // DD + struct ggml_tensor * d_te; // decoder.ln - struct ggml_tensor * d_ln_w; // DD - struct ggml_tensor * d_ln_b; // DD + struct ggml_tensor * d_ln_w; + struct ggml_tensor * d_ln_b; std::vector layers_encoder; std::vector layers_decoder; - // key + value memory - struct ggml_tensor * memory_k; - struct ggml_tensor * memory_v; - - struct ggml_tensor * memory_cross_k; - struct ggml_tensor * memory_cross_v; - // context struct ggml_context * ctx; - struct ggml_context * ctx_mem; + + // the model memory buffer is read-only and can be shared between processors + std::vector * buf; // tensors int n_loaded; std::map tensors; }; +struct whisper_sequence { + std::vector tokens; + + // the accumulated transcription in the current interation (used to truncate the tokens array) + int result_len; + + double sum_logprobs_all; // the sum of the log probabilities of the tokens + double sum_logprobs; // the sum of the log probabilities of the tokens (first result_len tokens) + double avg_logprobs; // the average log probability of the tokens + double entropy; // the entropy of the tokens + double score; // likelihood rank score +}; + +// TAGS: WHISPER_DECODER_INIT +struct whisper_decoder { + // each decoders keeps its own KV-cache + whisper_kv_cache kv_self; + + // the currently generated sequence of tokens + whisper_sequence sequence; + + int seek_delta; // the window shift found so far based on the decoded timestamp tokens + + bool failed; // has the current segment failed to decode? + bool completed; // has the decoder completed the current segment? + bool has_ts; // have we already sampled a non-beg timestamp token for the current segment? + + // new token probs, logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab]) + std::vector probs; + std::vector logits; + std::vector logprobs; + + std::vector tokens_tmp; // used for whisper_decode calls +}; + struct whisper_context { int64_t t_load_us = 0; int64_t t_mel_us = 0; @@ -407,24 +474,33 @@ struct whisper_context { int64_t t_decode_us = 0; int64_t t_start_us = 0; - std::vector * buf_model; // the model buffer is read-only and can be shared between processors - std::vector buf_memory; - std::vector buf_compute; - std::vector buf_compute_layer; - ggml_type wtype; // weight type (FP32 or FP16) + whisper_mel mel; + whisper_model model; whisper_vocab vocab; - whisper_mel mel; + // cross-attention KV cache for the decoders + // shared between all decoders + whisper_kv_cache kv_cross; - std::vector probs; + whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; + + // memory buffers used by encode / decode contexts + std::vector buf_compute; + std::vector buf_compute_layer; + + // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; std::vector result_all; + std::vector prompt_past; + + // work container used to avoid memory allocations + std::vector> logits_id; - std::vector prompt_past; + mutable std::mt19937 rng; // used for sampling at t > 0.0 // [EXPERIMENTAL] token-level timestamps data int64_t t_beg; @@ -441,6 +517,72 @@ static void read_safe(whisper_model_loader * loader, T & dest) { loader->read(loader->context, &dest, sizeof(T)); } +static bool kv_cache_init( + const struct whisper_hparams & hparams, + const size_t mem_bytes, + struct whisper_kv_cache & cache, + ggml_type wtype, + int n_ctx) { + cache.buf.resize(mem_bytes); + + struct ggml_init_params params; + params.mem_size = cache.buf.size(); + params.mem_buffer = cache.buf.data(); + + cache.ctx = ggml_init(params); + + if (!cache.ctx) { + fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); + return false; + } + + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + + const int n_mem = n_text_layer*n_ctx; + const int n_elements = n_text_state*n_mem; + + cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + + return true; +} + +static bool kv_cache_reinit(struct whisper_kv_cache & cache) { + WHISPER_ASSERT(cache.ctx); + + const int n_elements = ggml_nelements(cache.k); + WHISPER_ASSERT(n_elements == ggml_nelements(cache.v)); + + const ggml_type wtype = cache.k->type; + WHISPER_ASSERT(wtype == cache.v->type); + + WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype)); + + struct ggml_init_params params; + params.mem_size = cache.buf.size(); + params.mem_buffer = cache.buf.data(); + + cache.ctx = ggml_init(params); + + if (!cache.ctx) { + fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); + return false; + } + + cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + + return true; +} + +static void kv_cache_free(struct whisper_kv_cache & cache) { + if (cache.ctx) { + ggml_free(cache.ctx); + cache.ctx = nullptr; + } +} + // load the model from a ggml file // // file format: @@ -455,6 +597,10 @@ static void read_safe(whisper_model_loader * loader, T & dest) { static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) { fprintf(stderr, "%s: loading model\n", __func__); + const int64_t t_start_us = ggml_time_us(); + + wctx.t_start_us = t_start_us; + auto & model = wctx.model; auto & vocab = wctx.vocab; @@ -506,6 +652,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con model.type = e_model::MODEL_LARGE; } + // for the big tensors, we have the option to store the data in 16-bit floats + // in order to save memory and also to speed up the computation + wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; + + const size_t scale = model.hparams.f16 ? 1 : 2; + fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab); fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); @@ -519,11 +671,51 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16); fprintf(stderr, "%s: type = %d\n", __func__, model.type); - wctx.buf_model = new std::vector(); - wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type)); - wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type)); - wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); - wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type))); + // print memory requirements + { + // this is the total memory required to run the inference + const size_t mem_required = + scale*MEM_REQ_MODEL.at (model.type) + + scale*MEM_REQ_KV_CROSS.at (model.type) + + scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)) + + scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)); + + // this is the memory required by one decoder + const size_t mem_required_decoder = + scale*MEM_REQ_KV_SELF.at(model.type); + + fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__, + mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0); + } + + // initialize all memory buffers + // always have at least one decoder + + wctx.model.buf = new std::vector(); + wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type)); + + if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) { + fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); + return false; + } + + { + const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v); + fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); + } + + if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) { + fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__); + return false; + } + + { + const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v); + fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); + } + + wctx.buf_compute.resize (scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); + wctx.buf_compute_layer.resize(scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type))); } // load mel filters @@ -607,30 +799,21 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx); - wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx); - vocab.probs_id.reserve(n_vocab); - } + wctx.logits_id.reserve(n_vocab); - { - // this is the total memory required to run the inference - const size_t mem_required = - wctx.buf_model->size() + - wctx.buf_memory.size() + - wctx.buf_compute.size() + - wctx.buf_compute_layer.size(); + // TAGS: WHISPER_DECODER_INIT + wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx); - fprintf(stderr, "%s: mem_required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0); + wctx.decoders[0].probs.reserve (vocab.n_vocab); + wctx.decoders[0].logits.reserve (vocab.n_vocab); + wctx.decoders[0].logprobs.reserve(vocab.n_vocab); } - // for the big tensors, we have the option to store the data in 16-bit floats - // in order to save memory and also to speed up the computation - wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; + size_t ctx_size = 0; const ggml_type wtype = wctx.wtype; - size_t ctx_size = 0; - { const auto & hparams = model.hparams; @@ -738,14 +921,14 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead - fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); + fprintf(stderr, "%s: model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); } // create the ggml context { struct ggml_init_params params; - params.mem_size = wctx.buf_model->size(); - params.mem_buffer = wctx.buf_model->data(); + params.mem_size = wctx.model.buf->size(); + params.mem_buffer = wctx.model.buf->data(); model.ctx = ggml_init(params); if (!model.ctx) { @@ -950,56 +1133,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } - // create the ggml memory context - { - struct ggml_init_params params; - params.mem_size = wctx.buf_memory.size(); - params.mem_buffer = wctx.buf_memory.data(); - - model.ctx_mem = ggml_init(params); - if (!model.ctx_mem) { - fprintf(stderr, "%s: ggml_init() failed\n", __func__); - return false; - } - } - - // key + value memory - { - auto & ctx = model.ctx_mem; - - const auto & hparams = model.hparams; - - const int n_text_state = hparams.n_text_state; - const int n_text_layer = hparams.n_text_layer; - const int n_text_ctx = hparams.n_text_ctx; - - // key/value memory for the self-attention layer - { - const int n_mem = n_text_layer*n_text_ctx; - const int n_elements = n_text_state*n_mem; - - model.memory_k = ggml_new_tensor_1d(ctx, wtype, n_elements); - model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements); - } - - // key/value memory for the cross-attention layer - { - const int n_audio_ctx = hparams.n_audio_ctx; - - const int n_mem = n_text_layer*n_audio_ctx; - const int n_elements = n_text_state*n_mem; - - model.memory_cross_k = ggml_new_tensor_1d(ctx, wtype, n_elements); - model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements); - } - - const size_t memory_size = - ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) + - ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v); - - fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); - } - // load weights { size_t total_size = 0; @@ -1073,6 +1206,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } + wctx.rng = std::mt19937(0); + + wctx.t_load_us = ggml_time_us() - t_start_us; + return true; } @@ -1086,9 +1223,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // - mel_offset: offset in the mel spectrogram (i.e. audio offset) // static bool whisper_encode( - whisper_context & wctx, - const int n_threads, - const int mel_offset) { + whisper_context & wctx, + const int mel_offset, + const int n_threads) { + const int64_t t_start_us = ggml_time_us(); + const auto & model = wctx.model; const auto & mel_inp = wctx.mel; const auto & hparams = model.hparams; @@ -1229,7 +1368,7 @@ static bool whisper_encode( // ------ -#ifdef USE_FLASH_ATTN +#ifdef WHISPER_USE_FLASH_ATTN struct ggml_tensor * Q = ggml_permute(ctxL, ggml_cpy(ctxL, @@ -1340,7 +1479,7 @@ static bool whisper_encode( ggml_repeat(ctxL, layer.mlp_ln_b, cur)); } -#ifdef USE_FLASH_FF +#ifdef WHISPER_USE_FLASH_FF cur = ggml_flash_ff(ctxL, ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)), layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); @@ -1461,10 +1600,10 @@ static bool whisper_encode( Vcross), Vcross); - //struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); - //struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); - struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx)); - struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx)); + //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); + //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); + struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx)); + struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*n_ctx)); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k)); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v)); @@ -1480,6 +1619,8 @@ static bool whisper_encode( ggml_free(ctx0); + wctx.t_encode_us += ggml_time_us() - t_start_us; + return true; } @@ -1494,16 +1635,22 @@ static bool whisper_encode( // - n_past: number of past tokens to prefix the prompt with // static bool whisper_decode( - whisper_context & wctx, - const int n_threads, - const whisper_token * tokens, - const int n_tokens, - const int n_past) { + whisper_context & wctx, + whisper_decoder & decoder, + const whisper_token * tokens, + const int n_tokens, + const int n_past, + const int n_threads) { + const int64_t t_start_us = ggml_time_us(); + const auto & model = wctx.model; const auto & hparams = model.hparams; + auto & kv_self = decoder.kv_self; + + WHISPER_ASSERT(!!kv_self.ctx); + auto & logits_out = wctx.logits; - auto & probs_out = wctx.probs; const int n_vocab = hparams.n_vocab; @@ -1515,6 +1662,8 @@ static bool whisper_decode( const int N = n_tokens; const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx; + //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx); + struct ggml_init_params params; params.mem_size = wctx.buf_compute.size(); params.mem_buffer = wctx.buf_compute.data(); @@ -1593,8 +1742,8 @@ static bool whisper_decode( // store key and value to memory { - struct ggml_tensor * k = ggml_view_1d(ctxL, model.memory_k, N*n_state, (ggml_element_size(model.memory_k)*n_state)*(il*n_ctx + n_past)); - struct ggml_tensor * v = ggml_view_1d(ctxL, model.memory_v, N*n_state, (ggml_element_size(model.memory_v)*n_state)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d(ctxL, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_1d(ctxL, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past)); ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k)); ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v)); @@ -1612,7 +1761,7 @@ static bool whisper_decode( struct ggml_tensor * K = ggml_permute(ctxL, ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state), + ggml_view_1d(ctxL, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state), n_state/n_head, n_head, n_past + N), 0, 2, 1, 3); @@ -1632,7 +1781,7 @@ static bool whisper_decode( struct ggml_tensor * V_trans = ggml_permute(ctxL, ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state), + ggml_view_1d(ctxL, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state), n_state/n_head, n_head, n_past + N), 1, 2, 0, 3); @@ -1687,12 +1836,12 @@ static bool whisper_decode( // Kcross is already scaled struct ggml_tensor * Kcross = ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, model.memory_cross_k, M*n_state, il*M*ggml_element_size(model.memory_cross_k)*n_state), + ggml_view_1d(ctxL, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state), n_state/n_head, n_head, M); struct ggml_tensor * Vcross = ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, model.memory_cross_v, M*n_state, il*M*ggml_element_size(model.memory_cross_v)*n_state), + ggml_view_1d(ctxL, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state), n_state/n_head, n_head, M); // ------ @@ -1823,25 +1972,18 @@ static bool whisper_decode( struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); - // logits -> probs - cur = ggml_dup(ctx0, logits); - cur = ggml_soft_max(ctx0, cur); // in-place - // run the computation { struct ggml_cgraph gf = {}; gf.n_threads = n_threads; - ggml_build_forward_expand(&gf, cur); + ggml_build_forward_expand(&gf, logits); ggml_graph_compute (ctx0, &gf); } logits_out.resize(N*n_vocab); memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab); - probs_out.resize(N*n_vocab); - memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab); - if (N > 1) { //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N; //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token); @@ -1850,98 +1992,9 @@ static bool whisper_decode( ggml_free(ctx0); - return true; -} - -// the most basic sampling scheme - select the top token -static whisper_token_data whisper_sample_best( - whisper_vocab & vocab, - const float * probs, - bool force_timestamp, - bool is_initial) { - whisper_token_data result = { - 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, - }; - - const int n_logits = vocab.n_vocab; - - auto & probs_id = vocab.probs_id; - - probs_id.clear(); - for (int i = 0; i < n_logits; i++) { - probs_id.emplace_back(probs[i], i); - } - - { - double sum_ts = 0.0; - double max_ts = -1.0; - double max_tx = -1.0; - - for (int i = 0; i < vocab.token_beg; i++) { - max_tx = std::max(max_tx, probs_id[i].first); - } - - const auto i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg; - const auto i1 = is_initial ? vocab.token_beg + 101 : n_logits; - - // the initial timestamp cannot be larger than 100 - // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 - if (is_initial) { - for (int i = i0; i < n_logits; ++ i) { - probs_id[i].first = -INFINITY; - } - } - - for (int i = vocab.token_beg; i < i1; i++) { - sum_ts += probs_id[i].first; - if (probs_id[i].first > max_ts) { - max_ts = probs_id[i].first; - result.tid = probs_id[i].second; - } - } - - // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a - // timestamp token - if (sum_ts > max_tx || force_timestamp) { - // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438 - for (int i = 0; i < vocab.token_beg; i++) { - probs_id[i].first = -INFINITY; - } - } - - result.pt = max_ts/(sum_ts + 1e-10); - result.ptsum = sum_ts; - } - - // find the top K tokens - const int top_k = 4; - - std::partial_sort( - probs_id.begin(), - probs_id.begin() + top_k, probs_id.end(), - [](const std::pair & a, const std::pair & b) { - return a.first > b.first; - }); - - probs_id.resize(top_k); - - //printf("\n"); - //for (int i = 0; i < (int) probs_id.size(); i++) { - // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second); - //} - - int res = 0; - while ((probs_id[res].second == vocab.token_sot || - probs_id[res].second == vocab.token_solm || - probs_id[res].second == vocab.token_not) && - res < (int) probs_id.size() - 1) { - res++; - } - - result.id = probs_id[res].second; - result.p = probs_id[res].first; + wctx.t_decode_us += ggml_time_us() - t_start_us; - return result; + return true; } // 500 -> 00:05.000 @@ -2043,16 +2096,18 @@ static void fft(const std::vector & in, std::vector & out) { // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124 static bool log_mel_spectrogram( - const float * samples, - const int n_samples, - const int /*sample_rate*/, - const int fft_size, - const int fft_step, - const int n_mel, - const int n_threads, - const whisper_filters & filters, - const bool speed_up, - whisper_mel & mel) { + whisper_context & wctx, + const float * samples, + const int n_samples, + const int /*sample_rate*/, + const int fft_size, + const int fft_step, + const int n_mel, + const int n_threads, + const whisper_filters & filters, + const bool speed_up, + whisper_mel & mel) { + const int64_t t_start_us = ggml_time_us(); // Hanning window std::vector hann; @@ -2161,6 +2216,8 @@ static bool log_mel_spectrogram( mel.data[i] = (mel.data[i] + 4.0)/4.0; } + wctx.t_mel_us += ggml_time_us() - t_start_us; + return true; } @@ -2305,10 +2362,6 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) { whisper_context * ctx = new whisper_context; - const int64_t t_start_us = ggml_time_us(); - - ctx->t_start_us = t_start_us; - if (!whisper_model_load(loader, *ctx)) { loader->close(loader->context); fprintf(stderr, "%s: failed to load model\n", __func__); @@ -2316,8 +2369,6 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) { return nullptr; } - ctx->t_load_us = ggml_time_us() - t_start_us; - loader->close(loader->context); return ctx; @@ -2328,40 +2379,37 @@ void whisper_free(struct whisper_context * ctx) { if (ctx->model.ctx) { ggml_free(ctx->model.ctx); } - if (ctx->model.ctx_mem) { - ggml_free(ctx->model.ctx_mem); + if (ctx->model.buf) { + delete ctx->model.buf; + } + if (ctx->kv_cross.ctx) { + ggml_free(ctx->kv_cross.ctx); } - if (ctx->buf_model) { - delete ctx->buf_model; + for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) { + if (ctx->decoders[i].kv_self.ctx) { + ggml_free(ctx->decoders[i].kv_self.ctx); + } } delete ctx; } } int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { - const int64_t t_start_us = ggml_time_us(); - - if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) { + if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) { fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); return -1; } - ctx->t_mel_us = ggml_time_us() - t_start_us; - return 0; } // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { - const int64_t t_start_us = ggml_time_us(); - - if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) { + if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) { fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); return -1; } - ctx->t_mel_us = ggml_time_us() - t_start_us; - return 0; } @@ -2385,51 +2433,26 @@ int whisper_set_mel( } int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { - const int64_t t_start_us = ggml_time_us(); - - if (!whisper_encode(*ctx, n_threads, offset)) { + if (!whisper_encode(*ctx, offset, n_threads)) { fprintf(stderr, "%s: failed to eval\n", __func__); return -1; } - ctx->t_encode_us += ggml_time_us() - t_start_us; - return 0; } int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { - const int64_t t_start_us = ggml_time_us(); + // TODO: add selected_decoder_id to context + const int selected_decoder_id = 0; - if (!whisper_decode(*ctx, n_threads, tokens, n_tokens, n_past)) { + if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; } - ctx->t_decode_us += ggml_time_us() - t_start_us; - return 0; } -struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) { - const int64_t t_start_sample_us = ggml_time_us(); - - const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false); - - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - - return res; -} - -struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) { - const int64_t t_start_sample_us = ggml_time_us(); - - const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial); - - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - - return res; -} - int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) { const auto res = tokenize(ctx->vocab, text); @@ -2510,34 +2533,39 @@ int whisper_lang_auto_detect( return -7; } - std::vector> probs_id; + auto & logits_id = ctx->logits_id; + logits_id.clear(); + for (const auto & kv : g_lang) { const auto token_lang = whisper_token_lang(ctx, kv.second.first); - probs_id.emplace_back(ctx->probs[token_lang], kv.second.first); + logits_id.emplace_back(ctx->logits[token_lang], kv.second.first); } // sort descending { - using pair_type = decltype(probs_id)::value_type; - std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) { + using pair_type = std::remove_reference::type::value_type; + std::sort(logits_id.begin(), logits_id.end(), [](const pair_type & a, const pair_type & b) { return a.first > b.first; }); } // softmax { - float sum = 0; - for (const auto & kv : probs_id) { - sum += exp(kv.first); + const auto max = logits_id[0].first; + + double sum = 0.0f; + for (auto & kv : logits_id) { + kv.first = exp(kv.first - max); + sum += kv.first; } - for (auto & kv : probs_id) { - kv.first = exp(kv.first) / sum; + for (auto & kv : logits_id) { + kv.first /= sum; } } { - for (const auto & prob : probs_id) { + for (const auto & prob : logits_id) { if (lang_probs) { lang_probs[prob.second] = prob.first; } @@ -2546,7 +2574,7 @@ int whisper_lang_auto_detect( } } - return probs_id[0].second; + return logits_id[0].second; } int whisper_n_len(struct whisper_context * ctx) { @@ -2569,8 +2597,8 @@ int whisper_is_multilingual(struct whisper_context * ctx) { return ctx->vocab.is_multilingual() ? 1 : 0; } -float * whisper_get_probs(struct whisper_context * ctx) { - return ctx->probs.data(); +float * whisper_get_logits(struct whisper_context * ctx) { + return ctx->logits.data(); } const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) { @@ -2654,105 +2682,77 @@ const char * whisper_print_system_info(void) { //////////////////////////////////////////////////////////////////////////// struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) { - struct whisper_full_params result; + struct whisper_full_params result = { + /*.strategy =*/ WHISPER_SAMPLING_GREEDY, + + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/ 16384, + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, + + /*.translate =*/ false, + /*.no_context =*/ false, + /*.single_segment =*/ false, + /*.print_special =*/ false, + /*.print_progress =*/ true, + /*.print_realtime =*/ false, + /*.print_timestamps =*/ true, + + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.max_tokens =*/ 0, + + /*.speed_up =*/ false, + /*.audio_ctx =*/ 0, + + /*.prompt_tokens =*/ nullptr, + /*.prompt_n_tokens =*/ 0, + + /*.language =*/ "en", + + /*.suppress_blank =*/ true, + + /*.temperature =*/ 0.0f, + /*.max_initial_ts =*/ 1.0f, + /*.length_penalty =*/ -1.0f, + + /*.temperature_inc =*/ 0.2f, + /*.entropy_thold =*/ 2.4f, + /*.logprob_thold =*/ -1.0f, + /*.no_speech_thold =*/ 0.6f, + + /*.greedy =*/ { + /*.best_of =*/ -1, + }, + + /*.beam_search =*/ { + /*.beam_size =*/ -1, + + /*.patience =*/ -1.0f, + }, + + /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback_user_data =*/ nullptr, + + /*.encoder_begin_callback =*/ nullptr, + /*.encoder_begin_callback_user_data =*/ nullptr, + }; switch (strategy) { case WHISPER_SAMPLING_GREEDY: { - result = { - /*.strategy =*/ WHISPER_SAMPLING_GREEDY, - - /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), - /*.n_max_text_ctx =*/ 16384, - /*.offset_ms =*/ 0, - /*.duration_ms =*/ 0, - - /*.translate =*/ false, - /*.no_context =*/ false, - /*.single_segment =*/ false, - /*.print_special =*/ false, - /*.print_progress =*/ true, - /*.print_realtime =*/ false, - /*.print_timestamps =*/ true, - - /*.token_timestamps =*/ false, - /*.thold_pt =*/ 0.01f, - /*.thold_ptsum =*/ 0.01f, - /*.max_len =*/ 0, - /*.max_tokens =*/ 0, - - /*.speed_up =*/ false, - /*.audio_ctx =*/ 0, - - /*.prompt_tokens =*/ nullptr, - /*.prompt_n_tokens =*/ 0, - - /*.language =*/ "en", - - /*.greedy =*/ { - /*.n_past =*/ 0, - }, - - /*.beam_search =*/ { - /*.n_past =*/ -1, - /*.beam_width =*/ -1, - /*.n_best =*/ -1, - }, - - /*.new_segment_callback =*/ nullptr, - /*.new_segment_callback_user_data =*/ nullptr, - - /*.encoder_begin_callback =*/ nullptr, - /*.encoder_begin_callback_user_data =*/ nullptr, + result.greedy = { + /*.best_of =*/ 1, }; } break; case WHISPER_SAMPLING_BEAM_SEARCH: { - result = { - /*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH, - - /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), - /*.n_max_text_ctx =*/ 16384, - /*.offset_ms =*/ 0, - /*.duration_ms =*/ 0, - - /*.translate =*/ false, - /*.no_context =*/ false, - /*.single_segment =*/ false, - /*.print_special =*/ false, - /*.print_progress =*/ true, - /*.print_realtime =*/ false, - /*.print_timestamps =*/ true, - - /*.token_timestamps =*/ false, - /*.thold_pt =*/ 0.01f, - /*.thold_ptsum =*/ 0.01f, - /*.max_len =*/ 0, - /*.max_tokens =*/ 0, - - /*.speed_up =*/ false, - /*.audio_ctx =*/ 0, - - /*.prompt_tokens =*/ nullptr, - /*.prompt_n_tokens =*/ 0, - - /*.language =*/ "en", - - /*.greedy =*/ { - /*.n_past =*/ -1, - }, - - /*.beam_search =*/ { - /*.n_past =*/ 0, - /*.beam_width =*/ 10, - /*.n_best =*/ 5, - }, - - /*.new_segment_callback =*/ nullptr, - /*.new_segment_callback_user_data =*/ nullptr, - - /*.encoder_begin_callback =*/ nullptr, - /*.encoder_begin_callback_user_data =*/ nullptr, + result.beam_search = { + /*.beam_size =*/ 5, + + /*.patience =*/ -1.0f, }; } break; } @@ -2763,15 +2763,15 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str // forward declarations static std::vector get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window); static void whisper_exp_compute_token_level_timestamps( - struct whisper_context * ctx, - int i_segment, - float thold_pt, - float thold_ptsum); + struct whisper_context & ctx, + int i_segment, + float thold_pt, + float thold_ptsum); // wrap the last segment to max_len characters // returns the number of new segments -static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) { - auto segment = ctx->result_all.back(); +static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) { + auto segment = ctx.result_all.back(); int res = 1; int acc = 0; @@ -2780,34 +2780,34 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) { for (int i = 0; i < (int) segment.tokens.size(); i++) { const auto & token = segment.tokens[i]; - if (token.id >= whisper_token_eot(ctx)) { + if (token.id >= whisper_token_eot(&ctx)) { continue; } - const auto txt = whisper_token_to_str(ctx, token.id); + const auto txt = whisper_token_to_str(&ctx, token.id); const int cur = strlen(txt); if (acc + cur > max_len && i > 0) { // split here - ctx->result_all.back().text = std::move(text); - ctx->result_all.back().t1 = token.t0; - ctx->result_all.back().tokens.resize(i); + ctx.result_all.back().text = std::move(text); + ctx.result_all.back().t1 = token.t0; + ctx.result_all.back().tokens.resize(i); - ctx->result_all.push_back({}); - ctx->result_all.back().t0 = token.t0; - ctx->result_all.back().t1 = segment.t1; + ctx.result_all.push_back({}); + ctx.result_all.back().t0 = token.t0; + ctx.result_all.back().t1 = segment.t1; // add tokens [i, end] to the new segment - ctx->result_all.back().tokens.insert( - ctx->result_all.back().tokens.end(), + ctx.result_all.back().tokens.insert( + ctx.result_all.back().tokens.end(), segment.tokens.begin() + i, segment.tokens.end()); acc = 0; text = ""; - segment = ctx->result_all.back(); + segment = ctx.result_all.back(); i = -1; res++; @@ -2817,52 +2817,409 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) { } } - ctx->result_all.back().text = std::move(text); + ctx.result_all.back().text = std::move(text); return res; } -int whisper_full( - struct whisper_context * ctx, - struct whisper_full_params params, - const float * samples, - int n_samples) { - // clear old results - auto & result_all = ctx->result_all; - - result_all.clear(); +// process the logits for the selected decoder +// - applies logit filters +// - computes logprobs and probs +static void whisper_process_logits( + const struct whisper_context & ctx, + const struct whisper_full_params params, + struct whisper_decoder & decoder, + float temperature) { + const auto & vocab = ctx.vocab; + const auto & tokens_cur = decoder.sequence.tokens; + + const bool is_initial = tokens_cur.size() == 0; + const int n_logits = vocab.id_to_token.size(); + + WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab); + + // extract the logits for the last token + // we will be mutating and therefore we don't want to use the ctx.logits buffer directly + auto & probs = decoder.probs; + auto & logits = decoder.logits; + auto & logprobs = decoder.logprobs; + { + logits.resize(n_logits); + memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float)); - // compute log mel spectrogram - if (params.speed_up) { - if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) { - fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); - return -1; - } - } else { - if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) { - fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); - return -2; + if (temperature > 0.0f) { + for (int i = 0; i < n_logits; i++) { + logits[i] /= temperature; + } } - } - // auto-detect language if not specified - if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) { - std::vector probs(whisper_lang_max_id() + 1, 0.0f); + // will be populated a bit later + probs.resize(n_logits); + logprobs.resize(n_logits); + } - const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data()); - if (lang_id < 0) { - fprintf(stderr, "%s: failed to auto-detect language\n", __func__); - return -3; + // apply logit filters here + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493 + { + // suppress blank + // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390 + if (params.suppress_blank) { + if (is_initial) { + logits[vocab.token_eot] = -INFINITY; + logits[vocab.token_to_id.at(" ")] = -INFINITY; + } } - params.language = whisper_lang_str(lang_id); + // suppress <|notimestamps|> token + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412 + logits[vocab.token_not] = -INFINITY; - fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); + // suppress sot and solm tokens + logits[vocab.token_sot] = -INFINITY; + logits[vocab.token_solm] = -INFINITY; + + // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly + // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424 + { + const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg; + const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg; + + //fprintf(stderr, "last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp); + + if (last_was_timestamp) { + if (penultimate_was_timestamp) { + for (int i = vocab.token_beg; i < n_logits; ++i) { + logits[i] = -INFINITY; + } + } else { + for (int i = 0; i < vocab.token_eot; ++i) { + logits[i] = -INFINITY; + } + } + } + } + + // the initial timestamp cannot be larger than max_initial_ts + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 + if (is_initial && params.max_initial_ts > 0.0f) { + const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx; + const int tid0 = std::round(params.max_initial_ts/precision); + + for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) { + logits[i] = -INFINITY; + } + } + + // populate the logprobs array (log_softmax) + { + const float logit_max = *std::max_element(logits.begin(), logits.end()); + float logsumexp = 0.0f; + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logsumexp += expf(logits[i] - logit_max); + } + } + logsumexp = logf(logsumexp) + logit_max; + + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logprobs[i] = logits[i] - logsumexp; + } else { + logprobs[i] = -INFINITY; + } + } + } + + // if sum of probability over timestamps is above any other token, sample timestamp + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437 + { + // logsumexp over timestamps + float timestamp_logprob = -INFINITY; + { + float logsumexp = 0.0f; + const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end()); + for (int i = vocab.token_beg; i < n_logits; ++i) { + if (logprobs[i] > -INFINITY) { + logsumexp += expf(logprobs[i] - logprob_max); + } + } + if (logsumexp > 0.0f) { + timestamp_logprob = logf(logsumexp) + logprob_max; + } + } + + const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg); + + //fprintf(stderr, "timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob); + + if (timestamp_logprob > max_text_token_logprob) { + for (int i = 0; i < vocab.token_beg; ++i) { + logits[i] = -INFINITY; + logprobs[i] = -INFINITY; + } + } + } + } + + // compute probs + { + for (int i = 0; i < n_logits; ++i) { + if (logits[i] == -INFINITY) { + probs[i] = 0.0f; + } else { + probs[i] = expf(logprobs[i]); + } + } + } + +#if 0 + // print first 100 logits - token string : logit + for (int i = 0; i < 100; i++) { + const auto token = vocab.id_to_token.at(i); + const auto prob = probs[i]; + const auto logit = logits[i]; + const auto logprob = logprobs[i]; + printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob); + } + + // "And", "and", " And", " and" + printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]); + printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]); + printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]); + printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]); + printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]); + + printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]); + printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]); + printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); + printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); + printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); + + printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]); + printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]); + printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]); + printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]); + printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]); +#endif +} + +static whisper_token_data whisper_sample_token( + const whisper_context & ctx, + const whisper_decoder & decoder, + bool best) { + whisper_token_data result = { + 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, + }; + + const auto & vocab = ctx.vocab; + + const auto & probs = decoder.probs; + const auto & logprobs = decoder.logprobs; + + const int n_logits = vocab.n_vocab; + + { + double sum_ts = 0.0; + double max_ts = 0.0; + + for (int i = vocab.token_beg; i < n_logits; i++) { + if (probs[i] == -INFINITY) { + continue; + } + + sum_ts += probs[i]; + if (max_ts < probs[i]) { + max_ts = probs[i]; + result.tid = i; + } + } + + result.pt = max_ts/(sum_ts + 1e-10); + result.ptsum = sum_ts; + } + + if (best) { + for (int i = 0; i < n_logits; ++i) { + if (result.p < probs[i]) { + result.id = i; + result.p = probs[i]; + result.plog = logprobs[i]; + } + } + } else { + std::discrete_distribution<> dist(probs.begin(), probs.end()); + + result.id = dist(ctx.rng); + result.p = probs[result.id]; + result.plog = logprobs[result.id]; + } + + if (result.id >= vocab.token_beg) { + result.tid = result.id; + result.pt = result.p; + } + + return result; +} + +static std::vector whisper_sample_token_topk( + whisper_context & ctx, + const whisper_decoder & decoder, + int k) { + const auto & vocab = ctx.vocab; + + const auto & probs = decoder.probs; + const auto & logits = decoder.logits; + const auto & logprobs = decoder.logprobs; + + const int n_logits = vocab.n_vocab; + + auto & logits_id = ctx.logits_id; + + logits_id.clear(); + for (int i = 0; i < n_logits; ++i) { + logits_id.push_back({ logits[i], i }); + } + + std::partial_sort( + logits_id.begin(), + logits_id.begin() + k, logits_id.end(), + [](const std::pair & a, const std::pair & b) { + return a.first > b.first; + }); + + std::vector result; + result.reserve(k); + + whisper_token tid; + + float pt; + float ptsum; + + { + double sum_ts = 0.0; + double max_ts = 0.0; + + for (int i = vocab.token_beg; i < n_logits; i++) { + if (probs[i] == -INFINITY) { + continue; + } + + sum_ts += probs[i]; + if (max_ts < probs[i]) { + max_ts = probs[i]; + tid = i; + } + } + + pt = max_ts/(sum_ts + 1e-10); + ptsum = sum_ts; + } + + for (int i = 0; i < k; ++i) { + const auto id = logits_id[i].second; + + result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, }); + + if (result[i].id >= vocab.token_beg) { + result[i].tid = result[i].id; + result[i].pt = result[i].p; + } + } + + return result; +} + +// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192 +static void whisper_sequence_score( + const struct whisper_full_params & params, + whisper_sequence & sequence) { + if (sequence.result_len == 0) { + return; + } + + double result = 0.0f; + + for (int i = 0; i < sequence.result_len; ++i) { + result += sequence.tokens[i].plog; + } + + sequence.sum_logprobs = result; + sequence.avg_logprobs = result/sequence.result_len; + + double penalty = sequence.result_len; + + if (params.length_penalty > 0.0f) { + penalty = pow((5.0 + penalty)/6.0, params.length_penalty); + } + + sequence.score = result/penalty; + + // compute the entropy of the sequence of the last 32 tokens + { + const int n = 32; + + int cnt = 0; + double entropy = 0.0f; + + std::map token_counts; + for (int i = std::max(0, sequence.result_len - n); i < sequence.result_len; ++i) { + token_counts[sequence.tokens[i].id]++; + cnt++; + } + + for (const auto & kv : token_counts) { + const auto p = kv.second/(double)cnt; + entropy -= p*log(p); + + //WHISPER_PRINT_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second); + } + + sequence.entropy = entropy; + } +} + +int whisper_full( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples) { + // clear old results + auto & result_all = ctx->result_all; + + result_all.clear(); + + // compute log mel spectrogram + if (params.speed_up) { + if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) { + fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); + return -1; + } + } else { + if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) { + fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); + return -2; + } + } + + // auto-detect language if not specified + if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) { + std::vector probs(whisper_lang_max_id() + 1, 0.0f); + + const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data()); + if (lang_id < 0) { + fprintf(stderr, "%s: failed to auto-detect language\n", __func__); + return -3; + } + + params.language = whisper_lang_str(lang_id); + + fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); } if (params.token_timestamps) { - ctx->t_beg = 0; - ctx->t_last = 0; + ctx->t_beg = 0; + ctx->t_last = 0; ctx->tid_last = 0; ctx->energy = get_signal_energy(samples, n_samples, 32); } @@ -2877,6 +3234,54 @@ int whisper_full( return 0; } + // a set of temperatures to use + // [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ] + std::vector temperatures; + if (params.temperature_inc > 0.0f) { + for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) { + temperatures.push_back(t); + } + } else { + temperatures.push_back(params.temperature); + } + + // initialize the decoders + int n_decoders = 1; + + switch (params.strategy) { + case WHISPER_SAMPLING_GREEDY: + { + n_decoders = params.greedy.best_of; + } break; + case WHISPER_SAMPLING_BEAM_SEARCH: + { + n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size); + } break; + }; + + n_decoders = std::max(1, n_decoders); + + // TAGS: WHISPER_DECODER_INIT + for (int j = 1; j < n_decoders; j++) { + auto & decoder = ctx->decoders[j]; + + if (decoder.kv_self.ctx == nullptr) { + decoder.kv_self = ctx->decoders[0].kv_self; + if (!kv_cache_reinit(decoder.kv_self)) { + fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j); + return -4; + } + + WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j); + + decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity()); + + decoder.probs.resize (ctx->vocab.n_vocab); + decoder.logits.resize (ctx->vocab.n_vocab); + decoder.logprobs.resize(ctx->vocab.n_vocab); + } + } + // the accumulated text context so far auto & prompt_past = ctx->prompt_past; if (params.no_context) { @@ -2895,7 +3300,7 @@ int whisper_full( // overwrite audio_ctx, max allowed is hparams.n_audio_ctx if (params.audio_ctx > whisper_n_audio_ctx(ctx)) { fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); - return -4; + return -5; } ctx->exp_n_audio_ctx = params.audio_ctx; @@ -2914,14 +3319,31 @@ int whisper_full( int progress_prev = 0; int progress_step = 5; - std::vector tokens_cur; - tokens_cur.reserve(whisper_n_text_ctx(ctx)); + int seek = seek_start; std::vector prompt; prompt.reserve(whisper_n_text_ctx(ctx)); + // beam-search helpers + struct kv_buf { + std::vector k; + std::vector v; + }; + + std::vector kv_bufs; + + struct beam_candidate { + int decoder_idx; + int seek_delta; + + bool has_ts; + + whisper_sequence sequence; + }; + + std::vector beam_candidates; + // main loop - int seek = seek_start; while (true) { const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start); while (progress_cur >= progress_prev + progress_step) { @@ -2936,12 +3358,6 @@ int whisper_full( break; } - // if there is a very short audio segment left to process, we remove any past prompt since it tends - // to confuse the decoder and often make it repeat or hallucinate stuff - if (seek > seek_start && seek + 500 >= seek_end) { - prompt_past.clear(); - } - if (params.encoder_begin_callback) { if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) { fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__); @@ -2950,239 +3366,526 @@ int whisper_full( } // encode audio features starting at offset seek - if (whisper_encode(ctx, seek, params.n_threads) != 0) { + if (!whisper_encode(*ctx, seek, params.n_threads)) { fprintf(stderr, "%s: failed to encode\n", __func__); - return -4; + return -6; + } + + // if there is a very short audio segment left to process, we remove any past prompt since it tends + // to confuse the decoder and often make it repeat or hallucinate stuff + if (seek > seek_start && seek + 500 >= seek_end) { + prompt_past.clear(); } - int n_past = 0; - prompt.clear(); + int best_decoder_id = 0; - // if we have already generated some text, use it as a prompt to condition the next generation - if (!prompt_past.empty()) { - int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); + for (int it = 0; it < (int) temperatures.size(); ++it) { + const float t_cur = temperatures[it]; - prompt = { whisper_token_prev(ctx) }; - prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end()); + int n_decoders_cur = 1; - prompt_past.clear(); - prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end()); - } + switch (params.strategy) { + case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: + { + if (t_cur > 0.0f) { + n_decoders_cur = params.greedy.best_of; + } + } break; + case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + { + if (t_cur > 0.0f) { + n_decoders_cur = params.greedy.best_of; + } else { + n_decoders_cur = params.beam_search.beam_size; + } + } break; + }; - prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); + n_decoders_cur = std::max(1, n_decoders_cur); - int seek_delta = 100*WHISPER_CHUNK_SIZE; + WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur); - // print the prompt - //printf("\n\n"); - //for (int i = 0; i < prompt.size(); i++) { - // printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str()); - //} - //printf("\n\n"); + // TAGS: WHISPER_DECODER_INIT + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; - // the accumulated transcription in the current interation - int result_len = 0; - tokens_cur.clear(); + decoder.kv_self.n = 0; - bool failed = false; - bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment? + decoder.sequence.tokens.clear(); + decoder.sequence.result_len = 0; + decoder.sequence.sum_logprobs_all = 0.0; + decoder.sequence.sum_logprobs = -INFINITY; + decoder.sequence.avg_logprobs = -INFINITY; + decoder.sequence.entropy = 0.0; + decoder.sequence.score = -INFINITY; - for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { - if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) { - fprintf(stderr, "%s: failed to decode\n", __func__); - return -5; - } + decoder.seek_delta = 100*WHISPER_CHUNK_SIZE; - n_past += prompt.size(); - prompt.clear(); + decoder.failed = false; + decoder.completed = false; + decoder.has_ts = false; + } - // very basic greedy sampling strategy: - // - // - always take the most probable token - // - // more sophisticated sampling strategies could be implemented here, but we keep it simple - // feel free to experiment! - // + // init prompt and kv cache for the current iteration + // run whisper_decoder() only for decoder 0 and copy the results for the other decoders { - const auto token = (i == 0) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx); + prompt.clear(); - // timestamp token - update sliding window - if (token.id > whisper_token_beg(ctx)) { - const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); + // if we have already generated some text, use it as a prompt to condition the next generation + if (!prompt_past.empty() && t_cur > 0.5f) { + int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); - // do not allow to go back in time - if (has_ts && seek_delta > seek_delta_new && result_len < i) { - break; + prompt = { whisper_token_prev(ctx) }; + prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end()); + } + + // init new transcription with sot, language (opt) and task tokens + prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); + + // print the prompt + //WHISPER_PRINT_DEBUG("\n\n"); + //for (int i = 0; i < (int) prompt.size(); i++) { + // WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str()); + //} + //WHISPER_PRINT_DEBUG("\n\n"); + + if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) { + fprintf(stderr, "%s: failed to decode\n", __func__); + return -7; + } + + { + const int64_t t_start_sample_us = ggml_time_us(); + + whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur); + + ctx->decoders[0].kv_self.n += prompt.size(); + + for (int j = 1; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k)); + memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v)); + + decoder.kv_self.n += prompt.size(); + + memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); + memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); + memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); } - seek_delta = seek_delta_new; - result_len = i + 1; - has_ts = true; + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } + } - // add it to the context - prompt.push_back(token.id); - tokens_cur.push_back(token); + for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { + const int64_t t_start_sample_us = ggml_time_us(); - //{ - // const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]"; - // printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str()); - //} + // store the KV caches of all decoders when doing beam-search + if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { + kv_bufs.resize(n_decoders_cur); + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; - // end of segment - if (token.id == whisper_token_eot(ctx) || // end of text token - (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached - (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached - ) { - if (result_len == 0) { - if (seek + seek_delta + 100 >= seek_end) { - result_len = i + 1; - } else { - failed = true; - break; + if (decoder.completed || decoder.failed) { + continue; } - } - if (params.single_segment) { - result_len = i + 1; - seek_delta = 100*WHISPER_CHUNK_SIZE; + kv_bufs[j].k.resize(ggml_nbytes(decoder.kv_self.k)); + kv_bufs[j].v.resize(ggml_nbytes(decoder.kv_self.v)); + + memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size()); + memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size()); } - break; + beam_candidates.clear(); } - // TESTS: if no tensors are loaded, it means we are running tests - if (ctx->model.n_loaded == 0) { - seek_delta = 100*WHISPER_CHUNK_SIZE; - break; + // generate new sequence candidates for each decoder + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + switch (params.strategy) { + case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: + { + if (t_cur < 1e-6f) { + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true)); + } else { + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false)); + } + + decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog; + } break; + case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + { + const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size); + + for (const auto & token : tokens_new) { + beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence }); + beam_candidates.back().sequence.tokens.push_back(token); + beam_candidates.back().sequence.sum_logprobs_all += token.plog; + + //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all); + } + } break; + }; } - } - // sometimes, the decoding can get stuck in a repetition loop - // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance - // the sliding window by 1 second - if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { - failed = true; - break; - } - } + // for beam-search, choose the top candidates and update the KV caches + if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { + std::sort( + beam_candidates.begin(), + beam_candidates.end(), + [](const beam_candidate & a, const beam_candidate & b) { + return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all; + }); - if (failed) { - // when we fail to sample timestamp token, retry by clearing the past prompt - // if it fails again, then we advance the window by 1 second - if (!prompt_past.empty()) { - prompt_past.clear(); - } else { - fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__); - seek += 100; - } - continue; - } + int cur_c = 0; - // shrink down to result_len - tokens_cur.resize(result_len); + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; - for (const auto & r : tokens_cur) { - prompt_past.push_back(r.id); - } + if (decoder.completed || decoder.failed) { + continue; + } - // store the text from this iteration - if (!tokens_cur.empty()) { - int i0 = 0; - auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); + auto & cur = beam_candidates[cur_c++]; - std::string text; + while (beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) { + ++cur_c; + } - for (int i = 0; i < (int) tokens_cur.size(); i++) { - //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, - // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, - // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); + decoder.sequence = cur.sequence; + decoder.seek_delta = cur.seek_delta; + decoder.has_ts = cur.has_ts; - if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) { - } else { - text += whisper_token_to_str(ctx, tokens_cur[i].id); + memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size()); + memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size()); + + WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n", + __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all); + } } - if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { - const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); - if (!text.empty()) { - const auto tt0 = params.speed_up ? 2*t0 : t0; - const auto tt1 = params.speed_up ? 2*t1 : t1; - - if (params.print_realtime) { - if (params.print_timestamps) { - printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); - } else { - printf("%s", text.c_str()); - fflush(stdout); + + // update the decoder state + // - check if the sequence is completed + // - check if the sequence is failed + // - update sliding window based on timestamp tokens + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + auto & has_ts = decoder.has_ts; + auto & failed = decoder.failed; + auto & completed = decoder.completed; + auto & seek_delta = decoder.seek_delta; + auto & result_len = decoder.sequence.result_len; + + { + const auto & token = decoder.sequence.tokens.back(); + + // timestamp token - update sliding window + if (token.id > whisper_token_beg(ctx)) { + const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); + + // do not allow to go back in time + if (has_ts && seek_delta > seek_delta_new && result_len < i) { + failed = true; // TODO: maybe this is not a failure ? + continue; } - } - result_all.push_back({ tt0, tt1, text, {} }); - for (int j = i0; j <= i; j++) { - result_all.back().tokens.push_back(tokens_cur[j]); + seek_delta = seek_delta_new; + result_len = i + 1; + has_ts = true; } - int n_new = 1; +#ifdef WHISPER_DEBUG + { + const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]"; + WHISPER_PRINT_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n", + __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str()); + } +#endif - if (params.token_timestamps) { - whisper_exp_compute_token_level_timestamps( - ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + // end of segment + if (token.id == whisper_token_eot(ctx) || // end of text token + (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached + (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached + ) { + if (result_len == 0) { + if (seek + seek_delta + 100 >= seek_end) { + result_len = i + 1; + } else { + failed = true; + continue; + } + } - if (params.max_len > 0) { - n_new = whisper_wrap_segment(ctx, params.max_len); + if (params.single_segment) { + result_len = i + 1; + seek_delta = 100*WHISPER_CHUNK_SIZE; } + + completed = true; + continue; } - if (params.new_segment_callback) { - params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); + + // TESTS: if no tensors are loaded, it means we are running tests + if (ctx->model.n_loaded == 0) { + seek_delta = 100*WHISPER_CHUNK_SIZE; + completed = true; + continue; } } - text = ""; - while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { - i++; + + // sometimes, the decoding can get stuck in a repetition loop + // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy + if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { + failed = true; + continue; } - i--; - t0 = t1; - i0 = i + 1; } - } - if (!text.empty()) { - const auto t1 = seek + seek_delta; + // check if all decoders have finished (i.e. completed or failed) + { + bool completed_all = true; - const auto tt0 = params.speed_up ? 2*t0 : t0; - const auto tt1 = params.speed_up ? 2*t1 : t1; + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; - if (params.print_realtime) { - if (params.print_timestamps) { - printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); - } else { - printf("%s", text.c_str()); - fflush(stdout); + if (decoder.completed || decoder.failed) { + continue; + } + + completed_all = false; + } + + if (completed_all) { + break; } } - result_all.push_back({ tt0, tt1, text, {} }); - for (int j = i0; j < (int) tokens_cur.size(); j++) { - result_all.back().tokens.push_back(tokens_cur[j]); + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + + // obtain logits for the next token + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.failed || decoder.completed) { + continue; + } + + decoder.tokens_tmp.resize(1); + decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id; + + //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta); + + if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) { + fprintf(stderr, "%s: failed to decode\n", __func__); + return -8; + } + + { + const int64_t t_start_sample_us = ggml_time_us(); + + whisper_process_logits(*ctx, params, decoder, t_cur); + + ++decoder.kv_self.n; + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } } + } + + // rank the resulting sequences and select the best one + { + double best_score = -INFINITY; + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.failed) { + continue; + } + + decoder.sequence.tokens.resize(decoder.sequence.result_len); + whisper_sequence_score(params, decoder.sequence); - int n_new = 1; + WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n", + __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy); - if (params.token_timestamps) { - whisper_exp_compute_token_level_timestamps( - ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_thold) { + WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n", + __func__, j, decoder.sequence.entropy, params.entropy_thold); - if (params.max_len > 0) { - n_new = whisper_wrap_segment(ctx, params.max_len); + decoder.failed = true; + + continue; + } + + if (best_score < decoder.sequence.score) { + best_score = decoder.sequence.score; + best_decoder_id = j; } } - if (params.new_segment_callback) { - params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); + + WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id); + } + + // was the decoding successful for the current temperature? + { + bool success = true; + + const auto & decoder = ctx->decoders[best_decoder_id]; + + if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) { + success = false; + } + + if (success) { + //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) { + // WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str()); + //} + + break; } } + + WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur); } - seek += seek_delta; + // output results through a user-provided callback + { + const auto & best_decoder = ctx->decoders[best_decoder_id]; + + const auto seek_delta = best_decoder.seek_delta; + const auto result_len = best_decoder.sequence.result_len; + + const auto & tokens_cur = best_decoder.sequence.tokens; + + //WHISPER_PRINT_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta); + + // update prompt_past + prompt_past.clear(); + if (prompt.front() == whisper_token_prev(ctx)) { + prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size()); + } + + for (int i = 0; i < result_len; ++i) { + prompt_past.push_back(tokens_cur[i].id); + } + + // store the text from this iteration + if (!tokens_cur.empty() && ctx->model.n_loaded > 0) { + int i0 = 0; + auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); + + std::string text; + + for (int i = 0; i < (int) tokens_cur.size(); i++) { + //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, + // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, + // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); + + if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) { + } else { + text += whisper_token_to_str(ctx, tokens_cur[i].id); + } + + if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { + const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); + if (!text.empty()) { + const auto tt0 = params.speed_up ? 2*t0 : t0; + const auto tt1 = params.speed_up ? 2*t1 : t1; + + if (params.print_realtime) { + if (params.print_timestamps) { + printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); + } else { + printf("%s", text.c_str()); + fflush(stdout); + } + } + + //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid); + + result_all.push_back({ tt0, tt1, text, {} }); + for (int j = i0; j <= i; j++) { + result_all.back().tokens.push_back(tokens_cur[j]); + } + + int n_new = 1; + + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + + if (params.max_len > 0) { + n_new = whisper_wrap_segment(*ctx, params.max_len); + } + } + if (params.new_segment_callback) { + params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); + } + } + text = ""; + while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { + i++; + } + i--; + t0 = t1; + i0 = i + 1; + } + } + + if (!text.empty()) { + const auto t1 = seek + seek_delta; + + const auto tt0 = params.speed_up ? 2*t0 : t0; + const auto tt1 = params.speed_up ? 2*t1 : t1; + + if (params.print_realtime) { + if (params.print_timestamps) { + printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); + } else { + printf("%s", text.c_str()); + fflush(stdout); + } + } + + result_all.push_back({ tt0, tt1, text, {} }); + for (int j = i0; j < (int) tokens_cur.size(); j++) { + result_all.back().tokens.push_back(tokens_cur[j]); + } + + int n_new = 1; + + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + + if (params.max_len > 0) { + n_new = whisper_wrap_segment(*ctx, params.max_len); + } + } + if (params.new_segment_callback) { + params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); + } + } + } + + // update audio window + seek += seek_delta; + + WHISPER_PRINT_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta); + } } return 0; @@ -3204,52 +3907,31 @@ int whisper_full_parallel( std::vector ctxs(n_processors - 1); for (int i = 0; i < n_processors - 1; ++i) { - ctxs[i] = *ctx; - - auto & model = ctxs[i].model; - - // create the ggml memory context - { - struct ggml_init_params params; - params.mem_size = ctxs[i].buf_memory.size(); - params.mem_buffer = ctxs[i].buf_memory.data(); + auto & ctx_p = ctxs[i]; - model.ctx_mem = ggml_init(params); - if (!model.ctx_mem) { - fprintf(stderr, "%s: ggml_init() failed\n", __func__); - return false; - } - } + ctx_p = *ctx; - // separate key + value memory for each processor - { - auto & mctx = model.ctx_mem; - - const auto & hparams = model.hparams; + ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx); - const int n_text_state = hparams.n_text_state; - const int n_text_layer = hparams.n_text_layer; - const int n_text_ctx = hparams.n_text_ctx; + ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab); - // key/value memory for the self-attention layer - { - const int n_mem = n_text_layer*n_text_ctx; - const int n_elements = n_text_state*n_mem; + if (!kv_cache_reinit(ctx_p.kv_cross)) { + fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i); + return false; + } - model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements); - model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements); + // TAGS: WHISPER_DECODER_INIT + for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) { + if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) { + fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i); + return false; } - // key/value memory for the cross-attention layer - { - const int n_audio_ctx = hparams.n_audio_ctx; - - const int n_mem = n_text_layer*n_audio_ctx; - const int n_elements = n_text_state*n_mem; + ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx); - model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements); - model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements); - } + ctx_p.decoders[j].probs.reserve (ctx_p.vocab.n_vocab); + ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab); + ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab); } } @@ -3314,6 +3996,12 @@ int whisper_full_parallel( ctx->t_sample_us += ctxs[i].t_sample_us; ctx->t_encode_us += ctxs[i].t_encode_us; ctx->t_decode_us += ctxs[i].t_decode_us; + + kv_cache_free(ctx->kv_cross); + + for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) { + kv_cache_free(ctx->decoders[j].kv_self); + } } // average the timings @@ -3438,14 +4126,14 @@ static std::vector get_signal_energy(const float * signal, int n_samples, } static void whisper_exp_compute_token_level_timestamps( - struct whisper_context * ctx, - int i_segment, - float thold_pt, - float thold_ptsum) { - auto & segment = ctx->result_all[i_segment]; + struct whisper_context & ctx, + int i_segment, + float thold_pt, + float thold_ptsum) { + auto & segment = ctx.result_all[i_segment]; auto & tokens = segment.tokens; - const int n_samples = ctx->energy.size(); + const int n_samples = ctx.energy.size(); if (n_samples == 0) { fprintf(stderr, "%s: no signal data available\n", __func__); @@ -3468,28 +4156,28 @@ static void whisper_exp_compute_token_level_timestamps( return; } - auto & t_beg = ctx->t_beg; - auto & t_last = ctx->t_last; - auto & tid_last = ctx->tid_last; + auto & t_beg = ctx.t_beg; + auto & t_last = ctx.t_last; + auto & tid_last = ctx.tid_last; for (int j = 0; j < n; ++j) { auto & token = tokens[j]; if (j == 0) { - if (token.id == whisper_token_beg(ctx)) { + if (token.id == whisper_token_beg(&ctx)) { tokens[j ].t0 = t0; tokens[j ].t1 = t0; tokens[j + 1].t0 = t0; t_beg = t0; t_last = t0; - tid_last = whisper_token_beg(ctx); + tid_last = whisper_token_beg(&ctx); } else { tokens[j ].t0 = t_last; } } - const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx)); + const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx)); tokens[j].id = token.id; tokens[j].tid = token.tid; @@ -3497,7 +4185,7 @@ static void whisper_exp_compute_token_level_timestamps( tokens[j].pt = token.pt; tokens[j].ptsum = token.ptsum; - tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id)); + tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id)); if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) { if (j > 0) { @@ -3529,6 +4217,8 @@ static void whisper_exp_compute_token_level_timestamps( p1--; } + //printf("p0=%d p1=%d t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1); + if (p1 > p0) { double psum = 0.0; for (int j = p0; j <= p1; j++) { @@ -3576,7 +4266,7 @@ static void whisper_exp_compute_token_level_timestamps( const int hw = WHISPER_SAMPLE_RATE/8; for (int j = 0; j < n; j++) { - if (tokens[j].id >= whisper_token_eot(ctx)) { + if (tokens[j].id >= whisper_token_eot(&ctx)) { continue; } @@ -3591,15 +4281,15 @@ static void whisper_exp_compute_token_level_timestamps( float sum = 0.0f; for (int k = ss0; k < ss1; k++) { - sum += ctx->energy[k]; + sum += ctx.energy[k]; } const float thold = 0.5*sum/ns; { int k = s0; - if (ctx->energy[k] > thold && j > 0) { - while (k > 0 && ctx->energy[k] > thold) { + if (ctx.energy[k] > thold && j > 0) { + while (k > 0 && ctx.energy[k] > thold) { k--; } tokens[j].t0 = sample_to_timestamp(k); @@ -3609,7 +4299,7 @@ static void whisper_exp_compute_token_level_timestamps( s0 = k; } } else { - while (ctx->energy[k] < thold && k < s1) { + while (ctx.energy[k] < thold && k < s1) { k++; } s0 = k; @@ -3619,8 +4309,8 @@ static void whisper_exp_compute_token_level_timestamps( { int k = s1; - if (ctx->energy[k] > thold) { - while (k < n_samples - 1 && ctx->energy[k] > thold) { + if (ctx.energy[k] > thold) { + while (k < n_samples - 1 && ctx.energy[k] > thold) { k++; } tokens[j].t1 = sample_to_timestamp(k); @@ -3630,7 +4320,7 @@ static void whisper_exp_compute_token_level_timestamps( s1 = k; } } else { - while (ctx->energy[k] < thold && k > s0) { + while (ctx.energy[k] < thold && k > s0) { k--; } s1 = k; @@ -3657,11 +4347,11 @@ static void whisper_exp_compute_token_level_timestamps( // debug info //for (int j = 0; j < n; ++j) { // const auto & token = tokens[j]; - // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]"; + // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(&ctx, token.tid) : "[?]"; // printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__, - // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id)); + // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(&ctx, token.id)); - // if (tokens[j].id >= whisper_token_eot(ctx)) { + // if (tokens[j].id >= whisper_token_eot(&ctx)) { // continue; // } //} diff --git a/whisper.h b/whisper.h index 63f61af5114..84504b7b23f 100644 --- a/whisper.h +++ b/whisper.h @@ -74,6 +74,7 @@ extern "C" { whisper_token tid; // forced timestamp token id float p; // probability of the token + float plog; // log probability of the token float pt; // probability of the timestamp token float ptsum; // sum of probabilities of all timestamp tokens @@ -136,6 +137,7 @@ extern "C" { // tokens + n_tokens is the provided context for the decoder. // n_past is the number of tokens to use from previous decoder calls. // Returns 0 on success + // TODO: add support for multiple decoders WHISPER_API int whisper_decode( struct whisper_context * ctx, const whisper_token * tokens, @@ -143,14 +145,6 @@ extern "C" { int n_past, int n_threads); - // Token sampling methods. - // These are provided for convenience and can be used after each call to whisper_decode(). - // You can also implement your own sampling method using the whisper_get_probs() function. - // whisper_sample_best() returns the token with the highest probability - // whisper_sample_timestamp() returns the most probable timestamp token - WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx); - WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial); - // Convert the provided text into tokens. // The tokens pointer must be large enough to hold the resulting tokens. // Returns the number of tokens on success, no more than n_max_tokens @@ -192,8 +186,11 @@ extern "C" { WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx); WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx); - // The probabilities for the next token - WHISPER_API float * whisper_get_probs(struct whisper_context * ctx); + // Token logits obtained from the last call to whisper_decode() + // The logits for the last token are stored in the last row + // Rows: n_tokens + // Cols: n_vocab + WHISPER_API float * whisper_get_logits(struct whisper_context * ctx); // Token Id -> String. Uses the vocabulary in the provided context WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token); @@ -222,8 +219,8 @@ extern "C" { // Available sampling strategies enum whisper_sampling_strategy { - WHISPER_SAMPLING_GREEDY, // Always select the most probable token - WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet! + WHISPER_SAMPLING_GREEDY, // similar to OpenAI's GreefyDecoder + WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder }; // Text segment callback @@ -243,17 +240,17 @@ extern "C" { enum whisper_sampling_strategy strategy; int n_threads; - int n_max_text_ctx; + int n_max_text_ctx; // max tokens to use from past text as prompt for the decoder int offset_ms; // start offset in ms int duration_ms; // audio duration to process in ms bool translate; - bool no_context; + bool no_context; // do not use initial prompt for the decoder (if any) bool single_segment; // force single segment output (useful for streaming) - bool print_special; - bool print_progress; - bool print_realtime; - bool print_timestamps; + bool print_special; // print special tokens (e.g. , , , etc.) + bool print_progress; // print progress information + bool print_realtime; // print results from within whisper.cpp (avoid it, use callback instead) + bool print_timestamps; // print timestamps for each text segment when printing realtime // [EXPERIMENTAL] token-level timestamps bool token_timestamps; // enable token-level timestamps @@ -263,10 +260,11 @@ extern "C" { int max_tokens; // max tokens per segment (0 = no limit) // [EXPERIMENTAL] speed-up techniques + // note: these can significantly reduce the quality of the output bool speed_up; // speed-up the audio by 2x using Phase Vocoder int audio_ctx; // overwrite the audio context size (0 = use default) - // tokens to provide the whisper model as initial prompt + // tokens to provide to the whisper decoder as initial prompt // these are prepended to any existing text context from a previous call const whisper_token * prompt_tokens; int prompt_n_tokens; @@ -274,19 +272,35 @@ extern "C" { // for auto-detection, set to nullptr, "" or "auto" const char * language; + // common decoding parameters: + bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89 + + float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478 + float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97 + float length_penalty; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267 + + // fallback parameters + // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278 + float temperature_inc; + float entropy_thold; // similar to OpenAI's "compression_ratio_threshold" + float logprob_thold; + float no_speech_thold; // TODO: not implemented + struct { - int n_past; + int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264 } greedy; struct { - int n_past; - int beam_width; - int n_best; + int beam_size; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265 + + float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf } beam_search; + // called for every newly generated text segment whisper_new_segment_callback new_segment_callback; void * new_segment_callback_user_data; + // called each time before the encoder starts whisper_encoder_begin_callback encoder_begin_callback; void * encoder_begin_callback_user_data; };