|
6 | 6 | import static io.qdrant.client.QueryFactory.fusion;
|
7 | 7 | import static io.qdrant.client.QueryFactory.nearest;
|
8 | 8 | import static io.qdrant.client.QueryFactory.orderBy;
|
| 9 | +import static io.qdrant.client.QueryFactory.sample; |
9 | 10 | import static io.qdrant.client.TargetVectorFactory.targetVector;
|
10 | 11 | import static io.qdrant.client.ValueFactory.value;
|
11 | 12 | import static io.qdrant.client.VectorFactory.vector;
|
|
38 | 39 | import io.qdrant.client.grpc.Points.PointsUpdateOperation.ClearPayload;
|
39 | 40 | import io.qdrant.client.grpc.Points.PointsUpdateOperation.UpdateVectors;
|
40 | 41 | import io.qdrant.client.grpc.Points.PrefetchQuery;
|
| 42 | +import io.qdrant.client.grpc.Points.QueryPointGroups; |
41 | 43 | import io.qdrant.client.grpc.Points.QueryPoints;
|
42 | 44 | import io.qdrant.client.grpc.Points.RecommendPointGroups;
|
43 | 45 | import io.qdrant.client.grpc.Points.RecommendPoints;
|
44 | 46 | import io.qdrant.client.grpc.Points.RetrievedPoint;
|
| 47 | +import io.qdrant.client.grpc.Points.Sample; |
45 | 48 | import io.qdrant.client.grpc.Points.ScoredPoint;
|
46 | 49 | import io.qdrant.client.grpc.Points.ScrollPoints;
|
47 | 50 | import io.qdrant.client.grpc.Points.ScrollResponse;
|
|
50 | 53 | import io.qdrant.client.grpc.Points.UpdateResult;
|
51 | 54 | import io.qdrant.client.grpc.Points.UpdateStatus;
|
52 | 55 | import io.qdrant.client.grpc.Points.Vectors;
|
| 56 | +import java.util.Arrays; |
53 | 57 | import java.util.List;
|
54 | 58 | import java.util.concurrent.ExecutionException;
|
55 | 59 | import java.util.concurrent.TimeUnit;
|
@@ -596,7 +600,7 @@ public void batchPointUpdate() throws ExecutionException, InterruptedException {
|
596 | 600 | createAndSeedCollection(testName);
|
597 | 601 |
|
598 | 602 | List<PointsUpdateOperation> operations =
|
599 |
| - List.of( |
| 603 | + Arrays.asList( |
600 | 604 | PointsUpdateOperation.newBuilder()
|
601 | 605 | .setClearPayload(
|
602 | 606 | ClearPayload.newBuilder()
|
@@ -757,6 +761,58 @@ public void queryWithPrefetchAndFusion() throws ExecutionException, InterruptedE
|
757 | 761 | assertEquals(2, points.size());
|
758 | 762 | }
|
759 | 763 |
|
| 764 | + @Test |
| 765 | + public void queryWithSampling() throws ExecutionException, InterruptedException { |
| 766 | + createAndSeedCollection(testName); |
| 767 | + |
| 768 | + List<ScoredPoint> points = |
| 769 | + client |
| 770 | + .queryAsync( |
| 771 | + QueryPoints.newBuilder() |
| 772 | + .setCollectionName(testName) |
| 773 | + .setQuery(sample(Sample.Random)) |
| 774 | + .setLimit(1) |
| 775 | + .build()) |
| 776 | + .get(); |
| 777 | + |
| 778 | + assertEquals(1, points.size()); |
| 779 | + } |
| 780 | + |
| 781 | + @Test |
| 782 | + public void queryGroups() throws ExecutionException, InterruptedException { |
| 783 | + createAndSeedCollection(testName); |
| 784 | + |
| 785 | + client |
| 786 | + .upsertAsync( |
| 787 | + testName, |
| 788 | + ImmutableList.of( |
| 789 | + PointStruct.newBuilder() |
| 790 | + .setId(id(10)) |
| 791 | + .setVectors(VectorsFactory.vectors(30f, 31f)) |
| 792 | + .putAllPayload(ImmutableMap.of("foo", value("hello"))) |
| 793 | + .build())) |
| 794 | + .get(); |
| 795 | + // 3 points in total, 2 with "foo" = "hello" and 1 with "foo" = "goodbye" |
| 796 | + |
| 797 | + List<PointGroup> groups = |
| 798 | + client |
| 799 | + .queryGroupsAsync( |
| 800 | + QueryPointGroups.newBuilder() |
| 801 | + .setCollectionName(testName) |
| 802 | + .setQuery(nearest(ImmutableList.of(10.4f, 11.4f))) |
| 803 | + .setGroupBy("foo") |
| 804 | + .setGroupSize(2) |
| 805 | + .setLimit(10) |
| 806 | + .build()) |
| 807 | + .get(); |
| 808 | + |
| 809 | + assertEquals(2, groups.size()); |
| 810 | + // A group with 2 hits because of 2 points with "foo" = "hello" |
| 811 | + assertEquals(1, groups.stream().filter(g -> g.getHitsCount() == 2).count()); |
| 812 | + // A group with 1 hit because of 1 point with "foo" = "goodbye" |
| 813 | + assertEquals(1, groups.stream().filter(g -> g.getHitsCount() == 1).count()); |
| 814 | + } |
| 815 | + |
760 | 816 | private void createAndSeedCollection(String collectionName)
|
761 | 817 | throws ExecutionException, InterruptedException {
|
762 | 818 | CreateCollection request =
|
|
0 commit comments