diff --git a/weaviate/graphql/nearvectorbuilder.go b/weaviate/graphql/nearvectorbuilder.go index fcc61477..bd29abd1 100644 --- a/weaviate/graphql/nearvectorbuilder.go +++ b/weaviate/graphql/nearvectorbuilder.go @@ -77,7 +77,6 @@ func (b *NearVectorArgumentBuilder) WithTargets(targets *MultiTargetArgumentBuil // Build build the given clause func (b *NearVectorArgumentBuilder) build() string { clause := []string{} - targetVectors := b.targetVectors if b.withCertainty { clause = append(clause, fmt.Sprintf("certainty: %v", b.certainty)) } @@ -106,9 +105,39 @@ func (b *NearVectorArgumentBuilder) build() string { if b.targets != nil { clause = append(clause, fmt.Sprintf("targets: {%s}", b.targets.build())) } - if len(targetVectors) > 0 && b.targets == nil { + + targetVectors := b.prepareTargetVectors(b.targetVectors) + if len(targetVectors) > 0 { targetVectors, _ := json.Marshal(targetVectors) clause = append(clause, fmt.Sprintf("targetVectors: %s", targetVectors)) } return fmt.Sprintf("nearVector:{%v}", strings.Join(clause, " ")) } + +// prepareTargetVectors adds appends the target name for each target vector associated with it. +// Example: +// +// // For target vectors: +// WithTargetVectors("v1", "v2"). +// WithVectorProTarget(map[string][][]float32{"v1": {{1,2,3}, {4,5,6}}}) +// // Outputs: +// []string{"v1", "v1", "v2"} +// +// The server requires that the target names be repeated for each target vector, +// and passing them once only is a mistake that the users can easily make. +// This way, the client provides some safeguard. +// +// Note, too, that in case the user fails to pass a value in TargetVectors, +// it will not be added to the query. +func (b NearVectorArgumentBuilder) prepareTargetVectors(targets []string) (out []string) { + for _, target := range targets { + if vectors, ok := b.vectorsPerTarget[target]; ok { + for range vectors { + out = append(out, target) + } + continue + } + out = append(out, target) + } + return +} diff --git a/weaviate/graphql/nearvectorbuilder_test.go b/weaviate/graphql/nearvectorbuilder_test.go index da14d968..c6927509 100644 --- a/weaviate/graphql/nearvectorbuilder_test.go +++ b/weaviate/graphql/nearvectorbuilder_test.go @@ -100,7 +100,10 @@ func TestNearMultiVectorBuilder_build(t *testing.T) { t.Run("No combination with vector per target and target vectors", func(t *testing.T) { vector := NearVectorArgumentBuilder{} - str := vector.WithVectorsPerTarget(map[string][][]float32{"one": {{1, 2, 3}, {7, 8, 9}}, "two": {{4, 5, 6}}}).WithTargetVectors("one", "one", "two").build() + str := vector.WithVectorsPerTarget(map[string][][]float32{ + "one": {{1, 2, 3}, {7, 8, 9}}, + "two": {{4, 5, 6}}}, + ).WithTargetVectors("one", "two").build() require.Contains(t, str, "vectorPerTarget: ") require.Contains(t, str, "one: [[1,2,3],[7,8,9]]") require.Contains(t, str, "two: [[4,5,6]]")