From 0a2d1210bcb98978214bbf4e100922a413afd39d Mon Sep 17 00:00:00 2001 From: pajowu Date: Thu, 30 Mar 2023 19:29:29 +0200 Subject: [PATCH] whisper : add progress callback (#600) --- whisper.cpp | 10 ++++++++++ whisper.h | 7 +++++++ 2 files changed, 17 insertions(+) diff --git a/whisper.cpp b/whisper.cpp index 13e11141cc9..95b6d33905d 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -3152,6 +3152,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.new_segment_callback =*/ nullptr, /*.new_segment_callback_user_data =*/ nullptr, + /*.progress_callback =*/ nullptr, + /*.progress_callback_user_data =*/ nullptr, + /*.encoder_begin_callback =*/ nullptr, /*.encoder_begin_callback_user_data =*/ nullptr, @@ -3868,6 +3871,10 @@ int whisper_full_with_state( fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev); } } + if (params.progress_callback) { + params.progress_callback( + ctx, ctx->state, progress_prev, params.progress_callback_user_data); + } // of only 1 second left, then stop if (seek + 100 >= seek_end) { @@ -4456,6 +4463,9 @@ int whisper_full_parallel( params_cur.new_segment_callback = nullptr; params_cur.new_segment_callback_user_data = nullptr; + params_cur.progress_callback = nullptr; + params_cur.progress_callback_user_data = nullptr; + workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur); } diff --git a/whisper.h b/whisper.h index fa6bff4fc8d..a96c96c927e 100644 --- a/whisper.h +++ b/whisper.h @@ -306,6 +306,9 @@ extern "C" { // Use the whisper_full_...() functions to obtain the text segments typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data); + // Progress callback + typedef void (*whisper_progress_callback)(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data); + // Encoder begin callback // If not NULL, called before the encoder starts // If it returns false, the computation is aborted @@ -392,6 +395,10 @@ extern "C" { whisper_new_segment_callback new_segment_callback; void * new_segment_callback_user_data; + // called on each progress update + whisper_progress_callback progress_callback; + void * progress_callback_user_data; + // called each time before the encoder starts whisper_encoder_begin_callback encoder_begin_callback; void * encoder_begin_callback_user_data;