diff --git a/bindings/go/context_params.go b/bindings/go/context_params.go new file mode 100644 index 00000000000..cfc6dbbf9f2 --- /dev/null +++ b/bindings/go/context_params.go @@ -0,0 +1,17 @@ +package whisper + +func (p *ContextParams) UseGPU() bool { + return bool(p.use_gpu) +} + +func (p *ContextParams) SetUseGPU(v bool) { + p.use_gpu = toBool(v) +} + +func (p *ContextParams) UseFlashAttention() bool { + return bool(p.flash_attn) +} + +func (p *ContextParams) SetUseFlashAttention(v bool) { + p.flash_attn = toBool(v) +} diff --git a/bindings/go/pkg/whisper/model.go b/bindings/go/pkg/whisper/model.go index 68a150223c7..7de2ff69e97 100644 --- a/bindings/go/pkg/whisper/model.go +++ b/bindings/go/pkg/whisper/model.go @@ -23,19 +23,21 @@ var _ Model = (*model)(nil) /////////////////////////////////////////////////////////////////////////////// // LIFECYCLE -func New(path string) (Model, error) { - model := new(model) +func New(path string, options ...modelOption) (Model, error) { if _, err := os.Stat(path); err != nil { return nil, err - } else if ctx := whisper.Whisper_init(path); ctx == nil { - return nil, ErrUnableToLoadModel - } else { - model.ctx = ctx - model.path = path } - // Return success - return model, nil + params := whisper.DefaultContextParams() + for _, option := range options { + option.apply(¶ms) + } + + if ctx := whisper.Whisper_init_with_params(path, params); ctx != nil { + return &model{path, ctx}, nil + } + + return nil, ErrUnableToLoadModel } func (model *model) Close() error { diff --git a/bindings/go/pkg/whisper/model_option.go b/bindings/go/pkg/whisper/model_option.go new file mode 100644 index 00000000000..2e630545d61 --- /dev/null +++ b/bindings/go/pkg/whisper/model_option.go @@ -0,0 +1,26 @@ +package whisper + +import whisper "github.com/ggerganov/whisper.cpp/bindings/go" + +type ContextParams = whisper.ContextParams + +type ( + modelOption interface{ apply(*ContextParams) } + modelOptionFunc func(*ContextParams) +) + +func (fn modelOptionFunc) apply(to *ContextParams) { + fn(to) +} + +func WithUseGPU(v bool) modelOption { + return modelOptionFunc(func(p *ContextParams) { + p.SetUseGPU(v) + }) +} + +func WithUseFlashAttention(v bool) modelOption { + return modelOptionFunc(func(p *ContextParams) { + p.SetUseFlashAttention(v) + }) +} diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go index 3ef73414d90..8081f553051 100644 --- a/bindings/go/whisper.go +++ b/bindings/go/whisper.go @@ -71,6 +71,7 @@ type ( TokenData C.struct_whisper_token_data SamplingStrategy C.enum_whisper_sampling_strategy Params C.struct_whisper_full_params + ContextParams C.struct_whisper_context_params ) /////////////////////////////////////////////////////////////////////////////// @@ -102,15 +103,23 @@ var ( // Allocates all memory needed for the model and loads the model from the given file. // Returns NULL on failure. func Whisper_init(path string) *Context { + return Whisper_init_with_params(path, DefaultContextParams()) +} + +func Whisper_init_with_params(path string, params ContextParams) *Context { cPath := C.CString(path) defer C.free(unsafe.Pointer(cPath)) - if ctx := C.whisper_init_from_file_with_params(cPath, C.whisper_context_default_params()); ctx != nil { + if ctx := C.whisper_init_from_file_with_params(cPath, (C.struct_whisper_context_params)(params)); ctx != nil { return (*Context)(ctx) } else { return nil } } +func DefaultContextParams() ContextParams { + return ContextParams(C.whisper_context_default_params()) +} + // Frees all memory allocated by the model. func (ctx *Context) Whisper_free() { C.whisper_free((*C.struct_whisper_context)(ctx))