chore: loadTokenizer
Anush008 committed Oct 8, 2023
1 parent 12f7278 commit 00d0074
Showing 1 changed file with 35 additions and 26 deletions.
fastembed.go: 61 changes (35 additions, 26 deletions)
@@ -27,7 +27,7 @@ const (
BGEBaseEN EmbeddingModel = "fast-bge-base-en"
BGESmallEN EmbeddingModel = "fast-bge-small-en"

- // A model type "Unigram" is not yet supported by the tokenizer
+ // A model with type "Unigram" is not yet supported by the tokenizer
// Ref: https://github.com/sugarme/tokenizer/blob/448e79b1ed65947b8c6343bf9aa39e78364f45c8/pretrained/model.go#L152
// MLE5Large EmbeddingModel = "fast-multilingual-e5-large"
)
@@ -48,7 +48,7 @@ type FlagEmbedding struct {
// ShowDownloadProgress: Whether to show the download progress bar
// NOTE:
// We use a pointer for "ShowDownloadProgress" so that we can distinguish between the user
- // not setting this flag and the user setting it to false.
+ // not setting this flag and the user setting it to false. We want the default value to be true.
// As Go assigns a default(empty) value of "false" to bools, we can't distinguish
// if the user set it to false or not set at all.
// A pointer to bool will be nil if not set explicitly
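
The NOTE above leans on a standard Go pattern: a *bool stays nil until the caller assigns it, so the library can default to true without mistaking "unset" for "false". A minimal, standalone sketch of that resolution logic (the struct and helper below are illustrative, not the library's actual code):

package main

import "fmt"

// InitOptions here loosely mirrors the options struct described in the NOTE;
// this is a self-contained sketch, not fastembed-go's real definition.
type InitOptions struct {
    ShowDownloadProgress *bool
}

// resolveShowProgress returns true when the flag was never set and otherwise
// honors the caller's explicit choice.
func resolveShowProgress(o *InitOptions) bool {
    if o == nil || o.ShowDownloadProgress == nil {
        return true
    }
    return *o.ShowDownloadProgress
}

func main() {
    off := false
    fmt.Println(resolveShowProgress(&InitOptions{}))                           // true: flag left unset
    fmt.Println(resolveShowProgress(&InitOptions{ShowDownloadProgress: &off})) // false: explicitly disabled
}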
@@ -102,34 +102,14 @@ func NewFlagEmbedding(options *InitOptions) (*FlagEmbedding, error) {
return nil, err
}

- tknzer, err := pretrained.FromFile(filepath.Join(modelPath, "tokenizer.json"))
-
+ tknzer, err := loadTokenizer(modelPath, options.MaxLength)
if err != nil {
return nil, err
}

- maxLen := options.MaxLength
-
- tknzer.WithTruncation(&tokenizer.TruncationParams{
- MaxLength: maxLen,
- Strategy: tokenizer.LongestFirst,
- Stride: 0,
- })
-
- padToken := "[PAD]"
- paddingStrategy := tokenizer.NewPaddingStrategy(tokenizer.WithFixed(maxLen))
-
- paddingParams := tokenizer.PaddingParams{
- Strategy: *paddingStrategy,
- Direction: tokenizer.Right,
- PadId: 0,
- PadToken: padToken,
- }
- tknzer.WithPadding(&paddingParams)
return &FlagEmbedding{
tokenizer: tknzer,
model: options.Model,
- maxLength: maxLen,
+ maxLength: options.MaxLength,
modelPath: modelPath,
}, nil

@@ -282,7 +262,7 @@ func (f *FlagEmbedding) PassageEmbed(input []string, batchSize int) ([]([]float3
return f.Embed(processedInput, batchSize)
}

- // Function to list the supported models
+ // Function to list the supported FastEmbed models
func ListSupportedModels() []ModelInfo {
return []ModelInfo{
{
@@ -308,6 +288,36 @@ func ListSupportedModels() []ModelInfo {
}
}

+ // TODO: Configure this from the model config files
+ func loadTokenizer(modelPath string, maxLength int) (*tokenizer.Tokenizer, error) {
+ tknzer, err := pretrained.FromFile(filepath.Join(modelPath, "tokenizer.json"))
+
+ if err != nil {
+ return nil, err
+ }
+
+ maxLen := maxLength
+
+ tknzer.WithTruncation(&tokenizer.TruncationParams{
+ MaxLength: maxLen,
+ Strategy: tokenizer.LongestFirst,
+ Stride: 0,
+ })
+
+ padToken := "[PAD]"
+ paddingStrategy := tokenizer.NewPaddingStrategy(tokenizer.WithFixed(maxLen))
+
+ paddingParams := tokenizer.PaddingParams{
+ Strategy: *paddingStrategy,
+ Direction: tokenizer.Right,
+ PadId: 0,
+ PadToken: padToken,
+ }
+ tknzer.WithPadding(&paddingParams)
+
+ return tknzer, nil
+ }

// Private function to get model information from the model name
func getModelInfo(model EmbeddingModel) (ModelInfo, error) {
for _, m := range ListSupportedModels() {
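
The new loadTokenizer helper above still configures truncation and padding from the caller-supplied MaxLength, and its TODO points toward deriving these values from the model's config files instead. A rough, stdlib-only sketch of that idea; the file name "config.json", the "max_position_embeddings" field, and the helper name are assumptions for illustration, not anything this commit defines:

package main

import (
    "encoding/json"
    "fmt"
    "os"
    "path/filepath"
)

// modelConfig captures only the one field this sketch needs. The file name
// ("config.json") and field ("max_position_embeddings") are assumptions, not
// something defined by this commit.
type modelConfig struct {
    MaxPositionEmbeddings int `json:"max_position_embeddings"`
}

// maxLengthFromConfig returns the model's maximum sequence length when the
// config file is present and readable, otherwise the caller-supplied fallback.
func maxLengthFromConfig(modelPath string, fallback int) int {
    raw, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
    if err != nil {
        return fallback
    }
    var cfg modelConfig
    if err := json.Unmarshal(raw, &cfg); err != nil || cfg.MaxPositionEmbeddings <= 0 {
        return fallback
    }
    return cfg.MaxPositionEmbeddings
}

func main() {
    // Illustrative path; in practice this would be the downloaded model directory.
    fmt.Println(maxLengthFromConfig("path/to/model", 512))
}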
@@ -349,7 +359,6 @@ func downloadFromGcs(model EmbeddingModel, cacheDir string, showDownloadProgress
return "", fmt.Errorf("model download failed: %s", response.Status)
}

- fmt.Println(response.ContentLength)
if showDownloadProgress {
bar := progressbar.DefaultBytes(
response.ContentLength,
