Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Implement GRPC Script query ([#19455](https://github.com/opensearch-project/OpenSearch/pull/19455))
- [Search Stats] Add search & star-tree search query failure count metrics ([#19210](https://github.com/opensearch-project/OpenSearch/issues/19210))
- [Star-tree] Support for multi-terms aggregation ([#18398](https://github.com/opensearch-project/OpenSearch/issues/18398))
- Add stream search feature flag and auto fallback logic ([#19373](https://github.com/opensearch-project/OpenSearch/pull/19373))

### Changed
- Refactor `if-else` chains to use `Java 17 pattern matching switch expressions`(([#18965](https://github.com/opensearch-project/OpenSearch/pull/18965))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ protected FeatureFlagSettings(
FeatureFlags.TERM_VERSION_PRECOMMIT_ENABLE_SETTING,
FeatureFlags.ARROW_STREAMS_SETTING,
FeatureFlags.STREAM_TRANSPORT_SETTING,
FeatureFlags.MERGED_SEGMENT_WARMER_EXPERIMENTAL_SETTING
FeatureFlags.MERGED_SEGMENT_WARMER_EXPERIMENTAL_SETTING,
FeatureFlags.STREAM_SEARCH_SETTING
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ public class FeatureFlags {
public static final String ARROW_STREAMS = FEATURE_FLAG_PREFIX + "arrow.streams.enabled";
public static final Setting<Boolean> ARROW_STREAMS_SETTING = Setting.boolSetting(ARROW_STREAMS, false, Property.NodeScope);

public static final String STREAM_SEARCH = FEATURE_FLAG_PREFIX + "stream.search.enabled";
public static final Setting<Boolean> STREAM_SEARCH_SETTING = Setting.boolSetting(STREAM_SEARCH, false, Property.NodeScope);

/**
* Underlying implementation for feature flags.
* All settable feature flags are tracked here in FeatureFlagsImpl.featureFlags.
Expand All @@ -145,6 +148,7 @@ static class FeatureFlagsImpl {
put(TERM_VERSION_PRECOMMIT_ENABLE_SETTING, TERM_VERSION_PRECOMMIT_ENABLE_SETTING.getDefault(Settings.EMPTY));
put(ARROW_STREAMS_SETTING, ARROW_STREAMS_SETTING.getDefault(Settings.EMPTY));
put(STREAM_TRANSPORT_SETTING, STREAM_TRANSPORT_SETTING.getDefault(Settings.EMPTY));
put(STREAM_SEARCH_SETTING, STREAM_SEARCH_SETTING.getDefault(Settings.EMPTY));
put(MERGED_SEGMENT_WARMER_EXPERIMENTAL_SETTING, MERGED_SEGMENT_WARMER_EXPERIMENTAL_SETTING.getDefault(Settings.EMPTY));
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@

package org.opensearch.rest.action.search;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.action.search.SearchAction;
Expand All @@ -52,6 +54,8 @@
import org.opensearch.rest.action.RestStatusToXContentListener;
import org.opensearch.search.Scroll;
import org.opensearch.search.SearchService;
import org.opensearch.search.aggregations.AggregatorFactories;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.StoredFieldsContext;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
Expand Down Expand Up @@ -83,6 +87,7 @@
* @opensearch.api
*/
public class RestSearchAction extends BaseRestHandler {
private static final Logger logger = LogManager.getLogger(RestSearchAction.class);
/**
* Indicates whether hits.total should be rendered as an integer or an object
* in the rest search response.
Expand Down Expand Up @@ -136,13 +141,16 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
parser -> parseSearchRequest(searchRequest, request, parser, client.getNamedWriteableRegistry(), setSize)
);

boolean stream = request.paramAsBoolean("stream", false);
if (stream) {
if (FeatureFlags.isEnabled(FeatureFlags.STREAM_SEARCH)) {
if (FeatureFlags.isEnabled(FeatureFlags.STREAM_TRANSPORT)) {
return channel -> {
RestCancellableNodeClient cancelClient = new RestCancellableNodeClient(client, request.getHttpChannel());
cancelClient.execute(StreamSearchAction.INSTANCE, searchRequest, new RestStatusToXContentListener<>(channel));
};
if (canUseStreamSearch(searchRequest)) {
return channel -> {
RestCancellableNodeClient cancelClient = new RestCancellableNodeClient(client, request.getHttpChannel());
cancelClient.execute(StreamSearchAction.INSTANCE, searchRequest, new RestStatusToXContentListener<>(channel));
};
} else {
logger.debug("Stream search requested but search contains unsupported aggregations. Falling back to normal search.");
}
} else {
throw new IllegalArgumentException("You need to enable stream transport first to use stream search.");
}
Expand Down Expand Up @@ -435,4 +443,24 @@ protected Set<String> responseParams() {
public boolean allowsUnsafeBuffers() {
return true;
}

/**
* Determines if a search request can use stream search.
*
* @param searchRequest the search request to validate
* @return true if the request can use stream search, false otherwise
*/
static boolean canUseStreamSearch(SearchRequest searchRequest) {
if (searchRequest.source() == null || searchRequest.source().aggregations() == null) {
return false; // No aggregations, stream search is not allowed
}

AggregatorFactories.Builder aggregations = searchRequest.source().aggregations();
if (aggregations.count() != 1) {
return false; // Must have exactly one aggregation
}

// Check if the single aggregation is a terms aggregation
return aggregations.getAggregatorFactories().stream().anyMatch(factory -> factory instanceof TermsAggregationBuilder);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.rest.action.search;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.search.SearchAction;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.SetOnce;
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.rest.RestRequest;
import org.opensearch.search.aggregations.AggregationBuilders;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.client.NoOpNodeClient;
import org.opensearch.test.rest.FakeRestChannel;
import org.opensearch.test.rest.FakeRestRequest;
import org.opensearch.transport.client.node.NodeClient;

import static org.opensearch.common.util.FeatureFlags.STREAM_SEARCH;
import static org.opensearch.common.util.FeatureFlags.STREAM_TRANSPORT;
import static org.hamcrest.Matchers.equalTo;

public class RestSearchActionTests extends OpenSearchTestCase {

private NodeClient createMockNodeClient(SetOnce<ActionType<?>> capturedActionType) {
return new NoOpNodeClient(this.getTestName()) {
@Override
public <Request extends ActionRequest, Response extends ActionResponse> Task executeLocally(
ActionType<Response> action,
Request request,
ActionListener<Response> listener
) {
capturedActionType.set(action);
listener.onResponse(null);
return new Task(1L, "test", action.name(), "test task", null, null);
}

@Override
public String getLocalNodeId() {
return "test-node";
}
};
}

private void testActionExecution(ActionType<?> expectedAction) throws Exception {
SetOnce<ActionType<?>> capturedActionType = new SetOnce<>();
try (NodeClient nodeClient = createMockNodeClient(capturedActionType)) {
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).build();
FakeRestChannel channel = new FakeRestChannel(request, false, 0);

new RestSearchAction().handleRequest(request, channel, nodeClient);

assertThat(capturedActionType.get(), equalTo(expectedAction));
}
}

public void testWithSearchStreamFlagDisabled() throws Exception {
// When SEARCH_STREAM flag is disabled, always use SearchAction
testActionExecution(SearchAction.INSTANCE);
}

@LockFeatureFlag(STREAM_SEARCH)
public void testWithStreamSearchEnabledButStreamTransportDisabled() throws Exception {
// When SEARCH_STREAM is enabled but STREAM_TRANSPORT is disabled, should throw exception
try (NodeClient nodeClient = new NoOpNodeClient(this.getTestName())) {
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).build();
FakeRestChannel channel = new FakeRestChannel(request, false, 0);

Exception e = expectThrows(
IllegalArgumentException.class,
() -> new RestSearchAction().handleRequest(request, channel, nodeClient)
);
assertThat(e.getMessage(), equalTo("You need to enable stream transport first to use stream search."));
}
}

public void testWithStreamSearchAndTransportEnabled() throws Exception {
// When both SEARCH_STREAM and STREAM_TRANSPORT are enabled, should use StreamSearchAction
try (
FeatureFlags.TestUtils.FlagWriteLock searchStreamLock = new FeatureFlags.TestUtils.FlagWriteLock(STREAM_SEARCH);
FeatureFlags.TestUtils.FlagWriteLock streamTransportLock = new FeatureFlags.TestUtils.FlagWriteLock(STREAM_TRANSPORT)
) {
testActionExecution(SearchAction.INSTANCE);
}
}

// Tests for canUseStreamSearch method
public void testCanUseStreamSearchWithNullSource() {
SearchRequest searchRequest = new SearchRequest();
assertFalse(RestSearchAction.canUseStreamSearch(searchRequest));
}

public void testCanUseStreamSearchWithNoAggregations() {
SearchRequest searchRequest = new SearchRequest();
SearchSourceBuilder source = new SearchSourceBuilder();
source.query(QueryBuilders.matchAllQuery());
searchRequest.source(source);
assertFalse(RestSearchAction.canUseStreamSearch(searchRequest));
}

public void testCanUseStreamSearchWithSingleTermsAggregation() {
SearchRequest searchRequest = new SearchRequest();
SearchSourceBuilder source = new SearchSourceBuilder();
source.aggregation(AggregationBuilders.terms("test_terms").field("category"));
searchRequest.source(source);
assertTrue(RestSearchAction.canUseStreamSearch(searchRequest));
}

public void testCanUseStreamSearchWithMultipleAggregations() {
SearchRequest searchRequest = new SearchRequest();
SearchSourceBuilder source = new SearchSourceBuilder();
source.aggregation(AggregationBuilders.terms("test_terms").field("category"));
source.aggregation(AggregationBuilders.avg("test_avg").field("price"));
searchRequest.source(source);
assertFalse(RestSearchAction.canUseStreamSearch(searchRequest));
}

public void testCanUseStreamSearchWithSingleNonTermsAggregation() {
SearchRequest searchRequest = new SearchRequest();
SearchSourceBuilder source = new SearchSourceBuilder();
source.aggregation(AggregationBuilders.avg("test_avg").field("price"));
searchRequest.source(source);
assertFalse(RestSearchAction.canUseStreamSearch(searchRequest));
}

public void testCanUseStreamSearchWithSingleHistogramAggregation() {
SearchRequest searchRequest = new SearchRequest();
SearchSourceBuilder source = new SearchSourceBuilder();
source.aggregation(AggregationBuilders.histogram("test_histogram").field("timestamp").interval(1000));
searchRequest.source(source);
assertFalse(RestSearchAction.canUseStreamSearch(searchRequest));
}
}
Loading