diff --git a/CHANGELOG.md b/CHANGELOG.md index 591faec2c533d..c1d23597d3d3e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,6 +50,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Optimization in String Terms Aggregation query for Large Bucket Counts([#18732](https://github.com/opensearch-project/OpenSearch/pull/18732)) - New cluster setting search.query.max_query_string_length ([#19491](https://github.com/opensearch-project/OpenSearch/pull/19491)) - Add `StreamNumericTermsAggregator` to allow numeric term aggregation streaming ([#19335](https://github.com/opensearch-project/OpenSearch/pull/19335)) +- Harden the circuit breaker and failure handling logic in query result consumer ([#19396](https://github.com/opensearch-project/OpenSearch/pull/19396)) ### Changed - Refactor `if-else` chains to use `Java 17 pattern matching switch expressions`(([#18965](https://github.com/opensearch-project/OpenSearch/pull/18965)) diff --git a/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java b/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java index fa40e50aa868a..b04d3086d8c95 100644 --- a/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java +++ b/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java @@ -66,7 +66,7 @@ * as shard results are consumed. * This implementation adds the memory that it used to save and reduce the results of shard aggregations * in the {@link CircuitBreaker#REQUEST} circuit breaker. Before any partial or final reduce, the memory - * needed to reduce the aggregations is estimated and a {@link CircuitBreakingException} is thrown if it + * needed to reduce the aggregations is estimated and a {@link CircuitBreakingException} is handled if it * exceeds the maximum memory allowed in this breaker.
* * @opensearch.internal @@ -86,8 +86,8 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults onPartialMergeFailure; + final PendingReduces pendingReduces; + private final Consumer cancelTaskOnFailure; private final BooleanSupplier isTaskCancelled; public QueryPhaseResultConsumer( @@ -98,7 +98,7 @@ public QueryPhaseResultConsumer( SearchProgressListener progressListener, NamedWriteableRegistry namedWriteableRegistry, int expectedResultSize, - Consumer onPartialMergeFailure + Consumer cancelTaskOnFailure ) { this( request, @@ -108,7 +108,7 @@ public QueryPhaseResultConsumer( progressListener, namedWriteableRegistry, expectedResultSize, - onPartialMergeFailure, + cancelTaskOnFailure, () -> false ); } @@ -125,7 +125,7 @@ public QueryPhaseResultConsumer( SearchProgressListener progressListener, NamedWriteableRegistry namedWriteableRegistry, int expectedResultSize, - Consumer onPartialMergeFailure, + Consumer cancelTaskOnFailure, BooleanSupplier isTaskCancelled ) { super(expectedResultSize); @@ -137,13 +137,13 @@ public QueryPhaseResultConsumer( this.namedWriteableRegistry = namedWriteableRegistry; this.topNSize = SearchPhaseController.getTopDocsSize(request); this.performFinalReduce = request.isFinalReduce(); - this.onPartialMergeFailure = onPartialMergeFailure; + this.cancelTaskOnFailure = cancelTaskOnFailure; SearchSourceBuilder source = request.source(); this.hasTopDocs = source == null || source.size() != 0; this.hasAggs = source != null && source.aggregations() != null; int batchReduceSize = getBatchReduceSize(request.getBatchedReduceSize(), expectedResultSize); - this.pendingMerges = new PendingMerges(batchReduceSize, request.resolveTrackTotalHitsUpTo()); + this.pendingReduces = new PendingReduces(batchReduceSize, request.resolveTrackTotalHitsUpTo()); this.isTaskCancelled = isTaskCancelled; } @@ -153,7 +153,7 @@ int getBatchReduceSize(int requestBatchedReduceSize, int minBatchReduceSize) { @Override public void close() { - Releasables.close(pendingMerges); + Releasables.close(pendingReduces); } @Override @@ -161,35 +161,35 @@ public void consumeResult(SearchPhaseResult result, Runnable next) { super.consumeResult(result, () -> {}); QuerySearchResult querySearchResult = result.queryResult(); progressListener.notifyQueryResult(querySearchResult.getShardIndex()); - checkCancellation(); - pendingMerges.consume(querySearchResult, next); + pendingReduces.consume(querySearchResult, next); } @Override public SearchPhaseController.ReducedQueryPhase reduce() throws Exception { - if (pendingMerges.hasPendingMerges()) { + if (pendingReduces.hasPendingReduceTask()) { throw new AssertionError("partial reduce in-flight"); - } else if (pendingMerges.hasFailure()) { - throw pendingMerges.getFailure(); } checkCancellation(); + if (pendingReduces.hasFailure()) { + throw pendingReduces.failure.get(); + } // ensure consistent ordering - pendingMerges.sortBuffer(); - final SearchPhaseController.TopDocsStats topDocsStats = pendingMerges.consumeTopDocsStats(); - final List topDocsList = pendingMerges.consumeTopDocs(); - final List aggsList = pendingMerges.consumeAggs(); - long breakerSize = pendingMerges.circuitBreakerBytes; + pendingReduces.sortBuffer(); + final SearchPhaseController.TopDocsStats topDocsStats = pendingReduces.consumeTopDocsStats(); + final List topDocsList = pendingReduces.consumeTopDocs(); + final List aggsList = pendingReduces.consumeAggs(); + long breakerSize = pendingReduces.circuitBreakerBytes; if (hasAggs) { // Add an estimate of the final reduce size - breakerSize = 
pendingMerges.addEstimateAndMaybeBreak(pendingMerges.estimateRamBytesUsedForReduce(breakerSize)); + breakerSize = pendingReduces.addEstimateAndMaybeBreak(pendingReduces.estimateRamBytesUsedForReduce(breakerSize)); } SearchPhaseController.ReducedQueryPhase reducePhase = controller.reducedQueryPhase( results.asList(), aggsList, topDocsList, topDocsStats, - pendingMerges.numReducePhases, + pendingReduces.numReducePhases, false, aggReduceContextBuilder, performFinalReduce @@ -197,8 +197,12 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception { if (hasAggs) { // Update the circuit breaker to replace the estimation with the serialized size of the newly reduced result long finalSize = reducePhase.aggregations.getSerializedSize() - breakerSize; - pendingMerges.addWithoutBreaking(finalSize); - logger.trace("aggs final reduction [{}] max [{}]", pendingMerges.aggsCurrentBufferSize, pendingMerges.maxAggsCurrentBufferSize); + pendingReduces.addWithoutBreaking(finalSize); + logger.trace( + "aggs final reduction [{}] max [{}]", + pendingReduces.aggsCurrentBufferSize, + pendingReduces.maxAggsCurrentBufferSize + ); } progressListener.notifyFinalReduce( SearchProgressListener.buildSearchShards(results.asList()), @@ -209,16 +213,16 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception { return reducePhase; } - private MergeResult partialReduce( + private ReduceResult partialReduce( QuerySearchResult[] toConsume, List emptyResults, SearchPhaseController.TopDocsStats topDocsStats, - MergeResult lastMerge, + ReduceResult lastReduceResult, int numReducePhases ) { checkCancellation(); - if (pendingMerges.hasFailure()) { - return lastMerge; + if (pendingReduces.hasFailure()) { + return lastReduceResult; } // ensure consistent ordering Arrays.sort(toConsume, Comparator.comparingInt(QuerySearchResult::getShardIndex)); @@ -230,8 +234,8 @@ private MergeResult partialReduce( final TopDocs newTopDocs; if (hasTopDocs) { List topDocsList = new ArrayList<>(); - if (lastMerge != null) { - topDocsList.add(lastMerge.reducedTopDocs); + if (lastReduceResult != null) { + topDocsList.add(lastReduceResult.reducedTopDocs); } for (QuerySearchResult result : toConsume) { TopDocsAndMaxScore topDocs = result.consumeTopDocs(); @@ -251,8 +255,8 @@ private MergeResult partialReduce( final InternalAggregations newAggs; if (hasAggs) { List aggsList = new ArrayList<>(); - if (lastMerge != null) { - aggsList.add(lastMerge.reducedAggs); + if (lastReduceResult != null) { + aggsList.add(lastReduceResult.reducedAggs); } for (QuerySearchResult result : toConsume) { aggsList.add(result.consumeAggs().expand()); @@ -262,8 +266,8 @@ private MergeResult partialReduce( newAggs = null; } List processedShards = new ArrayList<>(emptyResults); - if (lastMerge != null) { - processedShards.addAll(lastMerge.processedShards); + if (lastReduceResult != null) { + processedShards.addAll(lastReduceResult.processedShards); } for (QuerySearchResult result : toConsume) { SearchShardTarget target = result.getSearchShardTarget(); @@ -273,29 +277,31 @@ private MergeResult partialReduce( // we leave the results un-serialized because serializing is slow but we compute the serialized // size as an estimate of the memory used by the newly reduced aggregations. long serializedSize = hasAggs ? newAggs.getSerializedSize() : 0; - return new MergeResult(processedShards, newTopDocs, newAggs, hasAggs ? serializedSize : 0); + return new ReduceResult(processedShards, newTopDocs, newAggs, hasAggs ? 
serializedSize : 0); } private void checkCancellation() { if (isTaskCancelled.getAsBoolean()) { - pendingMerges.resetCircuitBreakerForCurrentRequest(); - // This check is to ensure that we are not masking the actual reason for cancellation i,e; CircuitBreakingException - if (!pendingMerges.hasFailure()) { - pendingMerges.failure.set(new TaskCancelledException("request has been terminated")); - } + pendingReduces.onFailure(new TaskCancelledException("request has been terminated")); } } public int getNumReducePhases() { - return pendingMerges.numReducePhases; + return pendingReduces.numReducePhases; } /** - * Class representing pending merges + * Manages incremental query result reduction by buffering incoming results and + * triggering partial reduce operations when the threshold is reached. + *
+     * <ul>
+     *   <li>Handles circuit breaker memory accounting</li>
+     *   <li>Coordinates reduce task execution to be one at a time</li>
+     *   <li>Provides thread-safe failure handling with cleanup</li>
+     * </ul>
* * @opensearch.internal */ - class PendingMerges implements Releasable { + class PendingReduces implements Releasable { private final int batchReduceSize; private final List buffer = new ArrayList<>(); private final List emptyResults = new ArrayList<>(); @@ -305,23 +311,23 @@ class PendingMerges implements Releasable { private volatile long aggsCurrentBufferSize; private volatile long maxAggsCurrentBufferSize = 0; - private final ArrayDeque queue = new ArrayDeque<>(); - private final AtomicReference runningTask = new AtomicReference<>(); + private final ArrayDeque queue = new ArrayDeque<>(); + private final AtomicReference runningTask = new AtomicReference<>(); // ensure only one task is running private final AtomicReference failure = new AtomicReference<>(); private final SearchPhaseController.TopDocsStats topDocsStats; - private volatile MergeResult mergeResult; + private volatile ReduceResult reduceResult; private volatile boolean hasPartialReduce; private volatile int numReducePhases; - PendingMerges(int batchReduceSize, int trackTotalHitsUpTo) { + PendingReduces(int batchReduceSize, int trackTotalHitsUpTo) { this.batchReduceSize = batchReduceSize; this.topDocsStats = new SearchPhaseController.TopDocsStats(trackTotalHitsUpTo); } @Override public synchronized void close() { - assert hasPendingMerges() == false : "cannot close with partial reduce in-flight"; + assert hasPendingReduceTask() == false : "cannot close with partial reduce in-flight"; if (hasFailure()) { assert circuitBreakerBytes == 0; return; @@ -331,43 +337,52 @@ public synchronized void close() { circuitBreakerBytes = 0; } - synchronized Exception getFailure() { - return failure.get(); - } - - boolean hasFailure() { + private boolean hasFailure() { return failure.get() != null; } - boolean hasPendingMerges() { + private boolean hasPendingReduceTask() { return queue.isEmpty() == false || runningTask.get() != null; } - void sortBuffer() { + private void sortBuffer() { if (buffer.size() > 0) { Collections.sort(buffer, Comparator.comparingInt(QuerySearchResult::getShardIndex)); } } - synchronized long addWithoutBreaking(long size) { + private synchronized long addWithoutBreaking(long size) { + if (hasFailure()) { + return circuitBreakerBytes; + } circuitBreaker.addWithoutBreaking(size); circuitBreakerBytes += size; maxAggsCurrentBufferSize = Math.max(maxAggsCurrentBufferSize, circuitBreakerBytes); return circuitBreakerBytes; } - synchronized long addEstimateAndMaybeBreak(long estimatedSize) { + private synchronized long addEstimateAndMaybeBreak(long estimatedSize) { + if (hasFailure()) { + return circuitBreakerBytes; + } circuitBreaker.addEstimateBytesAndMaybeBreak(estimatedSize, ""); circuitBreakerBytes += estimatedSize; maxAggsCurrentBufferSize = Math.max(maxAggsCurrentBufferSize, circuitBreakerBytes); return circuitBreakerBytes; } + private synchronized void resetCircuitBreaker() { + if (circuitBreakerBytes > 0) { + circuitBreaker.addWithoutBreaking(-circuitBreakerBytes); + circuitBreakerBytes = 0; + } + } + /** * Returns the size of the serialized aggregation that is contained in the * provided {@link QuerySearchResult}. */ - long ramBytesUsedQueryResult(QuerySearchResult result) { + private long ramBytesUsedQueryResult(QuerySearchResult result) { if (hasAggs == false) { return 0; } @@ -382,129 +397,62 @@ long ramBytesUsedQueryResult(QuerySearchResult result) { * off for some aggregations but it is corrected with the real size after * the reduce completes. 
*/ - long estimateRamBytesUsedForReduce(long size) { - return Math.round(1.5d * size - size); + private long estimateRamBytesUsedForReduce(long size) { + return Math.round(0.5d * size); } - public void consume(QuerySearchResult result, Runnable next) throws CircuitBreakingException { - boolean executeNextImmediately = true; - synchronized (this) { - checkCircuitBreaker(next); - if (hasFailure() || result.isNull()) { - result.consumeAll(); - if (result.isNull()) { - SearchShardTarget target = result.getSearchShardTarget(); - emptyResults.add(new SearchShard(target.getClusterAlias(), target.getShardId())); - } - } else { - // add one if a partial merge is pending - int size = buffer.size() + (hasPartialReduce ? 1 : 0); - if (size >= batchReduceSize) { - hasPartialReduce = true; - executeNextImmediately = false; - QuerySearchResult[] clone = buffer.stream().toArray(QuerySearchResult[]::new); - MergeTask task = new MergeTask(clone, aggsCurrentBufferSize, new ArrayList<>(emptyResults), next); - aggsCurrentBufferSize = 0; - buffer.clear(); - emptyResults.clear(); - queue.add(task); - tryExecuteNext(); - } - if (hasAggs) { - long aggsSize = ramBytesUsedQueryResult(result); - addWithoutBreaking(aggsSize); - aggsCurrentBufferSize += aggsSize; - } - buffer.add(result); - } - } - if (executeNextImmediately) { - next.run(); - } - } + void consume(QuerySearchResult result, Runnable callback) { + checkCancellation(); - /** - * This method is needed to prevent OOM when the buffered results are too large - * - */ - private void checkCircuitBreaker(Runnable next) throws CircuitBreakingException { - try { - // force the CircuitBreaker eval to ensure during buffering we did not hit the circuit breaker limit - addEstimateAndMaybeBreak(0); - } catch (CircuitBreakingException e) { - resetCircuitBreakerForCurrentRequest(); - // onPartialMergeFailure should only be invoked once since this is responsible for cancelling the - // search task - if (!hasFailure()) { - failure.set(e); - onPartialMergeFailure.accept(e); - } + if (consumeResult(result, callback)) { + callback.run(); } } - private synchronized void onMergeFailure(Exception exc) { + private synchronized boolean consumeResult(QuerySearchResult result, Runnable callback) { if (hasFailure()) { - assert circuitBreakerBytes == 0; - return; - } - assert circuitBreakerBytes >= 0; - resetCircuitBreakerForCurrentRequest(); - failure.compareAndSet(null, exc); - MergeTask task = runningTask.get(); - runningTask.compareAndSet(task, null); - onPartialMergeFailure.accept(exc); - clearPendingMerges(task); - } - - void clearPendingMerges(MergeTask task) { - List toCancels = new ArrayList<>(); - if (task != null) { - toCancels.add(task); - } - queue.stream().forEach(toCancels::add); - queue.clear(); - mergeResult = null; - for (MergeTask toCancel : toCancels) { - toCancel.cancel(); + result.consumeAll(); // release memory + return true; } - } - - private void resetCircuitBreakerForCurrentRequest() { - if (circuitBreakerBytes > 0) { - circuitBreaker.addWithoutBreaking(-circuitBreakerBytes); - circuitBreakerBytes = 0; + if (result.isNull()) { + SearchShardTarget target = result.getSearchShardTarget(); + emptyResults.add(new SearchShard(target.getClusterAlias(), target.getShardId())); + return true; } - } - - private void onAfterMerge(MergeTask task, MergeResult newResult, long estimatedSize) { - synchronized (this) { - runningTask.compareAndSet(task, null); - if (hasFailure()) { - task.cancel(); - return; - } - mergeResult = newResult; - if (hasAggs) { - // Update the 
circuit breaker to remove the size of the source aggregations - // and replace the estimation with the serialized size of the newly reduced result. - long newSize = mergeResult.estimatedSize - estimatedSize; - addWithoutBreaking(newSize); - logger.trace( - "aggs partial reduction [{}->{}] max [{}]", - estimatedSize, - mergeResult.estimatedSize, - maxAggsCurrentBufferSize - ); + // Check circuit breaker before consuming + if (hasAggs) { + long aggsSize = ramBytesUsedQueryResult(result); + try { + addEstimateAndMaybeBreak(aggsSize); + aggsCurrentBufferSize += aggsSize; + } catch (CircuitBreakingException e) { + onFailure(e); + return true; } - task.consumeListener(); } + // Process non-empty results + int size = buffer.size() + (hasPartialReduce ? 1 : 0); + if (size >= batchReduceSize) { + hasPartialReduce = true; + // the callback must wait for the new reduce task to complete to maintain proper result processing order + QuerySearchResult[] clone = buffer.toArray(QuerySearchResult[]::new); + ReduceTask task = new ReduceTask(clone, aggsCurrentBufferSize, new ArrayList<>(emptyResults), callback); + aggsCurrentBufferSize = 0; + buffer.clear(); + emptyResults.clear(); + queue.add(task); + tryExecuteNext(); + buffer.add(result); + return false; // callback will be run by reduce task + } + buffer.add(result); + return true; } private void tryExecuteNext() { - final MergeTask task; + final ReduceTask task; synchronized (this) { if (hasFailure()) { - clearPendingMerges(null); return; } if (queue.isEmpty() || runningTask.get() != null) { @@ -517,49 +465,102 @@ private void tryExecuteNext() { executor.execute(new AbstractRunnable() { @Override protected void doRun() { - final MergeResult thisMergeResult = mergeResult; - long estimatedTotalSize = (thisMergeResult != null ? thisMergeResult.estimatedSize : 0) + task.aggsBufferSize; - final MergeResult newMerge; + final ReduceResult thisReduceResult = reduceResult; + long estimatedTotalSize = (thisReduceResult != null ? 
thisReduceResult.estimatedSize : 0) + task.aggsBufferSize; + final ReduceResult newReduceResult; try { final QuerySearchResult[] toConsume = task.consumeBuffer(); if (toConsume == null) { - task.cancel(); + onAfterReduce(task, null, 0); return; } - long estimatedMergeSize = estimateRamBytesUsedForReduce(estimatedTotalSize); - addEstimateAndMaybeBreak(estimatedMergeSize); - estimatedTotalSize += estimatedMergeSize; + long estimateRamBytesUsedForReduce = estimateRamBytesUsedForReduce(estimatedTotalSize); + addEstimateAndMaybeBreak(estimateRamBytesUsedForReduce); + estimatedTotalSize += estimateRamBytesUsedForReduce; ++numReducePhases; - newMerge = partialReduce(toConsume, task.emptyResults, topDocsStats, thisMergeResult, numReducePhases); + newReduceResult = partialReduce(toConsume, task.emptyResults, topDocsStats, thisReduceResult, numReducePhases); } catch (Exception t) { - onMergeFailure(t); + PendingReduces.this.onFailure(t); return; } - onAfterMerge(task, newMerge, estimatedTotalSize); - tryExecuteNext(); + onAfterReduce(task, newReduceResult, estimatedTotalSize); } @Override public void onFailure(Exception exc) { - onMergeFailure(exc); + PendingReduces.this.onFailure(exc); } }); } - public synchronized SearchPhaseController.TopDocsStats consumeTopDocsStats() { + private void onAfterReduce(ReduceTask task, ReduceResult newResult, long estimatedSize) { + if (newResult != null) { + synchronized (this) { + if (hasFailure()) { + return; + } + runningTask.compareAndSet(task, null); + reduceResult = newResult; + if (hasAggs) { + // Update the circuit breaker to remove the size of the source aggregations + // and replace the estimation with the serialized size of the newly reduced result. + long newSize = reduceResult.estimatedSize - estimatedSize; + addWithoutBreaking(newSize); + logger.trace( + "aggs partial reduction [{}->{}] max [{}]", + estimatedSize, + reduceResult.estimatedSize, + maxAggsCurrentBufferSize + ); + } + } + } + task.consumeListener(); + executor.execute(this::tryExecuteNext); + } + + // Idempotent and thread-safe failure handling + private synchronized void onFailure(Exception exc) { + if (hasFailure()) { + assert circuitBreakerBytes == 0; + return; + } + assert circuitBreakerBytes >= 0; + resetCircuitBreaker(); + failure.compareAndSet(null, exc); + clearReduceTaskQueue(); + cancelTaskOnFailure.accept(exc); + } + + private synchronized void clearReduceTaskQueue() { + ReduceTask task = runningTask.get(); + runningTask.compareAndSet(task, null); + List toCancels = new ArrayList<>(); + if (task != null) { + toCancels.add(task); + } + toCancels.addAll(queue); + queue.clear(); + reduceResult = null; + for (ReduceTask toCancel : toCancels) { + toCancel.cancel(); + } + } + + private synchronized SearchPhaseController.TopDocsStats consumeTopDocsStats() { for (QuerySearchResult result : buffer) { topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly()); } return topDocsStats; } - public synchronized List consumeTopDocs() { + private synchronized List consumeTopDocs() { if (hasTopDocs == false) { return Collections.emptyList(); } List topDocsList = new ArrayList<>(); - if (mergeResult != null) { - topDocsList.add(mergeResult.reducedTopDocs); + if (reduceResult != null) { + topDocsList.add(reduceResult.reducedTopDocs); } for (QuerySearchResult result : buffer) { TopDocsAndMaxScore topDocs = result.consumeTopDocs(); @@ -569,13 +570,13 @@ public synchronized List consumeTopDocs() { return topDocsList; } - public synchronized List consumeAggs() { + private 
synchronized List consumeAggs() { if (hasAggs == false) { return Collections.emptyList(); } List aggsList = new ArrayList<>(); - if (mergeResult != null) { - aggsList.add(mergeResult.reducedAggs); + if (reduceResult != null) { + aggsList.add(reduceResult.reducedAggs); } for (QuerySearchResult result : buffer) { aggsList.add(result.consumeAggs().expand()); @@ -585,41 +586,26 @@ public synchronized List consumeAggs() { } /** - * A single merge result + * Immutable container holding the outcome of a partial reduce operation * * @opensearch.internal */ - private static class MergeResult { - private final List processedShards; - private final TopDocs reducedTopDocs; - private final InternalAggregations reducedAggs; - private final long estimatedSize; - - private MergeResult( - List processedShards, - TopDocs reducedTopDocs, - InternalAggregations reducedAggs, - long estimatedSize - ) { - this.processedShards = processedShards; - this.reducedTopDocs = reducedTopDocs; - this.reducedAggs = reducedAggs; - this.estimatedSize = estimatedSize; - } + private record ReduceResult(List processedShards, TopDocs reducedTopDocs, InternalAggregations reducedAggs, + long estimatedSize) { } /** - * A single merge task + * ReduceTask is created to reduce buffered query results when buffer size hits threshold * * @opensearch.internal */ - private static class MergeTask { + private static class ReduceTask { private final List emptyResults; private QuerySearchResult[] buffer; - private long aggsBufferSize; + private final long aggsBufferSize; private Runnable next; - private MergeTask(QuerySearchResult[] buffer, long aggsBufferSize, List emptyResults, Runnable next) { + private ReduceTask(QuerySearchResult[] buffer, long aggsBufferSize, List emptyResults, Runnable next) { this.buffer = buffer; this.aggsBufferSize = aggsBufferSize; this.emptyResults = emptyResults; diff --git a/server/src/main/java/org/opensearch/action/search/StreamQueryPhaseResultConsumer.java b/server/src/main/java/org/opensearch/action/search/StreamQueryPhaseResultConsumer.java index 6186e4546afc5..75612b081e5e5 100644 --- a/server/src/main/java/org/opensearch/action/search/StreamQueryPhaseResultConsumer.java +++ b/server/src/main/java/org/opensearch/action/search/StreamQueryPhaseResultConsumer.java @@ -59,6 +59,6 @@ void consumeStreamResult(SearchPhaseResult result, Runnable next) { // For streaming, we skip the ArraySearchPhaseResults.consumeResult() call // since it doesn't support multiple results from the same shard. 
QuerySearchResult querySearchResult = result.queryResult(); - pendingMerges.consume(querySearchResult, next); + pendingReduces.consume(querySearchResult, next); } } diff --git a/server/src/test/java/org/opensearch/action/search/SearchPhaseControllerTests.java b/server/src/test/java/org/opensearch/action/search/SearchPhaseControllerTests.java index 63c67c1e0cd6b..defe13e2e320d 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchPhaseControllerTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchPhaseControllerTests.java @@ -1065,7 +1065,7 @@ private void consumerTestCase(int numEmptyResponses) throws Exception { SearchRequest request = randomSearchRequest(); request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo"))); request.setBatchedReduceSize(bufferSize); - ArraySearchPhaseResults consumer = searchPhaseController.newSearchPhaseResults( + QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults( fixedExecutor, new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, @@ -1134,18 +1134,17 @@ private void consumerTestCase(int numEmptyResponses) throws Exception { result.setSearchShardTarget(new SearchShardTarget("node", new ShardId("a", "b", shardId), null, OriginalIndices.NONE)); consumer.consumeResult(result, latch::countDown); numEmptyResponses--; - } latch.await(); final int numTotalReducePhases; if (numShards > bufferSize) { if (bufferSize == 2) { - assertEquals(1, ((QueryPhaseResultConsumer) consumer).getNumReducePhases()); + assertEquals(1, consumer.getNumReducePhases()); assertEquals(1, reductions.size()); assertEquals(false, reductions.get(0)); numTotalReducePhases = 2; } else { - assertEquals(0, ((QueryPhaseResultConsumer) consumer).getNumReducePhases()); + assertEquals(0, consumer.getNumReducePhases()); assertEquals(0, reductions.size()); numTotalReducePhases = 1; }
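
Reviewer note: the consume path in this patch buffers incoming shard results and, once the buffer (plus one slot for a pending partial reduce) reaches batchReduceSize, snapshots the batch into a queued reduce task, deferring the caller's callback so results keep their processing order. Below is a minimal, self-contained sketch of that shape; the class name and the String stand-in for QuerySearchResult are illustrative, not the OpenSearch types.

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;

class BatchingConsumerSketch {
    private final int batchReduceSize;
    private final List<String> buffer = new ArrayList<>();      // stand-in for QuerySearchResult
    private final ArrayDeque<List<String>> reduceQueue = new ArrayDeque<>();
    private boolean hasPartialReduce;

    BatchingConsumerSketch(int batchReduceSize) {
        this.batchReduceSize = batchReduceSize;
    }

    // Returns true when the caller may run its completion callback immediately;
    // false means a queued reduce task now owns the callback, preserving order.
    synchronized boolean consume(String result) {
        int size = buffer.size() + (hasPartialReduce ? 1 : 0);  // +1 for the last partial reduce
        if (size >= batchReduceSize) {
            hasPartialReduce = true;
            reduceQueue.add(new ArrayList<>(buffer));           // snapshot the batch
            buffer.clear();
            buffer.add(result);
            return false;                                       // callback deferred to the reduce task
        }
        buffer.add(result);
        return true;
    }
}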
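The tryExecuteNext/runningTask pair in the patch is an "at most one reduce in flight" scheduler: tasks queue up, a single AtomicReference slot guards execution, and completing a task re-arms the loop. A condensed sketch of the pattern with generic Runnable tasks (illustrative names, not the ReduceTask type); note the real code re-submits tryExecuteNext through the executor after a reduce completes, which this sketch simplifies to a direct call:

import java.util.ArrayDeque;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;

class SingleRunnerQueueSketch {
    private final ArrayDeque<Runnable> queue = new ArrayDeque<>();
    private final AtomicReference<Runnable> running = new AtomicReference<>();
    private final Executor executor;

    SingleRunnerQueueSketch(Executor executor) {
        this.executor = executor;
    }

    synchronized void submit(Runnable task) {
        queue.add(task);
        tryExecuteNext();
    }

    private void tryExecuteNext() {
        final Runnable task;
        synchronized (this) {
            // only dequeue when nothing is in flight
            if (queue.isEmpty() || running.get() != null) {
                return;
            }
            task = queue.poll();
            running.set(task);
        }
        executor.execute(() -> {
            try {
                task.run();
            } finally {
                running.compareAndSet(task, null);  // free the single slot
                tryExecuteNext();                   // pick up queued work, if any
            }
        });
    }
}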
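The new PendingReduces.onFailure is the single funnel for every failure path (circuit breaker trips, reduce exceptions, task cancellation). Because it is synchronized and idempotent, the breaker reservation is released exactly once, cancelTaskOnFailure fires exactly once, and a later TaskCancelledException cannot mask an earlier CircuitBreakingException. A minimal sketch of the idea, with stand-in names rather than the OpenSearch circuit breaker API:

import java.util.ArrayDeque;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

class FailureStateSketch {
    private final AtomicReference<Exception> failure = new AtomicReference<>();
    private final ArrayDeque<Runnable> pendingReduceTasks = new ArrayDeque<>();
    private final Consumer<Exception> cancelTaskOnFailure;
    private long circuitBreakerBytes;

    FailureStateSketch(Consumer<Exception> cancelTaskOnFailure) {
        this.cancelTaskOnFailure = cancelTaskOnFailure;
    }

    boolean hasFailure() {
        return failure.get() != null;
    }

    // First failure wins; later callers return immediately without touching
    // the breaker, the queue, or the downstream cancellation hook.
    synchronized void onFailure(Exception exc) {
        if (hasFailure()) {
            return;
        }
        circuitBreakerBytes = 0;          // stand-in for circuitBreaker.addWithoutBreaking(-bytes)
        failure.compareAndSet(null, exc);
        pendingReduceTasks.clear();       // queued reduce tasks must not run after a failure
        cancelTaskOnFailure.accept(exc);  // cancel the parent search task exactly once
    }
}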
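Finally, the estimateRamBytesUsedForReduce change is behavior-preserving arithmetic: Math.round(1.5d * size - size) and Math.round(0.5d * size) both reserve an extra 50% of the buffered aggregation size ahead of a reduce, and agree for any realistic byte count (sizes small enough that the doubles are exact). A quick check with an illustrative value:

public class ReduceEstimateCheck {
    public static void main(String[] args) {
        long size = 1024;                               // illustrative buffered-aggs size in bytes
        long oldForm = Math.round(1.5d * size - size);  // 512
        long newForm = Math.round(0.5d * size);         // 512
        System.out.println(oldForm == newForm);         // true
    }
}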