Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions bindings/go/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func (p *Params) SetPrintTimestamps(v bool) {
p.print_timestamps = toBool(v)
}


// Set language id
func (p *Params) SetLanguage(lang int) error {
if lang == -1 {
Expand Down Expand Up @@ -146,6 +147,10 @@ func (p *Params) SetInitialPrompt(prompt string) {
p.initial_prompt = C.CString(prompt)
}

// SetCarryInitialPrompt stores v, converted through toBool, in the
// carry_initial_prompt field of the underlying parameters.
func (p *Params) SetCarryInitialPrompt(v bool) {
	flag := toBool(v)
	p.carry_initial_prompt = flag
}

///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS

Expand Down Expand Up @@ -199,6 +204,9 @@ func (p *Params) String() string {
if p.token_timestamps {
str += " token_timestamps"
}
if p.carry_initial_prompt {
str += " carry_initial_prompt"
}

return str + ">"
}
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ public void tdrzEnable(boolean enable) {
/** Tokens to provide to the whisper decoder as an initial prompt.
* These are prepended to any existing text context from a previous call. */
public String initial_prompt;
/** Always prepend initial_prompt for every decode chunk. */
public CBool carry_initial_prompt;

/** Prompt tokens. (int*) */
public Pointer prompt_tokens;
Expand Down Expand Up @@ -336,8 +338,8 @@ protected List<String> getFieldOrder() {
"no_timestamps", "single_segment", "print_special",
"print_progress", "print_realtime", "print_timestamps",
"token_timestamps", "thold_pt", "thold_ptsum", "max_len",
"split_on_word", "max_tokens", "debug_mode", "audio_ctx",
"tdrz_enable", "suppress_regex", "initial_prompt",
"split_on_word", "max_tokens", "debug_mode", "audio_ctx",
"tdrz_enable", "suppress_regex", "initial_prompt", "carry_initial_prompt",
"prompt_tokens", "prompt_n_tokens", "language", "detect_language",
"suppress_blank", "suppress_nst", "temperature",
"max_initial_ts", "length_penalty", "temperature_inc",
Expand Down
69 changes: 46 additions & 23 deletions bindings/ruby/ext/ruby_whisper_params.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);

#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 36
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 37

extern VALUE cParams;
extern VALUE cVADParams;
Expand All @@ -46,6 +46,7 @@ static ID id_print_special;
static ID id_print_progress;
static ID id_print_realtime;
static ID id_print_timestamps;
static ID id_carry_initial_prompt;
static ID id_suppress_blank;
static ID id_suppress_nst;
static ID id_token_timestamps;
Expand Down Expand Up @@ -455,6 +456,26 @@ ruby_whisper_params_get_print_timestamps(VALUE self)
{
BOOL_PARAMS_GETTER(self, print_timestamps)
}

/*
 * Reader for the carry_initial_prompt parameter.
 *
 * The body is a single expansion of BOOL_PARAMS_GETTER with the receiver and
 * the field name; that macro is defined outside this view.
 * NOTE(review): presumably it returns Qtrue/Qfalse from the wrapped params
 * struct - confirm against the macro definition.
 *
 * call-seq:
 *   carry_initial_prompt -> true or false
 */
static VALUE
ruby_whisper_params_get_carry_initial_prompt(VALUE self)
{
BOOL_PARAMS_GETTER(self, carry_initial_prompt)
}

/*
 * Writer for the carry_initial_prompt parameter.
 *
 * The body is a single expansion of BOOL_PARAMS_SETTER with the receiver, the
 * field name and the new value; that macro is defined outside this view.
 * NOTE(review): presumably it applies Ruby truthiness to +value+ and returns
 * it - confirm against the macro definition.
 *
 * call-seq:
 *   carry_initial_prompt = bool -> bool
 */
static VALUE
ruby_whisper_params_set_carry_initial_prompt(VALUE self, VALUE value)
{
BOOL_PARAMS_SETTER(self, carry_initial_prompt, value)
}
/*
* call-seq:
* suppress_blank = force_suppress -> force_suppress
Expand Down Expand Up @@ -1168,6 +1189,7 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
SET_PARAM_IF_SAME(max_len)
SET_PARAM_IF_SAME(split_on_word)
SET_PARAM_IF_SAME(initial_prompt)
SET_PARAM_IF_SAME(carry_initial_prompt)
SET_PARAM_IF_SAME(offset)
SET_PARAM_IF_SAME(duration)
SET_PARAM_IF_SAME(max_text_tokens)
Expand Down Expand Up @@ -1303,28 +1325,29 @@ init_ruby_whisper_params(VALUE *mWhisper)
DEFINE_PARAM(max_len, 11)
DEFINE_PARAM(split_on_word, 12)
DEFINE_PARAM(initial_prompt, 13)
DEFINE_PARAM(diarize, 14)
DEFINE_PARAM(offset, 15)
DEFINE_PARAM(duration, 16)
DEFINE_PARAM(max_text_tokens, 17)
DEFINE_PARAM(temperature, 18)
DEFINE_PARAM(max_initial_ts, 19)
DEFINE_PARAM(length_penalty, 20)
DEFINE_PARAM(temperature_inc, 21)
DEFINE_PARAM(entropy_thold, 22)
DEFINE_PARAM(logprob_thold, 23)
DEFINE_PARAM(no_speech_thold, 24)
DEFINE_PARAM(new_segment_callback, 25)
DEFINE_PARAM(new_segment_callback_user_data, 26)
DEFINE_PARAM(progress_callback, 27)
DEFINE_PARAM(progress_callback_user_data, 28)
DEFINE_PARAM(encoder_begin_callback, 29)
DEFINE_PARAM(encoder_begin_callback_user_data, 30)
DEFINE_PARAM(abort_callback, 31)
DEFINE_PARAM(abort_callback_user_data, 32)
DEFINE_PARAM(vad, 33)
DEFINE_PARAM(vad_model_path, 34)
DEFINE_PARAM(vad_params, 35)
DEFINE_PARAM(carry_initial_prompt, 14)
DEFINE_PARAM(diarize, 15)
DEFINE_PARAM(offset, 16)
DEFINE_PARAM(duration, 17)
DEFINE_PARAM(max_text_tokens, 18)
DEFINE_PARAM(temperature, 19)
DEFINE_PARAM(max_initial_ts, 20)
DEFINE_PARAM(length_penalty, 21)
DEFINE_PARAM(temperature_inc, 22)
DEFINE_PARAM(entropy_thold, 23)
DEFINE_PARAM(logprob_thold, 24)
DEFINE_PARAM(no_speech_thold, 25)
DEFINE_PARAM(new_segment_callback, 26)
DEFINE_PARAM(new_segment_callback_user_data, 27)
DEFINE_PARAM(progress_callback, 28)
DEFINE_PARAM(progress_callback_user_data, 29)
DEFINE_PARAM(encoder_begin_callback, 30)
DEFINE_PARAM(encoder_begin_callback_user_data, 31)
DEFINE_PARAM(abort_callback, 32)
DEFINE_PARAM(abort_callback_user_data, 33)
DEFINE_PARAM(vad, 34)
DEFINE_PARAM(vad_model_path, 35)
DEFINE_PARAM(vad_params, 36)

rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
Expand Down
3 changes: 3 additions & 0 deletions bindings/ruby/sig/whisper.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ module Whisper
?max_len: Integer,
?split_on_word: boolish,
?initial_prompt: string | nil,
?carry_initial_prompt: boolish,
?diarize: boolish,
?offset: Integer,
?duration: Integer,
Expand Down Expand Up @@ -236,13 +237,15 @@ module Whisper
def split_on_word: () -> (true | false)

def initial_prompt=: (_ToS) -> _ToS
def carry_initial_prompt=: (boolish) -> boolish

# Tokens to provide to the whisper decoder as initial prompt
# these are prepended to any existing text context from a previous call
# use whisper_tokenize() to convert text to tokens.
# Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224).
#
def initial_prompt: () -> (String | nil)
def carry_initial_prompt: () -> (true | false)

def diarize=: (boolish) -> boolish

Expand Down
8 changes: 8 additions & 0 deletions bindings/ruby/test/test_params.rb
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class TestParams < TestBase
:max_len,
:split_on_word,
:initial_prompt,
:carry_initial_prompt,
:diarize,
:offset,
:duration,
Expand Down Expand Up @@ -119,6 +120,13 @@ def test_print_timestamps
assert [email protected]_timestamps
end

# Round-trips the carry_initial_prompt flag through its writer and reader:
# after assigning true the reader must be truthy, after assigning false it
# must be falsy.
def test_carry_initial_prompt
  @params.carry_initial_prompt = true
  assert @params.carry_initial_prompt
  @params.carry_initial_prompt = false
  # Negated read-back; the original line here was mangled by e-mail
  # obfuscation and was not valid Ruby.
  assert !@params.carry_initial_prompt
end

def test_suppress_blank
@params.suppress_blank = true
assert @params.suppress_blank
Expand Down
Loading