From 00d0074ca243c4b47a5b8c7771cf2ca05773b073 Mon Sep 17 00:00:00 2001 From: Anush008 <46051506+Anush008@users.noreply.github.com> Date: Sun, 8 Oct 2023 16:58:11 +0530 Subject: [PATCH] chore: loadTokenizer --- fastembed.go | 61 ++++++++++++++++++++++++++++++---------------------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/fastembed.go b/fastembed.go index 2bfe60d..a9247da 100644 --- a/fastembed.go +++ b/fastembed.go @@ -27,7 +27,7 @@ const ( BGEBaseEN EmbeddingModel = "fast-bge-base-en" BGESmallEN EmbeddingModel = "fast-bge-small-en" -// A model type "Unigram" is not yet supported by the tokenizer +// A model with type "Unigram" is not yet supported by the tokenizer // Ref: https://github.com/sugarme/tokenizer/blob/448e79b1ed65947b8c6343bf9aa39e78364f45c8/pretrained/model.go#L152 // MLE5Large EmbeddingModel = "fast-multilingual-e5-large" ) @@ -48,7 +48,7 @@ type FlagEmbedding struct { // ShowDownloadProgress: Whether to show the download progress bar // NOTE: // We use a pointer for "ShowDownloadProgress" so that we can distinguish between the user -// not setting this flag and the user setting it to false. +// not setting this flag and the user setting it to false. We want the default value to be true. // As Go assigns a default(empty) value of "false" to bools, we can't distinguish // if the user set it to false or not set at all. // A pointer to bool will be nil if not set explicitly @@ -102,34 +102,14 @@ func NewFlagEmbedding(options *InitOptions) (*FlagEmbedding, error) { return nil, err } - tknzer, err := pretrained.FromFile(filepath.Join(modelPath, "tokenizer.json")) - + tknzer, err := loadTokenizer(modelPath, options.MaxLength) if err != nil { return nil, err } - - maxLen := options.MaxLength - - tknzer.WithTruncation(&tokenizer.TruncationParams{ - MaxLength: maxLen, - Strategy: tokenizer.LongestFirst, - Stride: 0, - }) - - padToken := "[PAD]" - paddingStrategy := tokenizer.NewPaddingStrategy(tokenizer.WithFixed(maxLen)) - - paddingParams := tokenizer.PaddingParams{ - Strategy: *paddingStrategy, - Direction: tokenizer.Right, - PadId: 0, - PadToken: padToken, - } - tknzer.WithPadding(&paddingParams) return &FlagEmbedding{ tokenizer: tknzer, model: options.Model, - maxLength: maxLen, + maxLength: options.MaxLength, modelPath: modelPath, }, nil @@ -282,7 +262,7 @@ func (f *FlagEmbedding) PassageEmbed(input []string, batchSize int) ([]([]float3 return f.Embed(processedInput, batchSize) } -// Function to list the supported models +// Function to list the supported FastEmbed models func ListSupportedModels() []ModelInfo { return []ModelInfo{ { @@ -308,6 +288,36 @@ func ListSupportedModels() []ModelInfo { } } +// TODO: Configure the from model config files +func loadTokenizer(modelPath string, maxLength int) (*tokenizer.Tokenizer, error) { + tknzer, err := pretrained.FromFile(filepath.Join(modelPath, "tokenizer.json")) + + if err != nil { + return nil, err + } + + maxLen := maxLength + + tknzer.WithTruncation(&tokenizer.TruncationParams{ + MaxLength: maxLen, + Strategy: tokenizer.LongestFirst, + Stride: 0, + }) + + padToken := "[PAD]" + paddingStrategy := tokenizer.NewPaddingStrategy(tokenizer.WithFixed(maxLen)) + + paddingParams := tokenizer.PaddingParams{ + Strategy: *paddingStrategy, + Direction: tokenizer.Right, + PadId: 0, + PadToken: padToken, + } + tknzer.WithPadding(&paddingParams) + + return tknzer, nil +} + // Private function to get model information from the model name func getModelInfo(model EmbeddingModel) (ModelInfo, error) { for _, m := range ListSupportedModels() { @@ -349,7 +359,6 @@ func downloadFromGcs(model EmbeddingModel, cacheDir string, showDownloadProgress return "", fmt.Errorf("model download failed: %s", response.Status) } - fmt.Println(response.ContentLength) if showDownloadProgress { bar := progressbar.DefaultBytes( response.ContentLength,