From 1aa5192c964dc1ad2338466ffd1baff68eb6da05 Mon Sep 17 00:00:00 2001
From: Anush
Date: Sun, 8 Oct 2023 19:55:15 +0530
Subject: [PATCH] fix: GH Actions BGESmall (#3)

* fix: bge-small actions

* fix: handle max_model_length overflow

* chore: log maxLen

* fix: math.Abs() on the overflown val

* fix: coercing to int, then abs()

* fix: min(MaxInt32, model_max_length)
---
 fastembed.go      |  5 ++++-
 fastembed_test.go | 41 ++++++++++++++++++++---------------------
 2 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/fastembed.go b/fastembed.go
index 15db724..5717d8f 100644
--- a/fastembed.go
+++ b/fastembed.go
@@ -334,7 +334,10 @@ func loadTokenizer(modelPath string, maxLength int) (*tokenizer.Tokenizer, error
 		return nil, err
 	}

-	maxLength = min(maxLength, int(tokenizerConfig["model_max_length"].(float64)))
+	// Handle overflow when coercing to int, major hassle.
+	modelMaxLen := int(math.Min(float64(math.MaxInt32), math.Abs(tokenizerConfig["model_max_length"].(float64))))
+	maxLength = min(maxLength, modelMaxLen)
+
 	tknzer.WithTruncation(&tokenizer.TruncationParams{
 		MaxLength: maxLength,
 		Strategy:  tokenizer.LongestFirst,
diff --git a/fastembed_test.go b/fastembed_test.go
index 395b83f..6a206c8 100644
--- a/fastembed_test.go
+++ b/fastembed_test.go
@@ -47,31 +47,30 @@ func TestEmbedAllMiniLML6V2(t *testing.T) {
 	}
 }

-// // Breaks on GH Actions
 // --- FAIL: TestEmbedBGESmallEN (2.29s)
-// fastembed_test.go:63: Expected no error, got The tensor's shape ([1 512]) requires 512 elements, but only 8 were provided
 //
-// func TestEmbedBGESmallEN(t *testing.T) {
-// // Test with a single input
-// fe, err := NewFlagEmbedding(&InitOptions{
-// Model: BGESmallEN,
-// })
-// defer fe.Destroy()
-// if err != nil {
-// t.Fatalf("Expected no error, got %v", err)
-// }
-// input := []string{"Is the world doing okay?"}
-// result, err := fe.Embed(input, 1)
-// if err != nil {
-// t.Fatalf("Expected no error, got %v", err)
-// }
+// fastembed_test.go:63: Expected no error, got The tensor's shape ([1 512]) requires 512 elements, but only 8 were provided
+func TestEmbedBGESmallEN(t *testing.T) {
+	// Test with a single input
+	fe, err := NewFlagEmbedding(&InitOptions{
+		Model: BGESmallEN,
+	})
+	defer fe.Destroy()
+	if err != nil {
+		t.Fatalf("Expected no error, got %v", err)
+	}
+	input := []string{"Is the world doing okay?"}
+	result, err := fe.Embed(input, 1)
+	if err != nil {
+		t.Fatalf("Expected no error, got %v", err)
+	}

-// fmt.Printf("result: %v\n", result[0][0:10])
-// if len(result) != len(input) {
-// t.Errorf("Expected result length %v, got %v", len(input), len(result))
-// }
-// }
+	fmt.Printf("result: %v\n", result[0][0:10])
+	if len(result) != len(input) {
+		t.Errorf("Expected result length %v, got %v", len(input), len(result))
+	}
+}

 // A model type "Unigram" is not yet supported by the tokenizer
 // Ref: https://github.com/sugarme/tokenizer/blob/448e79b1ed65947b8c6343bf9aa39e78364f45c8/pretrained/model.go#L152