Skip to content

Commit

Permalink
feat: MLE5Large
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 committed Oct 8, 2023
1 parent 3ce0526 commit 12f7278
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
21 changes: 20 additions & 1 deletion fastembed.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ const (
AllMiniLML6V2 EmbeddingModel = "fast-all-MiniLM-L6-v2"
BGEBaseEN EmbeddingModel = "fast-bge-base-en"
BGESmallEN EmbeddingModel = "fast-bge-small-en"
// MLE5Large EmbeddingModel = "intfloat-multilingual-e5-large"

// The "Unigram" model type is not yet supported by the tokenizer, so MLE5Large stays disabled for now.
// Ref: https://github.com/sugarme/tokenizer/blob/448e79b1ed65947b8c6343bf9aa39e78364f45c8/pretrained/model.go#L152
// MLE5Large EmbeddingModel = "fast-multilingual-e5-large"
)

// Struct to interface with a FastEmbed model
Expand Down Expand Up @@ -297,6 +300,11 @@ func ListSupportedModels() []ModelInfo {
Dim: 384,
Description: "Fast and Default English model",
},
// {
// Model: MLE5Large,
// Dim: 1024,
// Description: "Multilingual model, e5-large. Recommend using this model for non-English languages",
// },
}
}

Expand All @@ -321,6 +329,14 @@ func retrieveModel(model EmbeddingModel, cacheDir string, showDownloadProgress b

// Private function to download the model from Google Cloud Storage
func downloadFromGcs(model EmbeddingModel, cacheDir string, showDownloadProgress bool) (string, error) {
// The MLE5Large model URL doesn't follow the same naming convention as the other models,
// so we transform "fast-multilingual-e5-large" -> "intfloat-multilingual-e5-large" in the download URL.
// The model directory name in the GCS storage is "fast-multilingual-e5-large", like the others.
// modelName := model
// if model == MLE5Large {
// modelName = "intfloat" + model[strings.Index(string(model), "-"):]
// }

downloadURL := fmt.Sprintf("https://storage.googleapis.com/qdrant-fastembed/%s.tar.gz", model)

response, err := http.Get(downloadURL)
Expand All @@ -333,6 +349,7 @@ func downloadFromGcs(model EmbeddingModel, cacheDir string, showDownloadProgress
return "", fmt.Errorf("model download failed: %s", response.Status)
}

fmt.Println(response.ContentLength)
if showDownloadProgress {
bar := progressbar.DefaultBytes(
response.ContentLength,
Expand All @@ -341,12 +358,14 @@ func downloadFromGcs(model EmbeddingModel, cacheDir string, showDownloadProgress
reader := progressbar.NewReader(response.Body, bar)
err = untar(&reader, cacheDir)
} else {
fmt.Printf("Downloading %s...", model)
err = untar(response.Body, cacheDir)
}

if err != nil {
return "", err
}

return filepath.Join(cacheDir, string(model)), nil
}

Expand Down
24 changes: 24 additions & 0 deletions fastembed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,27 @@ func TestEmbedAllMiniLML6V2(t *testing.T) {
t.Errorf("Expected result length %v, got %v", len(input), len(result))
}
}

// The "Unigram" model type is not yet supported by the tokenizer, so this test stays disabled for now.
// Ref: https://github.com/sugarme/tokenizer/blob/448e79b1ed65947b8c6343bf9aa39e78364f45c8/pretrained/model.go#L152
// func TestEmbedMLE5Large(t *testing.T) {
// // Test with a single input
// show := false
// fe, err := NewFlagEmbedding(&InitOptions{
// Model: MLE5Large,
// ShowDownloadProgress: &show,
// })
// defer fe.Destroy()
// if err != nil {
// t.Fatalf("Expected no error, got %v", err)
// }
// input := []string{"Is the world doing okay?"}
// result, err := fe.Embed(input, 1)
// if err != nil {
// t.Fatalf("Expected no error, got %v", err)
// }

// if len(result) != len(input) {
// t.Errorf("Expected result length %v, got %v", len(input), len(result))
// }
// }

0 comments on commit 12f7278

Please sign in to comment.