diff --git a/README.md b/README.md index 4068c0f..a3781c3 100644 --- a/README.md +++ b/README.md @@ -25,8 +25,11 @@ The default embedding supports "query" and "passage" prefixes for the input text ## 🤖 Models +- [**BAAI/bge-base-en**](https://huggingface.co/BAAI/bge-base-en) - [**BAAI/bge-base-en-v1.5**](https://huggingface.co/BAAI/bge-base-en-v1.5) +- [**BAAI/bge-small-en**](https://huggingface.co/BAAI/bge-small-en) - [**BAAI/bge-small-en-v1.5**](https://huggingface.co/BAAI/bge-small-en-v1.5) - Default +- [**BAAI/bge-small-zh-v1.5**](https://huggingface.co/BAAI/bge-small-zh-v1.5) - [**sentence-transformers/all-MiniLM-L6-v2**](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) ## 🚀 Installation diff --git a/fastembed.go b/fastembed.go index 6c2c90f..8a014d7 100644 --- a/fastembed.go +++ b/fastembed.go @@ -26,7 +26,10 @@ type EmbeddingModel string const ( AllMiniLML6V2 EmbeddingModel = "fast-all-MiniLM-L6-v2" BGEBaseEN EmbeddingModel = "fast-bge-base-en" + BGEBaseENV15 EmbeddingModel = "fast-bge-base-en-v1.5" BGESmallEN EmbeddingModel = "fast-bge-small-en" + BGESmallENV15 EmbeddingModel = "fast-bge-small-en-v1.5" + BGESmallZH EmbeddingModel = "fast-bge-small-zh-v1.5" // A model with type "Unigram" is not yet supported by the tokenizer // Ref: https://github.com/sugarme/tokenizer/blob/448e79b1ed65947b8c6343bf9aa39e78364f45c8/pretrained/model.go#L152 @@ -79,7 +82,7 @@ func NewFlagEmbedding(options *InitOptions) (*FlagEmbedding, error) { } if options.Model == "" { - options.Model = BGESmallEN + options.Model = BGESmallENV15 } if options.MaxLength == 0 { @@ -281,10 +284,25 @@ func ListSupportedModels() []ModelInfo { Dim: 768, Description: "Base English model", }, + { + Model: BGEBaseENV15, + Dim: 768, + Description: "v1.5 release of the base English model", + }, { Model: BGESmallEN, Dim: 384, - Description: "Fast and Default English model", + Description: "Fast English model", + }, + { + Model: BGESmallENV15, + Dim: 384, + Description: "Fast, default English 
model", + }, + { + Model: BGESmallZH, + Dim: 512, + Description: "Fast Chinese model", }, // { // Model: MLE5Large, diff --git a/fastembed_test.go b/fastembed_test.go index a494d84..9e59efb 100644 --- a/fastembed_test.go +++ b/fastembed_test.go @@ -5,95 +5,14 @@ import ( "testing" ) -func TestEmbedBGEBaseEN(t *testing.T) { - // Test with a single input - fe, err := NewFlagEmbedding(&InitOptions{ - Model: BGEBaseEN, - }) - defer fe.Destroy() - if err != nil { - t.Fatalf("Expected no error, got %v", err) - } - input := []string{"hello world"} - result, err := fe.Embed(input, 1) - if err != nil { - t.Fatalf("Expected no error, got %v", err) - } - - if len(result) != len(input) { - t.Errorf("Expected result length %v, got %v", len(input), len(result)) - } -} - -func TestEmbedAllMiniLML6V2(t *testing.T) { - // Test with a single input - fe, err := NewFlagEmbedding(&InitOptions{ - Model: AllMiniLML6V2, - }) - defer fe.Destroy() - if err != nil { - t.Fatalf("Expected no error, got %v", err) - } - input := []string{"hello world"} - result, err := fe.Embed(input, 1) - if err != nil { - t.Fatalf("Expected no error, got %v", err) - } - - if len(result) != len(input) { - t.Errorf("Expected result length %v, got %v", len(input), len(result)) - } -} - -func TestEmbedBGESmallEN(t *testing.T) { - // Test with a single input - fe, err := NewFlagEmbedding(&InitOptions{ - Model: BGESmallEN, - }) - defer fe.Destroy() - if err != nil { - t.Fatalf("Expected no error, got %v", err) - } - input := []string{"hello world"} - result, err := fe.Embed(input, 1) - if err != nil { - t.Fatalf("Expected no error, got %v", err) - } - - if len(result) != len(input) { - t.Errorf("Expected result length %v, got %v", len(input), len(result)) - } -} - -// A model type "Unigram" is not yet supported by the tokenizer -// Ref: https://github.com/sugarme/tokenizer/blob/448e79b1ed65947b8c6343bf9aa39e78364f45c8/pretrained/model.go#L152 -// func TestEmbedMLE5Large(t *testing.T) { -// // Test with a single 
input -// show := false -// fe, err := NewFlagEmbedding(&InitOptions{ -// Model: MLE5Large, -// ShowDownloadProgress: &show, -// }) -// defer fe.Destroy() -// if err != nil { -// t.Fatalf("Expected no error, got %v", err) -// } -// input := []string{"hello world"} -// result, err := fe.Embed(input, 1) -// if err != nil { -// t.Fatalf("Expected no error, got %v", err) -// } - -// if len(result) != len(input) { -// t.Errorf("Expected result length %v, got %v", len(input), len(result)) -// } -// } - func TestCanonicalValues(T *testing.T) { canonicalValues := map[EmbeddingModel]([]float32){ - AllMiniLML6V2: []float32{0.02591, 0.00573, 0.01147, 0.03796, -0.02328, -0.05493, 0.014040, -0.01079, -0.02440, -0.01822}, - BGESmallEN: []float32{-0.02313, -0.02552, 0.017357, -0.06393, -0.00061, 0.02212, -0.01472, 0.03925, 0.03444, 0.00459}, - BGEBaseEN: []float32{0.01140, 0.03722, 0.02941, 0.01230, 0.03451, 0.00876, 0.02356, 0.05414, -0.02945, -0.05472}, + AllMiniLML6V2: []float32{0.02591, 0.00573, 0.01147, 0.03796, -0.02328}, + BGESmallEN: []float32{-0.02313, -0.02552, 0.017357, -0.06393, -0.00061}, + BGEBaseEN: []float32{0.01140, 0.03722, 0.02941, 0.01230, 0.03451}, + BGEBaseENV15: []float32{0.01129394, 0.05493144, 0.02615099, 0.00328772, 0.02996045}, + BGESmallENV15: []float32{0.01522374, -0.02271799, 0.00860278, -0.07424029, 0.00386434}, + BGESmallZH: []float32{-0.01023294, 0.07634465, 0.0691722, -0.04458365, -0.03160762}, } for model, expected := range canonicalValues { @@ -114,10 +33,10 @@ func TestCanonicalValues(T *testing.T) { T.Errorf("Expected result length %v, got %v", len(input), len(result)) } - epsilon := float64(1e-5) + epsilon := float64(1e-4) for i, v := range expected { if math.Abs(float64(result[0][i]-v)) > float64(epsilon) { - T.Errorf("Element %d mismatch: expected %.6f, got %.6f", i, v, result[0][i]) + T.Errorf("Element %d mismatch for %s: expected %.6f, got %.6f", i, model, v, result[0][i]) } } }