Skip to content

Commit

Permalink
Add support for nearVector queries with ColBERT
Browse files Browse the repository at this point in the history
  • Loading branch information
antas-marcin committed Jan 2, 2025
1 parent c1c5f60 commit 895cb04
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 16 deletions.
4 changes: 2 additions & 2 deletions test/graphql/multi_target_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func TestMultiTargetNearVector(t *testing.T) {
mta *graphql.MultiTargetArgumentBuilder
}{
{name: "with vector", nva: client.GraphQL().NearVectorArgBuilder().WithVector([]float32{1, 0, 0}), mta: to.mta},
{name: "with vector per target", nva: client.GraphQL().NearVectorArgBuilder().WithVectorPerTarget(map[string][]float32{"first": {1, 0, 0}, "second": {1, 0, 0}}), mta: to.mta},
{name: "with vector per target", nva: client.GraphQL().NearVectorArgBuilder().WithVectorPerTarget(map[string]models.Vector{"first": []float32{1, 0, 0}, "second": []float32{1, 0, 0}}), mta: to.mta},
}
for _, ti := range inner {
t.Run(to.name+" combination "+ti.name, func(t *testing.T) {
Expand Down Expand Up @@ -247,7 +247,7 @@ func TestMultiTargetNearVectorMultipleVectors(t *testing.T) {
for _, to := range outer {
t.Run(to.name+" combination", func(t *testing.T) {
nv := &graphql.NearVectorArgumentBuilder{}
nv.WithVectorsPerTarget(map[string][][]float32{"first": {{1, 0, 0}, {0, 1, 0}}, "second": {{1, 0, 0}}}).WithTargets(to.mta)
nv.WithVectorsPerTarget(map[string][]models.Vector{"first": {[]float32{1, 0, 0}, []float32{0, 1, 0}}, "second": {[]float32{1, 0, 0}}}).WithTargets(to.mta)
resp, err := client.GraphQL().Get().WithNearVector(nv).WithClassName(class.Class).WithFields(graphql.Field{Name: "_additional", Fields: []graphql.Field{{Name: "id"}, {Name: "distance"}}}).Do(ctx)
require.Nil(t, err)
if resp.Errors != nil {
Expand Down
29 changes: 21 additions & 8 deletions weaviate/graphql/nearvectorbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"encoding/json"
"fmt"
"strings"

"github.com/weaviate/weaviate/entities/models"
)

type NearVectorArgumentBuilder struct {
vector []float32
vectorsPerTarget map[string][][]float32
vector models.Vector
vectorsPerTarget map[string][]models.Vector
withCertainty bool
certainty float32
withDistance bool
Expand All @@ -18,18 +20,18 @@ type NearVectorArgumentBuilder struct {
}

// WithVector sets the search vector to be used in query
func (b *NearVectorArgumentBuilder) WithVector(vector []float32) *NearVectorArgumentBuilder {
func (b *NearVectorArgumentBuilder) WithVector(vector models.Vector) *NearVectorArgumentBuilder {
b.vector = vector
return b
}

// WithVectorPerTarget sets the search vector per target to be used in a multi target search query. This builder method takes
// precedence over WithVector. So if WithVectorPerTarget is used, WithVector will be ignored.
func (b *NearVectorArgumentBuilder) WithVectorPerTarget(vectorPerTarget map[string][]float32) *NearVectorArgumentBuilder {
func (b *NearVectorArgumentBuilder) WithVectorPerTarget(vectorPerTarget map[string]models.Vector) *NearVectorArgumentBuilder {
if len(vectorPerTarget) > 0 {
vectorPerTargetTmp := make(map[string][][]float32)
vectorPerTargetTmp := make(map[string][]models.Vector)
for k, v := range vectorPerTarget {
vectorPerTargetTmp[k] = [][]float32{v}
vectorPerTargetTmp[k] = []models.Vector{v}
}
b.vectorsPerTarget = vectorPerTargetTmp
}
Expand All @@ -38,7 +40,7 @@ func (b *NearVectorArgumentBuilder) WithVectorPerTarget(vectorPerTarget map[stri

// WithVectorsPerTarget sets the search vectors per target to be used in a multi target search query. This builder method takes
// precedence over WithVector. It writes the same field as WithVectorPerTarget, so whichever of the two is called last wins.
func (b *NearVectorArgumentBuilder) WithVectorsPerTarget(vectorPerTarget map[string][][]float32) *NearVectorArgumentBuilder {
func (b *NearVectorArgumentBuilder) WithVectorsPerTarget(vectorPerTarget map[string][]models.Vector) *NearVectorArgumentBuilder {
if len(vectorPerTarget) > 0 {
b.vectorsPerTarget = vectorPerTarget
}
Expand Down Expand Up @@ -95,7 +97,7 @@ func (b *NearVectorArgumentBuilder) build() string {
}
clause = append(clause, fmt.Sprintf("vectorPerTarget: {%s}", strings.Join(vectorPerTarget, ",")))
}
if len(b.vector) != 0 && len(b.vectorsPerTarget) == 0 {
if !b.isVectorEmpty(b.vector) && len(b.vectorsPerTarget) == 0 {
vectorB, err := json.Marshal(b.vector)
if err != nil {
panic(fmt.Errorf("failed to unmarshal nearVector search vector: %s", err))
Expand All @@ -114,6 +116,17 @@ func (b *NearVectorArgumentBuilder) build() string {
return fmt.Sprintf("nearVector:{%v}", strings.Join(clause, " "))
}

// isVectorEmpty reports whether vector carries no search data.
//
// A nil vector (no vector was ever set) and zero-length []float32 or
// [][]float32 values are considered empty. Any other dynamic type is
// reported as non-empty so that build still attempts to serialize it.
func (b *NearVectorArgumentBuilder) isVectorEmpty(vector models.Vector) bool {
	switch v := vector.(type) {
	case nil:
		// An unset interface value matches neither slice case; without this
		// branch it would be treated as non-empty and marshaled as the JSON
		// literal null.
		return true
	case []float32:
		return len(v) == 0
	case [][]float32:
		return len(v) == 0
	default:
		return false
	}
}

// prepareTargetVectors appends the target name for each target vector associated with it.
// Example:
//
Expand Down
13 changes: 7 additions & 6 deletions weaviate/graphql/nearvectorbuilder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/stretchr/testify/require"
"github.com/weaviate/weaviate/entities/models"
)

func TestNearMultiVectorBuilder_build(t *testing.T) {
Expand All @@ -23,7 +24,7 @@ func TestNearMultiVectorBuilder_build(t *testing.T) {
t.Run("Average combination with vector per target", func(t *testing.T) {
vector := NearVectorArgumentBuilder{}
targets := MultiTargetArgumentBuilder{}
str := vector.WithVectorPerTarget(map[string][]float32{"one": {1, 2, 3}, "two": {4, 5, 6}}).WithTargets(targets.Average("one", "two")).build()
str := vector.WithVectorPerTarget(map[string]models.Vector{"one": []float32{1, 2, 3}, "two": []float32{4, 5, 6}}).WithTargets(targets.Average("one", "two")).build()
require.Contains(t, str, "vectorPerTarget: ")
require.NotContains(t, str, "vector: ")
require.Contains(t, str, "one: [[1,2,3]]")
Expand Down Expand Up @@ -78,7 +79,7 @@ func TestNearMultiVectorBuilder_build(t *testing.T) {

t.Run("No combination with vector per target", func(t *testing.T) {
vector := NearVectorArgumentBuilder{}
str := vector.WithVectorPerTarget(map[string][]float32{"one": {1, 2, 3}, "two": {4, 5, 6}}).build()
str := vector.WithVectorPerTarget(map[string]models.Vector{"one": []float32{1, 2, 3}, "two": []float32{4, 5, 6}}).build()
require.Contains(t, str, "vectorPerTarget: ")
require.Contains(t, str, "one: [[1,2,3]]")
require.Contains(t, str, "two: [[4,5,6]]")
Expand All @@ -89,7 +90,7 @@ func TestNearMultiVectorBuilder_build(t *testing.T) {

t.Run("No combination with multiple vectors per target", func(t *testing.T) {
vector := NearVectorArgumentBuilder{}
str := vector.WithVectorsPerTarget(map[string][][]float32{"one": {{1, 2, 3}, {7, 8, 9}}, "two": {{4, 5, 6}}}).build()
str := vector.WithVectorsPerTarget(map[string][]models.Vector{"one": []models.Vector{[]float32{1, 2, 3}, []float32{7, 8, 9}}, "two": []models.Vector{[]float32{4, 5, 6}}}).build()

Check failure on line 93 in weaviate/graphql/nearvectorbuilder_test.go

View workflow job for this annotation

GitHub Actions / lint

File is not `gofumpt`-ed (gofumpt)
require.Contains(t, str, "vectorPerTarget: ")
require.Contains(t, str, "one: [[1,2,3],[7,8,9]]")
require.Contains(t, str, "two: [[4,5,6]]")
Expand All @@ -100,9 +101,9 @@ func TestNearMultiVectorBuilder_build(t *testing.T) {

t.Run("No combination with vector per target and target vectors", func(t *testing.T) {
vector := NearVectorArgumentBuilder{}
str := vector.WithVectorsPerTarget(map[string][][]float32{
"one": {{1, 2, 3}, {7, 8, 9}},
"two": {{4, 5, 6}},
str := vector.WithVectorsPerTarget(map[string][]models.Vector{
"one": {[]float32{1, 2, 3}, []float32{7, 8, 9}},
"two": {[]float32{4, 5, 6}},
},
).WithTargetVectors("one", "two").build()
require.Contains(t, str, "vectorPerTarget: ")
Expand Down

0 comments on commit 895cb04

Please sign in to comment.