diff --git a/README.md b/README.md index 5ac8114..8a03853 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ err := index.Upsert(vector.Upsert{ ### Querying Vectors -The query vector must be present and it must have the same dimensions with the +The query vector must be present, and it must have the same dimensions with the all the other vectors in the index. When `TopK` is specified, at most that many vectors will be returned. @@ -135,6 +135,20 @@ scores, err := index.Query(vector.Query{ }) ``` +Additionally, a metadata filter can be specified in queries. When `Filter` is given, the response will contain +only the values whose metadata matches the given filter. See [Metadata Filtering](https://upstash.com/docs/vector/features/metadatafiltering) +docs for more information. + +```go +scores, err := index.Query(vector.Query{ + Vector: []float32{0.0, 1.0}, + TopK: 2, + IncludeVectors: false, + IncludeMetadata: false, + Filter: `foo = 'bar'` +}) +``` + ### Fetching Vectors Vectors can be fetched individually by providing the unique vector ids. diff --git a/go.mod b/go.mod index b77810e..e174c80 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.20 require ( github.com/joho/godotenv v1.5.1 - github.com/stretchr/testify v1.8.4 + github.com/stretchr/testify v1.9.0 ) require ( diff --git a/go.sum b/go.sum index 26059cd..28ecbf2 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,8 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/query_test.go b/query_test.go index a4b8acc..eb7e755 100644 --- a/query_test.go +++ b/query_test.go @@ -13,6 +13,7 @@ func TestQuery(t *testing.T) { id0 := randomString() id1 := randomString() + id2 := randomString() err = client.UpsertMany([]Upsert{ { Id: id0, @@ -23,6 +24,11 @@ func TestQuery(t *testing.T) { Id: id1, Vector: []float32{5, 10}, }, + { + Id: id2, + Vector: []float32{0.01, 1.01}, + Metadata: map[string]any{"foo": "nay"}, + }, }) require.NoError(t, err) @@ -35,26 +41,56 @@ func TestQuery(t *testing.T) { t.Run("score", func(t *testing.T) { scores, err := client.Query(Query{ Vector: []float32{0, 1}, - TopK: 1, + TopK: 2, }) require.NoError(t, err) - require.Equal(t, 1, len(scores)) + require.Equal(t, 2, len(scores)) require.Equal(t, id0, scores[0].Id) require.Equal(t, float32(1.0), scores[0].Score) + require.Equal(t, id2, scores[1].Id) }) t.Run("with metadata and vectors", func(t *testing.T) { scores, err := client.Query(Query{ Vector: []float32{0, 1}, - TopK: 1, + TopK: 2, IncludeMetadata: true, IncludeVectors: true, }) require.NoError(t, err) + require.Equal(t, 2, len(scores)) + require.Equal(t, id0, scores[0].Id) + require.Equal(t, float32(1.0), scores[0].Score) + require.Equal(t, map[string]any{"foo": "bar"}, scores[0].Metadata) + require.Equal(t, []float32{0, 1}, scores[0].Vector) + + require.Equal(t, id2, scores[1].Id) + require.Equal(t, []float32{0.01, 1.01}, scores[1].Vector) + }) + + t.Run("with metadata filtering", func(t *testing.T) { + query := Query{ + Vector: []float32{0, 1}, + TopK: 10, + IncludeMetadata: true, + IncludeVectors: true, + Filter: `foo = 'bar'`, + } + + scores, err := client.Query(query) + require.NoError(t, err) require.Equal(t, 1, len(scores)) require.Equal(t, id0, scores[0].Id) require.Equal(t, float32(1.0), scores[0].Score) require.Equal(t, map[string]any{"foo": "bar"}, scores[0].Metadata) require.Equal(t, []float32{0, 1}, scores[0].Vector) + + query.Filter = `foo = 'nay'` + scores, err = client.Query(query) + require.NoError(t, err) + require.Equal(t, 1, len(scores)) + require.Equal(t, id2, scores[0].Id) + require.Equal(t, map[string]any{"foo": "nay"}, scores[0].Metadata) + require.Equal(t, []float32{0.01, 1.01}, scores[0].Vector) }) } diff --git a/types.go b/types.go index 6c16ce9..0528a61 100644 --- a/types.go +++ b/types.go @@ -24,6 +24,9 @@ type Query struct { // Whether to include metadata in the query response, if any. IncludeMetadata bool `json:"includeMetadata,omitempty"` + + // Query filter + Filter any `json:"filter,omitempty"` } type Vector struct {