From a164354988ca033ba34b490395bf3f1f37ed5c37 Mon Sep 17 00:00:00 2001
From: Anush
Date: Mon, 9 Oct 2023 15:20:20 +0530
Subject: [PATCH] ci: Canonical testing (#4)

* test: canonical value testing

* ci: Added tests before release
---
 .github/workflows/release.yml | 21 +++++++++++
 fastembed.go                  |  2 +-
 fastembed_test.go             | 66 +++++++++++++++++++++++++++++------
 3 files changed, 76 insertions(+), 13 deletions(-)

diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 5e440f3..08bd4ae 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -7,8 +7,29 @@ on:
   workflow_dispatch:
 
 jobs:
+
+  test:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Setup Go
+        uses: actions/setup-go@v4
+        with:
+          go-version: "1.21"
+      - name: Install dependencies
+        run: go get .
+      - name: Install ONNX Runtime
+        run: |
+          wget https://github.com/microsoft/onnxruntime/releases/download/v1.16.0/onnxruntime-linux-x64-1.16.0.tgz
+          tar xvzf onnxruntime-linux-x64-1.16.0.tgz
+          echo "ONNX_PATH=$(pwd)/onnxruntime-linux-x64-1.16.0/lib/libonnxruntime.so" >> $GITHUB_ENV
+      - name: Test with Go
+        run: go test
+
   release:
     runs-on: ubuntu-latest
+    needs:
+      - test
     steps:
       - name: "☁️ checkout repository"
         uses: actions/checkout@v3
diff --git a/fastembed.go b/fastembed.go
index c9bf2d0..6c2c90f 100644
--- a/fastembed.go
+++ b/fastembed.go
@@ -338,7 +338,7 @@ func loadTokenizer(modelPath string, maxLength int) (*tokenizer.Tokenizer, error
 	}
 
 	// Handle overflow when coercing to int, major hassle.
-	modelMaxLen := int(math.Min(float64(math.MaxInt32), math.Abs(tokenizerConfig["model_max_length"].(float64))))
+	modelMaxLen := int(min(float64(math.MaxInt32), math.Abs(tokenizerConfig["model_max_length"].(float64))))
 	maxLength = min(maxLength, modelMaxLen)
 
 	tknzer.WithTruncation(&tokenizer.TruncationParams{
diff --git a/fastembed_test.go b/fastembed_test.go
index 6a206c8..a494d84 100644
--- a/fastembed_test.go
+++ b/fastembed_test.go
@@ -1,7 +1,7 @@
 package fastembed
 
 import (
-	"fmt"
+	"math"
 	"testing"
 )
 
@@ -14,13 +14,12 @@ func TestEmbedBGEBaseEN(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
-	input := []string{"Is the world doing okay?"}
+	input := []string{"hello world"}
 	result, err := fe.Embed(input, 1)
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
 
-	fmt.Printf("result: %v\n", result[0][0:10])
 	if len(result) != len(input) {
 		t.Errorf("Expected result length %v, got %v", len(input), len(result))
 	}
@@ -35,22 +34,17 @@ func TestEmbedAllMiniLML6V2(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
-	input := []string{"Is the world doing okay?"}
+	input := []string{"hello world"}
 	result, err := fe.Embed(input, 1)
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
 
-	fmt.Printf("result: %v\n", result[0][0:10])
 	if len(result) != len(input) {
 		t.Errorf("Expected result length %v, got %v", len(input), len(result))
 	}
 }
 
-// Breaks on GH Actions
-// --- FAIL: TestEmbedBGESmallEN (2.29s)
-//
-//	fastembed_test.go:63: Expected no error, got The tensor's shape ([1 512]) requires 512 elements, but only 8 were provided
 func TestEmbedBGESmallEN(t *testing.T) {
 	// Test with a single input
 	fe, err := NewFlagEmbedding(&InitOptions{
@@ -60,13 +54,12 @@ func TestEmbedBGESmallEN(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
-	input := []string{"Is the world doing okay?"}
+	input := []string{"hello world"}
 	result, err := fe.Embed(input, 1)
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
 
-	fmt.Printf("result: %v\n", result[0][0:10])
 	if len(result) != len(input) {
 		t.Errorf("Expected result length %v, got %v", len(input), len(result))
 	}
@@ -85,7 +78,7 @@ func TestEmbedBGESmallEN(t *testing.T) {
 // 	if err != nil {
 // 		t.Fatalf("Expected no error, got %v", err)
 // 	}
-// 	input := []string{"Is the world doing okay?"}
+// 	input := []string{"hello world"}
 // 	result, err := fe.Embed(input, 1)
 // 	if err != nil {
 // 		t.Fatalf("Expected no error, got %v", err)
@@ -95,3 +88,52 @@ func TestEmbedBGESmallEN(t *testing.T) {
 // 		t.Errorf("Expected result length %v, got %v", len(input), len(result))
 // 	}
 // }
+
+// TestCanonicalValues embeds "hello world" with each listed model and compares
+// the leading components of the result against recorded reference values.
+func TestCanonicalValues(t *testing.T) {
+	// Largest absolute per-component deviation that still counts as a match.
+	const epsilon = 1e-5
+
+	canonicalValues := map[EmbeddingModel][]float32{
+		AllMiniLML6V2: {0.02591, 0.00573, 0.01147, 0.03796, -0.02328, -0.05493, 0.014040, -0.01079, -0.02440, -0.01822},
+		BGESmallEN:    {-0.02313, -0.02552, 0.017357, -0.06393, -0.00061, 0.02212, -0.01472, 0.03925, 0.03444, 0.00459},
+		BGEBaseEN:     {0.01140, 0.03722, 0.02941, 0.01230, 0.03451, 0.00876, 0.02356, 0.05414, -0.02945, -0.05472},
+	}
+	input := []string{"hello world"}
+
+	for model, expected := range canonicalValues {
+		// The closure scopes the deferred Destroy to one iteration, so it
+		// runs once per model instead of piling up until the test returns.
+		func() {
+			fe, err := NewFlagEmbedding(&InitOptions{
+				Model: model,
+			})
+			if err != nil {
+				t.Fatalf("Expected no error for model %v, got %v", model, err)
+			}
+			// Registered only after the error check: Fatalf still runs
+			// deferred calls, and fe must not be touched if creation failed.
+			defer fe.Destroy()
+
+			result, err := fe.Embed(input, 1)
+			if err != nil {
+				t.Fatalf("Expected no error for model %v, got %v", model, err)
+			}
+			// Fatal rather than Error: the comparison below indexes result.
+			if len(result) != len(input) {
+				t.Fatalf("Expected result length %v for model %v, got %v", len(input), model, len(result))
+			}
+			if len(result[0]) < len(expected) {
+				t.Fatalf("Expected at least %d dimensions for model %v, got %d", len(expected), model, len(result[0]))
+			}
+
+			for i, want := range expected {
+				got := result[0][i]
+				if math.Abs(float64(got-want)) > epsilon {
+					t.Errorf("Model %v element %d mismatch: expected %.6f, got %.6f", model, i, want, got)
+				}
+			}
+		}()
+	}
+}