From 514cd044524f88e11d7061bd377110c696b2e978 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 14 Apr 2023 19:16:34 +0300 Subject: [PATCH] whisper : fix bug in prompt processing (close #705) Was dereferencing a dangling pointer --- examples/main/main.cpp | 4 ++-- whisper.cpp | 44 ++++++++++++++++++++++-------------------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 7b2885c7339..3f1caf28832 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -208,8 +208,8 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper std::string speaker = ""; - int64_t t0; - int64_t t1; + int64_t t0 = 0; + int64_t t1 = 0; // print the last n_new segments const int s0 = n_segments - n_new; diff --git a/whisper.cpp b/whisper.cpp index 24b9e5d8bd0..3c9bdc45402 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -1260,12 +1260,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con break; } - int64_t nelements = 1; - int64_t ne[3] = { 1, 1, 1 }; + int32_t nelements = 1; + int32_t ne[3] = { 1, 1, 1 }; for (int i = 0; i < n_dims; ++i) { - int32_t ne_cur; - read_safe(loader, ne_cur); - ne[i] = ne_cur; + read_safe(loader, ne[i]); nelements *= ne[i]; } @@ -1286,15 +1284,15 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) { - fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%lld, %lld, %lld], expected [%lld, %lld, %lld]\n", - __func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]); + fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n", + __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]); return false; } const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t); if (nelements*bpe != ggml_nbytes(tensor)) { - fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %llu\n", + fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); return false; } @@ -3819,22 +3817,26 @@ int whisper_full_with_state( prompt_past.clear(); } - // initial prompt - if (!params.prompt_tokens && params.initial_prompt) { + // prepare prompt + { std::vector prompt_tokens; - prompt_tokens.resize(1024); - prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size())); - params.prompt_tokens = prompt_tokens.data(); - params.prompt_n_tokens = prompt_tokens.size(); - } - // prepend the prompt tokens to the prompt_past - if (params.prompt_tokens && params.prompt_n_tokens > 0) { - // parse tokens from the pointer - for (int i = 0; i < params.prompt_n_tokens; i++) { - prompt_past.push_back(params.prompt_tokens[i]); + // initial prompt + if (!params.prompt_tokens && params.initial_prompt) { + prompt_tokens.resize(1024); + prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size())); + params.prompt_tokens = prompt_tokens.data(); + params.prompt_n_tokens = prompt_tokens.size(); + } + + // prepend the prompt tokens to the prompt_past + if (params.prompt_tokens && params.prompt_n_tokens > 0) { + // parse tokens from the pointer + for (int i = 0; i < params.prompt_n_tokens; i++) { + prompt_past.push_back(params.prompt_tokens[i]); + } + std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end()); } - std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end()); } // overwrite audio_ctx, max allowed is hparams.n_audio_ctx