Skip to content

Commit

Permalink
fix: GH Actions BGESmall (#3)
Browse files Browse the repository at this point in the history
* fix: bge-small actions

* fix: handle max_model_length overflow

* chore: log maxLen

* fix: math.Abs() on the overflown val

* fix: coercing to int, then abs()

* fix: min(MaxInt32, model_max_length)
  • Loading branch information
Anush008 authored Oct 8, 2023
1 parent 8d48594 commit 1aa5192
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 22 deletions.
5 changes: 4 additions & 1 deletion fastembed.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,10 @@ func loadTokenizer(modelPath string, maxLength int) (*tokenizer.Tokenizer, error
return nil, err
}

maxLength = min(maxLength, int(tokenizerConfig["model_max_length"].(float64)))
// Handle overflow when coercing to int, major hassle.
modelMaxLen := int(math.Min(float64(math.MaxInt32), math.Abs(tokenizerConfig["model_max_length"].(float64))))
maxLength = min(maxLength, modelMaxLen)

tknzer.WithTruncation(&tokenizer.TruncationParams{
MaxLength: maxLength,
Strategy: tokenizer.LongestFirst,
Expand Down
41 changes: 20 additions & 21 deletions fastembed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,31 +47,30 @@ func TestEmbedAllMiniLML6V2(t *testing.T) {
}
}

//
// Breaks on GH Actions
// --- FAIL: TestEmbedBGESmallEN (2.29s)
// fastembed_test.go:63: Expected no error, got The tensor's shape ([1 512]) requires 512 elements, but only 8 were provided
//
// func TestEmbedBGESmallEN(t *testing.T) {
// // Test with a single input
// fe, err := NewFlagEmbedding(&InitOptions{
// Model: BGESmallEN,
// })
// defer fe.Destroy()
// if err != nil {
// t.Fatalf("Expected no error, got %v", err)
// }
// input := []string{"Is the world doing okay?"}
// result, err := fe.Embed(input, 1)
// if err != nil {
// t.Fatalf("Expected no error, got %v", err)
// }
// fastembed_test.go:63: Expected no error, got The tensor's shape ([1 512]) requires 512 elements, but only 8 were provided
func TestEmbedBGESmallEN(t *testing.T) {
// Test with a single input
fe, err := NewFlagEmbedding(&InitOptions{
Model: BGESmallEN,
})
defer fe.Destroy()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
input := []string{"Is the world doing okay?"}
result, err := fe.Embed(input, 1)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}

// fmt.Printf("result: %v\n", result[0][0:10])
// if len(result) != len(input) {
// t.Errorf("Expected result length %v, got %v", len(input), len(result))
// }
// }
fmt.Printf("result: %v\n", result[0][0:10])
if len(result) != len(input) {
t.Errorf("Expected result length %v, got %v", len(input), len(result))
}
}

// A model type "Unigram" is not yet supported by the tokenizer
// Ref: https://github.com/sugarme/tokenizer/blob/448e79b1ed65947b8c6343bf9aa39e78364f45c8/pretrained/model.go#L152
Expand Down

0 comments on commit 1aa5192

Please sign in to comment.