Skip to content

Commit

Permalink
whisper : add progress callback (ggerganov#600)
Browse files Browse the repository at this point in the history
  • Loading branch information
pajowu authored Mar 30, 2023
1 parent 859ffc9 commit 0a2d121
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
10 changes: 10 additions & 0 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3152,6 +3152,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,

/*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,

/*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr,

Expand Down Expand Up @@ -3868,6 +3871,10 @@ int whisper_full_with_state(
fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev);
}
}
if (params.progress_callback) {
params.progress_callback(
ctx, ctx->state, progress_prev, params.progress_callback_user_data);
}

// of only 1 second left, then stop
if (seek + 100 >= seek_end) {
Expand Down Expand Up @@ -4456,6 +4463,9 @@ int whisper_full_parallel(
params_cur.new_segment_callback = nullptr;
params_cur.new_segment_callback_user_data = nullptr;

params_cur.progress_callback = nullptr;
params_cur.progress_callback_user_data = nullptr;

workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur);
}

Expand Down
7 changes: 7 additions & 0 deletions whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ extern "C" {
// Use the whisper_full_...() functions to obtain the text segments
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data);

// Progress callback
typedef void (*whisper_progress_callback)(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data);

// Encoder begin callback
// If not NULL, called before the encoder starts
// If it returns false, the computation is aborted
Expand Down Expand Up @@ -392,6 +395,10 @@ extern "C" {
whisper_new_segment_callback new_segment_callback;
void * new_segment_callback_user_data;

// called on each progress update
whisper_progress_callback progress_callback;
void * progress_callback_user_data;

// called each time before the encoder starts
whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;
Expand Down

0 comments on commit 0a2d121

Please sign in to comment.