Skip to content

Commit

Permalink
chore: default to ONNX_PATH env
Browse files Browse the repository at this point in the history
BREAKING
  • Loading branch information
Anush008 committed Oct 8, 2023
1 parent e4a24e2 commit f3ada9e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 21 deletions.
5 changes: 2 additions & 3 deletions fastembed.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ type InitOptions struct {
MaxLength int
CacheDir string
ShowDownloadProgress *bool
OnnxPath string
}

// Struct to represent FastEmbed model information
Expand Down Expand Up @@ -84,8 +83,8 @@ func NewFlagEmbedding(options *InitOptions) (*FlagEmbedding, error) {
options.ShowDownloadProgress = &showDownloadProgress
}

if options.OnnxPath != "" {
ort.SetSharedLibraryPath(options.OnnxPath)
if onnxPath := os.Getenv("ONNX_PATH"); onnxPath != "" {
ort.SetSharedLibraryPath(onnxPath)
}

if !ort.IsInitialized() {
Expand Down
22 changes: 4 additions & 18 deletions fastembed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,21 @@ package fastembed

import (
"fmt"
"os"
"reflect"
"testing"
)

func TestNewFlagEmbedding(t *testing.T) {
// Test with default options
options := &InitOptions{
OnnxPath: os.Getenv("ONNX_PATH"),
}
_, err := NewFlagEmbedding(options)
_, err := NewFlagEmbedding(&InitOptions{})
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
}

func TestEmbed(t *testing.T) {
// Test with a single input
options := &InitOptions{
OnnxPath: os.Getenv("ONNX_PATH"),
Model: AllMiniLML6V2,
}
fe, err := NewFlagEmbedding(options)
fe, err := NewFlagEmbedding(&InitOptions{})
defer fe.Destroy()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
Expand All @@ -42,10 +34,7 @@ func TestEmbed(t *testing.T) {

func TestQueryEmbed(t *testing.T) {
// Test with a single input
options := &InitOptions{
OnnxPath: os.Getenv("ONNX_PATH"),
}
fe, err := NewFlagEmbedding(options)
fe, err := NewFlagEmbedding(&InitOptions{})
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
Expand All @@ -61,10 +50,7 @@ func TestQueryEmbed(t *testing.T) {

func TestPassageEmbed(t *testing.T) {
// Test with a single input
options := &InitOptions{
OnnxPath: os.Getenv("ONNX_PATH"),
}
fe, err := NewFlagEmbedding(options)
fe, err := NewFlagEmbedding(&InitOptions{})
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
Expand Down

0 comments on commit f3ada9e

Please sign in to comment.