Skip to content

Commit a8b1734

Browse files
Skip approximation when track_total_hits is set to true (#18087)
Signed-off-by: Prudhvi Godithi <[email protected]>
1 parent 069c2d8 commit a8b1734

File tree

5 files changed

+86
-0
lines changed

5 files changed

+86
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1313
### Changed
1414
- Change the default max header size from 8KB to 16KB. ([#18024](https://github.com/opensearch-project/OpenSearch/pull/18024))
1515
- Enable concurrent_segment_search auto mode by default[#17978](https://github.com/opensearch-project/OpenSearch/pull/17978)
16+
- Skip approximation when `track_total_hits` is set to `true` [#18017](https://github.com/opensearch-project/OpenSearch/pull/18017)
1617

1718
### Dependencies
1819

server/src/main/java/org/opensearch/search/approximate/ApproximateMatchAllQuery.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ protected boolean canApproximate(SearchContext context) {
3636
if (context.aggregations() != null) {
3737
return false;
3838
}
39+
// Exclude approximation when "track_total_hits": true
40+
if (context.trackTotalHitsUpTo() == SearchContext.TRACK_TOTAL_HITS_ACCURATE) {
41+
return false;
42+
}
3943

4044
if (context.request() != null && context.request().source() != null && context.innerHits().getInnerHits().isEmpty()) {
4145
FieldSortBuilder primarySortField = FieldSortBuilder.getPrimaryFieldSortOrNull(context.request().source());

server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,10 @@ public boolean canApproximate(SearchContext context) {
440440
if (context.aggregations() != null) {
441441
return false;
442442
}
443+
// Exclude approximation when "track_total_hits": true
444+
if (context.trackTotalHitsUpTo() == SearchContext.TRACK_TOTAL_HITS_ACCURATE) {
445+
return false;
446+
}
443447
// size 0 could be set for caching
444448
if (context.from() + context.size() == 0) {
445449
this.setSize(SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO);

server/src/test/java/org/opensearch/search/approximate/ApproximateMatchAllQueryTests.java

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.opensearch.search.aggregations.AggregatorFactories;
2121
import org.opensearch.search.aggregations.SearchContextAggregations;
2222
import org.opensearch.search.builder.SearchSourceBuilder;
23+
import org.opensearch.search.internal.SearchContext;
2324
import org.opensearch.search.internal.ShardSearchRequest;
2425
import org.opensearch.search.sort.FieldSortBuilder;
2526
import org.opensearch.search.sort.SortOrder;
@@ -105,4 +106,60 @@ public ShardSearchRequest request() {
105106
assertThrows(IllegalStateException.class, () -> approximateMatchAllQuery.rewrite(null));
106107
}
107108

109+
public void testCannotApproximateWithTrackTotalHits() {
110+
ApproximateMatchAllQuery approximateMatchAllQuery = new ApproximateMatchAllQuery();
111+
112+
ShardSearchRequest[] shardSearchRequest = new ShardSearchRequest[1];
113+
114+
MapperService mockMapper = mock(MapperService.class);
115+
String sortfield = "myfield";
116+
MappedFieldType myFieldType = new NumberFieldMapper.NumberFieldType(sortfield, NumberFieldMapper.NumberType.LONG);
117+
when(mockMapper.fieldType(sortfield)).thenReturn(myFieldType);
118+
119+
Settings settings = Settings.builder()
120+
.put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT)
121+
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
122+
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 1)
123+
.build();
124+
IndexMetadata indexMetadata = new IndexMetadata.Builder("index").settings(settings).build();
125+
QueryShardContext queryShardContext = new QueryShardContext(
126+
0,
127+
new IndexSettings(indexMetadata, settings),
128+
BigArrays.NON_RECYCLING_INSTANCE,
129+
null,
130+
null,
131+
mockMapper,
132+
null,
133+
null,
134+
null,
135+
null,
136+
null,
137+
null,
138+
null,
139+
null,
140+
null,
141+
null,
142+
null
143+
);
144+
TestSearchContext searchContext = new TestSearchContext(queryShardContext) {
145+
@Override
146+
public ShardSearchRequest request() {
147+
return shardSearchRequest[0];
148+
}
149+
};
150+
151+
SearchSourceBuilder source = new SearchSourceBuilder();
152+
shardSearchRequest[0] = new ShardSearchRequest(null, System.currentTimeMillis(), null);
153+
shardSearchRequest[0].source(source);
154+
source.sort(sortfield, SortOrder.ASC);
155+
156+
assertTrue(approximateMatchAllQuery.canApproximate(searchContext));
157+
158+
searchContext.trackTotalHitsUpTo(SearchContext.TRACK_TOTAL_HITS_ACCURATE);
159+
assertFalse("Should not approximate when track_total_hits is accurate", approximateMatchAllQuery.canApproximate(searchContext));
160+
161+
searchContext.trackTotalHitsUpTo(SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO);
162+
assertTrue("Should approximate when track_total_hits is not accurate", approximateMatchAllQuery.canApproximate(searchContext));
163+
}
164+
108165
}

server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import static java.util.Arrays.asList;
3434
import static org.apache.lucene.document.LongPoint.pack;
3535
import static org.mockito.Mockito.mock;
36+
import static org.mockito.Mockito.when;
3637

3738
public class ApproximatePointRangeQueryTests extends OpenSearchTestCase {
3839

@@ -372,4 +373,23 @@ public boolean canApproximate(SearchContext context) {
372373
SearchContext searchContext = mock(SearchContext.class);
373374
assertTrue(queryCanApproximate.canApproximate(searchContext));
374375
}
376+
377+
public void testCannotApproximateWithTrackTotalHits() {
378+
ApproximatePointRangeQuery query = new ApproximatePointRangeQuery(
379+
"point",
380+
pack(0).bytes,
381+
pack(20).bytes,
382+
1,
383+
ApproximatePointRangeQuery.LONG_FORMAT
384+
);
385+
SearchContext mockContext = mock(SearchContext.class);
386+
when(mockContext.trackTotalHitsUpTo()).thenReturn(SearchContext.TRACK_TOTAL_HITS_ACCURATE);
387+
assertFalse(query.canApproximate(mockContext));
388+
when(mockContext.trackTotalHitsUpTo()).thenReturn(SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO);
389+
when(mockContext.aggregations()).thenReturn(null);
390+
when(mockContext.from()).thenReturn(0);
391+
when(mockContext.size()).thenReturn(10);
392+
when(mockContext.request()).thenReturn(null);
393+
assertTrue(query.canApproximate(mockContext));
394+
}
375395
}

0 commit comments

Comments
 (0)