Skip to content

Commit

Permalink
go : exposed various parts to the Go Interface (ggerganov#697)
Browse files Browse the repository at this point in the history
  • Loading branch information
bmurray authored Apr 14, 2023
1 parent 463e463 commit 6704a81
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 7 deletions.
4 changes: 4 additions & 0 deletions bindings/go/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ func (p *Params) SetMaxSegmentLength(n int) {
p.max_len = C.int(n)
}

func (p *Params) SetTokenTimestamps(b bool) {
p.token_timestamps = toBool(b)
}

// Set max tokens per segment (0 = no limit)
func (p *Params) SetMaxTokensPerSegment(n int) {
p.max_tokens = C.int(n)
Expand Down
15 changes: 12 additions & 3 deletions bindings/go/pkg/whisper/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ func (context *context) SetMaxSegmentLength(n uint) {
context.params.SetMaxSegmentLength(int(n))
}

// Set token timestamps flag
func (context *context) SetTokenTimestamps(b bool) {
context.params.SetTokenTimestamps(b)
}

// Set max tokens per segment (0 = no limit)
func (context *context) SetMaxTokensPerSegment(n uint) {
context.params.SetMaxTokensPerSegment(int(n))
Expand Down Expand Up @@ -280,10 +285,14 @@ func toSegment(ctx *whisper.Context, n int) Segment {
func toTokens(ctx *whisper.Context, n int) []Token {
result := make([]Token, ctx.Whisper_full_n_tokens(n))
for i := 0; i < len(result); i++ {
data := ctx.Whisper_full_get_token_data(n, i)

result[i] = Token{
Id: int(ctx.Whisper_full_get_token_id(n, i)),
Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)),
P: ctx.Whisper_full_get_token_p(n, i),
Id: int(ctx.Whisper_full_get_token_id(n, i)),
Text: ctx.Whisper_full_get_token_text(n, i),
P: ctx.Whisper_full_get_token_p(n, i),
Start: time.Duration(data.T0()) * time.Millisecond * 10,
End: time.Duration(data.T1()) * time.Millisecond * 10,
}
}
return result
Expand Down
8 changes: 5 additions & 3 deletions bindings/go/pkg/whisper/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type Context interface {
SetTokenThreshold(float32) // Set timestamp token probability threshold
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
SetMaxSegmentLength(uint) // Set max segment length in characters
SetTokenTimestamps(bool) // Set token timestamps flag
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)

// Process mono audio data and return any errors.
Expand Down Expand Up @@ -85,7 +86,8 @@ type Segment struct {

// Token is a text or special token
type Token struct {
Id int
Text string
P float32
Id int
Text string
P float32
Start, End time.Duration
}
10 changes: 9 additions & 1 deletion bindings/go/whisper.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {

// Get token data for the specified token in the specified segment.
// This contains probabilities, timestamps, etc.
func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData {
func (ctx *Context) Whisper_full_get_token_data(segment int, token int) TokenData {
return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
}

Expand Down Expand Up @@ -407,3 +407,11 @@ func callEncoderBegin(user_data unsafe.Pointer) C.bool {
}
return true
}

func (t TokenData) T0() int64 {
return int64(t.t0)
}

func (t TokenData) T1() int64 {
return int64(t.t1)
}

0 comments on commit 6704a81

Please sign in to comment.