Skip to content

Commit

Permalink
feat: ensure query has a target name for each target vector
Browse files Browse the repository at this point in the history
  • Loading branch information
bevzzz committed Oct 29, 2024
1 parent 2f23565 commit 6fa31b9
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
33 changes: 31 additions & 2 deletions weaviate/graphql/nearvectorbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ func (b *NearVectorArgumentBuilder) WithTargets(targets *MultiTargetArgumentBuil
// Build build the given clause
func (b *NearVectorArgumentBuilder) build() string {
clause := []string{}
targetVectors := b.targetVectors
if b.withCertainty {
clause = append(clause, fmt.Sprintf("certainty: %v", b.certainty))
}
Expand Down Expand Up @@ -106,9 +105,39 @@ func (b *NearVectorArgumentBuilder) build() string {
if b.targets != nil {
clause = append(clause, fmt.Sprintf("targets: {%s}", b.targets.build()))
}
if len(targetVectors) > 0 && b.targets == nil {

targetVectors := b.prepareTargetVectors(b.targetVectors)
if len(targetVectors) > 0 {
targetVectors, _ := json.Marshal(targetVectors)
clause = append(clause, fmt.Sprintf("targetVectors: %s", targetVectors))
}
return fmt.Sprintf("nearVector:{%v}", strings.Join(clause, " "))
}

// prepareTargetVectors adds appends the target name for each target vector associated with it.
// Example:
//
// // For target vectors:
// WithTargetVectors("v1", "v2").
// WithVectorProTarget(map[string][][]float32{"v1": {{1,2,3}, {4,5,6}}})
// // Outputs:
// []string{"v1", "v1", "v2"}
//
// The server requires that the target names be repeated for each target vector,
// and passing them once only is a mistake that the users can easily make.
// This way, the client provides some safeguard.
//
// Note, too, that in case the user fails to pass a value in TargetVectors,
// it will not be added to the query.
func (b NearVectorArgumentBuilder) prepareTargetVectors(targets []string) (out []string) {
for _, target := range targets {
if vectors, ok := b.vectorsPerTarget[target]; ok {
for range vectors {
out = append(out, target)
}
continue
}
out = append(out, target)
}
return
}
5 changes: 4 additions & 1 deletion weaviate/graphql/nearvectorbuilder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ func TestNearMultiVectorBuilder_build(t *testing.T) {

t.Run("No combination with vector per target and target vectors", func(t *testing.T) {
vector := NearVectorArgumentBuilder{}
str := vector.WithVectorsPerTarget(map[string][][]float32{"one": {{1, 2, 3}, {7, 8, 9}}, "two": {{4, 5, 6}}}).WithTargetVectors("one", "one", "two").build()
str := vector.WithVectorsPerTarget(map[string][][]float32{
"one": {{1, 2, 3}, {7, 8, 9}},
"two": {{4, 5, 6}}},

Check failure on line 105 in weaviate/graphql/nearvectorbuilder_test.go

View workflow job for this annotation

GitHub Actions / lint

File is not `gofumpt`-ed (gofumpt)
).WithTargetVectors("one", "two").build()
require.Contains(t, str, "vectorPerTarget: ")
require.Contains(t, str, "one: [[1,2,3],[7,8,9]]")
require.Contains(t, str, "two: [[4,5,6]]")
Expand Down

0 comments on commit 6fa31b9

Please sign in to comment.