Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions bindings/go/context_params.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package whisper

func (p *ContextParams) UseGPU() bool {
return bool(p.use_gpu)
}

func (p *ContextParams) SetUseGPU(v bool) {
p.use_gpu = toBool(v)
}

func (p *ContextParams) UseFlashAttention() bool {
return bool(p.flash_attn)
}

func (p *ContextParams) SetUseFlashAttention(v bool) {
p.flash_attn = toBool(v)
}
20 changes: 11 additions & 9 deletions bindings/go/pkg/whisper/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,21 @@ var _ Model = (*model)(nil)
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE

func New(path string) (Model, error) {
model := new(model)
func New(path string, options ...modelOption) (Model, error) {
if _, err := os.Stat(path); err != nil {
return nil, err
} else if ctx := whisper.Whisper_init(path); ctx == nil {
return nil, ErrUnableToLoadModel
} else {
model.ctx = ctx
model.path = path
}

// Return success
return model, nil
params := whisper.DefaultContextParams()
for _, option := range options {
option.apply(&params)
}

if ctx := whisper.Whisper_init_with_params(path, params); ctx != nil {
return &model{path, ctx}, nil
}

return nil, ErrUnableToLoadModel
}

func (model *model) Close() error {
Expand Down
26 changes: 26 additions & 0 deletions bindings/go/pkg/whisper/model_option.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package whisper

import whisper "github.com/ggerganov/whisper.cpp/bindings/go"

type ContextParams = whisper.ContextParams

type (
modelOption interface{ apply(*ContextParams) }
modelOptionFunc func(*ContextParams)
)

func (fn modelOptionFunc) apply(to *ContextParams) {
fn(to)
}

func WithUseGPU(v bool) modelOption {
return modelOptionFunc(func(p *ContextParams) {
p.SetUseGPU(v)
})
}

func WithUseFlashAttention(v bool) modelOption {
return modelOptionFunc(func(p *ContextParams) {
p.SetUseFlashAttention(v)
})
}
11 changes: 10 additions & 1 deletion bindings/go/whisper.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ type (
TokenData C.struct_whisper_token_data
SamplingStrategy C.enum_whisper_sampling_strategy
Params C.struct_whisper_full_params
ContextParams C.struct_whisper_context_params
)

///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -102,15 +103,23 @@ var (
// Allocates all memory needed for the model and loads the model from the given file.
// Returns NULL on failure.
func Whisper_init(path string) *Context {
return Whisper_init_with_params(path, DefaultContextParams())
}

func Whisper_init_with_params(path string, params ContextParams) *Context {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
if ctx := C.whisper_init_from_file_with_params(cPath, C.whisper_context_default_params()); ctx != nil {
if ctx := C.whisper_init_from_file_with_params(cPath, (C.struct_whisper_context_params)(params)); ctx != nil {
return (*Context)(ctx)
} else {
return nil
}
}

func DefaultContextParams() ContextParams {
return ContextParams(C.whisper_context_default_params())
}

// Frees all memory allocated by the model.
func (ctx *Context) Whisper_free() {
C.whisper_free((*C.struct_whisper_context)(ctx))
Expand Down
Loading