-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathvector_test.go
116 lines (98 loc) · 2.91 KB
/
vector_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
package modusdb_test
import (
"context"
"encoding/json"
"fmt"
"strings"
"testing"
"github.com/dgraph-io/dgo/v240/protos/api"
"github.com/dgraph-io/dgraph/v24/dgraphapi"
"github.com/stretchr/testify/require"
"github.com/hypermodeinc/modusdb"
)
const (
	// numVectors is the number of random vectors TestVectorDelete inserts
	// (and subsequently deletes one by one).
	numVectors = 1000

	// vectorSchemaWithIndex is a DQL schema template declaring a
	// float32vector predicate with an HNSW index. Format arguments, in order:
	// predicate name, HNSW exponent, distance metric.
	vectorSchemaWithIndex = `%v: float32vector @index(hnsw(exponent: "%v", metric: "%v")) .`
)
// TestVectorDelete inserts numVectors random 10-dimensional vectors into an
// HNSW-indexed predicate and then deletes them one at a time. After each
// delete it verifies that the deleted node no longer has a vtest value and,
// while vectors remain, that a k=1 similar_to query seeded with the deleted
// vector still returns exactly one of the originally inserted vectors.
func TestVectorDelete(t *testing.T) {
	ctx := context.Background()

	db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir()))
	require.NoError(t, err)
	defer db.Close()

	require.NoError(t, db.DropAll(ctx))
	require.NoError(t, db.AlterSchema(ctx,
		fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "euclidean")))

	// Insert random vectors on freshly leased UIDs. The count check below
	// expects exactly numVectors vectors from this range.
	// NOTE(review): the -10 offsets presumably compensate for an offset that
	// GenerateRandomVectors adds when formatting subject UIDs — confirm
	// against dgraphapi.
	assignIDs, err := db.LeaseUIDs(numVectors + 1)
	require.NoError(t, err)
	//nolint:gosec
	rdf, vectors := dgraphapi.GenerateRandomVectors(int(assignIDs.StartId)-10, int(assignIDs.EndId)-10, 10, "vtest")
	_, err = db.Mutate(ctx, []*api.Mutation{{SetNquads: []byte(rdf)}})
	require.NoError(t, err)

	// Every generated vector must have been stored.
	const countQuery = `{
	vector(func: has(vtest)) {
		count(uid)
	}
}`
	resp, err := db.Query(ctx, countQuery)
	require.NoError(t, err)
	require.JSONEq(t, fmt.Sprintf(`{"vector":[{"count":%d}]}`, numVectors), string(resp.Json))

	// The stored vectors must match the generated ones element for element,
	// including order.
	const allVectorsQuery = `
{
	vector(func: has(vtest)) {
		uid
		vtest
	}
}`
	require.Equal(t, vectors, queryVectors(t, db, allVectorsQuery))

	// NOTE(review): the loops below never touch the final element of
	// triples, presumably because rdf ends with a newline and the last split
	// element is empty.
	triples := strings.Split(rdf, "\n")

	// deleteTriple deletes triples[idx], checks that its subject no longer
	// reports a vtest value, and returns the deleted triple.
	deleteTriple := func(idx int) string {
		triple := triples[idx]
		_, err := db.Mutate(ctx, []*api.Mutation{{
			DelNquads: []byte(triple),
		}})
		require.NoError(t, err)

		// The subject is the first space-separated token, written as <uid>.
		subject, _, _ := strings.Cut(triple, " ")
		uid := strings.TrimSuffix(strings.TrimPrefix(subject, "<"), ">")

		uidQuery := fmt.Sprintf(`{
	vector(func: uid(%s)) {
		vtest
	}
}`, uid)
		res, err := db.Query(ctx, uidQuery)
		require.NoError(t, err)
		require.JSONEq(t, `{"vector":[]}`, string(res.Json))
		return triple
	}

	// vectorLiteral extracts the quoted vector value from an RDF triple.
	vectorLiteral := func(triple string) string {
		return strings.Split(triple, `"`)[1]
	}

	// similarQuery is a template for a k=1 nearest-neighbour search around
	// the given vector literal.
	const similarQuery = `
{
	vector(func: similar_to(vtest, 1, "%v")) {
		uid
		vtest
	}
}`

	// Delete all but the last generated triple. Since vectors remain in the
	// index, a k=1 search around the deleted vector must return exactly one
	// of the originally inserted vectors.
	for i := 0; i < len(triples)-2; i++ {
		triple := deleteTriple(i)
		respVectors := queryVectors(t, db, fmt.Sprintf(similarQuery, vectorLiteral(triple)))
		require.Len(t, respVectors, 1)
		require.Contains(t, vectors, respVectors[0])
	}

	// Delete the last generated triple. The similar_to query must still
	// succeed (queryVectors fails the test on a query error), but its result
	// is deliberately not asserted, hence the discarded return value.
	triple := deleteTriple(len(triples) - 2)
	_ = queryVectors(t, db, fmt.Sprintf(similarQuery, vectorLiteral(triple)))
}
// queryVectors runs query against db and returns the vtest value of every
// node listed under the top-level "vector" key, in response order. It fails
// the test immediately if the query or the JSON decoding fails. The returned
// slice is never nil, even when the response contains no nodes.
func queryVectors(t *testing.T, db *modusdb.DB, query string) [][]float32 {
	// Attribute require failures to the caller's line rather than to this helper.
	t.Helper()

	resp, err := db.Query(context.Background(), query)
	require.NoError(t, err)

	var data struct {
		Vector []struct {
			UID   string    `json:"uid"`
			VTest []float32 `json:"vtest"`
		} `json:"vector"`
	}
	require.NoError(t, json.Unmarshal(resp.Json, &data))

	vectors := make([][]float32, 0, len(data.Vector))
	for _, node := range data.Vector {
		vectors = append(vectors, node.VTest)
	}
	return vectors
}