From addf0574f2f40a4297ad72d10d50e89d6fcd71a3 Mon Sep 17 00:00:00 2001 From: Yuanchun Shen Date: Mon, 11 Aug 2025 10:32:44 +0800 Subject: [PATCH 01/30] Update parameter handling of tools (#618) * Add parameter extraction utilities for tool inputs - Add utilities for extracting required parameters and JSON input parameters - Apply parameter extraction in AbstractRetrieverTool and RAGTool - Define TOOL_REQUIRED_PARAMS constant for consistent parameter handling Signed-off-by: Yuanchun Shen * Standardize parameter handling in all Tool implementations - Update all Tool interface implementations to use extractInputParameters utility Signed-off-by: Yuanchun Shen * Update release note Signed-off-by: Yuanchun Shen * Declare origin of helper method extractInputParameters Signed-off-by: Yuanchun Shen * Remove displaced comment in javadoc Signed-off-by: Yuanchun Shen * Fix failed test in AbstractRetrieverToolTests Signed-off-by: Yuanchun Shen * Replace copied tool utils to library ones Signed-off-by: Yuanchun Shen --------- Signed-off-by: Yuanchun Shen --- .../opensearch-skills.release-notes-3.2.0.0.md | 14 ++++++-------- .../agent/tools/AbstractRetrieverTool.java | 4 +++- .../opensearch/agent/tools/CreateAlertTool.java | 4 +++- .../agent/tools/CreateAnomalyDetectorTool.java | 2 ++ .../org/opensearch/agent/tools/DynamicTool.java | 4 +++- .../org/opensearch/agent/tools/LogPatternTool.java | 4 +++- .../java/org/opensearch/agent/tools/PPLTool.java | 4 +++- .../java/org/opensearch/agent/tools/RAGTool.java | 5 ++++- .../opensearch/agent/tools/SearchAlertsTool.java | 4 +++- .../agent/tools/SearchAnomalyDetectorsTool.java | 4 +++- .../agent/tools/SearchAnomalyResultsTool.java | 4 +++- .../opensearch/agent/tools/SearchMonitorsTool.java | 4 +++- .../org/opensearch/agent/tools/WebSearchTool.java | 4 +++- .../agent/tools/utils/ToolConstants.java | 1 - .../agent/tools/AbstractRetrieverToolTests.java | 4 +++- 15 files changed, 45 insertions(+), 21 deletions(-) diff --git 
a/release-notes/opensearch-skills.release-notes-3.2.0.0.md b/release-notes/opensearch-skills.release-notes-3.2.0.0.md index 7203ddd3..17a9f533 100644 --- a/release-notes/opensearch-skills.release-notes-3.2.0.0.md +++ b/release-notes/opensearch-skills.release-notes-3.2.0.0.md @@ -1,20 +1,18 @@ -## Version 3.2.0 Release Notes +## Version 3.2.0.0 Release Notes -Compatible with OpenSearch and OpenSearch Dashboards version 3.2.0 +Compatible with OpenSearch and OpenSearch Dashboards version 3.2.0.0 ### Features * Support dynamic tool in agent framework ([#606](https://github.com/opensearch-project/skills/pull/606)) ### Enhancements * Merge index schema meta ([#596](https://github.com/opensearch-project/skills/pull/596)) +* Mask error message in PPLTool ([#609](https://github.com/opensearch-project/skills/pull/609)) ### Bug Fixes * Fix attributes handling in dynamic tool ([#607](https://github.com/opensearch-project/skills/pull/607)) -* Mask error message in PPLTool ([#609](https://github.com/opensearch-project/skills/pull/609)) - -### Infrastructure -* Update the maven snapshot publish endpoint and credential ([#601](https://github.com/opensearch-project/skills/pull/601)) -* Gradle and Lombok bump, changing CI java to 24 and adjusting AD getConfigRequest ([#615](https://github.com/opensearch-project/skills/pull/615)) ### Maintenance -* [AUTO] Increment version to 3.2.0-SNAPSHOT ([#605](https://github.com/opensearch-project/skills/pull/605)) \ No newline at end of file +* Update the maven snapshot publish endpoint and credential ([#601](https://github.com/opensearch-project/skills/pull/601)) +* Bump gradle, java, lombok and fix ad configrequest change ([#615](https://github.com/opensearch-project/skills/pull/615)) +* Bump version to 3.2.0.0 ([#605](https://github.com/opensearch-project/skills/pull/605)) diff --git a/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java b/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java index 
dd713ae6..4865e70a 100644 --- a/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java +++ b/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java @@ -22,6 +22,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.transport.client.Client; @@ -94,7 +95,8 @@ protected SearchRequest buildSearchRequest(Map parameters) t } @Override - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); SearchRequest searchRequest; try { searchRequest = buildSearchRequest(parameters); diff --git a/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java b/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java index 368807c0..3e7096a5 100644 --- a/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java +++ b/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java @@ -39,6 +39,7 @@ import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.transport.client.Client; import com.google.gson.reflect.TypeToken; @@ -133,7 +134,8 @@ public boolean validate(Map parameters) { } @Override - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); Map tmpParams = new HashMap<>(parameters); if (!tmpParams.containsKey("indices") || 
Strings.isEmpty(tmpParams.get("indices"))) { throw new IllegalArgumentException( diff --git a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java index 2beb4cf4..2c4a2273 100644 --- a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java +++ b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java @@ -44,6 +44,7 @@ import org.opensearch.ml.common.spi.tools.WithModelTool; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.transport.client.Client; import com.google.common.collect.ImmutableMap; @@ -169,6 +170,7 @@ public CreateAnomalyDetectorTool(Client client, String modelId, String modelType */ @Override public void run(Map parameters, ActionListener listener) { + parameters = ToolUtils.extractInputParameters(parameters, attributes); final String tenantId = parameters.get(TENANT_ID_FIELD); Map enrichedParameters = enrichParameters(parameters); String indexName = enrichedParameters.get("index"); diff --git a/src/main/java/org/opensearch/agent/tools/DynamicTool.java b/src/main/java/org/opensearch/agent/tools/DynamicTool.java index 82c67f73..ae143520 100644 --- a/src/main/java/org/opensearch/agent/tools/DynamicTool.java +++ b/src/main/java/org/opensearch/agent/tools/DynamicTool.java @@ -25,6 +25,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.rest.DynamicRestRequestCreator; import org.opensearch.rest.DynamicToolExecutor; import org.opensearch.rest.RestRequest; @@ -114,7 +115,8 @@ public boolean validate(Map map) { } @Override - public void run(Map parameters, ActionListener listener) { 
+ public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); RestRequest.Method method = RestRequest.Method.valueOf(parameters.get(METHOD_KEY)); String uri = parameters.get(URI_KEY); String requestBody = parameters.get(REQUEST_BODY_KEY); diff --git a/src/main/java/org/opensearch/agent/tools/LogPatternTool.java b/src/main/java/org/opensearch/agent/tools/LogPatternTool.java index 4359b2dc..70464156 100644 --- a/src/main/java/org/opensearch/agent/tools/LogPatternTool.java +++ b/src/main/java/org/opensearch/agent/tools/LogPatternTool.java @@ -31,6 +31,7 @@ import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.search.SearchHit; import org.opensearch.sql.plugin.transport.PPLQueryAction; import org.opensearch.sql.plugin.transport.TransportPPLQueryRequest; @@ -107,7 +108,8 @@ protected String getQueryBody(String queryText) { } @Override - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); String dsl = parameters.get(INPUT_FIELD); String ppl = parameters.get(PPL_FIELD); if (!StringUtils.isBlank(dsl)) { diff --git a/src/main/java/org/opensearch/agent/tools/PPLTool.java b/src/main/java/org/opensearch/agent/tools/PPLTool.java index b1450cbf..82ec6871 100644 --- a/src/main/java/org/opensearch/agent/tools/PPLTool.java +++ b/src/main/java/org/opensearch/agent/tools/PPLTool.java @@ -52,6 +52,7 @@ import org.opensearch.ml.common.spi.tools.WithModelTool; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import 
org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.sql.plugin.transport.PPLQueryAction; @@ -196,7 +197,8 @@ public PPLTool( @SuppressWarnings("unchecked") @Override - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); final String tenantId = parameters.get(TENANT_ID_FIELD); extractFromChatParameters(parameters); String indexName = getIndexNameFromParameters(parameters); diff --git a/src/main/java/org/opensearch/agent/tools/RAGTool.java b/src/main/java/org/opensearch/agent/tools/RAGTool.java index 7771ca66..c1a32667 100644 --- a/src/main/java/org/opensearch/agent/tools/RAGTool.java +++ b/src/main/java/org/opensearch/agent/tools/RAGTool.java @@ -28,6 +28,7 @@ import org.opensearch.ml.common.spi.tools.WithModelTool; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.transport.client.Client; import com.google.gson.Gson; @@ -95,7 +96,9 @@ public Object parse(Object o) { }; } - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); + final String tenantId = parameters.get(TENANT_ID_FIELD); String input = null; diff --git a/src/main/java/org/opensearch/agent/tools/SearchAlertsTool.java b/src/main/java/org/opensearch/agent/tools/SearchAlertsTool.java index cab2bc7c..e144dd83 100644 --- a/src/main/java/org/opensearch/agent/tools/SearchAlertsTool.java +++ b/src/main/java/org/opensearch/agent/tools/SearchAlertsTool.java @@ -21,6 +21,7 @@ import 
org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.transport.client.Client; import org.opensearch.transport.client.node.NodeClient; @@ -70,7 +71,8 @@ public Object parse(Object o) { } @Override - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); final String tableSortOrder = parameters.getOrDefault("sortOrder", "asc"); final String tableSortString = parameters.getOrDefault("sortString", "monitor_name.keyword"); final int tableSize = parameters.containsKey("size") && StringUtils.isNumeric(parameters.get("size")) diff --git a/src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java b/src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java index cd6772a0..9830f5d6 100644 --- a/src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java +++ b/src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java @@ -36,6 +36,7 @@ import org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.sort.SortOrder; @@ -94,7 +95,8 @@ public Object parse(Object o) { // number of total detectors. The output will likely need to be updated, standardized, and include more fields in the // future to cover a sufficient amount of potential questions the agent will need to handle. 
@Override - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); final String detectorName = parameters.getOrDefault("detectorName", null); final String detectorNamePattern = parameters.getOrDefault("detectorNamePattern", null); final String indices = parameters.getOrDefault("indices", null); diff --git a/src/main/java/org/opensearch/agent/tools/SearchAnomalyResultsTool.java b/src/main/java/org/opensearch/agent/tools/SearchAnomalyResultsTool.java index 76b322cb..b7936417 100644 --- a/src/main/java/org/opensearch/agent/tools/SearchAnomalyResultsTool.java +++ b/src/main/java/org/opensearch/agent/tools/SearchAnomalyResultsTool.java @@ -26,6 +26,7 @@ import org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; @@ -84,7 +85,8 @@ public Object parse(Object o) { // and total # of results. The output will likely need to be updated, standardized, and include more fields in the // future to cover a sufficient amount of potential questions the agent will need to handle. @Override - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); final String detectorId = parameters.getOrDefault("detectorId", null); final Boolean realTime = parameters.containsKey("realTime") ? 
Boolean.parseBoolean(parameters.get("realTime")) : null; final Double anomalyGradeThreshold = parameters.containsKey("anomalyGradeThreshold") diff --git a/src/main/java/org/opensearch/agent/tools/SearchMonitorsTool.java b/src/main/java/org/opensearch/agent/tools/SearchMonitorsTool.java index 0e928c73..91c2bf14 100644 --- a/src/main/java/org/opensearch/agent/tools/SearchMonitorsTool.java +++ b/src/main/java/org/opensearch/agent/tools/SearchMonitorsTool.java @@ -30,6 +30,7 @@ import org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.sort.SortOrder; @@ -84,7 +85,8 @@ public Object parse(Object o) { // number of total monitors. The output will likely need to be updated, standardized, and include more fields in the // future to cover a sufficient amount of potential questions the agent will need to handle. 
@Override - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); final String monitorId = parameters.getOrDefault("monitorId", null); final String monitorName = parameters.getOrDefault("monitorName", null); final String monitorNamePattern = parameters.getOrDefault("monitorNamePattern", null); diff --git a/src/main/java/org/opensearch/agent/tools/WebSearchTool.java b/src/main/java/org/opensearch/agent/tools/WebSearchTool.java index c7081fcb..047d6769 100644 --- a/src/main/java/org/opensearch/agent/tools/WebSearchTool.java +++ b/src/main/java/org/opensearch/agent/tools/WebSearchTool.java @@ -33,6 +33,7 @@ import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.threadpool.ThreadPool; import com.google.common.collect.ImmutableMap; @@ -99,7 +100,8 @@ public WebSearchTool(ThreadPool threadPool) { } @Override - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); try { // common search parameters String query = parameters.getOrDefault("query", parameters.get("question")).replaceAll(" ", "+"); diff --git a/src/main/java/org/opensearch/agent/tools/utils/ToolConstants.java b/src/main/java/org/opensearch/agent/tools/utils/ToolConstants.java index b5433a0e..fa77024f 100644 --- a/src/main/java/org/opensearch/agent/tools/utils/ToolConstants.java +++ b/src/main/java/org/opensearch/agent/tools/utils/ToolConstants.java @@ -36,5 +36,4 @@ public static ModelType from(String value) { public static final String ALERTING_CONFIG_INDEX = ".opendistro-alerting-config"; public static final String 
ALERTING_ALERTS_INDEX = ".opendistro-alerting-alerts"; - } diff --git a/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java b/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java index 63ed33af..e93e661a 100644 --- a/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java @@ -165,7 +165,9 @@ public void testRunAsyncWithIllegalQueryThenListenerOnFailure() { mockedImpl.run(null, listener4); Exception exception4 = assertThrows(Exception.class, future4::join); - assertTrue(exception4.getCause() instanceof NullPointerException); + // parameter is re-created with extractInputParameters, thus will not be null + assertTrue(exception4.getCause() instanceof IllegalArgumentException); + assertEquals(exception4.getCause().getMessage(), "[input] is null or empty, can not process it."); } @Test From a33cbb33455cf42b3c0362268ad541acf316eaad Mon Sep 17 00:00:00 2001 From: zane-neo Date: Mon, 11 Aug 2025 18:46:23 +0800 Subject: [PATCH 02/30] Remove dynamic tool (#620) Signed-off-by: zane-neo --- .../java/org/opensearch/agent/ToolPlugin.java | 8 +- .../opensearch/agent/tools/DynamicTool.java | 236 ------------ .../rest/DynamicRestRequestCreator.java | 143 -------- .../opensearch/rest/DynamicToolExecutor.java | 66 ---- .../org/opensearch/agent/ToolPluginTests.java | 2 +- .../agent/tools/DynamicToolTests.java | 336 ------------------ .../rest/DynamicRestRequestCreatorTests.java | 70 ---- .../rest/DynamicToolExecutorTests.java | 114 ------ 8 files changed, 2 insertions(+), 973 deletions(-) delete mode 100644 src/main/java/org/opensearch/agent/tools/DynamicTool.java delete mode 100644 src/main/java/org/opensearch/rest/DynamicRestRequestCreator.java delete mode 100644 src/main/java/org/opensearch/rest/DynamicToolExecutor.java delete mode 100644 src/test/java/org/opensearch/agent/tools/DynamicToolTests.java delete mode 100644 
src/test/java/org/opensearch/rest/DynamicRestRequestCreatorTests.java delete mode 100644 src/test/java/org/opensearch/rest/DynamicToolExecutorTests.java diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index 8aa007d6..c1931f89 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -13,7 +13,6 @@ import org.opensearch.agent.tools.CreateAlertTool; import org.opensearch.agent.tools.CreateAnomalyDetectorTool; -import org.opensearch.agent.tools.DynamicTool; import org.opensearch.agent.tools.LogPatternTool; import org.opensearch.agent.tools.NeuralSparseSearchTool; import org.opensearch.agent.tools.PPLTool; @@ -41,8 +40,6 @@ import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.repositories.RepositoriesService; -import org.opensearch.rest.DynamicRestRequestCreator; -import org.opensearch.rest.DynamicToolExecutor; import org.opensearch.rest.RestController; import org.opensearch.rest.RestHandler; import org.opensearch.script.ScriptService; @@ -103,8 +100,6 @@ public Collection createComponents( CreateAnomalyDetectorTool.Factory.getInstance().init(client); LogPatternTool.Factory.getInstance().init(client, xContentRegistry); WebSearchTool.Factory.getInstance().init(threadPool); - DynamicToolExecutor toolExecutor = new DynamicToolExecutor(restControllerRef, client); - DynamicTool.Factory.getInstance().init(client, toolExecutor, new DynamicRestRequestCreator(), xContentRegistry); return Collections.emptyList(); } @@ -123,8 +118,7 @@ public List> getToolFactories() { CreateAlertTool.Factory.getInstance(), CreateAnomalyDetectorTool.Factory.getInstance(), LogPatternTool.Factory.getInstance(), - WebSearchTool.Factory.getInstance(), - DynamicTool.Factory.getInstance() + WebSearchTool.Factory.getInstance() ); } diff --git a/src/main/java/org/opensearch/agent/tools/DynamicTool.java 
b/src/main/java/org/opensearch/agent/tools/DynamicTool.java deleted file mode 100644 index ae143520..00000000 --- a/src/main/java/org/opensearch/agent/tools/DynamicTool.java +++ /dev/null @@ -1,236 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.agent.tools; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; - -import org.apache.commons.lang3.StringUtils; -import org.apache.commons.text.StringSubstitutor; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.bytes.BytesReference; -import org.opensearch.core.xcontent.DeprecationHandler; -import org.opensearch.core.xcontent.MediaType; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.spi.tools.Tool; -import org.opensearch.ml.common.spi.tools.ToolAnnotation; -import org.opensearch.ml.common.utils.ToolUtils; -import org.opensearch.rest.DynamicRestRequestCreator; -import org.opensearch.rest.DynamicToolExecutor; -import org.opensearch.rest.RestRequest; -import org.opensearch.rest.RestResponse; -import org.opensearch.transport.client.Client; - -import com.jayway.jsonpath.JsonPath; - -@ToolAnnotation(DynamicTool.TYPE) -public class DynamicTool implements Tool { - - private static final Logger log = LogManager.getLogger(DynamicTool.class); - public static final String TYPE = "DynamicTool"; - private static final String URI_KEY = "uri"; - private static final String METHOD_KEY = "method"; - private static final String REQUEST_BODY_KEY = "request_body"; - private static final String RESPONSE_FILTER_KEY = "response_filter"; - private static final 
String DEFAULT_DESCRIPTION = - "This is a template tool to enable OpenSearch APIs as tool, this tool accepts several parameters: uri, method, request_body and response_filter. uri represents the OpenSearch API uri, method represents the" - + "OpenSearch API method, request_body represents the OpenSearch API request body and response_filter is a json path expression so that target fields can be extracted from OpenSearch API response. Most OpenSearch APIs" - + "can be configured with this tool, during agent execution the configured API will be invoked and the response/filtered response will be returned as tool's response."; - - private final Client client; - private final DynamicToolExecutor toolExecutor; - private final DynamicRestRequestCreator dynamicRestRequestCreator; - private final NamedXContentRegistry namedXContentRegistry; - private String name = TYPE; - private String description; - private Map attributes; - - public DynamicTool( - Client client, - DynamicToolExecutor toolExecutor, - DynamicRestRequestCreator dynamicRestRequestCreator, - NamedXContentRegistry namedXContentRegistry - ) { - this.client = client; - this.toolExecutor = toolExecutor; - this.dynamicRestRequestCreator = dynamicRestRequestCreator; - this.namedXContentRegistry = namedXContentRegistry; - } - - @Override - public String getType() { - return TYPE; - } - - @Override - public String getVersion() { - return null; - } - - @Override - public String getName() { - return name; - } - - @Override - public void setName(String s) { - this.name = s; - } - - @Override - public String getDescription() { - return Optional.ofNullable(description).orElse(DEFAULT_DESCRIPTION); - } - - @Override - public Map getAttributes() { - return attributes; - } - - @Override - public void setAttributes(Map map) { - this.attributes = new HashMap<>(); - this.attributes.putAll(map); - } - - @Override - public void setDescription(String s) { - this.description = s; - } - - @Override - public boolean validate(Map map) 
{ - return true; - } - - @Override - public void run(Map originalParameters, ActionListener listener) { - Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); - RestRequest.Method method = RestRequest.Method.valueOf(parameters.get(METHOD_KEY)); - String uri = parameters.get(URI_KEY); - String requestBody = parameters.get(REQUEST_BODY_KEY); - String responseFileter = parameters.get(RESPONSE_FILTER_KEY); - StringSubstitutor substitution = new StringSubstitutor(parameters, "${parameters.", "}"); - uri = substitution.replace(uri); - try { - BytesReference content = null; - if (notNullOrEmpty(requestBody)) { - requestBody = substitution.replace(requestBody); - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - XContentParser parser = MediaType - .fromMediaType("application/json") - .xContent() - .createParser(namedXContentRegistry, DeprecationHandler.IGNORE_DEPRECATIONS, requestBody); - builder.copyCurrentStructure(parser); - content = BytesReference.bytes(builder); - } - Map> clientHeaders = client - .threadPool() - .getThreadContext() - .getHeaders() - .entrySet() - .stream() - .collect(Collectors.toMap(Map.Entry::getKey, entry -> List.of(entry.getValue()))); - RestRequest request = dynamicRestRequestCreator.createRestRequest(namedXContentRegistry, method, uri, content, clientHeaders); - ActionListener actionListener = ActionListener.wrap(r -> { - if (notNullOrEmpty(responseFileter)) { - // fetch with jsonpath from response. 
- Object result = JsonPath.read(r.content().utf8ToString(), responseFileter); - listener.onResponse((T) String.valueOf(result)); - } else { - listener.onResponse((T) r.content().utf8ToString()); - } - }, e -> { - log.error("Failed to run ToolExecutor", e); - listener.onFailure(e); - }); - toolExecutor.execute(request, actionListener); - } catch (Exception ex) { - log.error("Failed to run DynamicTool", ex); - listener.onFailure(ex); - } - } - - private boolean notNullOrEmpty(String s) { - return s != null && !s.isEmpty() && !"null".equals(s); - } - - public static class Factory implements Tool.Factory { - private Client client; - private DynamicToolExecutor toolExecutor; - private DynamicRestRequestCreator dynamicRestRequestCreator; - private NamedXContentRegistry namedXContentRegistry; - - private static DynamicTool.Factory INSTANCE; - - public static DynamicTool.Factory getInstance() { - if (INSTANCE != null) { - return INSTANCE; - } - synchronized (DynamicTool.class) { - if (INSTANCE != null) { - return INSTANCE; - } - INSTANCE = new DynamicTool.Factory(); - return INSTANCE; - } - } - - public void init( - Client client, - DynamicToolExecutor toolExecutor, - DynamicRestRequestCreator dynamicRestRequestCreator, - NamedXContentRegistry namedXContentRegistry - ) { - this.client = client; - this.toolExecutor = toolExecutor; - this.dynamicRestRequestCreator = dynamicRestRequestCreator; - this.namedXContentRegistry = namedXContentRegistry; - } - - @Override - public DynamicTool create(Map map) { - if (!map.containsKey(URI_KEY) || StringUtils.isBlank(String.valueOf(map.get(URI_KEY)))) { - throw new IllegalArgumentException("valid uri is required in DynamicTool configuration!"); - } - if (!map.containsKey(METHOD_KEY) || map.get(METHOD_KEY) == null) { - throw new IllegalArgumentException("method is required and not null in DynamicTool configuration!"); - } else { - try { - RestRequest.Method.valueOf(String.valueOf(map.get(METHOD_KEY))); - } catch (Exception e) { - throw 
new IllegalArgumentException("valid method value is required in DynamicTool configuration!"); - } - } - - return new DynamicTool(client, toolExecutor, dynamicRestRequestCreator, namedXContentRegistry); - } - - @Override - public String getDefaultDescription() { - return DEFAULT_DESCRIPTION; - } - - @Override - public String getDefaultType() { - return TYPE; - } - - @Override - public String getDefaultVersion() { - return null; - } - - } -} diff --git a/src/main/java/org/opensearch/rest/DynamicRestRequestCreator.java b/src/main/java/org/opensearch/rest/DynamicRestRequestCreator.java deleted file mode 100644 index 138a8c1c..00000000 --- a/src/main/java/org/opensearch/rest/DynamicRestRequestCreator.java +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.rest; - -import java.net.InetSocketAddress; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.bytes.BytesReference; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.http.HttpChannel; -import org.opensearch.http.HttpRequest; -import org.opensearch.http.HttpResponse; - -import com.google.common.net.HttpHeaders; - -public class DynamicRestRequestCreator { - public RestRequest createRestRequest( - NamedXContentRegistry namedXContentRegistry, - RestRequest.Method method, - String uri, - BytesReference content, - Map> headers - ) { - HttpRequest httpRequest = new HttpRequest() { - @Override - public RestRequest.Method method() { - return method; - } - - @Override - public String uri() { - return uri; - } - - @Override - public BytesReference content() { - return content; - } - - @Override - public Map> getHeaders() { - // The transport action needs correct headers to work, e.g. 
credentials so passing the original headers to the created - // request. - Map> internalRequestHeaders = new HashMap<>(headers); - internalRequestHeaders.put(HttpHeaders.CONTENT_TYPE, List.of("application/json")); - return internalRequestHeaders; - } - - @Override - public List strictCookies() { - return List.of(); - } - - @Override - public HttpVersion protocolVersion() { - // This doesn't have actual impact only to ensure no NPE in corner cases. - return HttpRequest.HttpVersion.HTTP_1_0; - } - - @Override - public HttpRequest removeHeader(String s) { - return this; - } - - @Override - public HttpResponse createResponse(RestStatus restStatus, BytesReference bytesReference) { - // An example of overriding this method is: - // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/http/DefaultRestChannel.java#L145-L174 - // After the response been created, it's been sent to the RestChannel, but in our case the rest channel is mock and - // sendResponse method will be used to send response to a listener, - // So this method never been invoked, so returning null here. - return null; - } - - @Override - public Exception getInboundException() { - return null; - } - - @Override - public void release() { - // Nothing needs to be released, for other implementation like: - // https://github.com/opensearch-project/OpenSearch/blob/main/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpRequest.java#L64 - // It needs to release the internal FullHttpRequest resources. - } - - @Override - public HttpRequest releaseAndCopy() { - // Some handlers can't handle pooled buffer correctly then it'll override the allowUnsafeBuffers method, e.g.: - // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/rest/action/document/RestBulkAction.java#L128 - // This HttpRequest is not created from the pooled buffer, so it's safe to return itself. 
- return this; - } - }; - HttpChannel httpChannel = new HttpChannel() { - @Override - public void sendResponse(HttpResponse httpResponse, ActionListener actionListener) { - // This is used in: - // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/http/DefaultRestChannel.java#L145-L174, - // But since this method is mainly invoked by - // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/rest/action/RestResponseListener.java#L52 - // and since the RestChannel.sendResponse has been rewritten to invoke actionListener, so this method won't be invoked in - // current case. - } - - @Override - public InetSocketAddress getLocalAddress() { - // Just for logging in Netty4HttpChannel, safe to return null. - return null; - } - - @Override - public InetSocketAddress getRemoteAddress() { - // Just for logging in Netty4HttpChannel, safe to return null. - return null; - } - - @Override - public void close() { - // Close resources that needs to be closed which hold by the channel, in this case nothing to close. 
- } - - @Override - public void addCloseListener(ActionListener actionListener) { - // No resources need to add listener - } - - @Override - public boolean isOpen() { - return true; - } - }; - return RestRequest.request(namedXContentRegistry, httpRequest, httpChannel); - } - -} diff --git a/src/main/java/org/opensearch/rest/DynamicToolExecutor.java b/src/main/java/org/opensearch/rest/DynamicToolExecutor.java deleted file mode 100644 index afb53435..00000000 --- a/src/main/java/org/opensearch/rest/DynamicToolExecutor.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.rest; - -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.Locale; -import java.util.Optional; -import java.util.concurrent.atomic.AtomicReference; - -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.bytes.BytesReference; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.transport.client.Client; -import org.opensearch.transport.client.node.NodeClient; - -public class DynamicToolExecutor { - private final AtomicReference restControllerRef; - private final Client client; - - public DynamicToolExecutor(AtomicReference restControllerRef, Client nodeClient) { - this.restControllerRef = restControllerRef; - this.client = nodeClient; - } - - public void execute(RestRequest request, ActionListener listener) throws Exception { - String rawPath = request.rawPath(); - String uri = request.uri(); - RestRequest.Method requestMethod = request.method(); - - Optional restHandler = restControllerRef.get().dispatchHandler(uri, rawPath, requestMethod, request.params()); - RestChannel dummyChannel = new AbstractRestChannel(request, true) { - @Override - public void sendResponse(RestResponse response) { - // This supposes to be the API's response, and will be encapsulated in the agent response, so either the API succeed or not, - // we 
use onResponse. - listener.onResponse(response); - } - }; - if (restHandler.isEmpty()) { - listener.onResponse(new RestResponse() { - @Override - public String contentType() { - return "text/plain"; - } - - @Override - public BytesReference content() { - String errorMessage = String - .format(Locale.ROOT, "No handler found for %s, please check your agent configuration!", uri); - return BytesReference.fromByteBuffer(ByteBuffer.wrap(errorMessage.getBytes(StandardCharsets.UTF_8))); - } - - @Override - public RestStatus status() { - return RestStatus.BAD_REQUEST; - } - }); - } else { - restHandler.get().handleRequest(request, dummyChannel, (NodeClient) client); - } - } -} diff --git a/src/test/java/org/opensearch/agent/ToolPluginTests.java b/src/test/java/org/opensearch/agent/ToolPluginTests.java index f7cfc3dd..22150cf6 100644 --- a/src/test/java/org/opensearch/agent/ToolPluginTests.java +++ b/src/test/java/org/opensearch/agent/ToolPluginTests.java @@ -96,7 +96,7 @@ public void test_getRestHandlers_successful() { @Test public void test_getToolFactories_successful() { - assertEquals(13, toolPlugin.getToolFactories().size()); + assertEquals(12, toolPlugin.getToolFactories().size()); } @Test diff --git a/src/test/java/org/opensearch/agent/tools/DynamicToolTests.java b/src/test/java/org/opensearch/agent/tools/DynamicToolTests.java deleted file mode 100644 index d142a841..00000000 --- a/src/test/java/org/opensearch/agent/tools/DynamicToolTests.java +++ /dev/null @@ -1,336 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.agent.tools; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertThrows; -import static org.junit.Assert.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.doAnswer; 
-import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.HashMap; -import java.util.Map; - -import org.junit.Before; -import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.bytes.BytesReference; -import org.opensearch.core.xcontent.MediaType; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.ml.common.utils.StringUtils; -import org.opensearch.rest.DynamicRestRequestCreator; -import org.opensearch.rest.DynamicToolExecutor; -import org.opensearch.rest.RestRequest; -import org.opensearch.rest.RestResponse; -import org.opensearch.test.rest.FakeRestRequest; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.client.Client; - -import com.google.common.collect.ImmutableMap; - -public class DynamicToolTests { - - @Mock - private DynamicToolExecutor dynamicToolExecutor; - @Mock - private ThreadPool threadPool; - @Mock - private ThreadContext threadContext; - @Mock - private Client client; - @Mock - private NamedXContentRegistry xContentRegistry; - @Mock - private ActionListener listener; - @Mock - private DynamicRestRequestCreator dynamicRestRequestCreator; - - @Before - public void setup() throws Exception { - MockitoAnnotations.openMocks(this); - when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(threadContext); - when(threadContext.getHeaders()).thenReturn(ImmutableMap.of()); - - BytesReference mockRequestBody = BytesReference - .fromByteBuffer(ByteBuffer.wrap("mock request 
body".getBytes(StandardCharsets.UTF_8))); - RestRequest restRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) - .withContent(mockRequestBody, MediaType.fromMediaType(XContentType.JSON.mediaType())) - .build(); - when(dynamicRestRequestCreator.createRestRequest(any(), any(), any(), any(), any())).thenReturn(restRequest); - DynamicTool.Factory.getInstance().init(client, dynamicToolExecutor, dynamicRestRequestCreator, xContentRegistry); - } - - @Test - public void test_createTool_successful() { - DynamicTool tool = DynamicTool.Factory.getInstance().create(ImmutableMap.of("uri", "/my_index/_search", "method", "POST")); - assertNotNull(tool); - } - - @Test - public void test_createTool_missUri() { - Exception exception = assertThrows( - IllegalArgumentException.class, - () -> DynamicTool.Factory.getInstance().create(ImmutableMap.of()) - ); - assertEquals("valid uri is required in DynamicTool configuration!", exception.getMessage()); - } - - @Test - public void test_createTool_invalidUri() { - Exception exception = assertThrows( - IllegalArgumentException.class, - () -> DynamicTool.Factory.getInstance().create(ImmutableMap.of("uri", "")) - ); - assertEquals("valid uri is required in DynamicTool configuration!", exception.getMessage()); - } - - @Test - public void test_createTool_missMethod() { - Exception exception = assertThrows( - IllegalArgumentException.class, - () -> DynamicTool.Factory.getInstance().create(ImmutableMap.of("uri", "/my_index/_search")) - ); - assertEquals("method is required and not null in DynamicTool configuration!", exception.getMessage()); - } - - @Test - public void test_createTool_invalidMethod() { - Exception exception = assertThrows( - IllegalArgumentException.class, - () -> DynamicTool.Factory.getInstance().create(ImmutableMap.of("uri", "/my_index/_search", "method", "NULL")) - ); - assertEquals("valid method value is required in DynamicTool configuration!", exception.getMessage()); - } - - @Test - public void 
test_run_withoutResponseFilter_successful() throws Exception { - Map registerAgentParameters = new HashMap<>(); - registerAgentParameters.put("uri", "/my_index/_search"); - registerAgentParameters.put("method", "POST"); - registerAgentParameters.put("request_body", "{\"query\": {\"match\": {\"name\": \"${parameters.search_content}\"}}}"); - DynamicTool tool = DynamicTool.Factory.getInstance().create(registerAgentParameters); - registerAgentParameters.put("search_content", "test"); - doAnswer(invocationOnMock -> { - ActionListener actionListener = invocationOnMock.getArgument(1); - RestResponse restResponse = mock(RestResponse.class); - BytesReference mockResponseBody = BytesReference - .fromByteBuffer(ByteBuffer.wrap("mock response body".getBytes(StandardCharsets.UTF_8))); - when(restResponse.content()).thenReturn(mockResponseBody); - actionListener.onResponse(restResponse); - return null; - }).when(dynamicToolExecutor).execute(any(), isA(ActionListener.class)); - tool.run(StringUtils.getParameterMap(registerAgentParameters), listener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(String.class); - verify(listener).onResponse(argumentCaptor.capture()); - assertNotNull(argumentCaptor.getValue()); - assertEquals("mock response body", argumentCaptor.getValue()); - } - - @Test - public void test_run_requestBodyNullOrNotExist_successful() throws Exception { - Map registerAgentParameters = new HashMap<>(); - registerAgentParameters.put("uri", "/my_index/_search"); - registerAgentParameters.put("method", "POST"); - registerAgentParameters.put("request_body", null); - DynamicTool tool0 = DynamicTool.Factory.getInstance().create(registerAgentParameters); - assertNotNull(tool0); - registerAgentParameters.remove("request_body"); - DynamicTool tool = DynamicTool.Factory.getInstance().create(registerAgentParameters); - registerAgentParameters.put("search_content", "test"); - doAnswer(invocationOnMock -> { - ActionListener actionListener = 
invocationOnMock.getArgument(1); - RestResponse restResponse = mock(RestResponse.class); - BytesReference mockResponseBody = BytesReference - .fromByteBuffer(ByteBuffer.wrap("mock response body".getBytes(StandardCharsets.UTF_8))); - when(restResponse.content()).thenReturn(mockResponseBody); - actionListener.onResponse(restResponse); - return null; - }).when(dynamicToolExecutor).execute(any(), isA(ActionListener.class)); - tool.run(StringUtils.getParameterMap(registerAgentParameters), listener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(String.class); - verify(listener).onResponse(argumentCaptor.capture()); - assertNotNull(argumentCaptor.getValue()); - assertEquals("mock response body", argumentCaptor.getValue()); - } - - @Test - public void test_run_withResponseFilter_successful() throws Exception { - Map registerAgentParameters = new HashMap<>(); - registerAgentParameters.put("uri", "/my_index/_search"); - registerAgentParameters.put("method", "POST"); - registerAgentParameters.put("request_body", "{\"query\": {\"match\": {\"name\": \"${parameters.search_content}\"}}}"); - registerAgentParameters.put("response_filter", "$.name"); - DynamicTool tool = DynamicTool.Factory.getInstance().create(registerAgentParameters); - registerAgentParameters.put("search_content", "test"); - doAnswer(invocationOnMock -> { - ActionListener actionListener = invocationOnMock.getArgument(1); - RestResponse restResponse = mock(RestResponse.class); - BytesReference mockResponseBody = BytesReference - .fromByteBuffer(ByteBuffer.wrap("{\"name\": \"This is a mock value\"}".getBytes(StandardCharsets.UTF_8))); - when(restResponse.content()).thenReturn(mockResponseBody); - actionListener.onResponse(restResponse); - return null; - }).when(dynamicToolExecutor).execute(any(), isA(ActionListener.class)); - tool.run(StringUtils.getParameterMap(registerAgentParameters), listener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(String.class); - 
verify(listener).onResponse(argumentCaptor.capture()); - assertNotNull(argumentCaptor.getValue()); - assertEquals("This is a mock value", argumentCaptor.getValue()); - } - - @Test - public void test_run_failureOnToolExecutor() throws Exception { - Map registerAgentParameters = new HashMap<>(); - registerAgentParameters.put("uri", "/my_index/_search"); - registerAgentParameters.put("method", "POST"); - registerAgentParameters.put("request_body", "{\"query\": {\"match\": {\"name\": \"${parameters.search_content}\"}}}"); - registerAgentParameters.put("response_filter", "$.name"); - DynamicTool tool = DynamicTool.Factory.getInstance().create(registerAgentParameters); - registerAgentParameters.put("search_content", "test"); - doAnswer(invocationOnMock -> { - ActionListener actionListener = invocationOnMock.getArgument(1); - actionListener.onFailure(new RuntimeException("System Error")); - return null; - }).when(dynamicToolExecutor).execute(any(), isA(ActionListener.class)); - tool.run(StringUtils.getParameterMap(registerAgentParameters), listener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); - verify(listener).onFailure(argumentCaptor.capture()); - assertNotNull(argumentCaptor.getValue()); - assertEquals("System Error", argumentCaptor.getValue().getMessage()); - } - - @Test - public void test_run_exceptionOnToolExecutor() throws Exception { - Map registerAgentParameters = new HashMap<>(); - registerAgentParameters.put("uri", "/my_index/_search"); - registerAgentParameters.put("method", "POST"); - registerAgentParameters.put("request_body", "{\"query\": {\"match\": {\"name\": \"${parameters.search_content}\"}}}"); - registerAgentParameters.put("response_filter", "$.name"); - DynamicTool tool = DynamicTool.Factory.getInstance().create(registerAgentParameters); - registerAgentParameters.put("search_content", "test"); - doThrow(new RuntimeException("System Error")).when(dynamicToolExecutor).execute(any(), any()); - 
tool.run(StringUtils.getParameterMap(registerAgentParameters), listener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); - verify(listener).onFailure(argumentCaptor.capture()); - assertNotNull(argumentCaptor.getValue()); - assertEquals("System Error", argumentCaptor.getValue().getMessage()); - } - - @Test - public void test_factory_getDefaultDescription() { - String description = DynamicTool.Factory.getInstance().getDefaultDescription(); - assertNotNull(description); - assertEquals( - "This is a template tool to enable OpenSearch APIs as tool, this tool accepts several parameters: uri, method, request_body and response_filter. uri represents the OpenSearch API uri, method represents the" - + "OpenSearch API method, request_body represents the OpenSearch API request body and response_filter is a json path expression so that target fields can be extracted from OpenSearch API response. Most OpenSearch APIs" - + "can be configured with this tool, during agent execution the configured API will be invoked and the response/filtered response will be returned as tool's response.", - description - ); - } - - @Test - public void test_factory_getDefaultType() { - String type = DynamicTool.Factory.getInstance().getDefaultType(); - assertNotNull(type); - assertEquals("DynamicTool", type); - } - - @Test - public void test_factory_getDefaultVersion() { - String version = DynamicTool.Factory.getInstance().getDefaultVersion(); - assertNull(version); - } - - @Test - public void test_tool_getType() { - Map registerAgentParameters = new HashMap<>(); - registerAgentParameters.put("uri", "/my_index/_search"); - registerAgentParameters.put("method", "POST"); - DynamicTool tool = DynamicTool.Factory.getInstance().create(registerAgentParameters); - assertEquals("DynamicTool", tool.getType()); - } - - @Test - public void test_tool_getVersion() { - Map registerAgentParameters = new HashMap<>(); - registerAgentParameters.put("uri", "/my_index/_search"); - 
registerAgentParameters.put("method", "POST"); - DynamicTool tool = DynamicTool.Factory.getInstance().create(registerAgentParameters); - assertNull(tool.getVersion()); - } - - @Test - public void test_tool_getName() { - Map registerAgentParameters = new HashMap<>(); - registerAgentParameters.put("uri", "/my_index/_search"); - registerAgentParameters.put("method", "POST"); - DynamicTool tool = DynamicTool.Factory.getInstance().create(registerAgentParameters); - tool.setName("test"); - assertEquals("test", tool.getName()); - } - - @Test - public void test_tool_getDescription() { - Map registerAgentParameters = new HashMap<>(); - registerAgentParameters.put("uri", "/my_index/_search"); - registerAgentParameters.put("method", "POST"); - DynamicTool tool = DynamicTool.Factory.getInstance().create(registerAgentParameters); - assertEquals( - "This is a template tool to enable OpenSearch APIs as tool, this tool accepts several parameters: uri, method, request_body and response_filter. uri represents the OpenSearch API uri, method represents the" - + "OpenSearch API method, request_body represents the OpenSearch API request body and response_filter is a json path expression so that target fields can be extracted from OpenSearch API response. 
Most OpenSearch APIs" - + "can be configured with this tool, during agent execution the configured API will be invoked and the response/filtered response will be returned as tool's response.", - tool.getDescription() - ); - } - - @Test - public void test_tool_getAttributes() { - Map registerAgentParameters = new HashMap<>(); - registerAgentParameters.put("uri", "/my_index/_search"); - registerAgentParameters.put("method", "POST"); - DynamicTool tool = DynamicTool.Factory.getInstance().create(registerAgentParameters); - tool.setAttributes(ImmutableMap.of("test_input", "{}")); - Map attributes = tool.getAttributes(); - assertNotNull(attributes); - assertEquals(ImmutableMap.of("test_input", "{}"), attributes); - } - - @Test - public void test_tool_setDescription() { - Map registerAgentParameters = new HashMap<>(); - registerAgentParameters.put("uri", "/my_index/_search"); - registerAgentParameters.put("method", "POST"); - DynamicTool tool = DynamicTool.Factory.getInstance().create(registerAgentParameters); - tool.setDescription("test description"); - assertEquals("test description", tool.getDescription()); - } - - @Test - public void test_tool_validate() { - Map registerAgentParameters = new HashMap<>(); - registerAgentParameters.put("uri", "/my_index/_search"); - registerAgentParameters.put("method", "POST"); - DynamicTool tool = DynamicTool.Factory.getInstance().create(registerAgentParameters); - Map runtimeParameters = ImmutableMap.of("search_content", "test"); - assertTrue(tool.validate(runtimeParameters)); - } - -} diff --git a/src/test/java/org/opensearch/rest/DynamicRestRequestCreatorTests.java b/src/test/java/org/opensearch/rest/DynamicRestRequestCreatorTests.java deleted file mode 100644 index a7e8bad7..00000000 --- a/src/test/java/org/opensearch/rest/DynamicRestRequestCreatorTests.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.rest; - -import static 
org.mockito.Mockito.mock; - -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.List; - -import org.junit.Before; -import org.junit.Test; -import org.mockito.MockitoAnnotations; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.bytes.BytesReference; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.http.HttpRequest; -import org.opensearch.http.HttpResponse; - -import com.google.common.collect.ImmutableMap; - -public class DynamicRestRequestCreatorTests { - - @Before - public void setUp() { - MockitoAnnotations.initMocks(this); - } - - @Test - public void test_createRestRequest() { - DynamicRestRequestCreator dynamicRestRequestCreator = new DynamicRestRequestCreator(); - RestRequest restRequest = dynamicRestRequestCreator - .createRestRequest( - null, - RestRequest.Method.GET, - "/_search", - null, - ImmutableMap.of("Content-Type", List.of("application/json")) - ); - assert restRequest != null; - assert restRequest.path().equals("/_search"); - assert restRequest.method().equals(RestRequest.Method.GET); - assert restRequest.content() == null; - assert restRequest.getHeaders().size() == 1; - - restRequest.getHttpRequest().release(); - restRequest.getHttpChannel().close(); - restRequest.getHttpChannel().addCloseListener(mock(ActionListener.class)); - restRequest.getHttpChannel().sendResponse(mock(HttpResponse.class), mock(ActionListener.class)); - - assert restRequest.getHttpRequest().removeHeader("Content-Type") == restRequest.getHttpRequest(); - assert restRequest.getHttpRequest().strictCookies().isEmpty(); - assert restRequest.getHttpRequest().protocolVersion().equals(HttpRequest.HttpVersion.HTTP_1_0); - assert restRequest - .getHttpRequest() - .createResponse( - RestStatus.BAD_REQUEST, - BytesReference.fromByteBuffer(ByteBuffer.wrap("mock response body".getBytes(StandardCharsets.UTF_8))) - ) == null; - assert restRequest.getHttpRequest().getInboundException() == null; - 
assert restRequest.getHttpRequest().releaseAndCopy() != null; - - assert restRequest.getHttpChannel().isOpen(); - assert restRequest.getHttpChannel().getLocalAddress() == null; - assert restRequest.getHttpChannel().getRemoteAddress() == null; - } -} diff --git a/src/test/java/org/opensearch/rest/DynamicToolExecutorTests.java b/src/test/java/org/opensearch/rest/DynamicToolExecutorTests.java deleted file mode 100644 index 77849cf7..00000000 --- a/src/test/java/org/opensearch/rest/DynamicToolExecutorTests.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.rest; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.Optional; -import java.util.concurrent.atomic.AtomicReference; - -import org.junit.Before; -import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.bytes.BytesReference; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.MediaType; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.test.rest.FakeRestRequest; -import org.opensearch.transport.client.node.NodeClient; - -import com.google.common.collect.ImmutableMap; - -public class DynamicToolExecutorTests { - @Mock - private AtomicReference restControllerRef; - @Mock - private RestController restController; - @Mock - private NodeClient client; - @Mock - private ActionListener listener; - - @Before - public void setUp() { - MockitoAnnotations.openMocks(this); - 
when(restControllerRef.get()).thenReturn(restController); - } - - @Test - public void test_constructor() { - DynamicToolExecutor executor = new DynamicToolExecutor(restControllerRef, client); - assertNotNull(executor); - } - - @Test - public void test_execute_successful() throws Exception { - Optional restHandler = Optional.of((request, channel, client) -> channel.sendResponse(new RestResponse() { - @Override - public String contentType() { - return "text/plain"; - } - - @Override - public BytesReference content() { - return BytesReference.fromByteBuffer(ByteBuffer.wrap("mock response body".getBytes(StandardCharsets.UTF_8))); - } - - @Override - public RestStatus status() { - return RestStatus.OK; - } - })); - when(restController.dispatchHandler(any(), any(), any(), any())).thenReturn(restHandler); - BytesReference mockRequestBody = BytesReference - .fromByteBuffer(ByteBuffer.wrap("mock request body".getBytes(StandardCharsets.UTF_8))); - RestRequest restRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) - .withMethod(RestRequest.Method.GET) - .withPath("/my_index/_search") - .withParams(ImmutableMap.of("allow_no_indices", "true")) - .withContent(mockRequestBody, MediaType.fromMediaType(XContentType.JSON.mediaType())) - .build(); - new DynamicToolExecutor(restControllerRef, client).execute(restRequest, listener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RestResponse.class); - verify(listener).onResponse(argumentCaptor.capture()); - assertNotNull(argumentCaptor.getValue()); - } - - @Test - public void test_execute_restHandlerNotFound() throws Exception { - Optional restHandler = Optional.empty(); - when(restController.dispatchHandler(any(), any(), any(), any())).thenReturn(restHandler); - BytesReference mockRequestBody = BytesReference - .fromByteBuffer(ByteBuffer.wrap("mock request body".getBytes(StandardCharsets.UTF_8))); - RestRequest restRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) - 
.withMethod(RestRequest.Method.GET) - .withPath("/my_index/_search") - .withParams(ImmutableMap.of("allow_no_indices", "true")) - .withContent(mockRequestBody, MediaType.fromMediaType(XContentType.JSON.mediaType())) - .build(); - new DynamicToolExecutor(restControllerRef, client).execute(restRequest, listener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RestResponse.class); - verify(listener).onResponse(argumentCaptor.capture()); - assertNotNull(argumentCaptor.getValue()); - RestResponse restResponse = argumentCaptor.getValue(); - assertEquals( - "No handler found for /my_index/_search, please check your agent configuration!", - restResponse.content().utf8ToString() - ); - assertEquals(RestStatus.BAD_REQUEST, restResponse.status()); - assertEquals("text/plain", restResponse.contentType()); - } -} From d1152a243cbf7f834e7bcf6a5cdea8120a3df336 Mon Sep 17 00:00:00 2001 From: Yuanchun Shen Date: Wed, 13 Aug 2025 17:11:25 +0800 Subject: [PATCH 03/30] Wait until LLM setup tasks complete in ToolIntegrationTest (#623) * Wait until LLM setup tasks complete in ToolIntegrationTest Signed-off-by: Yuanchun Shen * Update release notes of 3.2.0.0 Signed-off-by: Yuanchun Shen --------- Signed-off-by: Yuanchun Shen --- .../opensearch-skills.release-notes-3.2.0.0.md | 5 +---- .../org/opensearch/integTest/ToolIntegrationTest.java | 10 +++++----- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/release-notes/opensearch-skills.release-notes-3.2.0.0.md b/release-notes/opensearch-skills.release-notes-3.2.0.0.md index 17a9f533..16f7e946 100644 --- a/release-notes/opensearch-skills.release-notes-3.2.0.0.md +++ b/release-notes/opensearch-skills.release-notes-3.2.0.0.md @@ -2,15 +2,12 @@ Compatible with OpenSearch and OpenSearch Dashboards version 3.2.0.0 -### Features -* Support dynamic tool in agent framework ([#606](https://github.com/opensearch-project/skills/pull/606)) - ### Enhancements * Merge index schema meta 
([#596](https://github.com/opensearch-project/skills/pull/596)) * Mask error message in PPLTool ([#609](https://github.com/opensearch-project/skills/pull/609)) ### Bug Fixes -* Fix attributes handling in dynamic tool ([#607](https://github.com/opensearch-project/skills/pull/607)) +* Update parameter handling of tools ([#618](https://github.com/opensearch-project/skills/pull/618)) ### Maintenance * Update the maven snapshot publish endpoint and credential ([#601](https://github.com/opensearch-project/skills/pull/601)) diff --git a/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java b/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java index afdfb3d4..33f26e99 100644 --- a/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java +++ b/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java @@ -19,6 +19,7 @@ import org.opensearch.client.Response; import com.google.gson.Gson; +import com.google.gson.JsonObject; import com.google.gson.JsonParser; import com.sun.net.httpserver.HttpServer; @@ -46,8 +47,6 @@ public void setupTestAgent() throws IOException, InterruptedException { connectorId = setUpConnectorWithRetry(5); modelGroupId = setupModelGroup(); modelId = setupLLMModel(connectorId, modelGroupId); - // wait for model to get deployed - TimeUnit.SECONDS.sleep(1); agentId = setupConversationalAgent(modelId); log.info("model_id: {}, agent_id: {}", modelId, agentId); } @@ -172,10 +171,11 @@ private String setupLLMModel(String connectorId, String modelGroupId) throws IOE + "}" ); Response response = executeRequest(request); - String resp = readResponse(response); - - return JsonParser.parseString(resp).getAsJsonObject().get("model_id").getAsString(); + JsonObject respObj = JsonParser.parseString(resp).getAsJsonObject(); + String taskId = respObj.get("task_id").getAsString(); + waitTaskComplete(taskId); + return respObj.get("model_id").getAsString(); } private String setupConversationalAgent(String modelId) throws IOException { From 
fff1ca0cf9ac0b5033dbb477735a98911695dc37 Mon Sep 17 00:00:00 2001 From: Riley Jerger <214163063+RileyJergerAmazon@users.noreply.github.com> Date: Wed, 13 Aug 2025 02:44:08 -0700 Subject: [PATCH 04/30] Update delete_backport_branch workflow to include release-chores branches (#622) Signed-off-by: Riley Jerger --- .github/workflows/delete_backport_branch.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/delete_backport_branch.yml b/.github/workflows/delete_backport_branch.yml index 8ca5ed3e..be2dffd5 100644 --- a/.github/workflows/delete_backport_branch.yml +++ b/.github/workflows/delete_backport_branch.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest permissions: contents: write - if: startsWith(github.event.pull_request.head.ref,'backport/') + if: startsWith(github.event.pull_request.head.ref,'backport/') || startsWith(github.event.pull_request.head.ref,'release-chores/') steps: - name: Delete merged branch uses: actions/github-script@v7 From aa1e091219577b28a15658cd9400b01bfd9eca0d Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Wed, 20 Aug 2025 23:30:29 -0700 Subject: [PATCH 05/30] fix: Update System.env syntax for Gradle 9 compatibility (#630) Signed-off-by: Daniel Widdis --- build.gradle | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/build.gradle b/build.gradle index 78c211d3..f3b3aefb 100644 --- a/build.gradle +++ b/build.gradle @@ -442,8 +442,8 @@ publishing { name = "Snapshots" url = "https://central.sonatype.com/repository/maven-snapshots/" credentials { - username "$System.env.SONATYPE_USERNAME" - password "$System.env.SONATYPE_PASSWORD" + username System.getenv("SONATYPE_USERNAME") + password System.getenv("SONATYPE_PASSWORD") } } } From 32ac2158bd1a8df46ea616ba365642abbfd95001 Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Fri, 5 Sep 2025 11:33:17 +0800 Subject: [PATCH 06/30] Log patterns analysis tool (#625) * Add LogPatternAnalysisTool Signed-off-by: Binlong Gao Signed-off-by: Hailong 
Cui * output more for log pattern analysis Signed-off-by: Hailong Cui * adjust the order of comparing selection and baseline Signed-off-by: Hailong Cui * log pattern analysis Signed-off-by: Hailong Cui * using HCA as clustering method Signed-off-by: Hailong Cui * log pattern analysis Signed-off-by: Hailong Cui * using cosine similarity Signed-off-by: Hailong Cui * result format change Signed-off-by: Hailong Cui * Improve processing speed through sharding. Signed-off-by: Hailong Cui * add more error keywords Signed-off-by: Hailong Cui * refactor and ignore single trace for log sequence Signed-off-by: Hailong Cui * fix unit test Signed-off-by: Hailong Cui * limit top 10 difference Signed-off-by: Hailong Cui * don't remove single trace id event for log sequence analysis Signed-off-by: Hailong Cui * remove unused code Signed-off-by: Hailong Cui * add input schema for MCP Signed-off-by: Hailong Cui * fix spotless Signed-off-by: Hailong Cui * fix ci Signed-off-by: Hailong Cui * add extractInputParameters Signed-off-by: Hailong Cui * add:UT&IT Signed-off-by: Jiaru Jiang * using aggregation mode to improve performance Signed-off-by: Hailong Cui * fix:spotlessCheck Signed-off-by: Jiaru Jiang * add:ClusteringHelperTests Signed-off-by: Jiaru Jiang * fix:ClusteringHelperTests spotlessCheck failed Signed-off-by: Jiaru Jiang * fix:Improve code coverage Signed-off-by: Jiaru Jiang * fix:Modify exception handling Signed-off-by: Jiaru Jiang * update date format in input schema Signed-off-by: Hailong Cui * address review comments Signed-off-by: Hailong Cui --------- Signed-off-by: Binlong Gao Signed-off-by: Hailong Cui Signed-off-by: Jiaru Jiang Co-authored-by: Binlong Gao Co-authored-by: Jiaru Jiang --- build.gradle | 1 + .../java/org/opensearch/agent/ToolPlugin.java | 9 +- .../agent/tools/LogPatternAnalysisTool.java | 1036 +++++++++++++++++ .../utils/clustering/ClusteringHelper.java | 514 ++++++++ .../HierarchicalAgglomerativeClustering.java | 281 +++++
.../org/opensearch/agent/ToolPluginTests.java | 2 +- .../tools/LogPatternAnalysisToolTests.java | 566 +++++++++ .../tools/utils/ClusteringHelperTests.java | 164 +++ .../integTest/LogPatternAnalysisToolIT.java | 220 ++++ ...og_pattern_analysis_tool_request_body.json | 10 + 10 files changed, 2798 insertions(+), 5 deletions(-) create mode 100644 src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java create mode 100644 src/main/java/org/opensearch/agent/tools/utils/clustering/ClusteringHelper.java create mode 100644 src/main/java/org/opensearch/agent/tools/utils/clustering/HierarchicalAgglomerativeClustering.java create mode 100644 src/test/java/org/opensearch/agent/tools/LogPatternAnalysisToolTests.java create mode 100644 src/test/java/org/opensearch/agent/tools/utils/ClusteringHelperTests.java create mode 100644 src/test/java/org/opensearch/integTest/LogPatternAnalysisToolIT.java create mode 100644 src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_log_pattern_analysis_tool_request_body.json diff --git a/build.gradle b/build.gradle index f3b3aefb..541ba649 100644 --- a/build.gradle +++ b/build.gradle @@ -140,6 +140,7 @@ dependencies { compileOnly("com.google.guava:guava:33.2.1-jre") compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: '3.16.0' compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.12.0' + compileOnly group: 'org.apache.commons', name: 'commons-math3', version: '3.6.1' compileOnly("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}") compileOnly("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}") compileOnly(group: 'org.apache.httpcomponents.core5', name: 'httpcore5', version: "${versions.httpcore5}") diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index c1931f89..6dcdc829 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ 
b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -13,6 +13,7 @@ import org.opensearch.agent.tools.CreateAlertTool; import org.opensearch.agent.tools.CreateAnomalyDetectorTool; +import org.opensearch.agent.tools.LogPatternAnalysisTool; import org.opensearch.agent.tools.LogPatternTool; import org.opensearch.agent.tools.NeuralSparseSearchTool; import org.opensearch.agent.tools.PPLTool; @@ -49,8 +50,6 @@ import org.opensearch.transport.client.Client; import org.opensearch.watcher.ResourceWatcherService; -import com.google.common.collect.ImmutableList; - import lombok.SneakyThrows; public class ToolPlugin extends Plugin implements MLCommonsExtension, ActionPlugin { @@ -100,6 +99,7 @@ public Collection createComponents( CreateAnomalyDetectorTool.Factory.getInstance().init(client); LogPatternTool.Factory.getInstance().init(client, xContentRegistry); WebSearchTool.Factory.getInstance().init(threadPool); + LogPatternAnalysisTool.Factory.getInstance().init(client); return Collections.emptyList(); } @@ -118,7 +118,8 @@ public List> getToolFactories() { CreateAlertTool.Factory.getInstance(), CreateAnomalyDetectorTool.Factory.getInstance(), LogPatternTool.Factory.getInstance(), - WebSearchTool.Factory.getInstance() + WebSearchTool.Factory.getInstance(), + LogPatternAnalysisTool.Factory.getInstance() ); } @@ -133,7 +134,7 @@ public List> getExecutorBuilders(Settings settings) { false ); - return ImmutableList.of(websearchCrawlThread); + return List.of(websearchCrawlThread); } } diff --git a/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java b/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java new file mode 100644 index 00000000..0f63ff6e --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java @@ -0,0 +1,1036 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static 
org.opensearch.agent.tools.utils.ToolHelper.getPPLTransportActionListener; +import static org.opensearch.agent.tools.utils.clustering.HierarchicalAgglomerativeClustering.calculateCosineSimilarity; +import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.json.JSONObject; +import org.opensearch.agent.tools.utils.clustering.ClusteringHelper; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; +import org.opensearch.sql.plugin.transport.PPLQueryAction; +import org.opensearch.sql.plugin.transport.TransportPPLQueryRequest; +import org.opensearch.sql.ppl.domain.PPLQueryRequest; +import org.opensearch.transport.client.Client; + +import com.google.common.collect.ImmutableMap; +import com.google.gson.reflect.TypeToken; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * Usage: + * 1. Register agent: + * POST /_plugins/_ml/agents/_register + * { + * "name": "LogPatternAnalysis", + * "type": "flow", + * "tools": [ + * { + * "name": "log_pattern_analysis_tool", + * "type": "LogPatternAnalysisTool", + * "parameters": { + * } + * } + * ] + * } + * 2. 
Execute agent: + * POST /_plugins/_ml/agents/{agent_id}/_execute + * { + * "parameters": { + * "index": "ss4o_logs-otel-2025.06.24", + * "logFieldName": "body", + * "traceFieldName": "traceId", + * "baseTimeRangeStart": "2025-06-24 07:33:05", + * "baseTimeRangeEnd": "2025-06-24 07:51:27", + * "selectionTimeRangeStart": "2025-06-24 07:50:26", + * "selectionTimeRangeEnd": "2025-06-24 07:55:56" + * } + * } + * 3. Result: a list of selection traceId + * { + * "inference_results": [ + * { + * "output": [ + * { + * "name": "response", + * "result": """{"EXCEPTIONAL": {"traceId": "sequence"}}""" + * } + * ] + * } + * ] + * } + */ +@Log4j2 +@Setter +@Getter +@ToolAnnotation(LogPatternAnalysisTool.TYPE) +public class LogPatternAnalysisTool implements Tool { + public static final String TYPE = "LogPatternAnalysisTool"; + public static final String STRICT_FIELD = "strict"; + + // Constants + private static final String DEFAULT_DESCRIPTION = + "This is a tool used to detect selection log patterns by the patterns command in PPL or to detect selection log sequences by the log clustering algorithm."; + private static final double LOG_VECTORS_CLUSTERING_THRESHOLD = 0.5; + private static final double LOG_PATTERN_THRESHOLD = 0.75; + private static final double LOG_PATTERN_LIFT = 3; + private static final String DEFAULT_TIME_FIELD = "@timestamp"; + + public static final String DEFAULT_INPUT_SCHEMA = + """ + { + "type": "object", + "properties": { + "index": { + "type": "string", + "description": "Target OpenSearch index name containing log data (e.g., 'ss4o_logs-otel-2025.06.24')" + }, + "timeField": { + "type": "string", + "description": "Date/time field in the index mapping used for time-based filtering" + }, + "logFieldName": { + "type": "string", + "description": "Field containing raw log messages to analyze (e.g., 'body', 'message', 'log')" + }, + "traceFieldName": { + "type": "string", + "description": "[OPTIONAL] Field for trace/correlation ID to enable sequence analysis 
(e.g., 'traceId', 'correlationId'). Leave empty for pattern-only analysis." + }, + "baseTimeRangeStart": { + "type": "string", + "description": "Start time for baseline comparison period (date string in utc timezone, e.g., '2025-06-24 07:33:05')" + }, + "baseTimeRangeEnd": { + "type": "string", + "description": "End time for baseline comparison period (date string in utc timezone, e.g., '2025-06-24 07:51:27')" + }, + "selectionTimeRangeStart": { + "type": "string", + "description": "Start time for analysis target period (date string in utc timezone, e.g., '2025-06-24 07:50:26')" + }, + "selectionTimeRangeEnd": { + "type": "string", + "description": "End time for analysis target period (date string in utc timezone, e.g., '2025-06-24 07:55:56')" + } + }, + "required": [ + "index", + "timeField", + "logFieldName", + "selectionTimeRangeStart", + "selectionTimeRangeEnd" + ], + "additionalProperties": false + } + """; + + public static final Map DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, false); + + // Compiled regex patterns for better performance + private static final Pattern REPEATED_WILDCARDS_PATTERN = Pattern.compile("(<\\*>)(\\s+<\\*>)+"); + + /** + * Parameter class to hold analysis parameters with validation + */ + private static class AnalysisParameters { + final String index; + final String timeField; + final String logFieldName; + final String traceFieldName; + final String baseTimeRangeStart; + final String baseTimeRangeEnd; + final String selectionTimeRangeStart; + final String selectionTimeRangeEnd; + + AnalysisParameters(Map parameters) { + this.index = parameters.getOrDefault("index", ""); + this.timeField = parameters.getOrDefault("timeField", DEFAULT_TIME_FIELD); + this.logFieldName = parameters.getOrDefault("logFieldName", "message"); + this.traceFieldName = parameters.getOrDefault("traceFieldName", ""); + this.baseTimeRangeStart = parameters.getOrDefault("baseTimeRangeStart", ""); + this.baseTimeRangeEnd 
= parameters.getOrDefault("baseTimeRangeEnd", ""); + this.selectionTimeRangeStart = parameters.getOrDefault("selectionTimeRangeStart", ""); + this.selectionTimeRangeEnd = parameters.getOrDefault("selectionTimeRangeEnd", ""); + } + + private void validate() { + List missingParams = new ArrayList<>(); + if (Strings.isEmpty(index)) + missingParams.add("index"); + if (Strings.isEmpty(timeField)) + missingParams.add("timeField"); + if (Strings.isEmpty(logFieldName)) + missingParams.add("logFieldName"); + if (Strings.isEmpty(selectionTimeRangeStart)) + missingParams.add("selectionTimeRangeStart"); + if (Strings.isEmpty(selectionTimeRangeEnd)) + missingParams.add("selectionTimeRangeEnd"); + if (!missingParams.isEmpty()) { + throw new IllegalArgumentException("Missing required parameters: " + String.join(", ", missingParams)); + } + } + + boolean hasBaseTime() { + return !Strings.isEmpty(baseTimeRangeStart) && !Strings.isEmpty(baseTimeRangeEnd); + } + + boolean hasTraceField() { + return !Strings.isEmpty(traceFieldName); + } + } + + /** + * Result class for pattern analysis + */ + private record PatternAnalysisResult(Map> tracePatternMap, Map> patternCountMap, + Map patternWeightsMap) { + } + + private record PatternDiffResult(String pattern, Double base, Double selection, Double lift) { + } + + Comparator comparator = (d1, d2) -> { + Double lift1 = Optional.ofNullable(d1.lift).orElse(Double.NEGATIVE_INFINITY); // null sorts last (Double.MIN_VALUE is positive) + Double lift2 = Optional.ofNullable(d2.lift).orElse(Double.NEGATIVE_INFINITY); + + if (lift1.compareTo(lift2) == 0) { + return Optional + .ofNullable(d2.selection) + .orElse(Double.NEGATIVE_INFINITY) + .compareTo(Optional.ofNullable(d1.selection).orElse(Double.NEGATIVE_INFINITY)); + } else { + return lift2.compareTo(lift1); + } + }; + + private record PatternWithSamples(String pattern, double count, List sampleLogs) { + } + + // Instance fields + @Setter + @Getter + private String name = TYPE; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + @Getter + private String
version; + private Client client; + private ClusteringHelper clusteringHelper; + + public LogPatternAnalysisTool(Client client) { + this.client = client; + this.clusteringHelper = new ClusteringHelper(LOG_VECTORS_CLUSTERING_THRESHOLD); + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public Map getAttributes() { + return Map.of(); + } + + @Override + public void setAttributes(Map map) { + + } + + @Override + public boolean validate(Map map) { + try { + new AnalysisParameters(map).validate(); + } catch (Exception e) { + return false; + } + return true; + } + + @Override + public void run(Map originalParameters, ActionListener listener) { + try { + Map parameters = ToolUtils.extractInputParameters(originalParameters, DEFAULT_ATTRIBUTES); + log.debug("Starting log pattern analysis with parameters: {}", parameters.keySet()); + AnalysisParameters params = new AnalysisParameters(parameters); + params.validate(); + + if (params.hasTraceField() && params.hasBaseTime()) { + log.debug("Performing log sequence analysis for index: {}", params.index); + logSequenceAnalysis(params, listener); + } else if (params.hasBaseTime()) { + log.debug("Performing log pattern analysis for index: {}", params.index); + logPatternDiffAnalysis(params, listener); + } else { + logInsight(params, listener); + } + } catch (IllegalArgumentException e) { + log.error("Invalid parameters for LogPatternAnalysisTool: {}", e.getMessage()); + listener.onFailure(new IllegalArgumentException("Invalid parameters: " + e.getMessage(), e)); + } catch (Exception e) { + log.error("Unexpected error in LogPatternAnalysisTool", e); + listener.onFailure(new RuntimeException("Failed to execute log pattern analysis", e)); + } + } + + private void logSequenceAnalysis(AnalysisParameters params, ActionListener listener) { + // Step 1: Analyze selection time range + analyzeSelectionTimeRange(params, ActionListener.wrap(selectionResult -> { + log.debug("Selection time range analysis completed, 
found {} traces", selectionResult.tracePatternMap.size()); + + if (selectionResult.tracePatternMap.isEmpty()) { + Map> emptyResult = buildFinalResult( + List.of(), + List.of(), + Collections.emptyMap(), + Collections.emptyMap() + ); + listener.onResponse((T) gson.toJson(emptyResult)); + return; + } + + // Step 2: Analyze base time range + analyzeBaseTimeRange(params, ActionListener.wrap(baseResult -> { + log.debug("Base time range analysis completed, found {} traces", baseResult.tracePatternMap.size()); + + // Step 3: Generate comparison result + generateSequenceComparisonResult(baseResult, selectionResult, listener); + }, listener::onFailure)); + }, error -> { + log.error("Failed to execute analysis", error); + listener.onFailure(new RuntimeException("Analysis failed: " + error.getMessage(), error)); + })); + } + + private void analyzeBaseTimeRange(AnalysisParameters params, ActionListener listener) { + String baseTimeRangeLogPatternPPL = buildLogPatternPPL( + params.index, + params.timeField, + params.logFieldName, + params.traceFieldName, + params.baseTimeRangeStart, + params.baseTimeRangeEnd + ); + + executePPL(baseTimeRangeLogPatternPPL, listener); + } + + private void analyzeSelectionTimeRange(AnalysisParameters params, ActionListener listener) { + String selectionTimeRangeLogPatternPPL = buildLogPatternPPL( + params.index, + params.timeField, + params.logFieldName, + params.traceFieldName, + params.selectionTimeRangeStart, + params.selectionTimeRangeEnd + ); + + executePPL(selectionTimeRangeLogPatternPPL, listener); + } + + private void executePPL(String ppl, ActionListener listener) { + Function>, PatternAnalysisResult> rowParser = dataRows -> { + Map> tracePatternMap = new HashMap<>(); + Map> patternCountMap = new HashMap<>(); + Map rawPatternCache = new HashMap<>(); + + for (List row : dataRows) { + if (row.size() < 2) { + continue; + } + + String traceId = (String) row.get(0); + String rawPattern = (String) row.get(1); + + String simplifiedPattern =
rawPatternCache.computeIfAbsent(rawPattern, this::postProcessPattern); + + tracePatternMap.computeIfAbsent(traceId, k -> new LinkedHashSet<>()).add(simplifiedPattern); + patternCountMap.computeIfAbsent(simplifiedPattern, k -> new HashSet<>()).add(traceId); + } + + // Calculate pattern values using IDF and sigmoid + Map patternVectors = vectorizePattern(patternCountMap, tracePatternMap.size()); + + return new PatternAnalysisResult(tracePatternMap, patternCountMap, patternVectors); + }; + + executePPLAndParseResult(ppl, rowParser, listener); + } + + private String buildLogPatternPPL( + String index, + String timeField, + String logFieldName, + String traceFieldName, + String startTime, + String endTime + ) { + return String + .format( + Locale.ROOT, + "source=%s | where %s!='' | where %s>'%s' and %s<'%s' | patterns %s method=brain " + + "variable_count_threshold=3 | fields %s, patterns_field, %s | sort %s", + index, + traceFieldName, + timeField, + startTime, + timeField, + endTime, + logFieldName, + traceFieldName, + timeField, + timeField + ); + } + + private Map vectorizePattern(Map> patternCountMap, int totalTraceCount) { + Map patternValues = new HashMap<>(); + + for (Map.Entry> entry : patternCountMap.entrySet()) { + String pattern = entry.getKey(); + Set traceIds = entry.getValue(); + + if (traceIds != null && !traceIds.isEmpty()) { + // IDF calculation + double idf = Math.log((double) totalTraceCount / traceIds.size()); + // Apply sigmoid function + double value = 1.0 / (1.0 + Math.exp(-idf)); + patternValues.put(pattern, value); + } else { + patternValues.put(pattern, 0.0); + } + } + + return patternValues; + } + + private void generateSequenceComparisonResult( + PatternAnalysisResult baseResult, + PatternAnalysisResult selectionResult, + ActionListener listener + ) { + try { + // Step 3: Build pattern index for vector construction + Map patternIndexMap = buildPatternIndex(baseResult, selectionResult); + + // Step 4: Build vectors for base time range + Map 
baseVectorMap = buildVectorMap( + baseResult.tracePatternMap, + baseResult.patternWeightsMap, + patternIndexMap, + false + ); + + // Step 5: Cluster base vectors and find centroids + List baseRepresentative = this.clusteringHelper.clusterLogVectorsAndGetRepresentative(baseVectorMap); + + // Step 6: Build vectors for traceNeedToExamine time range + Map selectionVectorMap = buildVectorMap( + selectionResult.tracePatternMap, + selectionResult.patternWeightsMap, + patternIndexMap, + true, + baseResult.patternCountMap, + selectionResult.patternCountMap + ); + + // Step 7: Find traceNeedToExamine centroids + List selectionRepresentative = this.clusteringHelper.clusterLogVectorsAndGetRepresentative(selectionVectorMap); + + List traceNeedToExamine = filterSelectionCentroids( + baseRepresentative, + selectionRepresentative, + baseVectorMap, + selectionVectorMap + ); + + log + .info( + "Identified {} traceNeedToExamine centroids from {} candidates", + traceNeedToExamine.size(), + selectionRepresentative.size() + ); + + // Generate final result + Map> result = buildFinalResult( + baseRepresentative, + traceNeedToExamine, + baseResult.tracePatternMap, + selectionResult.tracePatternMap + ); + listener.onResponse((T) gson.toJson(result)); + + } catch (Exception e) { + log.error("Failed to generate sequence comparison result", e); + listener.onFailure(new RuntimeException("Failed to generate comparison result: " + e.getMessage(), e)); + } + } + + private Map buildPatternIndex(PatternAnalysisResult baseResult, PatternAnalysisResult selectionResult) { + Set allPatterns = new HashSet<>(baseResult.patternCountMap.keySet()); + allPatterns.addAll(selectionResult.patternCountMap.keySet()); + + List sortedPatterns = new ArrayList<>(allPatterns); + Collections.sort(sortedPatterns); + log.debug("vector dimension is {}", sortedPatterns.size()); + + // pattern and its index in a vector + Map patternIndexMap = new HashMap<>(); + for (int i = 0; i < sortedPatterns.size(); i++) { + 
patternIndexMap.put(sortedPatterns.get(i), i); + } + + return patternIndexMap; + } + + @SafeVarargs + private Map buildVectorMap( + Map> tracePatternMap, + Map patternWeightsMap, + Map patternIndexMap, + boolean isSelection, + Map>... additionalPatternMaps + ) { + Map vectorMap = new HashMap<>(); + int dimension = patternIndexMap.size(); + + for (Map.Entry> entry : tracePatternMap.entrySet()) { + String traceId = entry.getKey(); + Set patterns = entry.getValue(); + double[] vector = new double[dimension]; + + for (String pattern : patterns) { + Integer index = patternIndexMap.get(pattern); + if (index != null) { + double baseValue = 0.5 * patternWeightsMap.getOrDefault(pattern, 0.0); + + if (isSelection && additionalPatternMaps.length >= 2) { + // Add existence weight for selection patterns + Map> basePatterns = additionalPatternMaps[0]; + + int existenceWeight = basePatterns.containsKey(pattern) ? 0 : 1; + vector[index] = baseValue + 0.5 * existenceWeight; + } else { + vector[index] = baseValue; + } + } + } + + vectorMap.put(traceId, vector); + } + + return vectorMap; + } + + private List filterSelectionCentroids( + List baseCentroids, + List selectionCandidates, + Map baseVectorMap, + Map selectionVectorMap + ) { + List selectionCentroids = new ArrayList<>(); + + for (String candidate : selectionCandidates) { + boolean isSelection = true; + double[] candidateVector = selectionVectorMap.get(candidate); + + if (candidateVector == null) { + log.warn("No vector found for selection candidate: {}", candidate); + continue; + } + + for (String baseCentroid : baseCentroids) { + double[] baseVector = baseVectorMap.get(baseCentroid); + if (baseVector != null && calculateCosineSimilarity(baseVector, candidateVector) > LOG_VECTORS_CLUSTERING_THRESHOLD) { + isSelection = false; + break; + } + } + + if (isSelection) { + selectionCentroids.add(candidate); + } + } + + return selectionCentroids; + } + + private Map> buildFinalResult( + List baseCentroids, + List 
selectionCentroids, + Map> baseTracePatternMap, + Map> selectionTracePatternMap + ) { + Map baseSequences = new HashMap<>(); + for (String centroid : baseCentroids) { + Set patterns = baseTracePatternMap.get(centroid); + if (patterns != null) { + baseSequences.put(centroid, String.join(" -> ", patterns)); + } + } + + Map selectionSequences = new HashMap<>(); + for (String centroid : selectionCentroids) { + Set patterns = selectionTracePatternMap.get(centroid); + if (patterns != null) { + selectionSequences.put(centroid, String.join(" -> ", patterns)); + } + } + + Map> result = new HashMap<>(); + result.put("BASE", baseSequences); + result.put("EXCEPTIONAL", selectionSequences); + + return result; + } + + private void logPatternDiffAnalysis(AnalysisParameters params, ActionListener listener) { + // Step 1: Generate log patterns for baseline time range + String baseTimeRangeLogPatternPPL = buildLogPatternPPL( + params.index, + params.timeField, + params.logFieldName, + params.baseTimeRangeStart, + params.baseTimeRangeEnd + ); + Function>, Map> dataRowsParser = dataRows -> { + Map patternMap = new HashMap<>(); + for (List row : dataRows) { + if (row.size() == 2) { + String pattern = (String) row.get(1); + double count = ((Number) row.get(0)).doubleValue(); + patternMap.put(pattern, count); + } + } + return patternMap; + }; + + log.debug("Executing base time range pattern PPL: {}", baseTimeRangeLogPatternPPL); + executePPLAndParseResult(baseTimeRangeLogPatternPPL, dataRowsParser, ActionListener.wrap(basePatterns -> { + try { + mergeSimilarPatterns(basePatterns); + + log.debug("Base patterns processed: {} patterns", basePatterns.size()); + + // Step 2: Generate log patterns for selection time range + String selectionTimeRangeLogPatternPPL = buildLogPatternPPL( + params.index, + params.timeField, + params.logFieldName, + params.selectionTimeRangeStart, + params.selectionTimeRangeEnd + ); + + log.debug("Executing selection time range pattern PPL: {}", 
selectionTimeRangeLogPatternPPL); + executePPLAndParseResult(selectionTimeRangeLogPatternPPL, dataRowsParser, ActionListener.wrap(selectionPatterns -> { + mergeSimilarPatterns(selectionPatterns); + + log.debug("Selection patterns processed: {} patterns", selectionPatterns.size()); + + // Step 3: Calculate pattern differences + List patternDifferences = calculatePatternDifferences(basePatterns, selectionPatterns); + + // Step 4: Sort and keep the top 10 lifted patterns plus the top 10 selection-only (null lift) patterns + List topDiffs = Stream + .concat( + patternDifferences.stream().filter(diff -> !Objects.isNull(diff.lift)).sorted(comparator).limit(10), + patternDifferences.stream().filter(diff -> Objects.isNull(diff.lift)).sorted(comparator).limit(10) + ) + .collect(Collectors.toList()); + + Map finalResult = new HashMap<>(); + finalResult.put("patternMapDifference", topDiffs); + + log.debug("Pattern analysis completed: {} differences found", patternDifferences.size()); + listener.onResponse((T) gson.toJson(finalResult)); + }, listener::onFailure)); + + } catch (Exception e) { + log.error("Failed to process base pattern response", e); + listener.onFailure(new RuntimeException("Failed to process base patterns: " + e.getMessage(), e)); + } + }, error -> { + log.error("Failed to execute pattern analysis", error); + listener.onFailure(new RuntimeException("Analysis failed: " + error.getMessage(), error)); + })); + } + + private void logInsight(AnalysisParameters params, ActionListener listener) { + Set errorKeywords = Set + .of( + "error", + "err", + "exception", + "failed", + "failure", + "timeout", + "panic", + "fatal", + "critical", + "severe", + "abort", + "aborted", + "aborting", + "crash", + "crashed", + "broken", + "corrupt", + "corrupted", + "invalid", + "malformed", + "unprocessable", + "denied", + "forbidden", + "unauthorized", + "conflict", + "deadlock", + "overflow", + "underflow", + "throttled", + "disk_full", + "insufficient", + "retrying", + "backpressure", + "degraded", + "unexpected", + "unusual", + 
"missing", + "stale", + "expired", + "mismatch", + "violation" + ); + + String selectionTimeRangeLogPatternPPL = String + .format( + Locale.ROOT, + "source=%s | where %s>'%s' and %s<'%s' | where match(%s, '%s') | patterns %s method=brain " + + "mode=aggregation max_sample_count=2 " + + "variable_count_threshold=3 | fields patterns_field, pattern_count, sample_logs " + + "| sort -pattern_count | head 5", + params.index, + params.timeField, + params.selectionTimeRangeStart, + params.timeField, + params.selectionTimeRangeEnd, + params.logFieldName, + String.join(" ", errorKeywords), + params.logFieldName + ); + + Function>, List> dataRowsParser = dataRows -> { + List patternWithSamplesList = new ArrayList<>(); + for (List row : dataRows) { + if (row.size() == 3) { + String pattern = (String) row.get(0); + double count = ((Number) row.get(1)).doubleValue(); + List samples = (List) row.get(2); + patternWithSamplesList.add(new PatternWithSamples(pattern, count, samples)); + } + } + return patternWithSamplesList; + }; + + executePPLAndParseResult(selectionTimeRangeLogPatternPPL, dataRowsParser, ActionListener.wrap(logInsights -> { + try { + Map finalResult = new HashMap<>(); + finalResult.put("logInsights", logInsights); + listener.onResponse((T) gson.toJson(finalResult)); + } catch (Exception e) { + log.error("Failed to process log insights response", e); + listener.onFailure(new RuntimeException("Failed to process log insights: " + e.getMessage(), e)); + } + }, error -> { + log.error("Failed to execute log insights analysis", error); + listener.onFailure(new RuntimeException("Log insights analysis failed: " + error.getMessage(), error)); + })); + } + + private String buildLogPatternPPL(String index, String timeField, String logFieldName, String startTime, String endTime) { + return String + .format( + Locale.ROOT, + "source=%s | where %s>'%s' and %s<'%s' | patterns %s method=brain mode=aggregation " + + "variable_count_threshold=3 | fields pattern_count, 
patterns_field", + index, + timeField, + startTime, + timeField, + endTime, + logFieldName + ); + } + + private List calculatePatternDifferences(Map basePatterns, Map selectionPatterns) { + List differences = new ArrayList<>(); + + double selectionTotal = selectionPatterns.values().stream().mapToDouble(Double::doubleValue).sum(); + double baseTotal = basePatterns.values().stream().mapToDouble(Double::doubleValue).sum(); + + for (Map.Entry entry : selectionPatterns.entrySet()) { + String pattern = entry.getKey(); + double selectionCount = entry.getValue(); + + if (basePatterns.containsKey(pattern)) { + double baseCount = basePatterns.get(pattern); + double lift = (selectionCount / selectionTotal) / (baseCount / baseTotal); + + if (lift < 1) { + lift = 1.0 / lift; + } + + if (lift > LOG_PATTERN_LIFT) { + differences.add(new PatternDiffResult(pattern, baseCount / baseTotal, selectionCount / selectionTotal, lift)); + } + } else { + // Pattern only exists in selection time range + differences.add(new PatternDiffResult(pattern, 0.0, selectionCount / selectionTotal, null)); + log.debug("New selection pattern detected: {} (count: {})", pattern, selectionCount); + } + } + + return differences; + } + + private void handlePPLError(Throwable error) { + String errorMsg = error.getMessage(); + String errorType = error.getClass().getSimpleName(); + log.error("PPL execution failed [{}]: {}", errorType, errorMsg); + String errorString = error.toString(); + String fullErrorMessage = errorMsg != null ? 
errorMsg : errorString; + throw new RuntimeException("PPL execution failed: " + fullErrorMessage, error); + } + + private double jaccardSimilarity(String pattern1, String pattern2) { + if (Strings.isEmpty(pattern1) && Strings.isEmpty(pattern2)) { + return 1.0; + } + if (Strings.isEmpty(pattern1) || Strings.isEmpty(pattern2)) { + return 0.0; + } + + Set set1 = new HashSet<>(Arrays.asList(pattern1.split("\\s+"))); + Set set2 = new HashSet<>(Arrays.asList(pattern2.split("\\s+"))); + + // Calculate union + Set union = new HashSet<>(set1); + union.addAll(set2); + + int intersectionSize = set1.size() + set2.size() - union.size(); + return (double) intersectionSize / union.size(); + } + + private void mergeSimilarPatterns(Map patternMap) { + if (patternMap.isEmpty()) { + return; + } + + List patterns = new ArrayList<>(patternMap.keySet()); + patterns.sort(String::compareTo); + Set removed = new HashSet<>(); + + for (int i = 0; i < patterns.size(); i++) { + String pattern1 = patterns.get(i); + if (removed.contains(pattern1)) { + continue; + } + + for (int j = i + 1; j < patterns.size(); j++) { + String pattern2 = patterns.get(j); + if (removed.contains(pattern2)) { + continue; + } + + if (jaccardSimilarity(pattern1, pattern2) > LOG_PATTERN_THRESHOLD) { + // Merge pattern2 into pattern1 + double count1 = patternMap.getOrDefault(pattern1, 0.0); + double count2 = patternMap.getOrDefault(pattern2, 0.0); + patternMap.put(pattern1, count1 + count2); + patternMap.remove(pattern2); + removed.add(pattern2); + log.debug("Merged similar patterns: '{}' + '{}' -> '{}'", pattern1, pattern2, pattern1); + } + } + } + + // Post-process patterns and merge those with similar processed forms + Map toReplace = new HashMap<>(); + for (String pattern : patternMap.keySet()) { + String processedPattern = postProcessPattern(pattern); + if (!processedPattern.equals(pattern)) { + toReplace.put(pattern, processedPattern); + } + } + + for (Map.Entry entry : toReplace.entrySet()) { + String 
originalPattern = entry.getKey(); + String processedPattern = entry.getValue(); + double count = patternMap.remove(originalPattern); + patternMap.merge(processedPattern, count, Double::sum); + } + + log.debug("Pattern merging completed: {} patterns remaining", patternMap.size()); + } + + private String postProcessPattern(String pattern) { + if (Strings.isEmpty(pattern)) { + return pattern; + } + + // Replace repeated <*> with single <*> using compiled pattern + pattern = REPEATED_WILDCARDS_PATTERN.matcher(pattern).replaceAll("<*>"); + return pattern; + } + + private void executePPLAndParseResult(String ppl, Function>, T> rowParser, ActionListener listener) { + try { + JSONObject jsonContent = new JSONObject(ImmutableMap.of("query", ppl)); + PPLQueryRequest pplQueryRequest = new PPLQueryRequest(ppl, jsonContent, null, "jdbc"); + TransportPPLQueryRequest transportPPLQueryRequest = new TransportPPLQueryRequest(pplQueryRequest); + + client + .execute( + PPLQueryAction.INSTANCE, + transportPPLQueryRequest, + getPPLTransportActionListener(ActionListener.wrap(transportPPLQueryResponse -> { + String result = transportPPLQueryResponse.getResult(); + if (Strings.isEmpty(result)) { + listener.onFailure(new RuntimeException("Empty PPL response")); + } else { + Map pplResult = gson.fromJson(result, new TypeToken>() { + }.getType()); + if (pplResult.containsKey("error")) { + Object errorObj = pplResult.get("error"); + String errorDetail; + if (errorObj instanceof Map) { + Map errorMap = (Map) errorObj; + Object reason = errorMap.get("reason"); + errorDetail = reason != null ? reason.toString() : errorMap.toString(); + } else { + errorDetail = errorObj != null ? 
errorObj.toString() : "Unknown error"; + } + throw new RuntimeException("PPL query error: " + errorDetail); + } + + Object datarowsObj = pplResult.get("datarows"); + if (!(datarowsObj instanceof List)) { + throw new IllegalStateException("Invalid PPL response format: missing or invalid datarows"); + } + + @SuppressWarnings("unchecked") + List> dataRows = (List>) datarowsObj; + if (dataRows.isEmpty()) { + log.warn("PPL query returned no data rows for the specified criteria"); + } + listener.onResponse(rowParser.apply(dataRows)); + } + }, error -> { + try { + handlePPLError(error); + } catch (Exception handledException) { + listener.onFailure(handledException); + } + })) + ); + } catch (Exception e) { + String errorMessage = String + .format( + Locale.ROOT, + "Failed to execute PPL query: %s. Query: %s", + e.getMessage(), + ppl.substring(0, Math.min(100, ppl.length())) + ); + log.error(errorMessage, e); + listener.onFailure(new RuntimeException(errorMessage, e)); + } + } + + public static class Factory implements Tool.Factory { + private Client client; + + private static LogPatternAnalysisTool.Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static LogPatternAnalysisTool.Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (LogPatternAnalysisTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new LogPatternAnalysisTool.Factory(); + return INSTANCE; + } + } + + /** + * Initialize this factory + * + * @param client The OpenSearch client + */ + public void init(Client client) { + this.client = client; + } + + @Override + public LogPatternAnalysisTool create(Map map) { + + return new LogPatternAnalysisTool(client); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public Map getDefaultAttributes() { + return DEFAULT_ATTRIBUTES; + } + + @Override 
+ public String getDefaultVersion() { + return null; + } + } +} diff --git a/src/main/java/org/opensearch/agent/tools/utils/clustering/ClusteringHelper.java b/src/main/java/org/opensearch/agent/tools/utils/clustering/ClusteringHelper.java new file mode 100644 index 00000000..c937e997 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/utils/clustering/ClusteringHelper.java @@ -0,0 +1,514 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools.utils.clustering; + +import static org.opensearch.agent.tools.utils.clustering.HierarchicalAgglomerativeClustering.calculateCosineSimilarity; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.commons.math3.ml.clustering.CentroidCluster; +import org.apache.commons.math3.ml.clustering.DoublePoint; +import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer; +import org.apache.commons.math3.ml.distance.DistanceMeasure; + +import com.google.common.collect.Lists; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class ClusteringHelper { + private final double logVectorsClusteringThreshold; + + /** + * Constructor for ClusteringHelper + * + * @param logVectorsClusteringThreshold Threshold for determining when two vectors are similar + * Should be between 0 and 1.0 (inclusive) + * @throws IllegalArgumentException if threshold is outside valid range + */ + public ClusteringHelper(double logVectorsClusteringThreshold) { + if (logVectorsClusteringThreshold < 0.0 || logVectorsClusteringThreshold > 1.0) { + throw new IllegalArgumentException("Clustering threshold must be between 0.0 and 1.0, got: " + logVectorsClusteringThreshold); + } + this.logVectorsClusteringThreshold = logVectorsClusteringThreshold; + } + + /** + * Cluster log vectors using a two-phase approach and get representative 
vectors. + * Input validation is performed to ensure log vectors are valid. + * + * @param logVectors Map of trace IDs to their vector representations + * @return List of trace IDs representing the centroids of each cluster + * @throws IllegalArgumentException if logVectors contains invalid entries + */ + public List clusterLogVectorsAndGetRepresentative(Map logVectors) { + if (logVectors == null || logVectors.isEmpty()) { + return new ArrayList<>(); + } + + // Validate input vectors + validateLogVectors(logVectors); + + log.debug("Starting two-phase clustering for {} log vectors", logVectors.size()); + + // Convert map to arrays for processing + double[][] vectors = new double[logVectors.size()][]; + Map indexTraceIdMap = new HashMap<>(); + convertLogVectorsToArrays(logVectors, vectors, indexTraceIdMap); + + List finalCentroids; + + // Choose clustering approach based on dataset size + if (logVectors.size() > 1000) { + finalCentroids = processTwoPhaseClusteringForLargeDataset(vectors, indexTraceIdMap); + } else { + // Small dataset - use hierarchical clustering directly + finalCentroids = performClustering(vectors, indexTraceIdMap); + } + + log + .debug( + "Two-phase clustering completed: {} input vectors -> {} representative centroids", + logVectors.size(), + finalCentroids.size() + ); + + return finalCentroids; + } + + /** + * Converts log vectors map to arrays for processing + * + * @param logVectors Map of trace IDs to vector representations + * @param vectors Output array for vectors + * @param indexTraceIdMap Output map for index to trace ID mapping + */ + private void convertLogVectorsToArrays(Map logVectors, double[][] vectors, Map indexTraceIdMap) { + int i = 0; + for (Map.Entry entry : logVectors.entrySet()) { + vectors[i] = entry.getValue(); + indexTraceIdMap.put(i, entry.getKey()); + i++; + } + } + + /** + * Processes large datasets using two-phase clustering approach + * + * @param vectors Array of vectors + * @param indexTraceIdMap Mapping from 
vector index to trace ID + * @return List of trace IDs representing cluster centroids + */ + private List processTwoPhaseClusteringForLargeDataset(double[][] vectors, Map indexTraceIdMap) { + List finalCentroids = new ArrayList<>(); + log.debug("Large dataset detected ({}), applying K-means pre-clustering", vectors.length); + + // Calculate optimal number of K-means clusters (target 500 points per cluster) + int targetClusterSize = 500; + int numKMeansClusters = (vectors.length + (targetClusterSize - 1)) / targetClusterSize; + + log.debug("Using {} K-means clusters for pre-clustering", numKMeansClusters); + + try { + List> kMeansClusters = performKMeansClustering(vectors, numKMeansClusters); + + // Process each K-means cluster + for (int clusterIdx = 0; clusterIdx < kMeansClusters.size(); clusterIdx++) { + List kMeansCluster = kMeansClusters.get(clusterIdx); + log.debug("Processing K-means cluster {} with {} points", clusterIdx, kMeansCluster.size()); + + List clusterCentroids = processCluster(kMeansCluster, vectors, indexTraceIdMap, clusterIdx); + finalCentroids.addAll(clusterCentroids); + } + + } catch (Exception e) { + log.warn("K-means clustering failed, falling back to hierarchical clustering only: {}", e.getMessage()); + // Fallback to hierarchical clustering only + finalCentroids = performClustering(vectors, indexTraceIdMap); + } + + return finalCentroids; + } + + /** + * Processes a single K-means cluster + * + * @param kMeansCluster List of indices in the K-means cluster + * @param vectors Original vector array + * @param indexTraceIdMap Original mapping from indices to trace IDs + * @param clusterIdx Index of the cluster (for logging) + * @return List of trace IDs representing cluster centroids + */ + private List processCluster( + List kMeansCluster, + double[][] vectors, + Map indexTraceIdMap, + int clusterIdx + ) { + if (kMeansCluster.isEmpty()) { + return List.of(); + } + + if (kMeansCluster.size() == 1) { + return 
List.of(indexTraceIdMap.get(kMeansCluster.getFirst())); + } + + if (kMeansCluster.size() > 500) { + log.debug("The cluster size is greater than 500, performing partitioned clustering"); + return performHierarchicalClusteringOfPartition(kMeansCluster, vectors, indexTraceIdMap); + } + + log.debug("Applying hierarchical clustering to K-means cluster {} with {} points", clusterIdx, kMeansCluster.size()); + + // Extract vectors for this K-means cluster + double[][] clusterVectors = extractVectors(kMeansCluster, vectors); + Map clusterIndexTraceIdMap = createTraceIdMapping(kMeansCluster, indexTraceIdMap); + + // Apply hierarchical clustering within this K-means cluster + return performClustering(clusterVectors, clusterIndexTraceIdMap); + } + + /** + * Perform K-means clustering using Apache Commons Math3 + * + * @param vectors Input vectors for clustering + * @param numClusters Number of K-means clusters + * @return List of clusters, each containing indices of points in that cluster + * @throws RuntimeException if clustering fails + */ + private List> performKMeansClustering(double[][] vectors, int numClusters) { + if (vectors == null || vectors.length == 0) { + return new ArrayList<>(); + } + + if (numClusters <= 0) { + numClusters = 1; + } + + // Cap number of clusters to vector size + numClusters = Math.min(numClusters, vectors.length); + + try { + KMeansPlusPlusClusterer clusterer = createKMeansClusterer(numClusters); + List points = convertVectorsToPoints(vectors); + List> clusters = clusterer.cluster(points); + return extractClusterIndices(clusters, vectors); + } catch (Exception e) { + log.error("K-means clustering failed: {}", e.getMessage(), e); + throw new RuntimeException("K-means clustering failed: " + e.getMessage(), e); + } + } + + /** + * Creates a KMeansPlusPlusClusterer with cosine distance metric + * + * @param numClusters Number of clusters to create + * @return Configured KMeansPlusPlusClusterer + */ + private KMeansPlusPlusClusterer 
createKMeansClusterer(int numClusters) { + return new KMeansPlusPlusClusterer<>( + numClusters, + 300, // Maximum iterations + (DistanceMeasure) (a, b) -> 1 - calculateCosineSimilarity(a, b) + ); + } + + /** + * Converts vector array to list of DoublePoint objects + * + * @param vectors Array of vectors + * @return List of DoublePoint objects + */ + private List convertVectorsToPoints(double[][] vectors) { + List points = new ArrayList<>(vectors.length); + for (double[] vector : vectors) { + points.add(new DoublePoint(vector)); + } + return points; + } + + /** + * Validates log vectors to ensure they are valid for clustering + * + * @param logVectors Map of trace IDs to vector representations + * @throws IllegalArgumentException if vectors are invalid + */ + private void validateLogVectors(Map logVectors) { + int vectorDimension = -1; + + for (Map.Entry entry : logVectors.entrySet()) { + String traceId = entry.getKey(); + double[] vector = entry.getValue(); + + if (traceId == null || traceId.isEmpty()) { + throw new IllegalArgumentException("Trace ID cannot be null or empty"); + } + + if (vector == null) { + throw new IllegalArgumentException("Vector for trace ID '" + traceId + "' is null"); + } + + if (vector.length == 0) { + throw new IllegalArgumentException("Vector for trace ID '" + traceId + "' is empty"); + } + + // Ensure all vectors have the same dimension + if (vectorDimension == -1) { + vectorDimension = vector.length; + } else if (vector.length != vectorDimension) { + throw new IllegalArgumentException( + "Vector dimension mismatch: expected " + + vectorDimension + + " but got " + + vector.length + + " for trace ID '" + + traceId + + "'" + ); + } + + // Check for NaN or Infinity values + for (int i = 0; i < vector.length; i++) { + if (Double.isNaN(vector[i]) || Double.isInfinite(vector[i])) { + throw new IllegalArgumentException( + "Vector for trace ID '" + traceId + "' contains invalid value at index " + i + ": " + vector[i] + ); + } + } + } + } + + /** 
+ * Extracts original vector indices for each K-means cluster + * + * @param clusters K-means clustering result + * @param vectors Original vector array + * @return List of clusters with original vector indices + */ + private List> extractClusterIndices(List> clusters, double[][] vectors) { + List> result = new ArrayList<>(); + for (CentroidCluster cluster : clusters) { + List clusterIndices = new ArrayList<>(); + for (DoublePoint point : cluster.getPoints()) { + // Find the original index of this point + for (int i = 0; i < vectors.length; i++) { + if (Arrays.equals(vectors[i], point.getPoint())) { + clusterIndices.add(i); + break; + } + } + } + if (!clusterIndices.isEmpty()) { + result.add(clusterIndices); + } + } + return result; + } + + /** + * Generic method to perform clustering with specified linkage method + * + * @param vectors Input vectors for clustering + * @param indexTraceIdMap Mapping from vector index to trace ID + * @return List of trace IDs representing cluster centroids + */ + private List performClustering(double[][] vectors, Map indexTraceIdMap) { + if (vectors == null || vectors.length == 0) { + return List.of(); + } + + if (vectors.length == 1) { + String traceId = indexTraceIdMap.get(0); + return List.of(traceId); + } + + List centroids = new ArrayList<>(); + try { + HierarchicalAgglomerativeClustering hac = new HierarchicalAgglomerativeClustering(vectors); + List clusters = hac + .fit(HierarchicalAgglomerativeClustering.LinkageMethod.COMPLETE, this.logVectorsClusteringThreshold); + + for (HierarchicalAgglomerativeClustering.ClusterNode cluster : clusters) { + int centroidIndex = hac.getClusterCentroid(cluster); + String traceId = indexTraceIdMap.get(centroidIndex); + centroids.add(traceId); + } + } catch (Exception e) { + log.error("Hierarchical clustering failed: {}", e.getMessage(), e); + // Fallback: return first point as representative if available + String traceId = indexTraceIdMap.get(0); + centroids.add(traceId); + } + + return 
centroids; + } + + /** + * If the first stage K-means clustering results exceed 500 clusters, implement batch processing and merge the results. + * @param kMeansCluster Clustering results from the first stage. + * @param vectors List of vectors by index. + * @param indexTraceIdMap Map of index to their trace id. + * @return List of trace IDs representing cluster centroids after partitioned processing + */ + private List performHierarchicalClusteringOfPartition( + List kMeansCluster, + double[][] vectors, + Map indexTraceIdMap + ) { + List> partition = Lists.partition(kMeansCluster, 500); + + List vectorRes = new ArrayList<>(); + Map index2Trace = new HashMap<>(); + + for (List partList : partition) { + double[][] clusterVectors = extractVectors(partList, vectors); + Map clusterIndexTraceIdMap = createTraceIdMapping(partList, indexTraceIdMap); + + log.debug("Starting performHierarchicalClusteringOfPartition!"); + processPartition(clusterVectors, clusterIndexTraceIdMap, vectorRes, index2Trace); + } + + return removeSimilarVectors(vectorRes, index2Trace); + } + + /** + * Extracts vectors for a partition based on indices + * + * @param partList List of indices in the partition + * @param vectors Original vector array + * @return Array of vectors for the partition + */ + private double[][] extractVectors(List partList, double[][] vectors) { + double[][] clusterVectors = new double[partList.size()][]; + for (int j = 0; j < partList.size(); j++) { + int originalIndex = partList.get(j); + clusterVectors[j] = vectors[originalIndex]; + } + return clusterVectors; + } + + /** + * Creates a mapping from partition indices to trace IDs + * + * @param partList List of indices in the partition + * @param indexTraceIdMap Original mapping from indices to trace IDs + * @return Mapping from partition indices to trace IDs + */ + private Map createTraceIdMapping(List partList, Map indexTraceIdMap) { + Map clusterIndexTraceIdMap = new HashMap<>(); + for (int j = 0; j < partList.size(); 
j++) { + int originalIndex = partList.get(j); + clusterIndexTraceIdMap.put(j, indexTraceIdMap.get(originalIndex)); + } + return clusterIndexTraceIdMap; + } + + /** + * Processes a partition for hierarchical clustering + * + * @param clusterVectors Vectors in the partition + * @param clusterIndexTraceIdMap Mapping from partition indices to trace IDs + * @param vectorRes Result vector collection to append to + * @param index2Trace Result mapping from indices to trace IDs to append to + */ + private void processPartition( + double[][] clusterVectors, + Map clusterIndexTraceIdMap, + List vectorRes, + Map index2Trace + ) { + if (clusterVectors.length == 0) { + return; + } + + if (clusterVectors.length == 1) { + vectorRes.add(clusterVectors[0]); + index2Trace.put(vectorRes.size() - 1, clusterIndexTraceIdMap.get(0)); + return; + } + + try { + HierarchicalAgglomerativeClustering hac = new HierarchicalAgglomerativeClustering(clusterVectors); + List clusters = hac + .fit(HierarchicalAgglomerativeClustering.LinkageMethod.COMPLETE, this.logVectorsClusteringThreshold); + log.info("Completing performHierarchicalClusteringOfPartition!"); + + for (HierarchicalAgglomerativeClustering.ClusterNode cluster : clusters) { + int centroidIndex = hac.getClusterCentroid(cluster); + vectorRes.add(clusterVectors[centroidIndex]); + index2Trace.put(vectorRes.size() - 1, clusterIndexTraceIdMap.get(centroidIndex)); + } + } catch (Exception e) { + log.error("Hierarchical clustering failed: {}", e.getMessage(), e); + // Fallback: return first point as representative + vectorRes.add(clusterVectors[0]); + index2Trace.put(vectorRes.size() - 1, clusterIndexTraceIdMap.get(0)); + } + } + + /** + * Compute the cosine similarity pairwise and remove vectors that are too similar. + * Vectors with similarity higher than threshold are considered duplicates. 
+ * + * @param vectorRes List of vectors + * @param index2Trace Map of index to their trace id + * @return List of trace IDs after removing similar vectors + */ + private List removeSimilarVectors(List vectorRes, Map index2Trace) { + Set toRemove = new HashSet<>(); + + for (int i = 0; i < vectorRes.size(); i++) { + if (toRemove.contains(i)) { + continue; + } + + for (int j = i + 1; j < vectorRes.size(); j++) { + if (toRemove.contains(j)) { + continue; + } + + double similarity = calculateCosineSimilarity(vectorRes.get(i), vectorRes.get(j)); + // If similarity is higher than threshold, vectors are considered similar enough to remove one + if (similarity > this.logVectorsClusteringThreshold) { + log.debug("Removing similar vector with similarity: {}", similarity); + toRemove.add(j); + } + } + } + + log.debug("Removed {} similar vectors out of {}", toRemove.size(), vectorRes.size()); + return collectNonRemovedTraceIds(vectorRes, index2Trace, toRemove); + } + + /** + * Collects trace IDs for vectors that are not marked for removal + * + * @param vectors List of vectors + * @param indexToTraceMap Mapping from indices to trace IDs + * @param indicesToRemove Set of indices to exclude + * @return List of trace IDs for non-removed vectors + */ + private List collectNonRemovedTraceIds( + List vectors, + Map indexToTraceMap, + Set indicesToRemove + ) { + List result = new ArrayList<>(vectors.size() - indicesToRemove.size()); + for (int i = 0; i < vectors.size(); i++) { + if (!indicesToRemove.contains(i)) { + result.add(indexToTraceMap.get(i)); + } + } + return result; + } + +} diff --git a/src/main/java/org/opensearch/agent/tools/utils/clustering/HierarchicalAgglomerativeClustering.java b/src/main/java/org/opensearch/agent/tools/utils/clustering/HierarchicalAgglomerativeClustering.java new file mode 100644 index 00000000..a9600118 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/utils/clustering/HierarchicalAgglomerativeClustering.java @@ -0,0 +1,281 @@ +/* + * 
Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools.utils.clustering; + +import java.util.ArrayList; +import java.util.List; + +public class HierarchicalAgglomerativeClustering { + + private final double[][] data; + private final double[][] distanceMatrix; + private final int nSamples; + private final int nFeatures; + + public enum LinkageMethod { + SINGLE, // Minimum distance between clusters + COMPLETE, // Maximum distance between clusters + AVERAGE // Average distance between clusters + } + + /** + * Internal cluster node for tracking during clustering process + */ + public static class ClusterNode { + final int id; + final List samples; + final int size; + + ClusterNode(int id, int sample) { + this.id = id; + this.samples = new ArrayList<>(); + this.samples.add(sample); + this.size = 1; + } + + ClusterNode(int id, ClusterNode left, ClusterNode right) { + this.id = id; + this.samples = new ArrayList<>(); + this.samples.addAll(left.samples); + this.samples.addAll(right.samples); + this.size = left.size + right.size; + } + } + + /** + * Constructor - computes cosine distance matrix + */ + public HierarchicalAgglomerativeClustering(double[][] data) { + this.data = data; + this.nSamples = data.length; + this.nFeatures = data[0].length; + this.distanceMatrix = new double[nSamples][nSamples]; + + // Compute cosine distance matrix + computeCosineDistanceMatrix(); + } + + /** + * Compute pairwise cosine distances + * Cosine distance = 1 - cosine similarity + */ + private void computeCosineDistanceMatrix() { + // Pre-calculate norms for efficiency + double[] norms = new double[nSamples]; + for (int i = 0; i < nSamples; i++) { + double norm = 0.0; + for (int j = 0; j < nFeatures; j++) { + norm += data[i][j] * data[i][j]; + } + norms[i] = Math.sqrt(norm); + } + + // Calculate cosine distances + for (int i = 0; i < nSamples; i++) { + distanceMatrix[i][i] = 0.0; + for (int j = i + 1; j < nSamples; j++) { + 
double similarity = calculateCosineSimilarity(data[i], data[j], norms[i], norms[j]); + double distance = 1.0 - similarity; + distanceMatrix[i][j] = distanceMatrix[j][i] = distance; + } + } + } + + /** + * Optimized cosine similarity calculation with pre-calculated norms + */ + private static double calculateCosineSimilarity(double[] a, double[] b, double normA, double normB) { + if (normA == 0.0 || normB == 0.0) { + return 0.0; + } + + double dotProduct = 0.0; + for (int i = 0; i < a.length; i++) { + dotProduct += a[i] * b[i]; + } + + return dotProduct / (normA * normB); + } + + /** + * Perform hierarchical clustering with distance threshold + * + * @param linkage The linkage method to use + * @param threshold Distance threshold - clustering stops when minimum distance exceeds this value + * @return List of final clusters + */ + public List fit(LinkageMethod linkage, double threshold) { + if (threshold < 0) { + throw new IllegalArgumentException("Distance threshold must be non-negative"); + } + + // Initialize clusters - each sample starts as its own cluster + List activeClusters = new ArrayList<>(); + for (int i = 0; i < nSamples; i++) { + activeClusters.add(new ClusterNode(i, i)); + } + + int nextClusterId = nSamples; + + // Main clustering loop + while (activeClusters.size() > 1) { + // Find the closest pair of clusters + int[] closestPair = findClosestClusters(activeClusters, linkage, threshold); + if (closestPair == null) { + break; + } + + int i = closestPair[0]; + int j = closestPair[1]; + + // Merge the two closest clusters + ClusterNode newCluster = new ClusterNode(nextClusterId++, activeClusters.get(i), activeClusters.get(j)); + + // Remove old clusters and add new one + activeClusters.remove(Math.max(i, j)); + activeClusters.remove(Math.min(i, j)); + activeClusters.add(newCluster); + } + + return activeClusters; + } + + /** + * Find the two closest clusters + */ + private int[] findClosestClusters(List clusters, LinkageMethod linkage, double threshold) { 
+ double minDistance = threshold; + int bestI = -1, bestJ = -1; + + for (int i = 0; i < clusters.size(); i++) { + for (int j = i + 1; j < clusters.size(); j++) { + double distance = computeClusterDistance(clusters.get(i), clusters.get(j), linkage); + if (distance < minDistance) { + minDistance = distance; + bestI = i; + bestJ = j; + } + } + } + + return (bestI == -1) ? null : new int[] { bestI, bestJ }; + } + + /** + * Compute distance between clusters using specified linkage method + */ + private double computeClusterDistance(ClusterNode c1, ClusterNode c2, LinkageMethod linkage) { + return switch (linkage) { + case SINGLE -> singleLinkage(c1, c2); + case COMPLETE -> completeLinkage(c1, c2); + case AVERAGE -> averageLinkage(c1, c2); + }; + } + + /** + * Single linkage: minimum distance between any two points in different clusters + */ + private double singleLinkage(ClusterNode c1, ClusterNode c2) { + double minDist = Double.MAX_VALUE; + + for (int i : c1.samples) { + for (int j : c2.samples) { + double dist = distanceMatrix[i][j]; + if (dist < minDist) { + minDist = dist; + // Early termination for very small distances + if (minDist < 1e-10) { + return minDist; + } + } + } + } + + return minDist; + } + + /** + * Complete linkage: maximum distance between any two points in different clusters + */ + private double completeLinkage(ClusterNode c1, ClusterNode c2) { + double maxDist = Double.MIN_VALUE; + + for (int i : c1.samples) { + for (int j : c2.samples) { + double dist = distanceMatrix[i][j]; + if (dist > maxDist) { + maxDist = dist; + } + } + } + + return maxDist; + } + + /** + * Average linkage: average distance between all pairs of points in different clusters + */ + private double averageLinkage(ClusterNode c1, ClusterNode c2) { + double sumDist = 0.0; + int count = 0; + + for (int i : c1.samples) { + for (int j : c2.samples) { + sumDist += distanceMatrix[i][j]; + count++; + } + } + + return sumDist / count; + } + + /** + * Get cluster centroid (medoid) - the 
point with minimum total distance to other points in cluster + */ + public int getClusterCentroid(ClusterNode cluster) { + if (cluster.samples.size() == 1) { + return cluster.samples.getFirst(); + } + + int medoidIndex = cluster.samples.getFirst(); + double minTotalDistance = Double.MAX_VALUE; + + for (int pointI : cluster.samples) { + double totalDistance = 0.0; + for (int pointJ : cluster.samples) { + totalDistance += distanceMatrix[pointI][pointJ]; + } + + if (totalDistance < minTotalDistance) { + minTotalDistance = totalDistance; + medoidIndex = pointI; + } + } + + return medoidIndex; + } + + /** + * Backward compatibility method for cosine similarity calculation + */ + public static double calculateCosineSimilarity(double[] a, double[] b) { + double dotProduct = 0.0; + double normA = 0.0; + double normB = 0.0; + + for (int i = 0; i < a.length; i++) { + dotProduct += a[i] * b[i]; + normA += a[i] * a[i]; + normB += b[i] * b[i]; + } + + if (normA == 0 || normB == 0) { + return 0; + } + + return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); + } +} diff --git a/src/test/java/org/opensearch/agent/ToolPluginTests.java b/src/test/java/org/opensearch/agent/ToolPluginTests.java index 22150cf6..f7cfc3dd 100644 --- a/src/test/java/org/opensearch/agent/ToolPluginTests.java +++ b/src/test/java/org/opensearch/agent/ToolPluginTests.java @@ -96,7 +96,7 @@ public void test_getRestHandlers_successful() { @Test public void test_getToolFactories_successful() { - assertEquals(12, toolPlugin.getToolFactories().size()); + assertEquals(13, toolPlugin.getToolFactories().size()); } @Test diff --git a/src/test/java/org/opensearch/agent/tools/LogPatternAnalysisToolTests.java b/src/test/java/org/opensearch/agent/tools/LogPatternAnalysisToolTests.java new file mode 100644 index 00000000..3adffe17 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/LogPatternAnalysisToolTests.java @@ -0,0 +1,566 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: 
Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.HashMap; +import java.util.Map; + +import org.hamcrest.MatcherAssert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.plugin.transport.PPLQueryAction; +import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; +import org.opensearch.transport.client.Client; + +import com.google.common.collect.ImmutableMap; +import com.google.gson.JsonElement; + +import lombok.SneakyThrows; + +public class LogPatternAnalysisToolTests { + + private Map params = new HashMap<>(); + private final Client client = mock(Client.class); + @Mock + private TransportPPLQueryResponse pplQueryResponse; + + @SneakyThrows + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + LogPatternAnalysisTool.Factory.getInstance().init(client); + } + + private void mockPPLInvocation(String response) { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(pplQueryResponse); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + when(pplQueryResponse.getResult()).thenReturn(response); + } + + @Test + @SneakyThrows + public void testCreateTool() { + LogPatternAnalysisTool tool = 
LogPatternAnalysisTool.Factory.getInstance().create(params); + assertEquals("LogPatternAnalysisTool", tool.getType()); + assertEquals("LogPatternAnalysisTool", tool.getName()); + assertEquals(LogPatternAnalysisTool.Factory.getInstance().getDefaultDescription(), tool.getDescription()); + assertNull(LogPatternAnalysisTool.Factory.getInstance().getDefaultVersion()); + } + + @Test + public void testValidate() { + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + // Valid parameters + assertTrue( + tool + .validate( + Map + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ) + ) + ); + + // Missing required parameters + assertFalse(tool.validate(Map.of("index", "test_index"))); + assertFalse(tool.validate(Map.of())); + } + + @Test + @SneakyThrows + public void testLogInsightExecution() { + String pplResponse = + """ + {"schema":[{"name":"patterns_field","type":"string"},{"name":"pattern_count","type":"long"},{"name":"sample_logs","type":"array"}], + "datarows":[["Error in processing <*>",5,["Error in processing request","Error in processing data"]], + ["Failed to connect <*>",3,["Failed to connect to database","Failed to connect to server"]]], + "total":2,"size":2} + """; + + mockPPLInvocation(pplResponse); + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener.wrap(response -> { + System.out.println(response); + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue(result.getAsJsonObject().has("logInsights")); + }, e -> fail("Tool execution failed: " + 
e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testLogPatternDiffAnalysis() { + // Mock different responses for base and selection time ranges + String baseResponse = """ + {"schema":[{"name":"cnt","type":"long"},{"name":"patterns_field","type":"string"}], + "datarows":[[100,"User login successful"],[20,"Database query executed"],[10,"Cache hit"]], + "total":3,"size":3} + """; + + String selectionResponse = + """ + {"schema":[{"name":"cnt","type":"long"},{"name":"patterns_field","type":"string"}], + "datarows":[[50,"User login successful"],[80,"Error in authentication <*>"],[15,"Connection timeout <*>"],[5,"Database query executed"]], + "total":4,"size":4} + """; + + // Mock sequential PPL calls - first base, then selection + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(pplQueryResponse); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + when(pplQueryResponse.getResult()) + .thenReturn(baseResponse) // First call returns base data + .thenReturn(selectionResponse); // Second call returns selection data + + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "baseTimeRangeStart", + "2025-01-01T00:00:00Z", + "baseTimeRangeEnd", + "2025-01-01T01:00:00Z", + "selectionTimeRangeStart", + "2025-01-01T01:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T02:00:00Z" + ), + ActionListener.wrap(response -> { + System.out.println("Pattern diff response: " + response); + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue(result.getAsJsonObject().has("patternMapDifference")); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testLogSequenceAnalysis() { + // Mock different responses for base and 
selection time ranges + String baseResponse = + """ + {"schema":[{"name":"traceId","type":"string"},{"name":"patterns_field","type":"string"},{"name":"@timestamp","type":"timestamp"}], + "datarows":[["trace1","User login attempt","2025-01-01T00:00:00Z"],["trace1","Authentication successful","2025-01-01T00:00:01Z"],["trace1","Session created","2025-01-01T00:00:02Z"], + ["trace2","User login attempt","2025-01-01T00:00:10Z"],["trace2","Authentication successful","2025-01-01T00:00:11Z"],["trace2","Session created","2025-01-01T00:00:12Z"]], + "total":6,"size":6} + """; + + String selectionResponse = + """ + {"schema":[{"name":"traceId","type":"string"},{"name":"patterns_field","type":"string"},{"name":"@timestamp","type":"timestamp"}], + "datarows":[["trace3","User login attempt","2025-01-01T01:00:00Z"],["trace3","Authentication failed","2025-01-01T01:00:01Z"],["trace3","Account locked","2025-01-01T01:00:02Z"], + ["trace4","Database connection timeout","2025-01-01T01:00:10Z"],["trace4","Retry connection","2025-01-01T01:00:11Z"],["trace4","Connection failed","2025-01-01T01:00:12Z"], + ["trace5","User login attempt","2025-01-01T01:00:20Z"],["trace5","Authentication successful","2025-01-01T01:00:21Z"],["trace5","Session created","2025-01-01T01:00:22Z"]], + "total":9,"size":9} + """; + + // Mock sequential PPL calls - first base, then selection + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(pplQueryResponse); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + when(pplQueryResponse.getResult()) + .thenReturn(baseResponse) // First call returns base data + .thenReturn(selectionResponse); // Second call returns selection data + + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "traceFieldName", 
+ "traceId", + "baseTimeRangeStart", + "2025-01-01T00:00:00Z", + "baseTimeRangeEnd", + "2025-01-01T01:00:00Z", + "selectionTimeRangeStart", + "2025-01-01T01:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T02:00:00Z" + ), + ActionListener.wrap(response -> { + System.out.println("Sequence analysis response: " + response); + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue(result.getAsJsonObject().has("BASE") || result.getAsJsonObject().has("EXCEPTIONAL")); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithInvalidParameters() { + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap.of("index", "test_index"), + ActionListener + .wrap( + response -> fail("Should have failed with invalid parameters"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("Missing required parameters")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithEmptyPPLResponse() { + String emptyResponse = """ + {"schema":[],"datarows":[],"total":0,"size":0} + """; + + mockPPLInvocation(emptyResponse); + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener.wrap(response -> { + System.out.println("response: " + response); + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue(result.getAsJsonObject().has("logInsights")); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testExecutionFailedInPPL() { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new 
Exception("PPL execution failed")); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener + .wrap( + response -> fail("Should have failed"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("PPL execution failed:")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithIndexNotFound() { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new Exception("IndexNotFoundException: no such index")); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "nonexistent_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener + .wrap( + response -> fail("Should have failed with IndexNotFoundException"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("IndexNotFoundException")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithEmptyPPLResult() { + String emptyResponse = ""; + mockPPLInvocation(emptyResponse); + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener + 
.wrap( + response -> fail("Should have failed with empty response"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("Empty PPL response")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithInvalidPPLResponse() { + String invalidResponse = "{\"invalid\":\"response\"}"; + mockPPLInvocation(invalidResponse); + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener + .wrap( + response -> fail("Should have failed with invalid response"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("Invalid PPL response")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithNonExistentIndex() { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new Exception("no such index [nonexistent_index]")); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "nonexistent_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener + .wrap( + response -> fail("Should have failed with non-existent index"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("no such index")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithNonExistentLogField() { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new Exception("Unknown field [nonexistent_field]")); + return 
null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "nonexistent_field", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener + .wrap( + response -> fail("Should have failed with non-existent field"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("Unknown field")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithInvalidTimeFormat() { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new Exception("Invalid date format: invalid-time-format")); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "invalid-time-format", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener + .wrap( + response -> fail("Should have failed with invalid time format"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("Invalid date format")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithPPLErrorResponse() { + String errorResponse = "{\"error\":{\"type\":\"parsing_exception\",\"reason\":\"Syntax error in PPL query\"}}"; + mockPPLInvocation(errorResponse); + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + 
"selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener + .wrap( + response -> fail("Should have failed with PPL error response"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("PPL query error")) + ) + ); + } +} diff --git a/src/test/java/org/opensearch/agent/tools/utils/ClusteringHelperTests.java b/src/test/java/org/opensearch/agent/tools/utils/ClusteringHelperTests.java new file mode 100644 index 00000000..81345627 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/utils/ClusteringHelperTests.java @@ -0,0 +1,164 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools.utils; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.agent.tools.utils.clustering.ClusteringHelper; +import org.opensearch.test.OpenSearchTestCase; + +public class ClusteringHelperTests extends OpenSearchTestCase { + + public void testConstructorWithValidThreshold() { + new ClusteringHelper(0.0); + new ClusteringHelper(0.5); + new ClusteringHelper(1.0); + } + + public void testConstructorWithInvalidThreshold() { + assertThrows(IllegalArgumentException.class, () -> new ClusteringHelper(-0.1)); + assertThrows(IllegalArgumentException.class, () -> new ClusteringHelper(1.1)); + } + + public void testClusterLogVectorsWithNullInput() { + ClusteringHelper helper = new ClusteringHelper(0.8); + assertTrue(helper.clusterLogVectorsAndGetRepresentative(null).isEmpty()); + } + + public void testClusterLogVectorsWithEmptyInput() { + ClusteringHelper helper = new ClusteringHelper(0.8); + assertTrue(helper.clusterLogVectorsAndGetRepresentative(new HashMap<>()).isEmpty()); + } + + public void testClusterLogVectorsWithSingleVector() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + logVectors.put("trace1", new double[] { 1.0, 2.0, 3.0 }); + + List result = 
helper.clusterLogVectorsAndGetRepresentative(logVectors); + assertEquals(1, result.size()); + assertEquals("trace1", result.get(0)); + } + + public void testClusterLogVectorsWithSmallDataset() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + logVectors.put("trace1", new double[] { 1.0, 0.0, 0.0 }); + logVectors.put("trace2", new double[] { 0.9, 0.1, 0.0 }); + logVectors.put("trace3", new double[] { 0.0, 1.0, 0.0 }); + + List result = helper.clusterLogVectorsAndGetRepresentative(logVectors); + assertFalse(result.isEmpty()); + assertTrue(result.size() <= 3); + } + + public void testValidateLogVectorsWithNullTraceId() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + logVectors.put(null, new double[] { 1.0, 2.0 }); + + assertThrows(IllegalArgumentException.class, () -> helper.clusterLogVectorsAndGetRepresentative(logVectors)); + } + + public void testValidateLogVectorsWithEmptyTraceId() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + logVectors.put("", new double[] { 1.0, 2.0 }); + + assertThrows(IllegalArgumentException.class, () -> helper.clusterLogVectorsAndGetRepresentative(logVectors)); + } + + public void testValidateLogVectorsWithNullVector() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + logVectors.put("trace1", null); + + assertThrows(IllegalArgumentException.class, () -> helper.clusterLogVectorsAndGetRepresentative(logVectors)); + } + + public void testValidateLogVectorsWithEmptyVector() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + logVectors.put("trace1", new double[] {}); + + assertThrows(IllegalArgumentException.class, () -> helper.clusterLogVectorsAndGetRepresentative(logVectors)); + } + + public void testValidateLogVectorsWithDimensionMismatch() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = 
new HashMap<>(); + logVectors.put("trace1", new double[] { 1.0, 2.0 }); + logVectors.put("trace2", new double[] { 1.0, 2.0, 3.0 }); + + assertThrows(IllegalArgumentException.class, () -> helper.clusterLogVectorsAndGetRepresentative(logVectors)); + } + + public void testValidateLogVectorsWithNaNValue() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + logVectors.put("trace1", new double[] { 1.0, Double.NaN }); + + assertThrows(IllegalArgumentException.class, () -> helper.clusterLogVectorsAndGetRepresentative(logVectors)); + } + + public void testValidateLogVectorsWithInfiniteValue() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + logVectors.put("trace1", new double[] { 1.0, Double.POSITIVE_INFINITY }); + + assertThrows(IllegalArgumentException.class, () -> helper.clusterLogVectorsAndGetRepresentative(logVectors)); + } + + public void testClusterLogVectorsWithLargeDataset() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + + // Create 1500 vectors from the seeded test random source (reproducible via -Dtests.seed) to trigger large dataset processing + for (int i = 0; i < 1500; i++) { + double[] vector = new double[] { randomDouble(), randomDouble(), randomDouble() }; + logVectors.put("trace" + i, vector); + } + + List result = helper.clusterLogVectorsAndGetRepresentative(logVectors); + assertFalse(result.isEmpty()); + assertTrue(result.size() < 1500); // Should reduce the number of representatives + } + + public void testClusterLogVectorsWithIdenticalVectors() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + double[] vector = { 1.0, 2.0, 3.0 }; + + for (int i = 0; i < 5; i++) { + logVectors.put("trace" + i, vector.clone()); + } + + List result = helper.clusterLogVectorsAndGetRepresentative(logVectors); + assertEquals(1, result.size()); // Should cluster identical vectors into one + } + + public void testClusterLogVectorsWithHighThreshold() { +
ClusteringHelper helper = new ClusteringHelper(0.99); + Map logVectors = new HashMap<>(); + logVectors.put("trace1", new double[] { 1.0, 0.0 }); + logVectors.put("trace2", new double[] { 0.0, 1.0 }); + + List result = helper.clusterLogVectorsAndGetRepresentative(logVectors); + assertEquals(2, result.size()); // High threshold should keep vectors separate + } + + public void testClusterLogVectorsWithLowThreshold() { + ClusteringHelper helper = new ClusteringHelper(0.1); + Map logVectors = new HashMap<>(); + logVectors.put("trace1", new double[] { 1.0, 0.1 }); + logVectors.put("trace2", new double[] { 0.9, 0.2 }); + + List result = helper.clusterLogVectorsAndGetRepresentative(logVectors); + assertTrue(result.size() <= 2); // Low threshold may cluster similar vectors + } +} diff --git a/src/test/java/org/opensearch/integTest/LogPatternAnalysisToolIT.java b/src/test/java/org/opensearch/integTest/LogPatternAnalysisToolIT.java new file mode 100644 index 00000000..4ad50662 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/LogPatternAnalysisToolIT.java @@ -0,0 +1,220 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import static org.hamcrest.Matchers.containsString; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Locale; + +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Before; + +import lombok.SneakyThrows; + +public class LogPatternAnalysisToolIT extends BaseAgentToolsIT { + + public static String requestBodyResourceFile = + "org/opensearch/agent/tools/register_flow_agent_of_log_pattern_analysis_tool_request_body.json"; + public String registerAgentRequestBody; + public static String TEST_LOG_INDEX_NAME = "test_log_analysis_index"; + + private String agentId; + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + prepareLogIndex(); + registerAgentRequestBody = 
Files.readString(Path.of(this.getClass().getClassLoader().getResource(requestBodyResourceFile).toURI())); + agentId = createAgent(registerAgentRequestBody); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteExternalIndices(); + } + + @SneakyThrows + private void prepareLogIndex() { + createIndexWithConfiguration( + TEST_LOG_INDEX_NAME, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"@timestamp\": {\n" + + " \"type\": \"date\",\n" + + " \"format\": \"yyyy-MM-dd HH:mm:ss||strict_date_optional_time||epoch_millis\"\n" + + " },\n" + + " \"message\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"traceId\": {\n" + + " \"type\": \"keyword\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + + // Add baseline data in base time range (09:00:00 to 10:00:00) + addDocToIndex( + TEST_LOG_INDEX_NAME, + "base1", + List.of("@timestamp", "message", "traceId"), + List.of("2025-01-01 09:30:00", "System startup completed", "trace-base-001") + ); + addDocToIndex( + TEST_LOG_INDEX_NAME, + "base2", + List.of("@timestamp", "message", "traceId"), + List.of("2025-01-01 09:45:00", "Database connection established", "trace-base-002") + ); + addDocToIndex( + TEST_LOG_INDEX_NAME, + "base3", + List.of("@timestamp", "message", "traceId"), + List.of("2025-01-01 09:50:00", "User session initialized", "trace-base-003") + ); + + // Add test log data with error keywords for logInsight + addDocToIndex( + TEST_LOG_INDEX_NAME, + "1", + List.of("@timestamp", "message", "traceId"), + List.of("2025-01-01 10:00:00", "User login successful", "trace-001") + ); + addDocToIndex( + TEST_LOG_INDEX_NAME, + "2", + List.of("@timestamp", "message", "traceId"), + List.of("2025-01-01 10:01:00", "Database connection established", "trace-001") + ); + addDocToIndex( + TEST_LOG_INDEX_NAME, + "3", + List.of("@timestamp", "message", "traceId"), + List.of("2025-01-01 10:02:00", "Error connection timeout failed", "trace-002") + ); + addDocToIndex( + 
TEST_LOG_INDEX_NAME, + "4", + List.of("@timestamp", "message", "traceId"), + List.of("2025-01-01 10:03:00", "User logout completed", "trace-001") + ); + addDocToIndex( + TEST_LOG_INDEX_NAME, + "5", + List.of("@timestamp", "message", "traceId"), + List.of("2025-01-01 10:04:00", "Exception in authentication service", "trace-003") + ); + } + + @SneakyThrows + public void testLogPatternAnalysisToolLogInsight() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\"}}", + TEST_LOG_INDEX_NAME + ) + ); + assertNotNull(result); + assertTrue(result.contains("logInsights")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolWithBaseTimeRange() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", \"baseTimeRangeStart\": \"2025-01-01 09:00:00\", \"baseTimeRangeEnd\": \"2025-01-01 10:00:00\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\"}}", + TEST_LOG_INDEX_NAME + ) + ); + assertNotNull(result); + assertTrue(result.contains("patternMapDifference")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolWithTraceField() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", \"traceFieldName\": \"traceId\", \"baseTimeRangeStart\": \"2025-01-01 09:00:00\", \"baseTimeRangeEnd\": \"2025-01-01 10:00:00\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\"}}", + TEST_LOG_INDEX_NAME + ) + ); + System.out.println(result); + assertNotNull(result); + 
assertTrue(result.contains("BASE") || result.contains("EXCEPTIONAL")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolMissingRequiredParameters() { + Exception exception = assertThrows(Exception.class, () -> executeAgent(agentId, "{\"parameters\": {\"index\": \"%s\"}}")); + MatcherAssert.assertThat(exception.getMessage(), containsString("Missing required parameters")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolInvalidIndex() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent( + agentId, + "{\"parameters\": {\"index\": \"non_existent_index\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\"}}" + ) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("no such index")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolNonExistentLogField() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent( + agentId, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"nonexistent_field\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\"}}" + ) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("not a valid term")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolInvalidTimeFormat() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent( + agentId, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", \"selectionTimeRangeStart\": \"invalid-time-format\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\"}}" + ) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("not a valid term")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolEmptyTimeRange() { + Exception exception = assertThrows( + Exception.class, + () -> 
executeAgent( + agentId, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", \"selectionTimeRangeStart\": \"2025-01-01 10:05:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:00:00\"}}" + ) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("not a valid term")); + } +} diff --git a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_log_pattern_analysis_tool_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_log_pattern_analysis_tool_request_body.json new file mode 100644 index 00000000..86bdbc7c --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_log_pattern_analysis_tool_request_body.json @@ -0,0 +1,10 @@ +{ + "name": "Test_log_pattern_analysis_tool_flow_agent", + "type": "flow", + "tools": [ + { + "type": "LogPatternAnalysisTool", + "parameters": {} + } + ] +} \ No newline at end of file From ca93fb4aae20d71d2db1314acb45dbc607849aab Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 17 Sep 2025 16:33:08 +0800 Subject: [PATCH 07/30] Increment version to 3.3.0-SNAPSHOT (#626) Signed-off-by: opensearch-ci-bot Co-authored-by: opensearch-ci-bot --- build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.gradle b/build.gradle index 541ba649..423333bb 100644 --- a/build.gradle +++ b/build.gradle @@ -11,7 +11,7 @@ import com.github.jengelman.gradle.plugins.shadow.ShadowPlugin buildscript { ext { opensearch_group = "org.opensearch" - opensearch_version = System.getProperty("opensearch.version", "3.2.0-SNAPSHOT") + opensearch_version = System.getProperty("opensearch.version", "3.3.0-SNAPSHOT") buildVersionQualifier = System.getProperty("build.version_qualifier", "") isSnapshot = "true" == System.getProperty("build.snapshot", "true") version_tokens = opensearch_version.tokenize('-') From 
4450039d7e52c1c459b9b53b428f699085acf190 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E4=BD=B3=E5=A6=82=EF=BC=88Jiaru=20Jiang=EF=BC=89?= Date: Thu, 18 Sep 2025 14:07:24 +0800 Subject: [PATCH 08/30] feat: Data Distribution Tool (#634) * add:init dataDistribution Signed-off-by: Jiaru Jiang * fix:spotlessApply Signed-off-by: Jiaru Jiang * fix:forbidden API Signed-off-by: Jiaru Jiang * fix:UT Signed-off-by: Jiaru Jiang * add:add unit and integration tests to improve code coverage Signed-off-by: Jiaru Jiang * fix:spotlessApply Signed-off-by: Jiaru Jiang * fix:Test assertion Signed-off-by: Jiaru Jiang * fix:getPPLQueryWithTimeRange Signed-off-by: Jiaru Jiang * fix:spotlessApply Signed-off-by: Jiaru Jiang * fix:reuse and check Signed-off-by: Jiaru Jiang * fix:reuse executePPLAndParseResult Signed-off-by: Jiaru Jiang * fix:constant Signed-off-by: Jiaru Jiang * fix:add error log Signed-off-by: Jiaru Jiang * fix:remove meaningless error log Signed-off-by: Jiaru Jiang * fix:throw exception instead of returning in-completed PPL Signed-off-by: Jiaru Jiang * fix:use ActionListener Signed-off-by: Jiaru Jiang * fix:remove redundant validate methods Signed-off-by: Jiaru Jiang * fix:magic number Signed-off-by: Jiaru Jiang * fix:reduce loops Signed-off-by: Jiaru Jiang * fix:pre-check special cases Signed-off-by: Jiaru Jiang * fix:use NumberUtils Signed-off-by: Jiaru Jiang * fix:split buildQueryFromMap Signed-off-by: Jiaru Jiang * fix:optimize groupNumericKeys Signed-off-by: Jiaru Jiang * fix:update description Signed-off-by: Jiaru Jiang * fix:change log level Signed-off-by: Jiaru Jiang * add: maximum value check for size Signed-off-by: Jiaru Jiang * add: support complete dsl query Signed-off-by: Jiaru Jiang * fix: milliseconds Signed-off-by: Jiaru Jiang * fix: getPPLQueryWithTimeRange Signed-off-by: Jiaru Jiang * delete: duplicate code Signed-off-by: Jiaru Jiang * fix: getUsefulFields Signed-off-by: Jiaru Jiang * fix: use Math.abs Signed-off-by: Jiaru Jiang * fix: verify the 
actual content of output Signed-off-by: Jiaru Jiang * fix: simplify IT Signed-off-by: Jiaru Jiang * fix: recover version Signed-off-by: Jiaru Jiang --------- Signed-off-by: Jiaru Jiang --- .../java/org/opensearch/agent/ToolPlugin.java | 5 +- .../agent/tools/DataDistributionTool.java | 1535 ++++++++++ .../agent/tools/LogPatternAnalysisTool.java | 232 +- .../agent/tools/utils/PPLExecuteHelper.java | 112 + .../org/opensearch/agent/ToolPluginTests.java | 2 +- .../tools/DataDistributionToolTests.java | 2720 +++++++++++++++++ .../integTest/DataDistributionToolIT.java | 448 +++ ...f_data_distribution_tool_request_body.json | 10 + 8 files changed, 4921 insertions(+), 143 deletions(-) create mode 100644 src/main/java/org/opensearch/agent/tools/DataDistributionTool.java create mode 100644 src/main/java/org/opensearch/agent/tools/utils/PPLExecuteHelper.java create mode 100644 src/test/java/org/opensearch/agent/tools/DataDistributionToolTests.java create mode 100644 src/test/java/org/opensearch/integTest/DataDistributionToolIT.java create mode 100644 src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_data_distribution_tool_request_body.json diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index 6dcdc829..5de1227d 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -13,6 +13,7 @@ import org.opensearch.agent.tools.CreateAlertTool; import org.opensearch.agent.tools.CreateAnomalyDetectorTool; +import org.opensearch.agent.tools.DataDistributionTool; import org.opensearch.agent.tools.LogPatternAnalysisTool; import org.opensearch.agent.tools.LogPatternTool; import org.opensearch.agent.tools.NeuralSparseSearchTool; @@ -100,6 +101,7 @@ public Collection createComponents( LogPatternTool.Factory.getInstance().init(client, xContentRegistry); WebSearchTool.Factory.getInstance().init(threadPool); 
LogPatternAnalysisTool.Factory.getInstance().init(client); + DataDistributionTool.Factory.getInstance().init(client); return Collections.emptyList(); } @@ -119,7 +121,8 @@ public List> getToolFactories() { CreateAnomalyDetectorTool.Factory.getInstance(), LogPatternTool.Factory.getInstance(), WebSearchTool.Factory.getInstance(), - LogPatternAnalysisTool.Factory.getInstance() + LogPatternAnalysisTool.Factory.getInstance(), + DataDistributionTool.Factory.getInstance() ); } diff --git a/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java b/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java new file mode 100644 index 00000000..d0300906 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java @@ -0,0 +1,1535 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.time.LocalDateTime; +import java.time.ZoneOffset; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeParseException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.math.NumberUtils; +import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.agent.tools.utils.PPLExecuteHelper; +import org.opensearch.agent.tools.utils.ToolHelper; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import 
org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.transport.client.Client; + +import com.google.gson.reflect.TypeToken; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * Usage: + * 1. Register agent: + * POST /_plugins/_ml/agents/_register + * { + * "name": "DataDistribution", + * "type": "flow", + * "tools": [ + * { + * "name": "data_distribution_tool", + * "type": "DataDistributionTool", + * "parameters": { + * } + * } + * ] + * } + * 2. Execute agent: + * POST /_plugins/_ml/agents/{agent_id}/_execute + * { + * "parameters": { + * "index": "logs-2025.01.15", + * "timeField": "@timestamp", + * "selectionTimeRangeStart": "2025-01-15 10:00:00", + * "selectionTimeRangeEnd": "2025-01-15 11:00:00", + * "baselineTimeRangeStart": "2025-01-15 08:00:00", + * "baselineTimeRangeEnd": "2025-01-15 09:00:00", + * "size": 1000, + * "queryType": "dsl", + * "filter": ["{'term': {'status': 'error'}}", "{'range': {'response_time': {'gte': 100}}}"], + * "dsl": "{\"bool\": {\"must\": [{\"term\": {\"status\": \"error\"}}]}}", + * "ppl": "source index where a=0" + * } + * } + * 3. 
Result: analysis of data distribution patterns + * { + * "comparisonAnalysis": [ + * { + * "field": "status", + * "divergence": 0.2, + * "topChanges": [ + * { + * "value": "error", + * "selectionPercentage": 0.3, + * "baselinePercentage": 0.1 + * }, + * { + * "value": "success", + * "selectionPercentage": 0.7, + * "baselinePercentage": 0.9 + * } + * ] + * } + * ] + * } + */ +@Log4j2 +@Setter +@Getter +@ToolAnnotation(DataDistributionTool.TYPE) +public class DataDistributionTool implements Tool { + public static final String TYPE = "DataDistributionTool"; + public static final String STRICT_FIELD = "strict"; + + private static final String DEFAULT_DESCRIPTION = + "This tool analyzes data distribution differences between time ranges or provides single dataset insights."; + private static final String DEFAULT_TIME_FIELD = "@timestamp"; + + private static final String PARAM_INDEX = "index"; + private static final String PARAM_TIME_FIELD = "timeField"; + private static final String PARAM_SELECTION_TIME_RANGE_START = "selectionTimeRangeStart"; + private static final String PARAM_SELECTION_TIME_RANGE_END = "selectionTimeRangeEnd"; + private static final String PARAM_BASELINE_TIME_RANGE_START = "baselineTimeRangeStart"; + private static final String PARAM_BASELINE_TIME_RANGE_END = "baselineTimeRangeEnd"; + private static final String PARAM_SIZE = "size"; + private static final String PARAM_QUERY_TYPE = "queryType"; + private static final String PARAM_FILTER = "filter"; + private static final String PARAM_DSL = "dsl"; + private static final String QUERY_TYPE_PPL = "ppl"; + private static final String QUERY_TYPE_DSL = "dsl"; + private static final String DEFAULT_SIZE = "1000"; + private static final String DATE_FORMAT_PATTERN = "yyyy-MM-dd HH:mm:ss"; + + private static final Set USEFUL_FIELD_TYPES = Set + .of("keyword", "boolean", "text", "byte", "short", "integer", "long", "float", "double", "half_float", "scaled_float"); + private static final Set NUMBER_FIELD_TYPES = Set 
+ .of("byte", "short", "integer", "long", "float", "double", "half_float", "scaled_float"); + + private static final int DEFAULT_COMPARISON_RESULT_LIMIT = 10; + private static final int DEFAULT_SINGLE_ANALYSIS_RESULT_LIMIT = 30; + private static final int MIN_CARDINALITY_DIVISOR = 4; + private static final int MIN_CARDINALITY_BASE = 5; + private static final int ID_FIELD_MAX_CARDINALITY = 30; + private static final int DATA_FIELD_MAX_CARDINALITY = 10; + private static final int DATA_FIELD_CARDINALITY_DIVISOR = 2; + private static final int NUMERIC_GROUPING_THRESHOLD = 10; + private static final double PERCENTAGE_MULTIPLIER = 100.0; + private static final int TOP_CHANGES_LIMIT = 10; + private static final int MAX_SIZE_LIMIT = 10000; + + public static final String DEFAULT_INPUT_SCHEMA = """ + { + "type": "object", + "properties": { + "index": { + "type": "string", + "description": "Target OpenSearch index name" + }, + "timeField": { + "type": "string", + "description": "Date/time field for filtering" + }, + "selectionTimeRangeStart": { + "type": "string", + "description": "Start time for analysis period" + }, + "selectionTimeRangeEnd": { + "type": "string", + "description": "End time for analysis period" + }, + "baselineTimeRangeStart": { + "type": "string", + "description": "Start time for baseline period (optional)" + }, + "baselineTimeRangeEnd": { + "type": "string", + "description": "End time for baseline period (optional)" + }, + "size": { + "type": "integer", + "description": "Maximum number of documents to analyze (default: 1000)" + }, + "queryType": { + "type": "string", + "description": "Query type: 'ppl' or 'dsl' (default: 'dsl')" + }, + "filter": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Additional DSL query conditions for filtering (optional)" + }, + "dsl": { + "type": "string", + "description": "Complete raw DSL query as JSON string (optional)" + }, + "ppl": { + "type": "string", + "description": "Complete PPL statement 
without time information (optional)" + } + }, + "required": ["index", "selectionTimeRangeStart", "selectionTimeRangeEnd"], + "additionalProperties": false + } + """; + + public static final Map DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, false); + + /** + * Parameter class to hold analysis parameters with validation + */ + private static class AnalysisParameters { + final String index; + final String timeField; + final String selectionTimeRangeStart; + final String selectionTimeRangeEnd; + final String baselineTimeRangeStart; + final String baselineTimeRangeEnd; + final int size; + final String queryType; + final List filter; + final String dsl; + final String ppl; + + /** + * Constructs analysis parameters from input map with default values + * + * @param parameters Input parameter map from user request + */ + AnalysisParameters(Map parameters) { + this.index = parameters.getOrDefault(PARAM_INDEX, ""); + this.timeField = parameters.getOrDefault(PARAM_TIME_FIELD, DEFAULT_TIME_FIELD); + this.selectionTimeRangeStart = parameters.getOrDefault(PARAM_SELECTION_TIME_RANGE_START, ""); + this.selectionTimeRangeEnd = parameters.getOrDefault(PARAM_SELECTION_TIME_RANGE_END, ""); + this.baselineTimeRangeStart = parameters.getOrDefault(PARAM_BASELINE_TIME_RANGE_START, ""); + this.baselineTimeRangeEnd = parameters.getOrDefault(PARAM_BASELINE_TIME_RANGE_END, ""); + + try { + this.size = Integer.parseInt(parameters.getOrDefault(PARAM_SIZE, DEFAULT_SIZE)); + if (this.size > MAX_SIZE_LIMIT) { + throw new IllegalArgumentException("Size parameter exceeds maximum limit of " + MAX_SIZE_LIMIT + ", got: " + this.size); + } + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + "Invalid 'size' parameter: must be a valid integer, got '" + parameters.get(PARAM_SIZE) + "'" + ); + } + + this.queryType = parameters.getOrDefault(PARAM_QUERY_TYPE, QUERY_TYPE_DSL); + + String filterParam = parameters.getOrDefault(PARAM_FILTER, 
""); + if (Strings.isEmpty(filterParam)) { + this.filter = List.of(); + } else { + try { + this.filter = Arrays.asList(gson.fromJson(filterParam, String[].class)); + } catch (Exception e) { + throw new IllegalArgumentException( + "Invalid 'filter' parameter: must be a valid JSON array of strings, got '" + + filterParam + + "'. Example: [\"{'term': {'status': 'error'}}\", \"{'range': {'level': {'gte': 3}}}\"]" + ); + } + } + + this.dsl = parameters.getOrDefault(PARAM_DSL, ""); + this.ppl = parameters.getOrDefault(QUERY_TYPE_PPL, ""); + } + + /** + * Validates required parameters are present + * + * @throws IllegalArgumentException if required parameters are missing + */ + void validate() { + List missingParams = new ArrayList<>(); + if (Strings.isEmpty(index)) + missingParams.add(PARAM_INDEX); + if (Strings.isEmpty(selectionTimeRangeStart)) + missingParams.add(PARAM_SELECTION_TIME_RANGE_START); + if (Strings.isEmpty(selectionTimeRangeEnd)) + missingParams.add(PARAM_SELECTION_TIME_RANGE_END); + if (Strings.isEmpty(timeField)) + missingParams.add(PARAM_TIME_FIELD); + if (!missingParams.isEmpty()) { + throw new IllegalArgumentException("Missing required parameters: " + String.join(", ", missingParams)); + } + } + + /** + * Checks if baseline time range is provided for comparison analysis + * + * @return true if both baseline start and end times are provided + */ + boolean hasBaselineTime() { + return !Strings.isEmpty(baselineTimeRangeStart) && !Strings.isEmpty(baselineTimeRangeEnd); + } + } + + /** + * Result class for data distribution analysis + */ + private record SummaryDataItem(String field, double divergence, List topChanges) { + } + + /** + * Individual change item for field values + */ + private record ChangeItem(String value, double selectionPercentage, Double baselinePercentage) { + } + + @Setter + @Getter + private String name = TYPE; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + @Getter + private String version; + private Client 
client; + + /** + * Constructs a DataDistributionTool with the given OpenSearch client + * + * @param client The OpenSearch client for executing queries + */ + public DataDistributionTool(Client client) { + this.client = client; + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public Map getAttributes() { + return DEFAULT_ATTRIBUTES; + } + + @Override + public void setAttributes(Map map) {} + + @Override + public boolean validate(Map map) { + try { + new AnalysisParameters(map).validate(); + } catch (Exception e) { + log.error("Failed to validate the data distribution analysis parameter: {}", e.getMessage()); + return false; + } + return true; + } + + /** + * Executes data distribution analysis based on provided parameters. + * Supports both single dataset analysis and comparative analysis between time periods. + * + * @param The response type + * @param originalParameters Input parameters for analysis + * @param listener Action listener for handling results or failures + */ + @Override + public void run(Map originalParameters, ActionListener listener) { + try { + Map parameters = ToolUtils.extractInputParameters(originalParameters, DEFAULT_ATTRIBUTES); + log.debug("Starting data distribution analysis with parameters: {}", parameters.keySet()); + AnalysisParameters params = new AnalysisParameters(parameters); + + if (QUERY_TYPE_PPL.equals(params.queryType)) { + executePPLAnalysis(params, listener); + } else { + executeDSLAnalysis(params, listener); + } + } catch (IllegalArgumentException e) { + log.error("Invalid parameters for DataDistributionTool: {}", e.getMessage()); + listener.onFailure(e); + } catch (Exception e) { + log.error("Unexpected error in DataDistributionTool", e); + listener.onFailure(e); + } + } + + /** + * Executes analysis using PPL (Piped Processing Language) queries + * + * @param The response type + * @param params Analysis parameters containing query details + * @param listener Action listener for handling 
results + */ + private void executePPLAnalysis(AnalysisParameters params, ActionListener listener) { + if (params.hasBaselineTime()) { + fetchPPLComparisonData(params, listener); + } else { + String pplQuery = buildPPLQuery( + params.index, + params.timeField, + params.selectionTimeRangeStart, + params.selectionTimeRangeEnd, + params.size, + params.ppl + ); + + Function, List>> pplResultParser = this::parsePPLResult; + + PPLExecuteHelper.executePPLAndParseResult(client, pplQuery, pplResultParser, ActionListener.wrap(data -> { + try { + analyzeSingleDataset(data, params.index, ActionListener.wrap(result -> { + listener.onResponse((T) gson.toJson(Map.of("singleAnalysis", result))); + }, listener::onFailure)); + } catch (Exception e) { + listener.onFailure(e); + } + }, listener::onFailure)); + } + } + + /** + * Executes analysis using DSL (Domain Specific Language) queries + * + * @param The response type + * @param params Analysis parameters containing query details + * @param listener Action listener for handling results + */ + private void executeDSLAnalysis(AnalysisParameters params, ActionListener listener) { + if (params.hasBaselineTime()) { + fetchComparisonData(params, listener); + } else { + getSingleDataDistribution(params, listener); + } + } + + /** + * Fetches data for both selection and baseline time ranges for comparison analysis + * + * @param The response type + * @param params Analysis parameters containing time ranges + * @param listener Action listener for handling comparison results + */ + private void fetchComparisonData(AnalysisParameters params, ActionListener listener) { + fetchIndexData(params.selectionTimeRangeStart, params.selectionTimeRangeEnd, params, ActionListener.wrap(selectionData -> { + fetchIndexData(params.baselineTimeRangeStart, params.baselineTimeRangeEnd, params, ActionListener.wrap(baselineData -> { + try { + if (selectionData.isEmpty()) { + throw new IllegalStateException("No data found for selection time range"); + } + if 
(baselineData.isEmpty()) { + throw new IllegalStateException("No data found for baseline time range"); + } + getComparisonDataDistribution(selectionData, baselineData, params.index, ActionListener.wrap(result -> { + listener.onResponse((T) gson.toJson(Map.of("comparisonAnalysis", result))); + }, listener::onFailure)); + } catch (Exception e) { + listener.onFailure(e); + } + }, listener::onFailure)); + }, listener::onFailure)); + } + + /** + * Performs single dataset distribution analysis for the selection time range + * + * @param The response type + * @param params Analysis parameters containing selection time range + * @param listener Action listener for handling single analysis results + */ + private void getSingleDataDistribution(AnalysisParameters params, ActionListener listener) { + fetchIndexData(params.selectionTimeRangeStart, params.selectionTimeRangeEnd, params, ActionListener.wrap(data -> { + try { + if (data.isEmpty()) { + throw new IllegalStateException("No data found for selection time range"); + } + analyzeSingleDataset(data, params.index, ActionListener.wrap(result -> { + listener.onResponse((T) gson.toJson(Map.of("singleAnalysis", result))); + }, listener::onFailure)); + } catch (Exception e) { + listener.onFailure(e); + } + }, listener::onFailure)); + } + + /** + * Formats time string to ISO 8601 format for OpenSearch compatibility + * + * @param timeString Input time string + * @return Formatted time string in ISO 8601 format + * @throws DateTimeParseException if time string cannot be parsed + */ + private String formatTimeString(String timeString) throws DateTimeParseException { + log.debug("Attempting to parse time string: {}", timeString); + + // Try parsing with zone first + try { + if (timeString.endsWith("Z")) { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss'Z'", Locale.ROOT); + ZonedDateTime dateTime = ZonedDateTime.parse(timeString, formatter.withZone(ZoneOffset.UTC)); + return 
dateTime.format(DateTimeFormatter.ISO_INSTANT); + } + } catch (DateTimeParseException e) { + log.debug("Failed to parse as UTC time: {}", e.getMessage()); + } + + // Try parsing as local time without zone + try { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern(DATE_FORMAT_PATTERN, Locale.ROOT); + LocalDateTime localDateTime = LocalDateTime.parse(timeString, formatter); + ZonedDateTime zonedDateTime = localDateTime.atOffset(ZoneOffset.UTC).toZonedDateTime(); + return zonedDateTime.format(DateTimeFormatter.ISO_INSTANT); + } catch (DateTimeParseException e) { + log.debug("Failed to parse as local time: {}", e.getMessage()); + } + + // Try ISO format + try { + ZonedDateTime dateTime = ZonedDateTime.parse(timeString); + return dateTime.format(DateTimeFormatter.ISO_INSTANT); + } catch (DateTimeParseException e) { + log.debug("Failed to parse as ISO format: {}", e.getMessage()); + } + + throw new DateTimeParseException("Unable to parse time string: " + timeString, timeString, 0); + } + + /** + * Fetches data from the specified index within the given time range + * + * @param startTime Start time for data retrieval + * @param endTime End time for data retrieval + * @param params Analysis parameters containing index and field information + * @param listener Action listener for handling retrieved data + */ + private void fetchIndexData( + String startTime, + String endTime, + AnalysisParameters params, + ActionListener>> listener + ) { + try { + String formattedStartTime = formatTimeString(startTime); + String formattedEndTime = formatTimeString(endTime); + BoolQueryBuilder query; + + // Use raw DSL query if provided + if (!Strings.isEmpty(params.dsl)) { + try { + Map dslMap = gson.fromJson(params.dsl, new TypeToken>() { + }.getType()); + query = QueryBuilders.boolQuery(); + + // Handle DSL query structure - check if it has "query" wrapper + if (dslMap.containsKey("query")) { + @SuppressWarnings("unchecked") + Map queryMap = (Map) dslMap.get("query"); + 
log.debug("Processing DSL query with wrapper: {}", queryMap); + + // Build the DSL query directly into the main query + buildQueryFromMap(queryMap, query); + + // Add time range filter + query.filter(new RangeQueryBuilder(params.timeField).gte(formattedStartTime).lte(formattedEndTime)); + } else { + log.debug("Processing DSL query without wrapper: {}", dslMap); + buildQueryFromMap(dslMap, query); + // Add time range filter to the raw DSL query + query.filter(new RangeQueryBuilder(params.timeField).gte(formattedStartTime).lte(formattedEndTime)); + } + + log.debug("Final DSL query: {}", query.toString()); + } catch (Exception e) { + log.warn("Failed to parse raw DSL query: {}, falling back to time range only", params.dsl, e); + query = QueryBuilders + .boolQuery() + .filter(new RangeQueryBuilder(params.timeField).gte(formattedStartTime).lte(formattedEndTime)); + } + } else { + query = QueryBuilders + .boolQuery() + .filter(new RangeQueryBuilder(params.timeField).gte(formattedStartTime).lte(formattedEndTime)); + + // Add additional filters if provided + if (!params.filter.isEmpty()) { + for (String filterStr : params.filter) { + try { + Map filterMap = gson.fromJson(filterStr, new TypeToken>() { + }.getType()); + BoolQueryBuilder filterQuery = QueryBuilders.boolQuery(); + buildQueryFromMap(filterMap, filterQuery); + query.must(filterQuery); + } catch (Exception e) { + log.warn("Failed to parse filter parameter: {}", filterStr, e); + } + } + } + } + + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(query).size(params.size); + + SearchRequest request = new SearchRequest(params.index).source(sourceBuilder); + + client.search(request, ActionListener.wrap(response -> { + List> data = Arrays + .stream(response.getHits().getHits()) + .map(SearchHit::getSourceAsMap) + .collect(Collectors.toList()); + listener.onResponse(data); + }, listener::onFailure)); + } catch (Exception e) { + log.error("Failed to format time strings: {}", e.getMessage()); + 
listener.onFailure(new IllegalArgumentException("Invalid time format: " + e.getMessage(), e)); + } + } + + /** + * Fetches data for both selection and baseline time ranges using PPL for comparison analysis + * + * @param <T> The response type + * @param params Analysis parameters containing time ranges + * @param listener Action listener for handling comparison results + */ + private void fetchPPLComparisonData(AnalysisParameters params, ActionListener listener) { + String selectionQuery = buildPPLQuery( + params.index, + params.timeField, + params.selectionTimeRangeStart, + params.selectionTimeRangeEnd, + params.size, + params.ppl + ); + String baselineQuery = buildPPLQuery( + params.index, + params.timeField, + params.baselineTimeRangeStart, + params.baselineTimeRangeEnd, + params.size, + params.ppl + ); + + Function, List>> pplResultParser = this::parsePPLResult; + + PPLExecuteHelper.executePPLAndParseResult(client, selectionQuery, pplResultParser, ActionListener.wrap(selectionData -> { + PPLExecuteHelper.executePPLAndParseResult(client, baselineQuery, pplResultParser, ActionListener.wrap(baselineData -> { + try { + if (selectionData.isEmpty()) { + throw new IllegalStateException("No data found for selection time range"); + } + if (baselineData.isEmpty()) { + throw new IllegalStateException("No data found for baseline time range"); + } + getComparisonDataDistribution(selectionData, baselineData, params.index, ActionListener.wrap(result -> { + listener.onResponse((T) gson.toJson(Map.of("comparisonAnalysis", result))); + }, listener::onFailure)); + } catch (Exception e) { + listener.onFailure(e); + } + }, listener::onFailure)); + }, listener::onFailure)); + } + + /** + * Converts time string to PPL format (yyyy-MM-dd HH:mm:ss.SSS for zoned ISO input, yyyy-MM-dd HH:mm:ss for local input) + * + * @param timeString Input time string + * @return Formatted time string for PPL, or the input unchanged if it cannot be parsed + */ + private String formatTimeForPPL(String timeString) { + try { + // Parse ISO format and convert to PPL format + ZonedDateTime dateTime =
ZonedDateTime.parse(timeString); + return dateTime.format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSS", Locale.ROOT)); + } catch (DateTimeParseException e) { + // Try parsing as local time without zone + try { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern(DATE_FORMAT_PATTERN, Locale.ROOT); + LocalDateTime localDateTime = LocalDateTime.parse(timeString, formatter); + return localDateTime.format(formatter); + } catch (DateTimeParseException e2) { + // Return original if parsing fails + return timeString; + } + } + } + + /** + * Adds time range filter to PPL query + * + * @param query PPL query string (must not be empty; an empty query is rejected with IllegalArgumentException) + * @param startTime Start time for filtering + * @param endTime End time for filtering + * @param timeField Time field name (if empty, the query is returned unchanged) + * @return PPL query with time range filter added + */ + private String getPPLQueryWithTimeRange(String query, String startTime, String endTime, String timeField) { + if (Strings.isEmpty(query)) { + throw new IllegalArgumentException("PPL query cannot be empty"); + } + if (Strings.isEmpty(timeField)) { + return query; + } + + String formattedStartTime = formatTimeForPPL(startTime); + String formattedEndTime = formatTimeForPPL(endTime); + String timePredicate = String + .format(Locale.ROOT, "`%s` >= '%s' AND `%s` <= '%s'", timeField, formattedStartTime, timeField, formattedEndTime); + + String[] commands = query.split("\\|"); + List commandList = new ArrayList<>(); + + // Always insert time filter right after first command (safest approach) + commandList.add(commands[0].trim()); + commandList.add("WHERE " + timePredicate); + + // Add remaining commands + for (int i = 1; i < commands.length; i++) { + String cmd = commands[i].trim(); + if (!cmd.isEmpty()) { + commandList.add(cmd); + } + } + + return String.join(" | ", commandList); + } + + /** + * Builds PPL query string for data retrieval within specified time range + * + * @param index Index name + * @param timeField Time field name + * @param startTime Start
time for query + * @param endTime End time for query + * @param size Maximum number of documents + * @param customPpl Custom PPL statement (optional) + * @return Formatted PPL query string + */ + private String buildPPLQuery(String index, String timeField, String startTime, String endTime, int size, String customPpl) { + String baseQuery; + + if (!Strings.isEmpty(customPpl)) { + baseQuery = getPPLQueryWithTimeRange(customPpl, startTime, endTime, timeField); + } else { + baseQuery = getPPLQueryWithTimeRange(String.format(Locale.ROOT, "source=%s", index), startTime, endTime, timeField); + } + + return baseQuery + String.format(Locale.ROOT, " | head %d", size); + } + + /** + * Analyzes and compares data distributions between selection and baseline datasets + * + * @param selectionData Data from the selection time period + * @param baselineData Data from the baseline time period + * @param index Index name for field mapping retrieval + * @param listener Action listener for handling comparison results + */ + private void getComparisonDataDistribution( + List> selectionData, + List> baselineData, + String index, + ActionListener> listener + ) { + getFieldTypes(index, ActionListener.wrap(fieldTypes -> { + try { + List usefulFields = getUsefulFields(selectionData, fieldTypes); + Set numberFields = getNumberFields(fieldTypes); + List analyses = new ArrayList<>(); + + for (String field : usefulFields) { + Map selectionDist = calculateFieldDistribution(selectionData, field); + Map baselineDist = calculateFieldDistribution(baselineData, field); + + if (numberFields.contains(field)) { + GroupedDistributions grouped = groupNumericKeys(selectionDist, baselineDist); + selectionDist = grouped.groupedSelectionDist(); + baselineDist = grouped.groupedBaselineDist(); + } + + double divergence = calculateMaxDifference(selectionDist, baselineDist); + analyses.add(new FieldAnalysis(field, divergence, selectionDist, baselineDist)); + } + + 
analyses.sort(Comparator.comparingDouble((FieldAnalysis a) -> a.divergence).reversed()); + listener.onResponse(formatComparisonSummary(analyses, DEFAULT_COMPARISON_RESULT_LIMIT)); + } catch (Exception e) { + listener.onFailure(e); + } + }, listener::onFailure)); + } + + /** + * Analyzes distribution patterns within a single dataset + * + * @param data Dataset to analyze + * @param index Index name for field mapping retrieval + * @param listener Action listener for handling single analysis results + */ + private void analyzeSingleDataset(List> data, String index, ActionListener> listener) { + getFieldTypes(index, ActionListener.wrap(fieldTypes -> { + try { + List usefulFields = getUsefulFields(data, fieldTypes); + Set numberFields = getNumberFields(fieldTypes); + List analyses = new ArrayList<>(); + + for (String field : usefulFields) { + Map selectionDist = calculateFieldDistribution(data, field); + Map baselineDist = new HashMap<>(); + + if (numberFields.contains(field)) { + GroupedDistributions grouped = groupNumericKeys(selectionDist, baselineDist); + selectionDist = grouped.groupedSelectionDist(); + } + + double divergence = calculateMaxDifference(selectionDist, baselineDist); + analyses.add(new FieldAnalysis(field, divergence, selectionDist, baselineDist)); + } + + analyses.sort(Comparator.comparingDouble((FieldAnalysis a) -> a.divergence).reversed()); + listener.onResponse(formatComparisonSummary(analyses, DEFAULT_SINGLE_ANALYSIS_RESULT_LIMIT)); + } catch (Exception e) { + listener.onFailure(e); + } + }, listener::onFailure)); + } + + /** + * Internal record for field analysis results + */ + private record FieldAnalysis(String field, double divergence, Map selectionDist, Map baselineDist) { + } + + /** + * Record for grouped numeric distributions + */ + private record GroupedDistributions(Map groupedSelectionDist, Map groupedBaselineDist) { + } + + /** + * Gets field type mappings from index + * + * @param index Index name for mapping retrieval + * @param 
listener Action listener for handling field types result + */ + private void getFieldTypes(String index, ActionListener> listener) { + try { + GetMappingsRequest getMappingsRequest = new GetMappingsRequest().indices(index); + client.admin().indices().getMappings(getMappingsRequest, ActionListener.wrap(response -> { + try { + Map mappings = response.getMappings(); + if (mappings.isEmpty()) { + listener.onResponse(Map.of()); + return; + } + + MappingMetadata mappingMetadata = mappings.values().iterator().next(); + Map mappingSource = (Map) mappingMetadata.getSourceAsMap().get("properties"); + if (mappingSource == null) { + listener.onResponse(Map.of()); + return; + } + + Map fieldsToType = new HashMap<>(); + ToolHelper.extractFieldNamesTypes(mappingSource, fieldsToType, "", true); + listener.onResponse(fieldsToType); + } catch (Exception e) { + log.error("Failed to process field types for index: {}", index, e); + listener.onResponse(Map.of()); + } + }, e -> { + log.error("Failed to get field types for index: {}", index, e); + listener.onResponse(Map.of()); + })); + } catch (Exception e) { + log.error("Failed to create getMappings request for index: {}", index, e); + listener.onResponse(Map.of()); + } + } + + /** + * Identifies useful fields for analysis based on index mapping and data characteristics + * + * @param data Sample data for cardinality analysis + * @param fieldTypes Map of field names to their types + * @return List of field names suitable for distribution analysis + */ + private List getUsefulFields(List> data, Map fieldTypes) { + if (fieldTypes.isEmpty()) { + log.warn("No field types available, using data-based field detection"); + return getFieldsFromData(data); + } + + Set keywordFields = new HashSet<>(); + Set numberFields = new HashSet<>(); + + for (Map.Entry entry : fieldTypes.entrySet()) { + String fieldType = entry.getValue(); + String fieldName = entry.getKey(); + + if (USEFUL_FIELD_TYPES.contains(fieldType)) { + keywordFields.add(fieldName); + 
} + if (NUMBER_FIELD_TYPES.contains(fieldType)) { + numberFields.add(fieldName); + } + } + + Set normalizedFields = keywordFields + .stream() + .map(field -> field.endsWith(".keyword") ? field.replace(".keyword", "") : field) + .collect(Collectors.toSet()); + + Map> fieldValueSets = new HashMap<>(); + normalizedFields.forEach(field -> fieldValueSets.put(field, new HashSet<>())); + + int maxCardinality = Math.max(MIN_CARDINALITY_BASE, data.size() / MIN_CARDINALITY_DIVISOR); + + data.forEach(doc -> { + normalizedFields.forEach(field -> { + Object value = getFlattenedValue(doc, field); + if (value != null) { + fieldValueSets.get(field).add(gson.toJson(value)); + } + }); + }); + + return normalizedFields.stream().filter(field -> { + int cardinality = fieldValueSets.get(field).size(); + if (field.toLowerCase(Locale.ROOT).endsWith("id")) { + return cardinality <= ID_FIELD_MAX_CARDINALITY && cardinality > 0; + } else if (numberFields.contains(field)) { + return true; + } + return cardinality <= maxCardinality && cardinality > 0; + }).collect(Collectors.toList()); + } + + /** + * Extracts nested field values from document using dot notation + * + * @param doc Document map to extract value from + * @param field Field path using dot notation (e.g., "user.name") + * @return Field value, the JSON string of a list met while path segments remain, or null if not found + */ + private Object getFlattenedValue(Map doc, String field) { + String[] parts = field.split("\\."); + Object current = doc; + + for (String part : parts) { + if (current instanceof Map) { + current = ((Map) current).get(part); + } else if (current instanceof List) { + return gson.toJson(current); + } else { + return null; + } + } + + return current; + } + + /** + * Calculates distribution of values for a specific field across the dataset + * + * @param data Dataset to analyze + * @param field Field name to calculate distribution for + * @return Map of field values to their relative frequencies + */ + private Map calculateFieldDistribution(List> data, String field) { +
if (data == null || data.isEmpty()) { + return new HashMap<>(); + } + + Map counts = new HashMap<>(); + + for (Map doc : data) { + Object value = getFlattenedValue(doc, field); + if (value != null) { + String strValue = String.valueOf(value); + counts.merge(strValue, 1, Integer::sum); + } + } + return counts.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> (double) entry.getValue() / data.size())); + } + + /** + * Calculates maximum difference between selection and baseline distributions + * + * @param selectionDist Selection period distribution + * @param baselineDist Baseline period distribution + * @return Maximum difference value across all field values + */ + private double calculateMaxDifference(Map selectionDist, Map baselineDist) { + Set allKeys = new HashSet<>(selectionDist.keySet()); + allKeys.addAll(baselineDist.keySet()); + + if (allKeys.isEmpty()) { + return Double.NEGATIVE_INFINITY; + } + return allKeys.stream().mapToDouble(key -> { + double selectionVal = selectionDist.getOrDefault(key, 0.0); + double baselineVal = baselineDist.getOrDefault(key, 0.0); + return Math.abs(selectionVal - baselineVal); + }).max().orElse(Double.NEGATIVE_INFINITY); + } + + /** + * Extracts field names from sample data when mapping is not available + * + * @param data Sample data to analyze + * @return List of field names suitable for analysis + */ + private List getFieldsFromData(List> data) { + if (data.isEmpty()) { + return List.of(); + } + + Set allFields = new HashSet<>(); + for (Map doc : data) { + allFields.addAll(doc.keySet()); + } + + // Filter out timestamp and other non-useful fields + return allFields + .stream() + .filter(field -> !field.equals("@timestamp") && !field.equals("_id") && !field.equals("_index")) + .filter(field -> { + // Check cardinality - exclude high cardinality fields + Set values = new HashSet<>(); + for (Map doc : data) { + Object value = doc.get(field); + if (value != null) { + values.add(String.valueOf(value)); + } + 
} + int cardinality = values.size(); + return cardinality > 0 && cardinality <= Math.max(DATA_FIELD_MAX_CARDINALITY, data.size() / DATA_FIELD_CARDINALITY_DIVISOR); + }) + .collect(Collectors.toList()); + } + + /** + * Gets number fields from field type mappings + * + * @param fieldTypes Map of field names to their types + * @return Set of number field names + */ + private Set getNumberFields(Map fieldTypes) { + return fieldTypes + .entrySet() + .stream() + .filter(entry -> NUMBER_FIELD_TYPES.contains(entry.getValue())) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + } + + /** + * Groups numeric keys and merges counts + * + * @param selectionDist Selection distribution + * @param baselineDist Baseline distribution + * @return Grouped distributions + */ + private GroupedDistributions groupNumericKeys(Map selectionDist, Map baselineDist) { + Set allKeys = new HashSet<>(selectionDist.keySet()); + allKeys.addAll(baselineDist.keySet()); + + if (allKeys.size() <= NUMERIC_GROUPING_THRESHOLD || allKeys.stream().anyMatch(key -> !NumberUtils.isCreatable(key))) { + return new GroupedDistributions(selectionDist, baselineDist); + } + + List numericKeys = allKeys.stream().map(Double::parseDouble).sorted().collect(Collectors.toList()); + Function getGroupLabel = getDoubleStringFunction(numericKeys); + // Group the keys and aggregate the values + Map groupedSelectionDist = numericKeys + .stream() + .collect( + Collectors + .groupingBy(getGroupLabel, Collectors.summingDouble(numKey -> selectionDist.getOrDefault(String.valueOf(numKey), 0.0))) + ); + Map groupedBaselineDist = numericKeys + .stream() + .collect( + Collectors + .groupingBy(getGroupLabel, Collectors.summingDouble(numKey -> baselineDist.getOrDefault(String.valueOf(numKey), 0.0))) + ); + // Ensure all groups are present in both maps (in case some have zero values) + Set allGroups = new HashSet<>(); + allGroups.addAll(groupedSelectionDist.keySet()); + allGroups.addAll(groupedBaselineDist.keySet()); + 
allGroups.forEach(group -> { + groupedSelectionDist.putIfAbsent(group, 0.0); + groupedBaselineDist.putIfAbsent(group, 0.0); + }); + + return new GroupedDistributions(groupedSelectionDist, groupedBaselineDist); + } + + private static Function getDoubleStringFunction(List numericKeys) { + double min = numericKeys.get(0); + double max = numericKeys.get(numericKeys.size() - 1); + double range = max - min; + int numGroups = 5; + double groupSize = range / numGroups; + // Create a function to determine which group a key belongs to + Function getGroupLabel = numKey -> { + int groupIndex = numKey == max ? numGroups - 1 : (int) ((numKey - min) / groupSize); + double lowerBound = min + groupIndex * groupSize; + double upperBound = groupIndex == numGroups - 1 ? max : min + (groupIndex + 1) * groupSize; + return String.format(Locale.ROOT, "%.1f-%.1f", lowerBound, upperBound); + }; + return getGroupLabel; + } + + /** + * Formats field analysis results into summary data items + * + * @param differences List of field analysis results + * @param maxResults Maximum number of results to return + * @return Formatted list of summary data items + */ + private List formatComparisonSummary(List differences, int maxResults) { + return differences.stream().filter(diff -> diff.divergence > 0).limit(maxResults).map(diff -> { + Set allKeys = new HashSet<>(diff.selectionDist.keySet()); + allKeys.addAll(diff.baselineDist.keySet()); + + List changes = allKeys.stream().map(value -> { + double selectionPercentage = Math.round(diff.selectionDist.getOrDefault(value, 0.0) * PERCENTAGE_MULTIPLIER) + / PERCENTAGE_MULTIPLIER; + double baselinePercentage = Math.round(diff.baselineDist.getOrDefault(value, 0.0) * PERCENTAGE_MULTIPLIER) + / PERCENTAGE_MULTIPLIER; + return new ChangeItem(value, selectionPercentage, baselinePercentage); + }).collect(Collectors.toList()); + + List topChanges = changes + .stream() + .sorted( + (a, b) -> Double + .compare( + Math.max(b.baselinePercentage, b.selectionPercentage), 
+ Math.max(a.baselinePercentage, a.selectionPercentage) + ) + ) + .limit(TOP_CHANGES_LIMIT) + .collect(Collectors.toList()); + + return new SummaryDataItem(diff.field, diff.divergence, topChanges); + }).collect(Collectors.toList()); + } + + /** + * Builds query conditions from filter map for DSL queries + * + * @param filterMap Filter conditions as map + * @param queryBuilder Query builder to add conditions to + */ + private void buildQueryFromMap(Map filterMap, BoolQueryBuilder queryBuilder) { + log.debug("Building query from map: {}", filterMap); + + for (Map.Entry entry : filterMap.entrySet()) { + String key = entry.getKey(); + Object value = entry.getValue(); + + log.debug("Processing query key: {}, value: {}", key, value); + + // Handle special query types + switch (key) { + case "match_all" -> { + // {"match_all": {}} + log.debug("Adding match_all query"); + queryBuilder.must(QueryBuilders.matchAllQuery()); + } + case "match_none" -> { + // {"match_none": {}} + log.debug("Adding match_none query"); + queryBuilder.mustNot(QueryBuilders.matchAllQuery()); + } + case "bool" -> { + if (value instanceof Map) { + log.debug("Processing bool query: {}", value); + processBoolQuery((Map) value, queryBuilder); + } + } + case "term" -> { + if (value instanceof Map) { + @SuppressWarnings("unchecked") + Map valueMap = (Map) value; + log.debug("Adding term query: {}", valueMap); + // {"term": {"field": "value"}} + for (Map.Entry termEntry : valueMap.entrySet()) { + log.debug("Term query - field: {}, value: {}", termEntry.getKey(), termEntry.getValue()); + queryBuilder.must(QueryBuilders.termQuery(termEntry.getKey(), termEntry.getValue())); + } + } + } + case "wildcard" -> { + if (value instanceof Map) { + @SuppressWarnings("unchecked") + Map valueMap = (Map) value; + log.debug("Adding wildcard query: {}", valueMap); + // {"wildcard": {"field": "pattern"}} + for (Map.Entry wildcardEntry : valueMap.entrySet()) { + log.debug("Wildcard query - field: {}, pattern: {}", 
wildcardEntry.getKey(), wildcardEntry.getValue()); + queryBuilder.must(QueryBuilders.wildcardQuery(wildcardEntry.getKey(), wildcardEntry.getValue().toString())); + } + } + } + case "range" -> { + if (value instanceof Map) { + @SuppressWarnings("unchecked") + Map valueMap = (Map) value; + // {"range": {"field": {"gte": 1, "lte": 10}}} + for (Map.Entry rangeEntry : valueMap.entrySet()) { + String field = rangeEntry.getKey(); + Object rangeValue = rangeEntry.getValue(); + if (rangeValue instanceof Map) { + processRangeQuery(field, rangeValue, queryBuilder); + } + } + } + } + case "match" -> { + if (value instanceof Map) { + @SuppressWarnings("unchecked") + Map valueMap = (Map) value; + // {"match": {"field": "value"}} + for (Map.Entry matchEntry : valueMap.entrySet()) { + queryBuilder.must(QueryBuilders.matchQuery(matchEntry.getKey(), matchEntry.getValue())); + } + } + } + case "match_phrase" -> { + if (value instanceof Map) { + @SuppressWarnings("unchecked") + Map valueMap = (Map) value; + // {"match_phrase": {"field": "value"}} + for (Map.Entry matchPhraseEntry : valueMap.entrySet()) { + queryBuilder.must(QueryBuilders.matchPhraseQuery(matchPhraseEntry.getKey(), matchPhraseEntry.getValue())); + } + } + } + case "prefix" -> { + if (value instanceof Map) { + @SuppressWarnings("unchecked") + Map valueMap = (Map) value; + // {"prefix": {"field": "value"}} + for (Map.Entry prefixEntry : valueMap.entrySet()) { + queryBuilder.must(QueryBuilders.prefixQuery(prefixEntry.getKey(), prefixEntry.getValue().toString())); + } + } + } + case "exists" -> { + if (value instanceof Map) { + @SuppressWarnings("unchecked") + Map valueMap = (Map) value; + // {"exists": {"field": "fieldname"}} + Object fieldValue = valueMap.get("field"); + if (fieldValue != null) { + queryBuilder.must(QueryBuilders.existsQuery(fieldValue.toString())); + } + } + } + case "regexp" -> { + if (value instanceof Map) { + @SuppressWarnings("unchecked") + Map valueMap = (Map) value; + // {"regexp": {"field": 
"pattern"}} + for (Map.Entry regexpEntry : valueMap.entrySet()) { + queryBuilder.must(QueryBuilders.regexpQuery(regexpEntry.getKey(), regexpEntry.getValue().toString())); + } + } + } + case "terms" -> { + if (value instanceof Map) { + @SuppressWarnings("unchecked") + Map valueMap = (Map) value; + // {"terms": {"field": ["value1", "value2"]}} + for (Map.Entry termsEntry : valueMap.entrySet()) { + if (termsEntry.getValue() instanceof List) { + queryBuilder.must(QueryBuilders.termsQuery(termsEntry.getKey(), (List) termsEntry.getValue())); + } + } + } + } + case "multi_match" -> { + if (value instanceof Map) { + @SuppressWarnings("unchecked") + Map valueMap = (Map) value; + Object queryValue = valueMap.get("query"); + Object fieldsValue = valueMap.get("fields"); + if (queryValue != null && fieldsValue instanceof List) { + @SuppressWarnings("unchecked") + List fields = (List) fieldsValue; + queryBuilder.must(QueryBuilders.multiMatchQuery(queryValue, fields.toArray(new String[0]))); + } + } + } + default -> { + // Handle direct field-value pairs or unknown query types + if (value instanceof Map) { + @SuppressWarnings("unchecked") + Map valueMap = (Map) value; + // This might be a field with nested operators like {"field": {"term": "value"}} + processNestedQuery(key, valueMap, queryBuilder); + } else { + // Direct field-value mapping + queryBuilder.must(QueryBuilders.termQuery(key, value)); + } + } + } + } + } + + /** + * Processes bool query conditions + * + * @param boolMap Bool query conditions + * @param queryBuilder Query builder to add conditions to + */ + private void processBoolQuery(Map boolMap, BoolQueryBuilder queryBuilder) { + for (Map.Entry boolEntry : boolMap.entrySet()) { + String boolType = boolEntry.getKey(); + Object boolValue = boolEntry.getValue(); + + if (boolValue instanceof List) { + @SuppressWarnings("unchecked") + List> clauses = (List>) boolValue; + for (Map clause : clauses) { + BoolQueryBuilder subQuery = QueryBuilders.boolQuery(); + 
buildQueryFromMap(clause, subQuery); + switch (boolType) { + case "must" -> queryBuilder.must(subQuery); + case "should" -> queryBuilder.should(subQuery); + case "must_not" -> queryBuilder.mustNot(subQuery); + case "filter" -> queryBuilder.filter(subQuery); + default -> log.warn("Unsupported bool query type: {}", boolType); + } + } + } + } + } + + /** + * Processes nested query conditions for a field + * + * @param field Field name + * @param nestedMap Nested query conditions + * @param queryBuilder Query builder to add conditions to + */ + private void processNestedQuery(String field, Map nestedMap, BoolQueryBuilder queryBuilder) { + for (Map.Entry nestedEntry : nestedMap.entrySet()) { + String operator = nestedEntry.getKey(); + Object operatorValue = nestedEntry.getValue(); + + switch (operator) { + case "term" -> queryBuilder.must(QueryBuilders.termQuery(field, operatorValue)); + case "range" -> processRangeQuery(field, operatorValue, queryBuilder); + case "match" -> queryBuilder.must(QueryBuilders.matchQuery(field, operatorValue)); + case "match_phrase" -> queryBuilder.must(QueryBuilders.matchPhraseQuery(field, operatorValue)); + case "prefix" -> queryBuilder.must(QueryBuilders.prefixQuery(field, operatorValue.toString())); + case "wildcard" -> processWildcardQuery(field, operatorValue, queryBuilder); + case "exists" -> queryBuilder.must(QueryBuilders.existsQuery(field)); + case "regexp" -> processRegexpQuery(field, operatorValue, queryBuilder); + default -> { + // Handle direct field-value mapping for nested structures + if (operatorValue instanceof Map) { + @SuppressWarnings("unchecked") + Map valueMap = (Map) operatorValue; + BoolQueryBuilder nestedQuery = QueryBuilders.boolQuery(); + buildQueryFromMap(Map.of(operator, valueMap), nestedQuery); + queryBuilder.must(nestedQuery); + } else { + log.warn("Unsupported query operator: {}", operator); + } + } + } + } + } + + /** + * Processes range query conditions + * + * @param field Field name + * @param 
operatorValue Range conditions + * @param queryBuilder Query builder to add conditions to + */ + private void processRangeQuery(String field, Object operatorValue, BoolQueryBuilder queryBuilder) { + if (!(operatorValue instanceof Map)) { + return; + } + + @SuppressWarnings("unchecked") + Map rangeMap = (Map) operatorValue; + RangeQueryBuilder rangeQuery = QueryBuilders.rangeQuery(field); + + rangeMap.forEach((rangeOp, rangeVal) -> { + switch (rangeOp) { + case "gte" -> rangeQuery.gte(rangeVal); + case "lte" -> rangeQuery.lte(rangeVal); + case "gt" -> rangeQuery.gt(rangeVal); + case "lt" -> rangeQuery.lt(rangeVal); + } + }); + + queryBuilder.must(rangeQuery); + } + + /** + * Processes wildcard query conditions + * + * @param field Field name + * @param operatorValue Wildcard conditions + * @param queryBuilder Query builder to add conditions to + */ + private void processWildcardQuery(String field, Object operatorValue, BoolQueryBuilder queryBuilder) { + if (operatorValue instanceof Map) { + @SuppressWarnings("unchecked") + Map wildcardMap = (Map) operatorValue; + Object wildcardValue = wildcardMap.get("value"); + if (wildcardValue != null) { + queryBuilder.must(QueryBuilders.wildcardQuery(field, wildcardValue.toString())); + } + } else { + queryBuilder.must(QueryBuilders.wildcardQuery(field, operatorValue.toString())); + } + } + + /** + * Processes regexp query conditions + * + * @param field Field name + * @param operatorValue Regexp conditions + * @param queryBuilder Query builder to add conditions to + */ + private void processRegexpQuery(String field, Object operatorValue, BoolQueryBuilder queryBuilder) { + if (operatorValue instanceof Map) { + @SuppressWarnings("unchecked") + Map regexpMap = (Map) operatorValue; + Object regexpValue = regexpMap.get("value"); + if (regexpValue != null) { + queryBuilder.must(QueryBuilders.regexpQuery(field, regexpValue.toString())); + } + } else { + queryBuilder.must(QueryBuilders.regexpQuery(field, operatorValue.toString())); + 
} + } + + /** + * Parses PPL query result into list of documents + * + * @param pplResult PPL query result containing datarows and schema + * @return List of documents as maps + */ + private List> parsePPLResult(Map pplResult) { + Object datarowsObj = pplResult.get("datarows"); + Object schemaObj = pplResult.get("schema"); + + if (!(datarowsObj instanceof List) || !(schemaObj instanceof List)) { + return List.of(); + } + + @SuppressWarnings("unchecked") + List> dataRows = (List>) datarowsObj; + @SuppressWarnings("unchecked") + List> schema = (List>) schemaObj; + + List> result = new ArrayList<>(); + for (List row : dataRows) { + Map doc = new HashMap<>(); + for (int i = 0; i < Math.min(row.size(), schema.size()); i++) { + String columnName = (String) schema.get(i).get("name"); + if (columnName != null) { + doc.put(columnName, row.get(i)); + } + } + result.add(doc); + } + return result; + } + + /** + * Factory class for creating DataDistributionTool instances + */ + public static class Factory implements Tool.Factory { + private Client client; + private static Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (DataDistributionTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + /** + * Initialize this factory + * + * @param client The OpenSearch client + */ + public void init(Client client) { + this.client = client; + } + + @Override + public DataDistributionTool create(Map map) { + return new DataDistributionTool(client); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public Map getDefaultAttributes() { + return DEFAULT_ATTRIBUTES; + } + + @Override + public String getDefaultVersion() { + return null; + } + } +} diff --git 
a/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java b/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java index 0f63ff6e..71508d95 100644 --- a/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java +++ b/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java @@ -5,7 +5,6 @@ package org.opensearch.agent.tools; -import static org.opensearch.agent.tools.utils.ToolHelper.getPPLTransportActionListener; import static org.opensearch.agent.tools.utils.clustering.HierarchicalAgglomerativeClustering.calculateCosineSimilarity; import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; @@ -28,21 +27,15 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import org.json.JSONObject; +import org.opensearch.agent.tools.utils.PPLExecuteHelper; import org.opensearch.agent.tools.utils.clustering.ClusteringHelper; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.utils.ToolUtils; -import org.opensearch.sql.plugin.transport.PPLQueryAction; -import org.opensearch.sql.plugin.transport.TransportPPLQueryRequest; -import org.opensearch.sql.ppl.domain.PPLQueryRequest; import org.opensearch.transport.client.Client; -import com.google.common.collect.ImmutableMap; -import com.google.gson.reflect.TypeToken; - import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; @@ -386,7 +379,7 @@ private void executePPL(String ppl, ActionListener listen return new PatternAnalysisResult(tracePatternMap, patternCountMap, patternVectors); }; - executePPLAndParseResult(ppl, rowParser, listener); + PPLExecuteHelper.executePPLAndParseResult(client, ppl, PPLExecuteHelper.dataRowsParser(rowParser), listener); } private String buildLogPatternPPL( @@ 
-638,53 +631,76 @@ private void logPatternDiffAnalysis(AnalysisParameters params, ActionListene }; log.debug("Executing base time range pattern PPL: {}", baseTimeRangeLogPatternPPL); - executePPLAndParseResult(baseTimeRangeLogPatternPPL, dataRowsParser, ActionListener.wrap(basePatterns -> { - try { - mergeSimilarPatterns(basePatterns); - - log.debug("Base patterns processed: {} patterns", basePatterns.size()); - - // Step 2: Generate log patterns for selection time range - String selectionTimeRangeLogPatternPPL = buildLogPatternPPL( - params.index, - params.timeField, - params.logFieldName, - params.selectionTimeRangeStart, - params.selectionTimeRangeEnd - ); - - log.debug("Executing selection time range pattern PPL: {}", selectionTimeRangeLogPatternPPL); - executePPLAndParseResult(selectionTimeRangeLogPatternPPL, dataRowsParser, ActionListener.wrap(selectionPatterns -> { - mergeSimilarPatterns(selectionPatterns); - - log.debug("Selection patterns processed: {} patterns", selectionPatterns.size()); - - // Step 3: Calculate pattern differences - List patternDifferences = calculatePatternDifferences(basePatterns, selectionPatterns); - - // Step 4: Sort the difference and get top 10 - List topDiffs = Stream - .concat( - patternDifferences.stream().filter(diff -> !Objects.isNull(diff.lift)).sorted(comparator).limit(10), - patternDifferences.stream().filter(diff -> Objects.isNull(diff.lift)).sorted(comparator).limit(10) - ) - .collect(Collectors.toList()); - - Map finalResult = new HashMap<>(); - finalResult.put("patternMapDifference", topDiffs); - - log.debug("Pattern analysis completed: {} differences found", patternDifferences.size()); - listener.onResponse((T) gson.toJson(finalResult)); - }, listener::onFailure)); - - } catch (Exception e) { - log.error("Failed to process base pattern response", e); - listener.onFailure(new RuntimeException("Failed to process base patterns: " + e.getMessage(), e)); - } - }, error -> { - log.error("Failed to execute pattern 
analysis", error); - listener.onFailure(new RuntimeException("Analysis failed: " + error.getMessage(), error)); - })); + PPLExecuteHelper + .executePPLAndParseResult( + client, + baseTimeRangeLogPatternPPL, + PPLExecuteHelper.dataRowsParser(dataRowsParser), + ActionListener.wrap(basePatterns -> { + try { + mergeSimilarPatterns(basePatterns); + + log.debug("Base patterns processed: {} patterns", basePatterns.size()); + + // Step 2: Generate log patterns for selection time range + String selectionTimeRangeLogPatternPPL = buildLogPatternPPL( + params.index, + params.timeField, + params.logFieldName, + params.selectionTimeRangeStart, + params.selectionTimeRangeEnd + ); + + log.debug("Executing selection time range pattern PPL: {}", selectionTimeRangeLogPatternPPL); + PPLExecuteHelper + .executePPLAndParseResult( + client, + selectionTimeRangeLogPatternPPL, + PPLExecuteHelper.dataRowsParser(dataRowsParser), + ActionListener.wrap(selectionPatterns -> { + mergeSimilarPatterns(selectionPatterns); + + log.debug("Selection patterns processed: {} patterns", selectionPatterns.size()); + + // Step 3: Calculate pattern differences + List patternDifferences = calculatePatternDifferences( + basePatterns, + selectionPatterns + ); + + // Step 4: Sort the difference and get top 10 + List topDiffs = Stream + .concat( + patternDifferences + .stream() + .filter(diff -> !Objects.isNull(diff.lift)) + .sorted(comparator) + .limit(10), + patternDifferences + .stream() + .filter(diff -> Objects.isNull(diff.lift)) + .sorted(comparator) + .limit(10) + ) + .collect(Collectors.toList()); + + Map finalResult = new HashMap<>(); + finalResult.put("patternMapDifference", topDiffs); + + log.debug("Pattern analysis completed: {} differences found", patternDifferences.size()); + listener.onResponse((T) gson.toJson(finalResult)); + }, listener::onFailure) + ); + + } catch (Exception e) { + log.error("Failed to process base pattern response", e); + listener.onFailure(new RuntimeException("Failed to 
process base patterns: " + e.getMessage(), e)); + } + }, error -> { + log.error("Failed to execute pattern analysis", error); + listener.onFailure(new RuntimeException("Analysis failed: " + error.getMessage(), error)); + }) + ); } private void logInsight(AnalysisParameters params, ActionListener listener) { @@ -763,19 +779,25 @@ private void logInsight(AnalysisParameters params, ActionListener listene return patternWithSamplesList; }; - executePPLAndParseResult(selectionTimeRangeLogPatternPPL, dataRowsParser, ActionListener.wrap(logInsights -> { - try { - Map finalResult = new HashMap<>(); - finalResult.put("logInsights", logInsights); - listener.onResponse((T) gson.toJson(finalResult)); - } catch (Exception e) { - log.error("Failed to process base pattern response", e); - listener.onFailure(new RuntimeException("Failed to process base patterns: " + e.getMessage(), e)); - } - }, error -> { - log.error("Failed to execute log insights analysis", error); - listener.onFailure(new RuntimeException("Log insights analysis failed: " + error.getMessage(), error)); - })); + PPLExecuteHelper + .executePPLAndParseResult( + client, + selectionTimeRangeLogPatternPPL, + PPLExecuteHelper.dataRowsParser(dataRowsParser), + ActionListener.wrap(logInsights -> { + try { + Map finalResult = new HashMap<>(); + finalResult.put("logInsights", logInsights); + listener.onResponse((T) gson.toJson(finalResult)); + } catch (Exception e) { + log.error("Failed to process base pattern response", e); + listener.onFailure(new RuntimeException("Failed to process base patterns: " + e.getMessage(), e)); + } + }, error -> { + log.error("Failed to execute log insights analysis", error); + listener.onFailure(new RuntimeException("Log insights analysis failed: " + error.getMessage(), error)); + }) + ); } private String buildLogPatternPPL(String index, String timeField, String logFieldName, String startTime, String endTime) { @@ -824,15 +846,6 @@ private List calculatePatternDifferences(Map return 
differences; } - private void handlePPLError(Throwable error) { - String errorMsg = error.getMessage(); - String errorType = error.getClass().getSimpleName(); - log.error("PPL execution failed [{}]: {}", errorType, errorMsg); - String errorString = error.toString(); - String fullErrorMessage = errorMsg != null ? errorMsg : errorString; - throw new RuntimeException("PPL execution failed: " + fullErrorMessage, error); - } - private double jaccardSimilarity(String pattern1, String pattern2) { if (Strings.isEmpty(pattern1) && Strings.isEmpty(pattern2)) { return 1.0; @@ -914,69 +927,6 @@ private String postProcessPattern(String pattern) { return pattern; } - private void executePPLAndParseResult(String ppl, Function>, T> rowParser, ActionListener listener) { - try { - JSONObject jsonContent = new JSONObject(ImmutableMap.of("query", ppl)); - PPLQueryRequest pplQueryRequest = new PPLQueryRequest(ppl, jsonContent, null, "jdbc"); - TransportPPLQueryRequest transportPPLQueryRequest = new TransportPPLQueryRequest(pplQueryRequest); - - client - .execute( - PPLQueryAction.INSTANCE, - transportPPLQueryRequest, - getPPLTransportActionListener(ActionListener.wrap(transportPPLQueryResponse -> { - String result = transportPPLQueryResponse.getResult(); - if (Strings.isEmpty(result)) { - listener.onFailure(new RuntimeException("Empty PPL response")); - } else { - Map pplResult = gson.fromJson(result, new TypeToken>() { - }.getType()); - if (pplResult.containsKey("error")) { - Object errorObj = pplResult.get("error"); - String errorDetail; - if (errorObj instanceof Map) { - Map errorMap = (Map) errorObj; - Object reason = errorMap.get("reason"); - errorDetail = reason != null ? reason.toString() : errorMap.toString(); - } else { - errorDetail = errorObj != null ? 
errorObj.toString() : "Unknown error"; - } - throw new RuntimeException("PPL query error: " + errorDetail); - } - - Object datarowsObj = pplResult.get("datarows"); - if (!(datarowsObj instanceof List)) { - throw new IllegalStateException("Invalid PPL response format: missing or invalid datarows"); - } - - @SuppressWarnings("unchecked") - List> dataRows = (List>) datarowsObj; - if (dataRows.isEmpty()) { - log.warn("PPL query returned no data rows for the specified criteria"); - } - listener.onResponse(rowParser.apply(dataRows)); - } - }, error -> { - try { - handlePPLError(error); - } catch (Exception handledException) { - listener.onFailure(handledException); - } - })) - ); - } catch (Exception e) { - String errorMessage = String - .format( - Locale.ROOT, - "Failed to execute PPL query: %s. Query: %s", - e.getMessage(), - ppl.substring(0, Math.min(100, ppl.length())) - ); - log.error(errorMessage, e); - listener.onFailure(new RuntimeException(errorMessage, e)); - } - } - public static class Factory implements Tool.Factory { private Client client; diff --git a/src/main/java/org/opensearch/agent/tools/utils/PPLExecuteHelper.java b/src/main/java/org/opensearch/agent/tools/utils/PPLExecuteHelper.java new file mode 100644 index 00000000..10f7cd6a --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/utils/PPLExecuteHelper.java @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools.utils; + +import static org.opensearch.agent.tools.utils.ToolHelper.getPPLTransportActionListener; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.Function; + +import org.json.JSONObject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.sql.plugin.transport.PPLQueryAction; +import 
org.opensearch.sql.plugin.transport.TransportPPLQueryRequest; +import org.opensearch.sql.ppl.domain.PPLQueryRequest; +import org.opensearch.transport.client.Client; + +import com.google.common.collect.ImmutableMap; +import com.google.gson.reflect.TypeToken; + +import lombok.extern.log4j.Log4j2; + +/** + * Utility class for executing PPL queries and parsing results + */ +@Log4j2 +public class PPLExecuteHelper { + + /** + * Executes PPL query and parses the result using provided result parser + * + * @param The parsed result type + * @param client OpenSearch client + * @param ppl PPL query string to execute + * @param resultParser Function to parse PPL result into desired format + * @param listener Action listener for handling parsed results or failures + */ + public static void executePPLAndParseResult( + Client client, + String ppl, + Function, T> resultParser, + ActionListener listener + ) { + try { + JSONObject jsonContent = new JSONObject(ImmutableMap.of("query", ppl)); + PPLQueryRequest pplQueryRequest = new PPLQueryRequest(ppl, jsonContent, null, "jdbc"); + TransportPPLQueryRequest transportPPLQueryRequest = new TransportPPLQueryRequest(pplQueryRequest); + + client + .execute( + PPLQueryAction.INSTANCE, + transportPPLQueryRequest, + getPPLTransportActionListener(ActionListener.wrap(transportPPLQueryResponse -> { + String result = transportPPLQueryResponse.getResult(); + if (Strings.isEmpty(result)) { + listener.onFailure(new RuntimeException("Empty PPL response")); + } else { + Map pplResult = gson.fromJson(result, new TypeToken>() { + }.getType()); + if (pplResult.containsKey("error")) { + Object errorObj = pplResult.get("error"); + String errorDetail; + if (errorObj instanceof Map) { + Map errorMap = (Map) errorObj; + Object reason = errorMap.get("reason"); + errorDetail = reason != null ? reason.toString() : errorMap.toString(); + } else { + errorDetail = errorObj != null ? 
errorObj.toString() : "Unknown error"; + } + throw new RuntimeException("PPL query error: " + errorDetail); + } + + Object datarowsObj = pplResult.get("datarows"); + if (!(datarowsObj instanceof List)) { + throw new IllegalStateException("Invalid PPL response format: missing or invalid datarows"); + } + + listener.onResponse(resultParser.apply(pplResult)); + } + }, error -> { + log.error("PPL execution failed: {}", error.getMessage()); + listener.onFailure(new RuntimeException("PPL execution failed: " + error.getMessage(), error)); + })) + ); + } catch (Exception e) { + String errorMessage = String.format(Locale.ROOT, "Failed to execute PPL query: %s", e.getMessage()); + log.error(errorMessage, e); + listener.onFailure(new RuntimeException(errorMessage, e)); + } + } + + /** + * Helper method to create a result parser that extracts datarows + */ + public static Function, T> dataRowsParser(Function>, T> rowParser) { + return pplResult -> { + Object datarowsObj = pplResult.get("datarows"); + @SuppressWarnings("unchecked") + List> dataRows = (List>) datarowsObj; + if (dataRows.isEmpty()) { + log.debug("PPL query returned no data rows for the specified criteria"); + } + return rowParser.apply(dataRows); + }; + } +} diff --git a/src/test/java/org/opensearch/agent/ToolPluginTests.java b/src/test/java/org/opensearch/agent/ToolPluginTests.java index f7cfc3dd..0bbee973 100644 --- a/src/test/java/org/opensearch/agent/ToolPluginTests.java +++ b/src/test/java/org/opensearch/agent/ToolPluginTests.java @@ -96,7 +96,7 @@ public void test_getRestHandlers_successful() { @Test public void test_getToolFactories_successful() { - assertEquals(13, toolPlugin.getToolFactories().size()); + assertEquals(14, toolPlugin.getToolFactories().size()); } @Test diff --git a/src/test/java/org/opensearch/agent/tools/DataDistributionToolTests.java b/src/test/java/org/opensearch/agent/tools/DataDistributionToolTests.java new file mode 100644 index 00000000..d53bf0d8 --- /dev/null +++ 
b/src/test/java/org/opensearch/agent/tools/DataDistributionToolTests.java @@ -0,0 +1,2720 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.hamcrest.Matchers.containsString; +import static org.jsoup.helper.Validate.fail; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.lucene.search.TotalHits; +import org.hamcrest.MatcherAssert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.sql.plugin.transport.PPLQueryAction; +import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; +import org.opensearch.transport.client.AdminClient; +import org.opensearch.transport.client.Client; +import org.opensearch.transport.client.IndicesAdminClient; + +import com.google.common.collect.ImmutableMap; +import 
com.google.gson.JsonElement; + +import lombok.SneakyThrows; + +public class DataDistributionToolTests { + + private Map params = new HashMap<>(); + private final Client client = mock(Client.class); + @Mock + private AdminClient adminClient; + @Mock + private IndicesAdminClient indicesAdminClient; + @Mock + private GetMappingsResponse getMappingsResponse; + @Mock + private MappingMetadata mappingMetadata; + @Mock + private SearchResponse searchResponse; + @Mock + private TransportPPLQueryResponse pplQueryResponse; + + @SneakyThrows + @Before + public void setup() { /* Initialises the annotated mock fields, wires the mapping stubs through setupMockMappings(), then hands the mocked client to the tool factory. */ + MockitoAnnotations.openMocks(this); + setupMockMappings(); + DataDistributionTool.Factory.getInstance().init(client); + } + + private void mockSearchResponse() { /* Stubs client.search(..) so the listener passed as its second argument is called back, inside the call itself, with searchResponse; that response reports the hits built by createSampleHits(). */ + SearchHit[] hits = createSampleHits(); + SearchHits searchHits = new SearchHits(hits, new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO), 1.0f); + when(searchResponse.getHits()).thenReturn(searchHits); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + } + + private void mockPPLInvocation(String response) { /* Stubs client.execute(PPLQueryAction.INSTANCE, ..) so the listener passed as its third argument is called back with pplQueryResponse, whose getResult() returns the given string. */ + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(pplQueryResponse); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + when(pplQueryResponse.getResult()).thenReturn(response); + } + + @Test + @SneakyThrows + public void testCreateTool() { /* Checks the factory-created tool's type, name and description, and that the factory reports a null default version. */ + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + assertEquals("DataDistributionTool", tool.getType()); + assertEquals("DataDistributionTool", tool.getName()); + assertEquals(DataDistributionTool.Factory.getInstance().getDefaultDescription(), tool.getDescription()); + assertNull(DataDistributionTool.Factory.getInstance().getDefaultVersion()); + } + + @Test + public void testValidate() { + 
DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + // Valid parameters + assertTrue( + tool + .validate( + Map + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00" + ) + ) + ); + + // Valid parameters with new fields + assertTrue( + tool + .validate( + Map + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "filter", + "[\"{'term': {'status': 'error'}}\"]", + "ppl", + "source=logs-* | where status='error'" + ) + ) + ); + + // Missing required parameters + assertFalse(tool.validate(Map.of("index", "test_index"))); + assertFalse(tool.validate(Map.of())); + + // Missing selectionTimeRangeStart + assertFalse(tool.validate(Map.of("index", "test_index", "selectionTimeRangeEnd", "2025-01-15 11:00:00"))); + + // Missing selectionTimeRangeEnd + assertFalse(tool.validate(Map.of("index", "test_index", "selectionTimeRangeStart", "2025-01-15 10:00:00"))); + + // Valid with default queryType and timeField + assertTrue( + tool + .validate( + Map + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00" + ) + ) + ); + + // Valid with explicit queryType and timeField + assertTrue( + tool + .validate( + Map + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "ppl", + "timeField", + "timestamp" + ) + ) + ); + } + + @Test + @SneakyThrows + public void testDSLSingleAnalysis() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl" + ), + 
ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + // Verify the analysis contains field distribution data + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("singleAnalysis should be a JSON array", singleAnalysis.isJsonArray()); + assertTrue("singleAnalysis should contain at least one field analysis", singleAnalysis.getAsJsonArray().size() > 0); + + // Verify each field analysis has required structure (SummaryDataItem) + for (int i = 0; i < singleAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldAnalysis = singleAnalysis.getAsJsonArray().get(i); + assertTrue("Field analysis should be a JSON object", fieldAnalysis.isJsonObject()); + assertTrue("Field analysis should have 'field' property", fieldAnalysis.getAsJsonObject().has("field")); + assertTrue("Field analysis should have 'divergence' property", fieldAnalysis.getAsJsonObject().has("divergence")); + assertTrue("Field analysis should have 'topChanges' property", fieldAnalysis.getAsJsonObject().has("topChanges")); + assertNotNull("Field name should not be null", fieldAnalysis.getAsJsonObject().get("field").getAsString()); + assertTrue("TopChanges should be a JSON array", fieldAnalysis.getAsJsonObject().get("topChanges").isJsonArray()); + } + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithFilter() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'term': {'status': 'error'}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, 
JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + // Verify filter was applied (should still have analysis data) + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("singleAnalysis should be a JSON array", singleAnalysis.isJsonArray()); + assertTrue("singleAnalysis should contain field analyses even with filter", singleAnalysis.getAsJsonArray().size() > 0); + + // Verify structure is maintained with filter + JsonElement firstField = singleAnalysis.getAsJsonArray().get(0); + assertTrue("Field analysis should have proper structure with filter", firstField.getAsJsonObject().has("field")); + assertTrue("Field analysis should have topChanges with filter", firstField.getAsJsonObject().has("topChanges")); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithMultipleFilters() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'term': {'status': 'error'}}\", \"{'range': {'level': {'gte': 3}}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + // Verify multiple filters were applied + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("singleAnalysis should be a JSON array", singleAnalysis.isJsonArray()); + assertTrue("singleAnalysis should contain analyses with multiple filters", singleAnalysis.getAsJsonArray().size() > 0); + + // Verify each field has proper structure with multiple filters + for (int i = 0; i < 
singleAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldAnalysis = singleAnalysis.getAsJsonArray().get(i); + assertTrue( + "Field analysis should maintain structure with multiple filters", + fieldAnalysis.getAsJsonObject().has("field") + ); + assertTrue( + "Field analysis should have topChanges with multiple filters", + fieldAnalysis.getAsJsonObject().has("topChanges") + ); + } + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testPPLSingleAnalysis() { + String pplResponse = + """ + {"schema":[{"name":"status","type":"keyword"},{"name":"level","type":"integer"},{"name":"host","type":"keyword"}], + "datarows":[["error",3,"server-01"],["info",1,"server-02"],["warning",2,"server-03"],["error",4,"server-01"],["debug",1,"server-02"]], + "total":5,"size":5} + """; + + mockPPLInvocation(pplResponse); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "ppl" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + // Verify PPL data was processed correctly + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("singleAnalysis should be a JSON array", singleAnalysis.isJsonArray()); + assertTrue("Should have at least one field from PPL response", singleAnalysis.getAsJsonArray().size() > 0); + + // Verify each field has proper structure + for (int i = 0; i < singleAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldAnalysis = singleAnalysis.getAsJsonArray().get(i); + assertTrue("Field analysis should have field property", fieldAnalysis.getAsJsonObject().has("field")); + 
assertTrue("Field analysis should have divergence property", fieldAnalysis.getAsJsonObject().has("divergence")); + assertTrue("Field analysis should have topChanges property", fieldAnalysis.getAsJsonObject().has("topChanges")); + assertTrue("TopChanges should be an array", fieldAnalysis.getAsJsonObject().get("topChanges").isJsonArray()); + } + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testPPLWithCustomStatement() { + String pplResponse = """ + {"schema":[{"name":"status","type":"keyword"},{"name":"host","type":"keyword"},{"name":"count","type":"long"}], + "datarows":[["error","server-01",15],["error","server-02",8],["warning","server-01",3]], + "total":3,"size":3} + """; + + mockPPLInvocation(pplResponse); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "ppl", + "ppl", + "source=logs-* | where status='error' | stats count() by host" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + // Verify custom PPL statement was processed + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("singleAnalysis should be a JSON array", singleAnalysis.isJsonArray()); + assertTrue("Should have at least one field from custom PPL response", singleAnalysis.getAsJsonArray().size() > 0); + + // Verify each field has proper structure + for (int i = 0; i < singleAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldAnalysis = singleAnalysis.getAsJsonArray().get(i); + assertTrue("Field analysis should have field property", fieldAnalysis.getAsJsonObject().has("field")); + assertTrue("Field 
analysis should have divergence property", fieldAnalysis.getAsJsonObject().has("divergence")); + assertTrue("Field analysis should have topChanges property", fieldAnalysis.getAsJsonObject().has("topChanges")); + } + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testComparisonAnalysis() { + // Mock different responses for baseline and selection data + SearchHit[] baselineHits = createBaselineHits(); + SearchHit[] selectionHits = createSelectionHits(); + + SearchHits baselineSearchHits = new SearchHits(baselineHits, new TotalHits(baselineHits.length, TotalHits.Relation.EQUAL_TO), 1.0f); + SearchHits selectionSearchHits = new SearchHits( + selectionHits, + new TotalHits(selectionHits.length, TotalHits.Relation.EQUAL_TO), + 1.0f + ); + + // Mock sequential search calls - first selection, then baseline (based on new implementation) + when(searchResponse.getHits()) + .thenReturn(selectionSearchHits) // First call returns selection data + .thenReturn(baselineSearchHits); // Second call returns baseline data + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "baselineTimeRangeStart", + "2025-01-15 08:00:00", + "baselineTimeRangeEnd", + "2025-01-15 09:00:00", + "queryType", + "dsl" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain comparisonAnalysis", result.getAsJsonObject().has("comparisonAnalysis")); + + // Verify comparison analysis contains divergence data + JsonElement comparisonAnalysis = 
result.getAsJsonObject().get("comparisonAnalysis"); + assertTrue("comparisonAnalysis should be a JSON array", comparisonAnalysis.isJsonArray()); + assertTrue("comparisonAnalysis should contain field comparisons", comparisonAnalysis.getAsJsonArray().size() > 0); + + // Verify each comparison has required structure (SummaryDataItem) + for (int i = 0; i < comparisonAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldComparison = comparisonAnalysis.getAsJsonArray().get(i); + assertTrue("Field comparison should be a JSON object", fieldComparison.isJsonObject()); + assertTrue("Field comparison should have 'field' property", fieldComparison.getAsJsonObject().has("field")); + assertTrue( + "Field comparison should have 'divergence' property", + fieldComparison.getAsJsonObject().has("divergence") + ); + assertTrue( + "Field comparison should have 'topChanges' property", + fieldComparison.getAsJsonObject().has("topChanges") + ); + + // Verify divergence is a valid number + assertTrue("Divergence should be a number", fieldComparison.getAsJsonObject().get("divergence").isJsonPrimitive()); + double divergence = fieldComparison.getAsJsonObject().get("divergence").getAsDouble(); + assertTrue("Divergence should be non-negative", divergence >= 0.0); + + // Verify topChanges structure + JsonElement topChanges = fieldComparison.getAsJsonObject().get("topChanges"); + assertTrue("TopChanges should be a JSON array", topChanges.isJsonArray()); + } + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testPPLComparisonAnalysis() { + String baseResponse = """ + {"schema":[{"name":"status","type":"keyword"},{"name":"level","type":"integer"}], + "datarows":[["info",1],["warning",2],["debug",1]], + "total":3,"size":3} + """; + + String selectionResponse = """ + {"schema":[{"name":"status","type":"keyword"},{"name":"level","type":"integer"}], + "datarows":[["error",3],["error",4],["warning",2]], + "total":3,"size":3} + """; + + // 
Mock sequential PPL calls - first selection, then baseline (based on new implementation) + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(pplQueryResponse); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + when(pplQueryResponse.getResult()) + .thenReturn(selectionResponse) // First call returns selection data + .thenReturn(baseResponse); // Second call returns baseline data + + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "baselineTimeRangeStart", + "2025-01-15 08:00:00", + "baselineTimeRangeEnd", + "2025-01-15 09:00:00", + "queryType", + "ppl", + "ppl", + "source=logs-* | where level > 1" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain comparisonAnalysis", result.getAsJsonObject().has("comparisonAnalysis")); + + // Verify comparison shows differences between baseline and selection + JsonElement comparisonAnalysis = result.getAsJsonObject().get("comparisonAnalysis"); + assertTrue("comparisonAnalysis should be a JSON array", comparisonAnalysis.isJsonArray()); + assertTrue("Should have at least one field from PPL comparison", comparisonAnalysis.getAsJsonArray().size() > 0); + + // Verify each field has proper structure + for (int i = 0; i < comparisonAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldComparison = comparisonAnalysis.getAsJsonArray().get(i); + assertTrue("Field comparison should have field property", fieldComparison.getAsJsonObject().has("field")); + assertTrue("Field comparison should have divergence property", fieldComparison.getAsJsonObject().has("divergence")); + assertTrue("Field comparison should have topChanges 
property", fieldComparison.getAsJsonObject().has("topChanges")); + + // Verify divergence is a valid number + double divergence = fieldComparison.getAsJsonObject().get("divergence").getAsDouble(); + assertTrue("Divergence should be non-negative", divergence >= 0.0); + } + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithInvalidParameters() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap.of("index", "test_index"), + ActionListener + .wrap( + response -> fail("Should have failed with invalid parameters"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("Invalid time format")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithInvalidFilter() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "invalid-json" + ), + ActionListener.wrap(response -> fail("Should have failed with invalid filter JSON"), e -> { + MatcherAssert.assertThat(e.getMessage(), containsString("Invalid 'filter' parameter")); + MatcherAssert.assertThat(e.getMessage(), containsString("must be a valid JSON array of strings")); + }) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithCustomTimeField() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "custom_timestamp", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, 
JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + // Verify custom time field was used + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("singleAnalysis should be a JSON array", singleAnalysis.isJsonArray()); + assertTrue( + "singleAnalysis should contain field analyses with custom time field", + singleAnalysis.getAsJsonArray().size() > 0 + ); + + // Verify that the custom time field doesn't appear in the analysis (it's used for filtering, not analysis) + for (int i = 0; i < singleAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldAnalysis = singleAnalysis.getAsJsonArray().get(i); + String fieldName = fieldAnalysis.getAsJsonObject().get("field").getAsString(); + assertFalse("Custom time field should not appear in analysis results", "custom_timestamp".equals(fieldName)); + assertTrue("Field analysis should have topChanges property", fieldAnalysis.getAsJsonObject().has("topChanges")); + } + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testExecutionFailedInSearch() { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onFailure(new Exception("Search execution failed")); + return null; + }).when(client).search(any(), any()); + + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl" + ), + ActionListener + .wrap( + response -> fail("Should have failed"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("Search execution failed")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionFailedInPPL() { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) 
invocation.getArguments()[2]; + listener.onFailure(new Exception("PPL execution failed")); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "ppl" + ), + ActionListener + .wrap( + response -> fail("Should have failed"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("PPL execution failed")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithEmptyPPLResponse() { + String emptyResponse = ""; + mockPPLInvocation(emptyResponse); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "ppl" + ), + ActionListener + .wrap( + response -> fail("Should have failed with empty response"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("Empty PPL response")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithPPLErrorResponse() { + String errorResponse = "{\"error\":{\"type\":\"parsing_exception\",\"reason\":\"Syntax error in PPL query\"}}"; + mockPPLInvocation(errorResponse); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "ppl" + ), + ActionListener + .wrap( + response -> fail("Should have failed with PPL error response"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("PPL query error")) + ) + ); + } + + @Test + @SneakyThrows + public void 
testExecutionWithNoData() { + // Mock empty search response + SearchHit[] emptyHits = new SearchHit[0]; + SearchHits emptySearchHits = new SearchHits(emptyHits, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0.0f); + when(searchResponse.getHits()).thenReturn(emptySearchHits); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl" + ), + ActionListener + .wrap( + response -> fail("Should have failed with no data"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("No data found for selection time range")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithInvalidTimeFormat() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "invalid-time-format", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl" + ), + ActionListener + .wrap( + response -> fail("Should have failed with invalid time format"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("Invalid time format")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithInvalidSize() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "size", + "not-a-number" + ), + ActionListener.wrap(response -> fail("Should have failed with invalid size"), e 
-> { + MatcherAssert.assertThat(e.getMessage(), containsString("Invalid 'size' parameter")); + MatcherAssert.assertThat(e.getMessage(), containsString("must be a valid integer")); + }) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithSizeExceedsMaxLimit() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "size", + "15000" + ), + ActionListener.wrap(response -> fail("Should have failed with size exceeding limit"), e -> { + MatcherAssert.assertThat(e.getMessage(), containsString("Size parameter exceeds maximum limit of 10000")); + MatcherAssert.assertThat(e.getMessage(), containsString("got: 15000")); + }) + ); + } + + private void setupMockMappings() { + Map indexMappings = Map + .of( + "properties", + Map + .of( + "status", + Map.of("type", "keyword"), + "level", + Map.of("type", "integer"), + "@timestamp", + Map.of("type", "date"), + "message", + Map.of("type", "text"), + "host", + Map.of("type", "keyword"), + "service", + Map.of("type", "keyword") + ) + ); + Map mockedMappings = Map.of("test_index", mappingMetadata); + + when(mappingMetadata.getSourceAsMap()).thenReturn(indexMappings); + when(getMappingsResponse.getMappings()).thenReturn(mockedMappings); + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesAdminClient); + + // Mock the ActionFuture returned by getMappings + org.opensearch.common.action.ActionFuture mockActionFuture = mock( + org.opensearch.common.action.ActionFuture.class + ); + when(mockActionFuture.actionGet(anyLong())).thenReturn(getMappingsResponse); + when(mockActionFuture.actionGet()).thenReturn(getMappingsResponse); + when(indicesAdminClient.getMappings(any())).thenReturn(mockActionFuture); + + doAnswer(invocation -> { + ActionListener listener = 
(ActionListener) invocation.getArguments()[1]; + listener.onResponse(getMappingsResponse); + return null; + }).when(indicesAdminClient).getMappings(any(), any()); + } + + private SearchHit[] createSampleHits() { + SearchHit[] hits = new SearchHit[20]; + String[] statuses = { "error", "info", "warning", "debug" }; + String[] hosts = { "server-01", "server-02", "server-03" }; + String[] services = { "auth", "payment", "notification" }; + int[] levels = { 1, 2, 3, 4, 5 }; + + for (int i = 0; i < 20; i++) { + SearchHit hit = new SearchHit(i + 1); + String status = statuses[i % statuses.length]; + String host = hosts[i % hosts.length]; + String service = services[i % services.length]; + int level = levels[i % levels.length]; + + String source = String + .format( + "{\"status\":\"%s\",\"level\":%d,\"@timestamp\":\"2025-01-15T10:%02d:00Z\",\"host\":\"%s\",\"service\":\"%s\",\"message\":\"Sample message %d\"}", + status, + level, + 30 + i, + host, + service, + i + ); + + BytesReference sourceRef = new BytesArray(source); + hit.sourceRef(sourceRef); + hits[i] = hit; + } + + return hits; + } + + private SearchHit[] createBaselineHits() { + SearchHit[] hits = new SearchHit[10]; + // Baseline data: mostly info and warning + String[] statuses = { "info", "warning" }; + String[] hosts = { "server-01", "server-02" }; + String[] services = { "auth", "payment" }; + int[] levels = { 1, 2 }; + + for (int i = 0; i < 10; i++) { + SearchHit hit = new SearchHit(i + 1); + String status = statuses[i % statuses.length]; + String host = hosts[i % hosts.length]; + String service = services[i % services.length]; + int level = levels[i % levels.length]; + + String source = String + .format( + "{\"status\":\"%s\",\"level\":%d,\"@timestamp\":\"2025-01-15T08:%02d:00Z\",\"host\":\"%s\",\"service\":\"%s\",\"message\":\"Baseline message %d\"}", + status, + level, + 30 + i, + host, + service, + i + ); + + BytesReference sourceRef = new BytesArray(source); + hit.sourceRef(sourceRef); + hits[i] = hit; + 
} + + return hits; + } + + private SearchHit[] createSelectionHits() { + SearchHit[] hits = new SearchHit[10]; + // Selection data: mostly error and debug (different from baseline) + String[] statuses = { "error", "debug" }; + String[] hosts = { "server-02", "server-03" }; + String[] services = { "payment", "notification" }; + int[] levels = { 3, 4, 5 }; + + for (int i = 0; i < 10; i++) { + SearchHit hit = new SearchHit(i + 1); + String status = statuses[i % statuses.length]; + String host = hosts[i % hosts.length]; + String service = services[i % services.length]; + int level = levels[i % levels.length]; + + String source = String + .format( + "{\"status\":\"%s\",\"level\":%d,\"@timestamp\":\"2025-01-15T10:%02d:00Z\",\"host\":\"%s\",\"service\":\"%s\",\"message\":\"Selection message %d\"}", + status, + level, + 30 + i, + host, + service, + i + ); + + BytesReference sourceRef = new BytesArray(source); + hit.sourceRef(sourceRef); + hits[i] = hit; + } + + return hits; + } + + @Test + @SneakyThrows + public void testGetUsefulFieldsWithValidMapping() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + List> testData = createTestDataForFieldAnalysis(); + Map fieldTypes = Map + .of("status", "keyword", "level", "integer", "host", "keyword", "service", "keyword", "@timestamp", "date", "message", "text"); + + java.lang.reflect.Method getUsefulFieldsMethod = DataDistributionTool.class + .getDeclaredMethod("getUsefulFields", List.class, Map.class); + getUsefulFieldsMethod.setAccessible(true); + + @SuppressWarnings("unchecked") + List usefulFields = (List) getUsefulFieldsMethod.invoke(tool, testData, fieldTypes); + + assertNotNull(usefulFields); + assertFalse(usefulFields.isEmpty()); + assertTrue(usefulFields.contains("status")); + assertTrue(usefulFields.contains("level")); + assertTrue(usefulFields.contains("host")); + assertTrue(usefulFields.contains("service")); + assertFalse(usefulFields.contains("@timestamp")); + } + + @Test + 
@SneakyThrows + public void testGetUsefulFieldsWithEmptyMapping() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + List> testData = createTestDataForFieldAnalysis(); + Map emptyFieldTypes = Map.of(); + + java.lang.reflect.Method getUsefulFieldsMethod = DataDistributionTool.class + .getDeclaredMethod("getUsefulFields", List.class, Map.class); + getUsefulFieldsMethod.setAccessible(true); + + @SuppressWarnings("unchecked") + List usefulFields = (List) getUsefulFieldsMethod.invoke(tool, testData, emptyFieldTypes); + + assertNotNull(usefulFields); + assertTrue(usefulFields.size() > 0); + assertFalse(usefulFields.contains("@timestamp")); + } + + @Test + @SneakyThrows + public void testGetUsefulFieldsWithMappingException() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + List> testData = createTestDataForFieldAnalysis(); + Map emptyFieldTypes = Map.of(); + + java.lang.reflect.Method getUsefulFieldsMethod = DataDistributionTool.class + .getDeclaredMethod("getUsefulFields", List.class, Map.class); + getUsefulFieldsMethod.setAccessible(true); + + @SuppressWarnings("unchecked") + List usefulFields = (List) getUsefulFieldsMethod.invoke(tool, testData, emptyFieldTypes); + + assertNotNull(usefulFields); + assertTrue(usefulFields.size() > 0); + assertFalse(usefulFields.contains("@timestamp")); + assertFalse(usefulFields.contains("_id")); + assertFalse(usefulFields.contains("_index")); + } + + @Test + @SneakyThrows + public void testGetUsefulFieldsWithHighCardinalityFields() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + List> testData = createHighCardinalityTestData(); + Map fieldTypes = Map.of("status", "keyword", "unique_field", "keyword", "@timestamp", "date"); + + java.lang.reflect.Method getUsefulFieldsMethod = DataDistributionTool.class + .getDeclaredMethod("getUsefulFields", List.class, Map.class); + 
getUsefulFieldsMethod.setAccessible(true); + + @SuppressWarnings("unchecked") + List usefulFields = (List) getUsefulFieldsMethod.invoke(tool, testData, fieldTypes); + + assertNotNull(usefulFields); + // unique_field has high cardinality (20 unique values in 20 documents) so should be excluded + assertFalse("High cardinality field unique_field should be excluded", usefulFields.contains("unique_field")); + // status has low cardinality (2 unique values) so should be included + assertTrue("Low cardinality field status should be included", usefulFields.contains("status")); + } + + @Test + @SneakyThrows + public void testGetUsefulFieldsWithEmptyData() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + List> emptyData = List.of(); + Map fieldTypes = Map.of("status", "keyword", "level", "integer"); + + java.lang.reflect.Method getUsefulFieldsMethod = DataDistributionTool.class + .getDeclaredMethod("getUsefulFields", List.class, Map.class); + getUsefulFieldsMethod.setAccessible(true); + + @SuppressWarnings("unchecked") + List usefulFields = (List) getUsefulFieldsMethod.invoke(tool, emptyData, fieldTypes); + + assertNotNull(usefulFields); + assertTrue(usefulFields.size() > 0); + } + + private List> createTestDataForFieldAnalysis() { + List> data = new ArrayList<>(); + String[] statuses = { "error", "info", "warning" }; + String[] hosts = { "server-01", "server-02" }; + String[] services = { "auth", "payment" }; + + for (int i = 0; i < 10; i++) { + Map doc = new HashMap<>(); + doc.put("status", statuses[i % statuses.length]); + doc.put("level", i % 5 + 1); + doc.put("host", hosts[i % hosts.length]); + doc.put("service", services[i % services.length]); + doc.put("@timestamp", "2025-01-15T10:" + String.format("%02d", 30 + i) + ":00Z"); + doc.put("message", "Test message " + i); + data.add(doc); + } + return data; + } + + private List> createHighCardinalityTestData() { + List> data = new ArrayList<>(); + String[] statuses = { "error", 
"info" }; + + for (int i = 0; i < 20; i++) { + Map doc = new HashMap<>(); + doc.put("status", statuses[i % statuses.length]); + doc.put("unique_field", "value_" + i); + doc.put("@timestamp", "2025-01-15T10:" + String.format("%02d", 30 + i) + ":00Z"); + data.add(doc); + } + return data; + } + + @Test + @SneakyThrows + public void testBuildQueryFromMapWithTermQuery() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class + .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); + buildQueryMethod.setAccessible(true); + + Map filterMap = Map.of("status", Map.of("term", "error")); + org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); + + buildQueryMethod.invoke(tool, filterMap, queryBuilder); + + assertNotNull(queryBuilder); + assertTrue(queryBuilder.toString().contains("term")); + assertTrue(queryBuilder.toString().contains("status")); + assertTrue(queryBuilder.toString().contains("error")); + } + + @Test + @SneakyThrows + public void testBuildQueryFromMapWithRangeQuery() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class + .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); + buildQueryMethod.setAccessible(true); + + Map filterMap = Map.of("level", Map.of("range", Map.of("gte", 3, "lte", 5))); + org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); + + buildQueryMethod.invoke(tool, filterMap, queryBuilder); + + assertNotNull(queryBuilder); + assertTrue(queryBuilder.toString().contains("range")); + assertTrue(queryBuilder.toString().contains("level")); + } + + @Test + @SneakyThrows + public void testBuildQueryFromMapWithMatchQuery() 
{ + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class + .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); + buildQueryMethod.setAccessible(true); + + Map filterMap = Map.of("message", Map.of("match", "test message")); + org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); + + buildQueryMethod.invoke(tool, filterMap, queryBuilder); + + assertNotNull(queryBuilder); + assertTrue(queryBuilder.toString().contains("match")); + assertTrue(queryBuilder.toString().contains("message")); + } + + @Test + @SneakyThrows + public void testBuildQueryFromMapWithExistsQuery() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class + .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); + buildQueryMethod.setAccessible(true); + + Map filterMap = Map.of("status", Map.of("exists", true)); + org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); + + buildQueryMethod.invoke(tool, filterMap, queryBuilder); + + assertNotNull(queryBuilder); + assertTrue(queryBuilder.toString().contains("exists")); + assertTrue(queryBuilder.toString().contains("status")); + } + + @Test + @SneakyThrows + public void testBuildQueryFromMapWithDirectTermQuery() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class + .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); + buildQueryMethod.setAccessible(true); + + Map filterMap = Map.of("status", "error"); + org.opensearch.index.query.BoolQueryBuilder queryBuilder = 
org.opensearch.index.query.QueryBuilders.boolQuery(); + + buildQueryMethod.invoke(tool, filterMap, queryBuilder); + + assertNotNull(queryBuilder); + assertTrue(queryBuilder.toString().contains("term")); + assertTrue(queryBuilder.toString().contains("status")); + assertTrue(queryBuilder.toString().contains("error")); + } + + @Test + @SneakyThrows + public void testBuildQueryFromMapWithMatchPhraseQuery() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class + .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); + buildQueryMethod.setAccessible(true); + + Map filterMap = Map.of("message", Map.of("match_phrase", "exact phrase")); + org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); + + buildQueryMethod.invoke(tool, filterMap, queryBuilder); + + assertNotNull(queryBuilder); + assertTrue(queryBuilder.toString().contains("match_phrase")); + } + + @Test + @SneakyThrows + public void testBuildQueryFromMapWithPrefixQuery() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class + .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); + buildQueryMethod.setAccessible(true); + + Map filterMap = Map.of("host", Map.of("prefix", "server")); + org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); + + buildQueryMethod.invoke(tool, filterMap, queryBuilder); + + assertNotNull(queryBuilder); + assertTrue(queryBuilder.toString().contains("prefix")); + } + + @Test + @SneakyThrows + public void testBuildQueryFromMapWithWildcardQuery() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method 
buildQueryMethod = DataDistributionTool.class + .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); + buildQueryMethod.setAccessible(true); + + Map filterMap = Map.of("host", Map.of("wildcard", "server*")); + org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); + + buildQueryMethod.invoke(tool, filterMap, queryBuilder); + + assertNotNull(queryBuilder); + assertTrue(queryBuilder.toString().contains("wildcard")); + } + + @Test + @SneakyThrows + public void testBuildQueryFromMapWithWildcardMapQuery() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class + .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); + buildQueryMethod.setAccessible(true); + + Map filterMap = Map.of("host", Map.of("wildcard", Map.of("value", "server*"))); + org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); + + buildQueryMethod.invoke(tool, filterMap, queryBuilder); + + assertNotNull(queryBuilder); + assertTrue(queryBuilder.toString().contains("wildcard")); + } + + @Test + @SneakyThrows + public void testBuildQueryFromMapWithRegexpQuery() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class + .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); + buildQueryMethod.setAccessible(true); + + Map filterMap = Map.of("host", Map.of("regexp", "server-[0-9]+")); + org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); + + buildQueryMethod.invoke(tool, filterMap, queryBuilder); + + assertNotNull(queryBuilder); + 
assertTrue(queryBuilder.toString().contains("regexp")); + } + + @Test + @SneakyThrows + public void testBuildQueryFromMapWithRegexpMapQuery() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class + .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); + buildQueryMethod.setAccessible(true); + + Map filterMap = Map.of("host", Map.of("regexp", Map.of("value", "server-[0-9]+"))); + org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); + + buildQueryMethod.invoke(tool, filterMap, queryBuilder); + + assertNotNull(queryBuilder); + assertTrue(queryBuilder.toString().contains("regexp")); + } + + @Test + @SneakyThrows + public void testBuildQueryFromMapWithTermsQuery() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class + .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); + buildQueryMethod.setAccessible(true); + + Map filterMap = Map.of("terms", Map.of("status", List.of("error", "warning"))); + org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); + + buildQueryMethod.invoke(tool, filterMap, queryBuilder); + + assertNotNull(queryBuilder); + assertTrue(queryBuilder.toString().contains("terms")); + assertTrue(queryBuilder.toString().contains("status")); + } + + @Test + @SneakyThrows + public void testBuildQueryFromMapWithMultiMatchQuery() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class + .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); + 
buildQueryMethod.setAccessible(true); + + Map filterMap = Map + .of("multi_match", Map.of("query", "error message", "fields", List.of("message", "description"))); + org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); + + buildQueryMethod.invoke(tool, filterMap, queryBuilder); + + assertNotNull(queryBuilder); + assertTrue(queryBuilder.toString().contains("multi_match")); + } + + @Test + @SneakyThrows + public void testBuildQueryFromMapWithComplexRangeQuery() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class + .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); + buildQueryMethod.setAccessible(true); + + Map filterMap = Map.of("level", Map.of("range", Map.of("gt", 1, "lt", 10))); + org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); + + buildQueryMethod.invoke(tool, filterMap, queryBuilder); + + assertNotNull(queryBuilder); + assertTrue(queryBuilder.toString().contains("range")); + } + + @Test + @SneakyThrows + public void testBuildQueryFromMapWithUnsupportedOperator() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class + .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); + buildQueryMethod.setAccessible(true); + + Map filterMap = Map.of("status", Map.of("unsupported_op", "value")); + org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); + + buildQueryMethod.invoke(tool, filterMap, queryBuilder); + + assertNotNull(queryBuilder); + } + + @Test + @SneakyThrows + public void testGroupNumericKeysWithManyNumericValues() { + DataDistributionTool tool = 
DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method groupNumericKeysMethod = DataDistributionTool.class + .getDeclaredMethod("groupNumericKeys", Map.class, Map.class); + groupNumericKeysMethod.setAccessible(true); + + Map selectionDist = new HashMap<>(); + Map baselineDist = new HashMap<>(); + + for (int i = 1; i <= 15; i++) { + selectionDist.put(String.valueOf(i), 0.1); + baselineDist.put(String.valueOf(i + 5), 0.1); + } + + Object result = groupNumericKeysMethod.invoke(tool, selectionDist, baselineDist); + + assertNotNull(result); + java.lang.reflect.Method groupedSelectionDistMethod = result.getClass().getDeclaredMethod("groupedSelectionDist"); + @SuppressWarnings("unchecked") + Map groupedSelection = (Map) groupedSelectionDistMethod.invoke(result); + + assertEquals(5, groupedSelection.size()); + assertTrue(groupedSelection.keySet().stream().allMatch(key -> key.contains("-"))); + } + + @Test + @SneakyThrows + public void testGroupNumericKeysWithFewNumericValues() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method groupNumericKeysMethod = DataDistributionTool.class + .getDeclaredMethod("groupNumericKeys", Map.class, Map.class); + groupNumericKeysMethod.setAccessible(true); + + Map selectionDist = Map.of("1", 0.3, "2", 0.4, "3", 0.3); + Map baselineDist = Map.of("1", 0.2, "2", 0.5, "3", 0.3); + + Object result = groupNumericKeysMethod.invoke(tool, selectionDist, baselineDist); + + assertNotNull(result); + java.lang.reflect.Method groupedSelectionDistMethod = result.getClass().getDeclaredMethod("groupedSelectionDist"); + @SuppressWarnings("unchecked") + Map groupedSelection = (Map) groupedSelectionDistMethod.invoke(result); + + assertEquals(selectionDist, groupedSelection); + } + + @Test + @SneakyThrows + public void testGroupNumericKeysWithNonNumericValues() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + 
java.lang.reflect.Method groupNumericKeysMethod = DataDistributionTool.class + .getDeclaredMethod("groupNumericKeys", Map.class, Map.class); + groupNumericKeysMethod.setAccessible(true); + + Map selectionDist = new HashMap<>(); + Map baselineDist = new HashMap<>(); + + for (int i = 1; i <= 15; i++) { + selectionDist.put(String.valueOf(i), 0.1); + } + selectionDist.put("error", 0.2); + selectionDist.put("warning", 0.3); + + Object result = groupNumericKeysMethod.invoke(tool, selectionDist, baselineDist); + + assertNotNull(result); + java.lang.reflect.Method groupedSelectionDistMethod = result.getClass().getDeclaredMethod("groupedSelectionDist"); + @SuppressWarnings("unchecked") + Map groupedSelection = (Map) groupedSelectionDistMethod.invoke(result); + + assertEquals(selectionDist, groupedSelection); + } + + @Test + @SneakyThrows + public void testGetNumberFieldsWithValidMapping() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + Map fieldTypes = Map.of("status", "keyword", "level", "integer", "host", "keyword", "response_time", "float"); + + java.lang.reflect.Method getNumberFieldsMethod = DataDistributionTool.class.getDeclaredMethod("getNumberFields", Map.class); + getNumberFieldsMethod.setAccessible(true); + + @SuppressWarnings("unchecked") + java.util.Set numberFields = (java.util.Set) getNumberFieldsMethod.invoke(tool, fieldTypes); + + assertNotNull(numberFields); + assertTrue(numberFields.contains("level")); + assertTrue(numberFields.contains("response_time")); + assertFalse(numberFields.contains("status")); + assertFalse(numberFields.contains("host")); + } + + @Test + @SneakyThrows + public void testGetNumberFieldsWithEmptyMapping() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + Map emptyFieldTypes = Map.of(); + + java.lang.reflect.Method getNumberFieldsMethod = DataDistributionTool.class.getDeclaredMethod("getNumberFields", Map.class); + 
getNumberFieldsMethod.setAccessible(true); + + @SuppressWarnings("unchecked") + java.util.Set numberFields = (java.util.Set) getNumberFieldsMethod.invoke(tool, emptyFieldTypes); + + assertNotNull(numberFields); + assertTrue(numberFields.isEmpty()); + } + + @Test + @SneakyThrows + public void testGetNumberFieldsWithMappingException() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + Map emptyFieldTypes = Map.of(); + + java.lang.reflect.Method getNumberFieldsMethod = DataDistributionTool.class.getDeclaredMethod("getNumberFields", Map.class); + getNumberFieldsMethod.setAccessible(true); + + @SuppressWarnings("unchecked") + java.util.Set numberFields = (java.util.Set) getNumberFieldsMethod.invoke(tool, emptyFieldTypes); + + assertNotNull(numberFields); + assertTrue(numberFields.isEmpty()); + } + + @Test + @SneakyThrows + public void testGetNumberFieldsWithNullActionFuture() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + Map emptyFieldTypes = Map.of(); + + java.lang.reflect.Method getNumberFieldsMethod = DataDistributionTool.class.getDeclaredMethod("getNumberFields", Map.class); + getNumberFieldsMethod.setAccessible(true); + + @SuppressWarnings("unchecked") + java.util.Set numberFields = (java.util.Set) getNumberFieldsMethod.invoke(tool, emptyFieldTypes); + + assertNotNull(numberFields); + assertTrue(numberFields.isEmpty()); + } + + @Test + @SneakyThrows + public void testGetFieldTypesWithValidMapping() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method getFieldTypesMethod = DataDistributionTool.class + .getDeclaredMethod("getFieldTypes", String.class, ActionListener.class); + getFieldTypesMethod.setAccessible(true); + + java.util.concurrent.CountDownLatch latch = new java.util.concurrent.CountDownLatch(1); + java.util.concurrent.atomic.AtomicReference> resultRef = new 
java.util.concurrent.atomic.AtomicReference<>(); + + ActionListener> listener = ActionListener.wrap(result -> { + resultRef.set(result); + latch.countDown(); + }, e -> { + latch.countDown(); + fail("getFieldTypes failed: " + e.getMessage()); + }); + + getFieldTypesMethod.invoke(tool, "test_index", listener); + latch.await(); + + Map fieldTypes = resultRef.get(); + assertNotNull(fieldTypes); + assertFalse(fieldTypes.isEmpty()); + assertEquals("keyword", fieldTypes.get("status")); + assertEquals("integer", fieldTypes.get("level")); + assertEquals("keyword", fieldTypes.get("host")); + } + + @Test + @SneakyThrows + public void testGetFieldTypesWithEmptyMapping() { + when(getMappingsResponse.getMappings()).thenReturn(Map.of()); + + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method getFieldTypesMethod = DataDistributionTool.class + .getDeclaredMethod("getFieldTypes", String.class, ActionListener.class); + getFieldTypesMethod.setAccessible(true); + + java.util.concurrent.CountDownLatch latch = new java.util.concurrent.CountDownLatch(1); + java.util.concurrent.atomic.AtomicReference> resultRef = new java.util.concurrent.atomic.AtomicReference<>(); + + ActionListener> listener = ActionListener.wrap(result -> { + resultRef.set(result); + latch.countDown(); + }, e -> { + latch.countDown(); + fail("getFieldTypes failed: " + e.getMessage()); + }); + + getFieldTypesMethod.invoke(tool, "test_index", listener); + latch.await(); + + Map fieldTypes = resultRef.get(); + assertNotNull(fieldTypes); + assertTrue(fieldTypes.isEmpty()); + } + + @Test + @SneakyThrows + public void testGetFieldTypesWithMappingException() { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onFailure(new RuntimeException("Mapping failed")); + return null; + }).when(indicesAdminClient).getMappings(any(), any()); + + DataDistributionTool tool = 
DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method getFieldTypesMethod = DataDistributionTool.class + .getDeclaredMethod("getFieldTypes", String.class, ActionListener.class); + getFieldTypesMethod.setAccessible(true); + + java.util.concurrent.CountDownLatch latch = new java.util.concurrent.CountDownLatch(1); + java.util.concurrent.atomic.AtomicReference> resultRef = new java.util.concurrent.atomic.AtomicReference<>(); + + ActionListener> listener = ActionListener.wrap(result -> { + resultRef.set(result); + latch.countDown(); + }, e -> { + resultRef.set(Map.of()); + latch.countDown(); + }); + + getFieldTypesMethod.invoke(tool, "test_index", listener); + latch.await(); + + Map fieldTypes = resultRef.get(); + assertNotNull(fieldTypes); + assertTrue(fieldTypes.isEmpty()); + } + + @Test + @SneakyThrows + public void testGetPPLQueryWithTimeRangeEmptyQuery() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method getPPLQueryWithTimeRangeMethod = DataDistributionTool.class + .getDeclaredMethod("getPPLQueryWithTimeRange", String.class, String.class, String.class, String.class); + getPPLQueryWithTimeRangeMethod.setAccessible(true); + + try { + getPPLQueryWithTimeRangeMethod.invoke(tool, "", "2025-01-15 10:00:00", "2025-01-15 11:00:00", "@timestamp"); + fail("Expected IllegalArgumentException for empty PPL query"); + } catch (java.lang.reflect.InvocationTargetException e) { + assertTrue("Expected IllegalArgumentException", e.getCause() instanceof IllegalArgumentException); + assertEquals("PPL query cannot be empty", e.getCause().getMessage()); + } + } + + @Test + @SneakyThrows + public void testGetPPLQueryWithTimeRangeEmptyTimeField() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method getPPLQueryWithTimeRangeMethod = DataDistributionTool.class + .getDeclaredMethod("getPPLQueryWithTimeRange", String.class, 
String.class, String.class, String.class); + getPPLQueryWithTimeRangeMethod.setAccessible(true); + + String result = (String) getPPLQueryWithTimeRangeMethod + .invoke(tool, "source=logs-*", "2025-01-15 10:00:00", "2025-01-15 11:00:00", ""); + + assertEquals("source=logs-*", result); + } + + @Test + @SneakyThrows + public void testGetPPLQueryWithTimeRangeExistingWhere() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method getPPLQueryWithTimeRangeMethod = DataDistributionTool.class + .getDeclaredMethod("getPPLQueryWithTimeRange", String.class, String.class, String.class, String.class); + getPPLQueryWithTimeRangeMethod.setAccessible(true); + + String result = (String) getPPLQueryWithTimeRangeMethod + .invoke(tool, "source=logs-* | where status='error'", "2025-01-15 10:00:00", "2025-01-15 11:00:00", "@timestamp"); + + assertEquals( + "source=logs-* | WHERE `@timestamp` >= '2025-01-15 10:00:00' AND `@timestamp` <= '2025-01-15 11:00:00' | where status='error'", + result + ); + } + + @Test + @SneakyThrows + public void testGetPPLQueryWithTimeRangeNoExistingWhere() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method getPPLQueryWithTimeRangeMethod = DataDistributionTool.class + .getDeclaredMethod("getPPLQueryWithTimeRange", String.class, String.class, String.class, String.class); + getPPLQueryWithTimeRangeMethod.setAccessible(true); + + String result = (String) getPPLQueryWithTimeRangeMethod + .invoke(tool, "source=logs-* | stats count() by status", "2025-01-15 10:00:00", "2025-01-15 11:00:00", "@timestamp"); + + assertEquals( + "source=logs-* | WHERE `@timestamp` >= '2025-01-15 10:00:00' AND `@timestamp` <= '2025-01-15 11:00:00' | stats count() by status", + result + ); + } + + // ========== DSL Query Format Tests ========== + + @Test + @SneakyThrows + public void testDSLWithRawDSLQuery() { + mockSearchResponse(); + DataDistributionTool 
tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "dsl", + "{\"bool\": {\"must\": [{\"term\": {\"status\": \"error\"}}]}}" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("singleAnalysis should be a JSON array", singleAnalysis.isJsonArray()); + assertTrue("singleAnalysis should contain field analyses with raw DSL", singleAnalysis.getAsJsonArray().size() > 0); + + for (int i = 0; i < singleAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldAnalysis = singleAnalysis.getAsJsonArray().get(i); + assertTrue( + "Field analysis should have proper structure with raw DSL", + fieldAnalysis.getAsJsonObject().has("field") + ); + assertTrue("Field analysis should have topChanges with raw DSL", fieldAnalysis.getAsJsonObject().has("topChanges")); + } + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithComplexRawDSLQuery() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + String complexDSL = """ + { + "bool": { + "must": [ + {"term": {"status": "error"}}, + {"range": {"level": {"gte": 3}}} + ], + "should": [ + {"match": {"message": "timeout"}}, + {"wildcard": {"host": "server-*"}} + ], + "minimum_should_match": 1 + } + } + """; + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "dsl", + complexDSL + ), + 
ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("singleAnalysis should be a JSON array", singleAnalysis.isJsonArray()); + assertTrue("Complex DSL should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithInvalidRawDSLQuery() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "dsl", + "invalid-json-query" + ), + ActionListener.wrap(response -> { + // Should fallback to time range only query when DSL parsing fails + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue( + "Response should contain singleAnalysis even with invalid DSL", + result.getAsJsonObject().has("singleAnalysis") + ); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithFilterArraySingleFilter() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'term': {'status': 'error'}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + 
JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Single filter should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithFilterArrayComplexFilters() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'term': {'status': 'error'}}\", \"{'range': {'level': {'gte': 3, 'lte': 5}}}\", \"{'wildcard': {'host': 'server-*'}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Complex filters should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithFilterArrayMatchQueries() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'match': {'message': 'error timeout'}}\", \"{'match_phrase': {'service': 'payment service'}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement 
singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Match query filters should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithFilterArrayExistsAndPrefixQueries() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'exists': {'field': 'error_code'}}\", \"{'prefix': {'host': 'prod'}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Exists and prefix filters should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithFilterArrayRegexpQueries() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'regexp': {'host': 'server-[0-9]+'}}\", \"{'wildcard': {'service': '*payment*'}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = 
result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Regexp and wildcard filters should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithFilterArrayTermsQueries() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'terms': {'status': ['error', 'warning']}}\", \"{'terms': {'level': [3, 4, 5]}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Terms filters should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithFilterArrayMultiMatchQueries() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'multi_match': {'query': 'error timeout', 'fields': ['message', 'description']}}\", \"{'multi_match': {'query': 'connection failed', 'fields': ['error_msg', 'details']}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + 
JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Multi-match filters should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithFilterArrayInvalidFilter() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'term': {'status': 'error'}}\", \"invalid-json-filter\"]" + ), + ActionListener.wrap(response -> { + // Should continue processing valid filters and ignore invalid ones + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue( + "Response should contain singleAnalysis even with some invalid filters", + result.getAsJsonObject().has("singleAnalysis") + ); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLComparisonWithRawDSLQuery() { + SearchHit[] baselineHits = createBaselineHits(); + SearchHit[] selectionHits = createSelectionHits(); + + SearchHits baselineSearchHits = new SearchHits(baselineHits, new TotalHits(baselineHits.length, TotalHits.Relation.EQUAL_TO), 1.0f); + SearchHits selectionSearchHits = new SearchHits( + selectionHits, + new TotalHits(selectionHits.length, TotalHits.Relation.EQUAL_TO), + 1.0f + ); + + when(searchResponse.getHits()).thenReturn(selectionSearchHits).thenReturn(baselineSearchHits); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + 
ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "baselineTimeRangeStart", + "2025-01-15 08:00:00", + "baselineTimeRangeEnd", + "2025-01-15 09:00:00", + "queryType", + "dsl", + "dsl", + "{\"bool\": {\"must\": [{\"term\": {\"status\": \"error\"}}]}}" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain comparisonAnalysis", result.getAsJsonObject().has("comparisonAnalysis")); + + JsonElement comparisonAnalysis = result.getAsJsonObject().get("comparisonAnalysis"); + assertTrue("comparisonAnalysis should be a JSON array", comparisonAnalysis.isJsonArray()); + assertTrue("Raw DSL comparison should produce results", comparisonAnalysis.getAsJsonArray().size() > 0); + + for (int i = 0; i < comparisonAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldComparison = comparisonAnalysis.getAsJsonArray().get(i); + assertTrue("Field comparison should have divergence", fieldComparison.getAsJsonObject().has("divergence")); + assertTrue("Field comparison should have topChanges", fieldComparison.getAsJsonObject().has("topChanges")); + } + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLComparisonWithFilterArray() { + SearchHit[] baselineHits = createBaselineHits(); + SearchHit[] selectionHits = createSelectionHits(); + + SearchHits baselineSearchHits = new SearchHits(baselineHits, new TotalHits(baselineHits.length, TotalHits.Relation.EQUAL_TO), 1.0f); + SearchHits selectionSearchHits = new SearchHits( + selectionHits, + new TotalHits(selectionHits.length, TotalHits.Relation.EQUAL_TO), + 1.0f + ); + + when(searchResponse.getHits()).thenReturn(selectionSearchHits).thenReturn(baselineSearchHits); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + 
listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "baselineTimeRangeStart", + "2025-01-15 08:00:00", + "baselineTimeRangeEnd", + "2025-01-15 09:00:00", + "queryType", + "dsl", + "filter", + "[\"{'term': {'status': 'error'}}\", \"{'range': {'level': {'gte': 3}}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain comparisonAnalysis", result.getAsJsonObject().has("comparisonAnalysis")); + + JsonElement comparisonAnalysis = result.getAsJsonObject().get("comparisonAnalysis"); + assertTrue("comparisonAnalysis should be a JSON array", comparisonAnalysis.isJsonArray()); + assertTrue("Filter array comparison should produce results", comparisonAnalysis.getAsJsonArray().size() > 0); + + for (int i = 0; i < comparisonAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldComparison = comparisonAnalysis.getAsJsonArray().get(i); + assertTrue("Field comparison should have divergence", fieldComparison.getAsJsonObject().has("divergence")); + assertTrue("Field comparison should have topChanges", fieldComparison.getAsJsonObject().has("topChanges")); + } + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithBothRawDSLAndFilterArray() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "dsl", + "{\"bool\": {\"must\": [{\"term\": {\"status\": \"error\"}}]}}", + 
"filter", + "[\"{'range': {'level': {'gte': 3}}}\"]" + ), + ActionListener.wrap(response -> { + // When both dsl and filter are provided, dsl should take precedence + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Both DSL and filter should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + // ========== Query Format Validation Tests ========== + + @Test + @SneakyThrows + public void testValidateFilterArrayFormat() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + // Valid filter array formats + assertTrue( + "Single filter should be valid", + tool + .validate( + Map + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "filter", + "[\"{'term': {'status': 'error'}}\"]" + ) + ) + ); + + assertTrue( + "Multiple filters should be valid", + tool + .validate( + Map + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "filter", + "[\"{'term': {'status': 'error'}}\", \"{'range': {'level': {'gte': 3}}}\"]" + ) + ) + ); + + assertTrue( + "Empty filter array should be valid", + tool + .validate( + Map + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "filter", + "[]" + ) + ) + ); + } + + @Test + @SneakyThrows + public void testValidateRawDSLFormat() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + // Valid DSL formats + assertTrue( + "Simple DSL should be valid", + tool + .validate( + Map + .of( + "index", + 
"test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "dsl", + "{\"term\": {\"status\": \"error\"}}" + ) + ) + ); + + assertTrue( + "Complex DSL should be valid", + tool + .validate( + Map + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "dsl", + "{\"bool\": {\"must\": [{\"term\": {\"status\": \"error\"}}], \"filter\": [{\"range\": {\"level\": {\"gte\": 3}}}]}}" + ) + ) + ); + + assertTrue( + "Empty DSL should be valid", + tool + .validate( + Map + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "dsl", + "" + ) + ) + ); + } + + @Test + @SneakyThrows + public void testValidateBothDSLAndFilterFormats() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + // Both DSL and filter provided should be valid + assertTrue( + "Both DSL and filter should be valid", + tool + .validate( + Map + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "dsl", + "{\"term\": {\"status\": \"error\"}}", + "filter", + "[\"{'range': {'level': {'gte': 3}}}\"]" + ) + ) + ); + } + + // ========== Edge Cases and Error Handling Tests ========== + + @Test + @SneakyThrows + public void testDSLWithEmptyFilterArray() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis with empty filter", 
result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Empty filter should still produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithMalformedFilterJSON() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[malformed-json]" + ), + ActionListener.wrap(response -> fail("Should have failed with malformed filter JSON"), e -> { + MatcherAssert.assertThat(e.getMessage(), containsString("Invalid 'filter' parameter")); + MatcherAssert.assertThat(e.getMessage(), containsString("must be a valid JSON array of strings")); + }) + ); + } + + @Test + @SneakyThrows + public void testDSLWithBoolQueryInRawDSL() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + String boolDSL = """ + { + "bool": { + "must": [ + {"term": {"status": "error"}} + ], + "should": [ + {"match": {"message": "timeout"}}, + {"match": {"message": "connection"}} + ], + "must_not": [ + {"term": {"level": 1}} + ], + "filter": [ + {"range": {"@timestamp": {"gte": "2025-01-15T09:00:00Z"}}} + ], + "minimum_should_match": 1 + } + } + """; + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "dsl", + boolDSL + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement 
singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Bool query DSL should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithRawDSLTermsQuery() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "dsl", + "{\"terms\": {\"status\": [\"error\", \"warning\"]}}" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Terms query DSL should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithRawDSLMultiMatchQuery() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "dsl", + "{\"multi_match\": {\"query\": \"error timeout\", \"fields\": [\"message\", \"description\"]}}" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Multi-match query DSL should produce 
analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLQueryPrecedenceOverFilter() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + // When both dsl and filter are provided, dsl should take precedence + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "dsl", + "{\"term\": {\"status\": \"error\"}}", + "filter", + "[\"{'term': {'status': 'info'}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("DSL should take precedence over filter", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testBuildQueryFromMapWithTermsQueryNonListValue() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class + .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); + buildQueryMethod.setAccessible(true); + + // Test terms query with non-list value (should be ignored) + Map filterMap = Map.of("terms", Map.of("status", "error")); + org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); + + buildQueryMethod.invoke(tool, filterMap, queryBuilder); + + assertNotNull(queryBuilder); + // Should not contain terms query since value is not a list + 
assertFalse(queryBuilder.toString().contains("terms")); + } + + @Test + @SneakyThrows + public void testBuildQueryFromMapWithMultiMatchQueryMissingFields() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class + .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); + buildQueryMethod.setAccessible(true); + + // Test multi_match query with missing fields (should be ignored) + Map filterMap = Map.of("multi_match", Map.of("query", "error message")); + org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); + + buildQueryMethod.invoke(tool, filterMap, queryBuilder); + + assertNotNull(queryBuilder); + // Should not contain multi_match query since fields is missing + assertFalse(queryBuilder.toString().contains("multi_match")); + } + + @Test + @SneakyThrows + public void testBuildQueryFromMapWithMultiMatchQueryMissingQuery() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class + .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); + buildQueryMethod.setAccessible(true); + + // Test multi_match query with missing query (should be ignored) + Map filterMap = Map.of("multi_match", Map.of("fields", List.of("message", "description"))); + org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); + + buildQueryMethod.invoke(tool, filterMap, queryBuilder); + + assertNotNull(queryBuilder); + // Should not contain multi_match query since query is missing + assertFalse(queryBuilder.toString().contains("multi_match")); + } + + @Test + @SneakyThrows + public void testBuildQueryFromMapWithMultiMatchQueryNonListFields() { + DataDistributionTool tool = 
DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class + .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); + buildQueryMethod.setAccessible(true); + + // Test multi_match query with non-list fields (should be ignored) + Map filterMap = Map.of("multi_match", Map.of("query", "error message", "fields", "message")); + org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); + + buildQueryMethod.invoke(tool, filterMap, queryBuilder); + + assertNotNull(queryBuilder); + // Should not contain multi_match query since fields is not a list + assertFalse(queryBuilder.toString().contains("multi_match")); + } +} diff --git a/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java b/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java new file mode 100644 index 00000000..c99f5ec7 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java @@ -0,0 +1,448 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import static org.hamcrest.Matchers.containsString; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Locale; + +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Before; + +import lombok.SneakyThrows; + +public class DataDistributionToolIT extends BaseAgentToolsIT { + + public static String requestBodyResourceFile = + "org/opensearch/agent/tools/register_flow_agent_of_data_distribution_tool_request_body.json"; + public String registerAgentRequestBody; + public static String TEST_DATA_INDEX_NAME = "test_data_distribution_index"; + + private String agentId; + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + prepareDataIndex(); + registerAgentRequestBody = 
/**
 * Per-test cleanup: runs the superclass teardown first and then calls
 * {@code deleteExternalIndices()}.
 */
@After
@SneakyThrows
public void tearDown() {
    super.tearDown();
    deleteExternalIndices();
}
List.of("@timestamp", "status", "level", "host", "response_time"), + List.of("2025-01-01 10:45:00", "warning", 2, "server-03", 140.1) + ); + addDocToIndex( + TEST_DATA_INDEX_NAME, + "sel4", + List.of("@timestamp", "status", "level", "host", "response_time"), + List.of("2025-01-01 10:50:00", "error", 3, "server-01", 300.5) + ); + } + + @SneakyThrows + public void testDataDistributionToolSingleAnalysis() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75,\"baselinePercentage\":0.0},{\"value\":\"warning\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"2\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"server-03\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"250.3\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]}]}"; + assertEquals(expectedResult, result); + } + + @SneakyThrows + public void 
testDataDistributionToolComparisonAnalysis() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"baselineTimeRangeStart\": \"2025-01-01 09:00:00\", \"baselineTimeRangeEnd\": \"2025-01-01 10:00:00\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + "{\"comparisonAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75,\"baselinePercentage\":0.0},{\"value\":\"success\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.67},{\"value\":\"info\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.33},{\"value\":\"warning\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"1\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.67},{\"value\":\"3\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"2\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.33},{\"value\":\"4\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"110.8\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.33},{\"value\":\"95.2\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.33},{\"value\":\"120.5\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.33},{\"value\":\"140.1\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"250.3\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.67},{\"value\":\"server-02\",\
"selectionPercentage\":0.25,\"baselinePercentage\":0.33},{\"value\":\"server-03\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]}]}"; + assertEquals(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolWithFilter() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"filter\": \"[\\\"{'term': {'status': 'error'}}\\\"]\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]}]}"; + assertEquals(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolMissingRequiredParameters() { + Exception exception = assertThrows(Exception.class, () -> executeAgent(agentId, "{\"parameters\": {\"index\": \"test_index\"}}")); + MatcherAssert.assertThat(exception.getMessage(), containsString("Unable to parse time string")); + } + + @SneakyThrows + public void testDataDistributionToolInvalidIndex() { + Exception 
exception = assertThrows( + Exception.class, + () -> executeAgent( + agentId, + "{\"parameters\": {\"index\": \"non_existent_index\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\"}}" + ) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("no such index")); + } + + @SneakyThrows + public void testDataDistributionToolPPLSingleAnalysis() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"queryType\": \"ppl\", \"ppl\": \"source=%s\"}}", + TEST_DATA_INDEX_NAME, + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75,\"baselinePercentage\":0.0},{\"value\":\"warning\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3.0\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"2.0\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"4.0\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"server-03\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"250.3\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]}]}"; + 
assertEquals(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolPPLComparisonAnalysis() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"baselineTimeRangeStart\": \"2025-01-01 09:00:00\", \"baselineTimeRangeEnd\": \"2025-01-01 10:00:00\", \"queryType\": \"ppl\", \"ppl\": \"source=%s\"}}", + TEST_DATA_INDEX_NAME, + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + "{\"comparisonAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75,\"baselinePercentage\":0.0},{\"value\":\"success\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.67},{\"value\":\"info\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.33},{\"value\":\"warning\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"1.0\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.67},{\"value\":\"3.0\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"2.0\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.33},{\"value\":\"4.0\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"110.8\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.33},{\"value\":\"95.2\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.33},{\"value\":\"120.5\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.33},{\"value\":\"140.1\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"250.3\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"fiel
d\":\"host\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.33},{\"value\":\"server-03\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]}]}"; + assertEquals(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolPPLWithCustomQuery() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"queryType\": \"ppl\", \"ppl\": \"source=%s | where level > 2\"}}", + TEST_DATA_INDEX_NAME, + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3.0\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"4.0\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]}]}"; + assertEquals(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolWithDSLQueryType() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", 
\"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"queryType\": \"dsl\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75,\"baselinePercentage\":0.0},{\"value\":\"warning\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"2\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"server-03\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"250.3\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]}]}"; + assertEquals(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolWithMultipleFilters() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"filter\": \"[\\\"{'term': {'status': 'error'}}\\\", \\\"{'range': {'level': {'gte': 3}}}\\\"]\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]}]}"; + assertEquals(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolWithCustomSize() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"size\": \"500\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75,\"baselinePercentage\":0.0},{\"value\":\"warning\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"2\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"server-03\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"250.3\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]}]}"; + assertEquals(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolWithCustomTimeField() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"timeField\": \"@timestamp\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75,\"baselinePercentage\":0.0},{\"value\":\"warning\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"2\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"server-03\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"250.3\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]}]}"; + assertEquals(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolWithRangeFilter() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"filter\": \"[\\\"{'range': {'response_time': {'gte': 150.0}}}\\\"]\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]}]}"; + assertEquals(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolWithMatchFilter() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"filter\": \"[\\\"{'match': {'status': 'error'}}\\\"]\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]}]}"; + assertEquals(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolWithRawDSLQuery() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"dsl\": \"{\\\"bool\\\": {\\\"must\\\": [{\\\"term\\\": {\\\"status\\\": \\\"error\\\"}}]}}\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]}]}"; + assertEquals(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolWithExistsFilter() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"filter\": \"[\\\"{'exists': {'field': 'response_time'}}\\\"]\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75,\"baselinePercentage\":0.0},{\"value\":\"warning\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"2\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"server-03\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"250.3\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]}]}"; + assertEquals(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolInvalidFilterFormat() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"filter\": \"invalid-json\"}}", + TEST_DATA_INDEX_NAME + ) + ) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("Invalid 'filter' parameter")); + } + + @SneakyThrows + public void testDataDistributionToolInvalidSizeParameter() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": 
{\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"size\": \"not-a-number\"}}", + TEST_DATA_INDEX_NAME + ) + ) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("Invalid 'size' parameter")); + } + + @SneakyThrows + public void testDataDistributionToolInvalidTimeFormat() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"invalid-time-format\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\"}}", + TEST_DATA_INDEX_NAME + ) + ) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("Unable to parse time string")); + } + + @SneakyThrows + public void testDataDistributionToolPPLWithComplexQuery() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"queryType\": \"ppl\", \"ppl\": \"source=%s | where level > 2\"}}", + TEST_DATA_INDEX_NAME, + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3.0\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"4.0\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]}]}"; + assertEquals(expectedResult, result); + } +} diff --git a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_data_distribution_tool_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_data_distribution_tool_request_body.json new file mode 100644 index 00000000..31aba32a --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_data_distribution_tool_request_body.json @@ -0,0 +1,10 @@ +{ + "name": "Test_data_distribution_tool_flow_agent", + "type": "flow", + "tools": [ + { + "type": "DataDistributionTool", + "parameters": {} + } + ] +} \ No newline at end of file From b2556bc292b44d0a74d4225eabb54b3787513560 Mon Sep 17 00:00:00 2001 From: Xinyuan Lu Date: Fri, 19 Sep 2025 15:42:04 +0800 Subject: [PATCH 09/30] Add more information in ppl tool when passing to sagemaker (#636) * apply multiply Signed-off-by: xinyual * add mappings Signed-off-by: xinyual * apply spotless Signed-off-by: xinyual * fix payload Signed-off-by: xinyual * apply spotless Signed-off-by: 
xinyual * fix IT Signed-off-by: xinyual * add ut Signed-off-by: xinyual * fix payload Signed-off-by: xinyual * use ml commons Signed-off-by: xinyual * Update src/main/java/org/opensearch/agent/tools/PPLTool.java Co-authored-by: zane-neo Signed-off-by: Xinyuan Lu * fix comment Signed-off-by: xinyual * fix comment Signed-off-by: xinyual --------- Signed-off-by: xinyual Signed-off-by: Xinyuan Lu Co-authored-by: zane-neo --- .../org/opensearch/agent/tools/PPLTool.java | 209 +++++++++++------- .../opensearch/agent/tools/PPLToolTests.java | 13 ++ .../org/opensearch/integTest/PPLToolIT.java | 2 +- 3 files changed, 137 insertions(+), 87 deletions(-) diff --git a/src/main/java/org/opensearch/agent/tools/PPLTool.java b/src/main/java/org/opensearch/agent/tools/PPLTool.java index 82ec6871..15bdc7c8 100644 --- a/src/main/java/org/opensearch/agent/tools/PPLTool.java +++ b/src/main/java/org/opensearch/agent/tools/PPLTool.java @@ -7,6 +7,7 @@ import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD; import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.ml.common.utils.ToolUtils.NO_ESCAPE_PARAMS; import java.io.IOException; import java.io.InputStream; @@ -14,6 +15,7 @@ import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.time.Instant; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -22,10 +24,14 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.StringJoiner; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.stream.Collectors; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.math.NumberUtils; @@ -33,6 +39,7 @@ import 
org.apache.spark.sql.types.DataType; import org.json.JSONObject; import org.opensearch.OpenSearchStatusException; +import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest; import org.opensearch.action.search.SearchRequest; @@ -82,6 +89,9 @@ public class PPLTool implements WithModelTool { private static final String DEFAULT_DESCRIPTION = "\"Use this tool when user ask question based on the data in the cluster or parse user statement about which index to use in a conversion.\nAlso use this tool when question only contains index information.\n1. If uesr question contain both question and index name, the input parameters are {'question': UserQuestion, 'index': IndexName}.\n2. If user question contain only question, the input parameter is {'question': UserQuestion}.\n3. If uesr question contain only index name, find the original human input from the conversation histroy and formulate parameter as {'question': UserQuestion, 'index': IndexName}\nThe index name should be exactly as stated in user's input."; + private static final String TABLE_INFO_KEY = "table_info"; + private static final String MAPPING_KEY = "mappings"; + @Setter private String name = TYPE; @Getter @@ -201,27 +211,50 @@ public void run(Map originalParameters, ActionListener li Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); final String tenantId = parameters.get(TENANT_ID_FIELD); extractFromChatParameters(parameters); - String indexName = getIndexNameFromParameters(parameters); - if (StringUtils.isBlank(indexName)) { + List indices = Optional + .ofNullable(getIndexNameFromParameters(parameters, "index")) + .filter(list -> !list.isEmpty()) + .orElseGet(() -> getIndexNameFromParameters(parameters, this.previousToolKey + ".output")); + if (indices.isEmpty()) { throw new IllegalArgumentException( "Return this final answer to human directly and do not use other tools: 'Please provide 
index name'. Please try to directly send this message to human to ask for index name" ); } String question = parameters.get("question"); - if (StringUtils.isBlank(indexName) || StringUtils.isBlank(question)) { + if (StringUtils.isBlank(question)) { throw new IllegalArgumentException("Parameter index and question can not be null or empty."); } - if (indexName.startsWith(".")) { - throw new IllegalArgumentException( - "PPLTool doesn't support searching indices starting with '.' since it could be system index, current searching index name: " - + indexName - ); + for (String index : indices) { + if (index.startsWith(".")) { + throw new IllegalArgumentException( + "PPLTool doesn't support searching indices starting with '.' since it could be system index, current searching index name: " + + index + ); + } } - ActionListener actionsAfterTableinfo = ActionListener.wrap(tableInfo -> { - String prompt = constructPrompt(tableInfo, question.strip(), indexName); + ActionListener> actionsAfterTableinfo = ActionListener.wrap(indexInfo -> { + if (Objects.isNull(indexInfo.get(TABLE_INFO_KEY)) || Objects.isNull(indexInfo.get(MAPPING_KEY))) { + log.error("The table info and mappings are missing in: {}", indexInfo); + listener.onFailure(new RuntimeException("The table info and mappings are missing in: " + indexInfo)); + } + String tableInfo = indexInfo.get(TABLE_INFO_KEY).toString(); + String prompt = constructPrompt(tableInfo, question.strip(), indices); + Map reformattedInput = Map + .of( + "prompt", + prompt, + "mappings", + indexInfo.get(MAPPING_KEY), + "os_version", + Version.CURRENT.toString(), + "current_time", + Instant.now().toString(), + "datasourceType", + parameters.getOrDefault("type", "Opensearch") + ); RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet .builder() - .parameters(Map.of("prompt", prompt, "datasourceType", parameters.getOrDefault("type", "Opensearch"))) + .parameters(Map.of("prompt", formatString(reformattedInput), NO_ESCAPE_PARAMS, 
"prompt")) .build(); ActionRequest request = new MLPredictionTaskRequest( modelId, @@ -238,7 +271,7 @@ public void run(Map originalParameters, ActionListener li listener.onFailure(new IllegalStateException("Remote endpoint fails to inference.")); return; } - String ppl = parseOutput(dataAsMap.get("response"), indexName); + String ppl = parseOutput(dataAsMap.get("response")); if (!this.execute) { Map ret = ImmutableMap.of("ppl", ppl); listener.onResponse((T) AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(ret))); @@ -281,6 +314,7 @@ public void run(Map originalParameters, ActionListener li } ); + // Logic for schema/samples as input if (parameters.containsKey("schema") && parameters.containsKey("samples") && Objects.equals(parameters.getOrDefault("type", ""), "s3")) { @@ -291,45 +325,61 @@ public void run(Map originalParameters, ActionListener li transferS3SchemaFormat(schema), (Map) samples.get(0) ); - actionsAfterTableinfo.onResponse(tableInfo); + actionsAfterTableinfo.onResponse(Map.of(TABLE_INFO_KEY, tableInfo, MAPPING_KEY, gson.toJson(schema))); } catch (Exception e) { log.info("fail to get table info for s3"); actionsAfterTableinfo.onFailure(e); } - return; } - GetMappingsRequest getMappingsRequest = buildGetMappingRequest(indexName); - client.admin().indices().getMappings(getMappingsRequest, ActionListener.wrap(getMappingsResponse -> { - Map mappings = getMappingsResponse.getMappings(); - if (mappings.isEmpty()) { - throw new IllegalArgumentException("No matching mapping with index name: " + indexName); - } - String firstIndexName = (String) mappings.keySet().toArray()[0]; - SearchRequest searchRequest = buildSearchRequest(firstIndexName); - client.search(searchRequest, ActionListener.wrap(searchResponse -> { - SearchHit[] searchHits = searchResponse.getHits().getHits(); - String tableInfo = constructTableInfo(searchHits, mappings); - actionsAfterTableinfo.onResponse(tableInfo); + CountDownLatch latch = new 
CountDownLatch(indices.size()); + ConcurrentHashMap tableInfos = new ConcurrentHashMap<>(); + ConcurrentHashMap mappingInfos = new ConcurrentHashMap<>(); + for (String index : indices) { + GetMappingsRequest getMappingsRequest = buildGetMappingRequest(index); + client.admin().indices().getMappings(getMappingsRequest, ActionListener.wrap(getMappingsResponse -> { + Map mappings = getMappingsResponse.getMappings(); + if (mappings.isEmpty()) { + throw new IllegalArgumentException("No matching mapping with index name: " + index); + } + String firstIndexName = (String) mappings.keySet().toArray()[0]; + SearchRequest searchRequest = buildSearchRequest(firstIndexName); + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + SearchHit[] searchHits = searchResponse.getHits().getHits(); + Map finalMappings = new HashMap<>(); + for (MappingMetadata mappingMetadata : mappings.values()) { + Map mappingSource = (Map) mappingMetadata.getSourceAsMap().get("properties"); + MergeRuleHelper.merge(mappingSource, finalMappings); + } + String tableInfo = constructTableInfo(searchHits, finalMappings); + tableInfos.put(index, tableInfo); + mappingInfos.put(index, finalMappings); + latch.countDown(); + if (latch.getCount() == 0) { + String mergedTableInfo = mergeTableInfo(tableInfos); + actionsAfterTableinfo.onResponse(Map.of(TABLE_INFO_KEY, mergedTableInfo, MAPPING_KEY, mappingInfos)); + } + }, e -> { + log.error(String.format(Locale.ROOT, "fail to search index: %s with error: %s", firstIndexName, e.getMessage()), e); + listener.onFailure(e); + })); }, e -> { - log.error(String.format(Locale.ROOT, "fail to search model: %s with error: %s", modelId, e.getMessage()), e); - listener.onFailure(e); + log.error(String.format(Locale.ROOT, "fail to get mapping of index: %s with error: %s", indices, e.getMessage()), e); + String errorMessage = e.getMessage(); + if (errorMessage.contains("no such index")) { + listener + .onFailure( + new IllegalArgumentException( + "Return this 
final answer to human directly and do not use other tools: 'Please provide the existing index name(s)'. Please try to directly send this message to human to ask for index name" + ) + ); + } else { + listener.onFailure(e); + } })); - }, e -> { - log.error(String.format(Locale.ROOT, "fail to get mapping of index: %s with error: %s", indexName, e.getMessage()), e); - String errorMessage = e.getMessage(); - if (errorMessage.contains("no such index")) { - listener - .onFailure( - new IllegalArgumentException( - "Return this final answer to human directly and do not use other tools: 'Please provide index name'. Please try to directly send this message to human to ask for index name" - ) - ); - } else { - listener.onFailure(e); - } - })); + } + } @Override @@ -515,17 +565,7 @@ private String constructTableInfoByPPLResultForSpark(Map schema, } - private String constructTableInfo(SearchHit[] searchHits, Map mappings) throws PrivilegedActionException { - if (mappings.keySet().size() == 0) { - throw new IllegalArgumentException( - "The querying index doesn't have mapping metadata, please add data to it or using another index." 
- ); - } - Map allFields = new HashMap<>(); - for (MappingMetadata mappingMetadata : mappings.values()) { - Map mappingSource = (Map) mappingMetadata.getSourceAsMap().get("properties"); - MergeRuleHelper.merge(mappingSource, allFields); - } + private String constructTableInfo(SearchHit[] searchHits, Map allFields) throws PrivilegedActionException { Map fieldsToType = new HashMap<>(); ToolHelper.extractFieldNamesTypes(allFields, fieldsToType, "", false); @@ -560,8 +600,8 @@ private String constructTableInfo(SearchHit[] searchHits, Map indexInfo = ImmutableMap.of("mappingInfo", tableInfo, "question", question, "indexName", indexName); + private String constructPrompt(String tableInfo, String question, List indices) { + Map indexInfo = ImmutableMap.of("mappingInfo", tableInfo, "question", question, "indexName", indices.toString()); StringSubstitutor substitutor = new StringSubstitutor(indexInfo, "${indexInfo.", "}"); return substitutor.replace(contextPrompt); } @@ -616,7 +656,7 @@ private void extractFromChatParameters(Map parameters) { } } - private String parseOutput(String llmOutput, String indexName) { + private String parseOutput(String llmOutput) { String ppl; Pattern pattern = Pattern.compile("((.|[\\r\\n])+?)"); // For ppl like source=a \n | fields b Matcher matcher = pattern.matcher(llmOutput); @@ -626,32 +666,10 @@ private String parseOutput(String llmOutput, String indexName) { } else { // logic for only ppl returned int sourceIndex = llmOutput.indexOf("source="); int describeIndex = llmOutput.indexOf("describe "); - if (sourceIndex != -1) { - llmOutput = llmOutput.substring(sourceIndex); - - // Splitting the string at "|" - String[] lists = llmOutput.split("\\|"); - - // Modifying the first element - if (lists.length > 0) { - lists[0] = "source=" + indexName; - } - - // Joining the string back together - ppl = String.join("|", lists); - } else if (describeIndex != -1) { - llmOutput = llmOutput.substring(describeIndex); - String[] lists = 
llmOutput.split("\\|"); - - // Modifying the first element - if (lists.length > 0) { - lists[0] = "describe " + indexName; - } - - // Joining the string back together - ppl = String.join("|", lists); - } else { + if (sourceIndex == -1 && describeIndex == -1) { throw new IllegalArgumentException("The returned PPL: " + llmOutput + " has wrong format"); + } else { + ppl = llmOutput; } } if (this.pplModelType != PPLModelType.FINETUNE) { @@ -670,12 +688,26 @@ private String parseOutput(String llmOutput, String indexName) { return ppl; } - private String getIndexNameFromParameters(Map parameters) { - String indexName = parameters.getOrDefault("index", ""); - if (!StringUtils.isBlank(this.previousToolKey) && StringUtils.isBlank(indexName)) { - indexName = parameters.getOrDefault(this.previousToolKey + ".output", ""); // read index name from previous key + private List getIndexNameFromParameters(Map parameters, String key) { + if (!parameters.containsKey(key)) { + return List.of(); + } + String indexName = parameters.get(key); + try { + List list = gson.fromJson(indexName, List.class); + return list.stream().map(Object::toString).map(String::trim).collect(Collectors.toList()); + } catch (Exception e) { + return List.of(indexName.trim()); + } + } + + private String mergeTableInfo(ConcurrentHashMap tableInfos) { + StringBuilder mergedTableInfo = new StringBuilder(); + for (Map.Entry entry : tableInfos.entrySet()) { + mergedTableInfo.append(entry.getKey()).append("\n"); + mergedTableInfo.append(entry.getValue()).append("\n"); } - return indexName.trim(); + return mergedTableInfo.toString(); } private Map transferS3SchemaFormat(Map originalSchema) { @@ -716,4 +748,9 @@ public static String redactCloudwatchUrl(String input) { return matcher.replaceAll(""); } + + public String formatString(Map targetMap) { + String mapString = gson.toJson(gson.toJson(targetMap)); + return mapString.substring(1, mapString.length() - 1); + } } diff --git 
a/src/test/java/org/opensearch/agent/tools/PPLToolTests.java b/src/test/java/org/opensearch/agent/tools/PPLToolTests.java index fcd04202..140022be 100644 --- a/src/test/java/org/opensearch/agent/tools/PPLToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/PPLToolTests.java @@ -10,6 +10,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.when; +import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.utils.StringUtils.gson; @@ -324,6 +325,18 @@ public void testTool_ForSparkInputWithStructInput() { } + @Test + public void testTool_basic() { + PPLTool tool = PPLTool.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt", "previous_tool_name", "previousTool", "head", "-5")); + assertEquals(tool.getDescription(), PPLTool.Factory.getInstance().getDefaultDescription()); + assertEquals(tool.getType(), PPLTool.Factory.getInstance().getDefaultType()); + assertEquals(null, PPLTool.Factory.getInstance().getDefaultVersion()); + assertEquals(List.of(COMMON_MODEL_ID_FIELD), PPLTool.Factory.getInstance().getAllModelKeys()); + + } + @Test public void testTool_withPreviousInput() { PPLTool tool = PPLTool.Factory diff --git a/src/test/java/org/opensearch/integTest/PPLToolIT.java b/src/test/java/org/opensearch/integTest/PPLToolIT.java index cf576be8..3d6120ee 100644 --- a/src/test/java/org/opensearch/integTest/PPLToolIT.java +++ b/src/test/java/org/opensearch/integTest/PPLToolIT.java @@ -54,7 +54,7 @@ public void testPPLTool() { String agentId = registerAgent(); String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"correct\", \"index\": \"employee\"}}"); assertEquals( - "{\"ppl\":\"source\\u003demployee| where age \\u003e 56 | stats COUNT() as cnt\"," + "{\"ppl\":\"source\\u003demployee | where age 
\\u003e 56 | stats COUNT() as cnt\"," + "\"executionResult\":\"{\\n \\\"schema\\\": [\\n {\\n \\\"name\\\": \\\"cnt\\\",\\n " + "\\\"type\\\": \\\"int\\\"\\n }\\n ],\\n \\\"datarows\\\": [\\n [\\n 0\\n ]\\n ],\\n " + "\\\"total\\\": 1,\\n \\\"size\\\": 1\\n}\"}", From 1f866c0158fe468ac7cf059759b0a08f5190370a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E4=BD=B3=E5=A6=82=EF=BC=88Jiaru=20Jiang=EF=BC=89?= Date: Tue, 30 Sep 2025 16:46:02 +0800 Subject: [PATCH 10/30] fix: delete-single-baseline (#641) Signed-off-by: Jiaru Jiang --- .../agent/tools/DataDistributionTool.java | 19 +++++++++----- .../integTest/DataDistributionToolIT.java | 26 +++++++++---------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java b/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java index d0300906..1e6dde61 100644 --- a/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java +++ b/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java @@ -1121,22 +1121,27 @@ private List formatComparisonSummary(List differ Set allKeys = new HashSet<>(diff.selectionDist.keySet()); allKeys.addAll(diff.baselineDist.keySet()); + boolean hasBaseline = !diff.baselineDist.isEmpty(); + List changes = allKeys.stream().map(value -> { double selectionPercentage = Math.round(diff.selectionDist.getOrDefault(value, 0.0) * PERCENTAGE_MULTIPLIER) / PERCENTAGE_MULTIPLIER; - double baselinePercentage = Math.round(diff.baselineDist.getOrDefault(value, 0.0) * PERCENTAGE_MULTIPLIER) - / PERCENTAGE_MULTIPLIER; + Double baselinePercentage = hasBaseline + ? 
Math.round(diff.baselineDist.getOrDefault(value, 0.0) * PERCENTAGE_MULTIPLIER) / PERCENTAGE_MULTIPLIER + : null; return new ChangeItem(value, selectionPercentage, baselinePercentage); }).collect(Collectors.toList()); List topChanges = changes .stream() .sorted( - (a, b) -> Double - .compare( - Math.max(b.baselinePercentage, b.selectionPercentage), - Math.max(a.baselinePercentage, a.selectionPercentage) - ) + (a, b) -> hasBaseline + ? Double + .compare( + Math.max(b.baselinePercentage != null ? b.baselinePercentage : 0.0, b.selectionPercentage), + Math.max(a.baselinePercentage != null ? a.baselinePercentage : 0.0, a.selectionPercentage) + ) + : Double.compare(b.selectionPercentage, a.selectionPercentage) ) .limit(TOP_CHANGES_LIMIT) .collect(Collectors.toList()); diff --git a/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java b/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java index c99f5ec7..79154570 100644 --- a/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java +++ b/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java @@ -131,7 +131,7 @@ public void testDataDistributionToolSingleAnalysis() { ); String expectedResult = - 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75,\"baselinePercentage\":0.0},{\"value\":\"warning\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"2\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"server-03\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"250.3\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]}]}"; + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; assertEquals(expectedResult, result); } @@ -165,7 +165,7 @@ public void testDataDistributionToolWithFilter() { ); String expectedResult = - "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]}]}"; + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; assertEquals(expectedResult, result); } @@ -201,7 +201,7 @@ public void testDataDistributionToolPPLSingleAnalysis() { ); String expectedResult = - "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75,\"baselinePercentage\":0.0},{\"value\":\"warning\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3.0\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"2.0\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"4.0\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"server-03\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"250.3\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\
":\"300.5\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]}]}"; + "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3.0\",\"selectionPercentage\":0.5},{\"value\":\"2.0\",\"selectionPercentage\":0.25},{\"value\":\"4.0\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; assertEquals(expectedResult, result); } @@ -237,7 +237,7 @@ public void testDataDistributionToolPPLWithCustomQuery() { ); String expectedResult = - 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3.0\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"4.0\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]}]}"; + "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3.0\",\"selectionPercentage\":0.67},{\"value\":\"4.0\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; assertEquals(expectedResult, result); } @@ -254,7 +254,7 @@ public void testDataDistributionToolWithDSLQueryType() { ); String expectedResult = - 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75,\"baselinePercentage\":0.0},{\"value\":\"warning\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"2\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"server-03\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"250.3\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]}]}"; + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; assertEquals(expectedResult, result); } @@ -271,7 +271,7 @@ public void testDataDistributionToolWithMultipleFilters() { ); String expectedResult = - "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]}]}"; + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; assertEquals(expectedResult, result); } @@ -288,7 +288,7 @@ public void testDataDistributionToolWithCustomSize() { ); String expectedResult = - "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75,\"baselinePercentage\":0.0},{\"value\":\"warning\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"2\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"server-03\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"250.3\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"300.5
\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]}]}"; + "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; assertEquals(expectedResult, result); } @@ -305,7 +305,7 @@ public void testDataDistributionToolWithCustomTimeField() { ); String expectedResult = - 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75,\"baselinePercentage\":0.0},{\"value\":\"warning\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"2\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"server-03\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"250.3\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]}]}"; + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; assertEquals(expectedResult, result); } @@ -322,7 +322,7 @@ public void testDataDistributionToolWithRangeFilter() { ); String expectedResult = - "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]}]}"; + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; assertEquals(expectedResult, result); } @@ -339,7 +339,7 @@ public void testDataDistributionToolWithMatchFilter() { ); String expectedResult = - "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]}]}"; + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; assertEquals(expectedResult, result); } @@ -356,7 +356,7 @@ public void testDataDistributionToolWithRawDSLQuery() { ); String expectedResult = - "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]}]}"; + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; assertEquals(expectedResult, result); } @@ -373,7 +373,7 @@ public void testDataDistributionToolWithExistsFilter() { ); String expectedResult = - "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75,\"baselinePercentage\":0.0},{\"value\":\"warning\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"2\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"4\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"server-03\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"250.3\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"300
.5\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]}]}"; + "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; assertEquals(expectedResult, result); } @@ -442,7 +442,7 @@ public void testDataDistributionToolPPLWithComplexQuery() { ); String expectedResult = - 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3.0\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"4.0\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67,\"baselinePercentage\":0.0},{\"value\":\"server-02\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.33,\"baselinePercentage\":0.0}]}]}"; + "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3.0\",\"selectionPercentage\":0.67},{\"value\":\"4.0\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; assertEquals(expectedResult, result); } } From f6f9cf680e5ef7c68f321d3b4cb964549b73e37f Mon Sep 17 00:00:00 2001 From: opensearch-ci <83309141+opensearch-ci-bot@users.noreply.github.com> Date: Tue, 7 Oct 2025 15:54:50 -0400 Subject: [PATCH 11/30] Add release notes for 3.3.0 (#642) Signed-off-by: opensearch-ci --- 
...opensearch-skills.release-notes-3.3.0.0.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 release-notes/opensearch-skills.release-notes-3.3.0.0.md diff --git a/release-notes/opensearch-skills.release-notes-3.3.0.0.md b/release-notes/opensearch-skills.release-notes-3.3.0.0.md new file mode 100644 index 00000000..83b5be43 --- /dev/null +++ b/release-notes/opensearch-skills.release-notes-3.3.0.0.md @@ -0,0 +1,19 @@ +## Version 3.3.0 Release Notes + +Compatible with OpenSearch and OpenSearch Dashboards version 3.3.0 + +### Features +* Log patterns analysis tool ([#625](https://github.com/opensearch-project/skills/pull/625)) +* Data Distribution Tool ([#634](https://github.com/opensearch-project/skills/pull/634)) + +### Enhancements +* Add more information in ppl tool when passing to sagemaker ([#636](https://github.com/opensearch-project/skills/pull/636)) + +### Bug Fixes +* Delete-single-baseline ([#641](https://github.com/opensearch-project/skills/pull/641)) + +### Infrastructure +* Update System.env syntax for Gradle 9 compatibility ([#630](https://github.com/opensearch-project/skills/pull/630)) + +### Maintenance +* Increment version to 3.3.0-SNAPSHOT ([#626](https://github.com/opensearch-project/skills/pull/626)) \ No newline at end of file From 10d97accacda059eaee8044e6a74da39426134ef Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 8 Oct 2025 15:14:46 +0800 Subject: [PATCH 12/30] Fix websearchtool issue (#639) * Fix issue in WebSearchTool Signed-off-by: zane-neo * Optimize code Signed-off-by: zane-neo * Fix build error Signed-off-by: zane-neo * Fix CVE Signed-off-by: zane-neo * remove unused file Signed-off-by: zane-neo * Fix failure ITs Signed-off-by: zane-neo --------- Signed-off-by: zane-neo --- .gitignore | 1 + build.gradle | 11 +- .../opensearch/agent/tools/WebSearchTool.java | 540 +++++++++++------- .../SearchAnomalyDetectorsToolTests.java | 1 + .../integTest/DataDistributionToolIT.java | 54 +- 
.../org/opensearch/integTest/PPLToolIT.java | 6 +- 6 files changed, 385 insertions(+), 228 deletions(-) diff --git a/.gitignore b/.gitignore index 722e14d1..21b99ea6 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ out/ .settings .vscode bin/ +.factorypath diff --git a/build.gradle b/build.gradle index 423333bb..1eda7969 100644 --- a/build.gradle +++ b/build.gradle @@ -138,7 +138,7 @@ dependencies { compileOnly "org.apache.logging.log4j:log4j-slf4j-impl:2.23.1" compileOnly group: 'org.json', name: 'json', version: '20240303' compileOnly("com.google.guava:guava:33.2.1-jre") - compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: '3.16.0' + compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: "${versions.commonslang}" compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.12.0' compileOnly group: 'org.apache.commons', name: 'commons-math3', version: '3.6.1' compileOnly("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}") @@ -148,6 +148,13 @@ dependencies { compileOnly ('com.jayway.jsonpath:json-path:2.9.0') { exclude group: 'net.minidev', module: 'json-smart' } + compileOnly (group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "${versions.aws}") { + exclude(group: 'org.reactivestreams', module: 'reactive-streams') + exclude(group: 'org.slf4j', module: 'slf4j-api') + } + compileOnly(group: 'software.amazon.awssdk', name: 'http-client-spi', version: "${versions.aws}") + compileOnly(group: 'software.amazon.awssdk', name: 'utils', version: "${versions.aws}") + compileOnly(group: 'software.amazon.awssdk', name: 'sdk-core', version: "${versions.aws}") spark 'org.apache.spark:spark-sql-api_2.13:3.5.4' spark ('org.apache.spark:spark-core_2.13:3.5.4') { @@ -194,7 +201,7 @@ dependencies { testImplementation group: 'org.json', name: 'json', version: '20240303' testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.14.2' testImplementation group: 
'org.mockito', name: 'mockito-inline', version: '5.2.0' - testImplementation("net.bytebuddy:byte-buddy:1.17.5") + testImplementation("net.bytebuddy:byte-buddy:1.17.7") testImplementation("net.bytebuddy:byte-buddy-agent:1.17.5") testImplementation 'org.junit.jupiter:junit-jupiter-api:5.11.2' testImplementation 'org.mockito:mockito-junit-jupiter:5.14.2' diff --git a/src/main/java/org/opensearch/agent/tools/WebSearchTool.java b/src/main/java/org/opensearch/agent/tools/WebSearchTool.java index 047d6769..1d5500f6 100644 --- a/src/main/java/org/opensearch/agent/tools/WebSearchTool.java +++ b/src/main/java/org/opensearch/agent/tools/WebSearchTool.java @@ -8,6 +8,9 @@ import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -17,24 +20,25 @@ import java.util.Optional; import org.apache.commons.lang3.math.NumberUtils; -import org.apache.hc.client5.http.classic.methods.HttpGet; -import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; -import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse; -import org.apache.hc.client5.http.impl.classic.HttpClients; import org.apache.hc.core5.http.HttpStatus; -import org.apache.hc.core5.http.io.entity.EntityUtils; import org.jsoup.Connection; import org.jsoup.Jsoup; import org.jsoup.nodes.Document; import org.jsoup.nodes.Element; import org.jsoup.select.Elements; +import org.opensearch.OpenSearchStatusException; import org.opensearch.agent.ToolPlugin; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.httpclient.MLHttpClientFactory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.utils.StringUtils; import 
org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.threadpool.ThreadPool; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import com.google.common.collect.ImmutableMap; import com.google.gson.JsonArray; @@ -45,6 +49,14 @@ import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; +import software.amazon.awssdk.core.internal.http.async.SimpleHttpContentPublisher; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpFullResponse; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.http.async.AsyncExecuteRequest; +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; +import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler; @Log4j2 @Setter @@ -77,6 +89,28 @@ public class WebSearchTool implements Tool { + "}"; public static final Map DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, "strict", false); + public static final String NEXT_PAGE = "next_page"; + public static final String ENGINE_ID = "engine_id"; + public static final String OFFSET = "offset"; + public static final String DUCKDUCKGO = "duckduckgo"; + public static final String GOOGLE = "google"; + public static final String BING = "bing"; + public static final String CUSTOM = "custom"; + public static final String ITEMS = "items"; + public static final String ENGINE = "engine"; + public static final String ENDPOINT = "endpoint"; + public static final String API_KEY = "api_key"; + public static final String CUSTOM_API = "custom_api"; + public static final String AUTHORIZATION = "Authorization"; + public static final String TITLE = "title"; + public static final String URL = "url"; + public static final String CONTENT = "content"; + public static final String QUERY = "query"; + public static final String QUESTION = "question"; + public 
static final String QUERY_KEY = "query_key"; + public static final String LIMIT_KEY = "limit_key"; + public static final String CUSTOM_RES_URL_JSONPATH = "custom_res_url_jsonpath"; + public static final String START = "start"; @Setter @Getter @@ -86,13 +120,14 @@ public class WebSearchTool implements Tool { private String description = DEFAULT_DESCRIPTION; @Getter private String version; - private CloseableHttpClient httpClient; + private final SdkAsyncHttpClient httpClient; private final ThreadPool threadPool; private Map attributes; public WebSearchTool(ThreadPool threadPool) { - this.httpClient = HttpClients.createDefault(); + // Use 1s for connection timeout, 3s for read timeout, 30 for max connections of httpclient. + this.httpClient = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(1), Duration.ofSeconds(3), 30); this.threadPool = threadPool; this.attributes = new HashMap<>(); attributes.put(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA); @@ -101,103 +136,101 @@ public WebSearchTool(ThreadPool threadPool) { @Override public void run(Map originalParameters, ActionListener listener) { - Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); try { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); // common search parameters - String query = parameters.getOrDefault("query", parameters.get("question")).replaceAll(" ", "+"); - String engine = parameters.getOrDefault("engine", "google"); - String endpoint = parameters.getOrDefault("endpoint", getDefaultEndpoint(engine)); - String apiKey = parameters.get("api_key"); - String nextPage = parameters.get("next_page"); + String query = parameters.getOrDefault(QUERY, parameters.get(QUESTION)).replaceAll(" ", "+"); + String engine = parameters.getOrDefault(ENGINE, GOOGLE); + String endpoint = parameters.getOrDefault(ENDPOINT, getDefaultEndpoint(engine)); + String apiKey = parameters.get(API_KEY); + String nextPage = parameters.get(NEXT_PAGE); // Google 
search parameters - String engineId = parameters.get("engine_id"); + String engineId = parameters.get(ENGINE_ID); // Custom search parameters - String authorization = parameters.get("Authorization"); - String queryKey = parameters.getOrDefault("query_key", "q"); - String offsetKey = parameters.getOrDefault("offset_key", "offset"); - String limitKey = parameters.getOrDefault("limit_key", "limit"); - String customResUrlJsonpath = parameters.get("custom_res_url_jsonpath"); + String authorization = parameters.get(AUTHORIZATION); + String queryKey = parameters.getOrDefault(QUERY_KEY, "q"); + String offsetKey = parameters.getOrDefault(OFFSET + "_key", OFFSET); + String limitKey = parameters.getOrDefault(LIMIT_KEY, "limit"); + String customResUrlJsonpath = parameters.get(CUSTOM_RES_URL_JSONPATH); + threadPool.executor(ToolPlugin.WEBSEARCH_CRAWLER_THREADPOOL).submit(() -> { - try { - String parsedNextPage = null; - if ("duckduckgo".equalsIgnoreCase(engine)) { - // duckduckgo has different approach to other APIs as it's not a standard public API. + String parsedNextPage; + if (DUCKDUCKGO.equalsIgnoreCase(engine)) { + // duckduckgo has different approach to other APIs as it's not a standard public API. 
+ if (nextPage != null) { + fetchDuckDuckGoResult(nextPage, listener); + } else { + fetchDuckDuckGoResult(buildDDGEndpoint(getDefaultEndpoint(engine), query), listener); + } + } else { + SdkHttpFullRequest.Builder builder = SdkHttpFullRequest.builder().method(SdkHttpMethod.GET); + if (GOOGLE.equalsIgnoreCase(engine)) { if (nextPage != null) { - fetchDuckDuckGoResult(nextPage, listener); + builder.uri(nextPage); + parsedNextPage = buildGoogleNextPage(endpoint, engineId, query, apiKey, nextPage); } else { - fetchDuckDuckGoResult(buildDDGEndpoint(getDefaultEndpoint(engine), query), listener); + builder.uri(buildGoogleUrl(endpoint, engineId, query, apiKey, 0)); + parsedNextPage = buildGoogleUrl(endpoint, engineId, query, apiKey, 10); } - } else { - HttpGet getRequest = null; - if ("google".equalsIgnoreCase(engine)) { - if (nextPage != null) { - getRequest = new HttpGet(nextPage); - parsedNextPage = buildGoogleNextPage(endpoint, engineId, query, apiKey, nextPage); - } else { - getRequest = new HttpGet(buildGoogleUrl(endpoint, engineId, query, apiKey, 0)); - parsedNextPage = buildGoogleUrl(endpoint, engineId, query, apiKey, 10); - } - } else if ("bing".equalsIgnoreCase(engine)) { - if (nextPage != null) { - getRequest = new HttpGet(nextPage); - parsedNextPage = buildBingNextPage(endpoint, query, nextPage); - } else { - getRequest = new HttpGet(buildBingUrl(endpoint, query, 0)); - parsedNextPage = buildBingUrl(endpoint, query, 10); - } - getRequest.addHeader("Ocp-Apim-Subscription-Key", apiKey); - } else if ("custom".equalsIgnoreCase(engine)) { - if (nextPage != null) { - getRequest = new HttpGet(nextPage); - parsedNextPage = buildCustomNextPage(endpoint, nextPage, queryKey, query, offsetKey, limitKey); - } else { - getRequest = new HttpGet(buildCustomUrl(endpoint, queryKey, query, offsetKey, 0, limitKey)); - parsedNextPage = buildCustomUrl(endpoint, queryKey, query, offsetKey, 10, limitKey); - } - getRequest.addHeader("Authorization", authorization); + } else if 
(BING.equalsIgnoreCase(engine)) { + if (nextPage != null) { + builder.uri(nextPage); + parsedNextPage = buildBingNextPage(endpoint, query, nextPage); } else { - // Search engine not supported. - listener.onFailure(new IllegalArgumentException("Unsupported search engine: %s".formatted(engine))); - return; + builder.uri(buildBingUrl(endpoint, query, 0)); + parsedNextPage = buildBingUrl(endpoint, query, 10); } - CloseableHttpResponse res = httpClient.execute(getRequest); - if (res.getCode() >= HttpStatus.SC_BAD_REQUEST) { - listener - .onFailure( - new IllegalArgumentException("Web search failed: %d %s".formatted(res.getCode(), res.getReasonPhrase())) - ); + builder.putHeader("Ocp-Apim-Subscription-Key", apiKey); + } else if (CUSTOM.equalsIgnoreCase(engine)) { + if (nextPage != null) { + builder.uri(nextPage); + parsedNextPage = buildCustomNextPage(endpoint, nextPage, queryKey, query, offsetKey, limitKey); } else { - String responseString = EntityUtils.toString(res.getEntity()); - parseResponse(responseString, authorization, parsedNextPage, engine, customResUrlJsonpath, listener); + builder.uri(buildCustomUrl(endpoint, queryKey, query, offsetKey, 0, limitKey)); + parsedNextPage = buildCustomUrl(endpoint, queryKey, query, offsetKey, 10, limitKey); } + builder.putHeader(AUTHORIZATION, authorization); + } else { + // Search engine not supported. 
+ listener + .onFailure(new IllegalArgumentException(String.format(Locale.ROOT, "Unsupported search engine: %s", engine))); + return; } - } catch (Exception e) { - listener.onFailure(new IllegalStateException("Web search failed: %s".formatted(e.getMessage()))); + SdkHttpFullRequest getRequest = builder.build(); + AsyncExecuteRequest executeRequest = AsyncExecuteRequest + .builder() + .request(getRequest) + .requestContentPublisher(new SimpleHttpContentPublisher(getRequest)) + .responseHandler( + new WebSearchResponseHandler(endpoint, authorization, parsedNextPage, engine, customResUrlJsonpath, listener) + ) + .build(); + httpClient.execute(executeRequest); } }); } catch (Exception e) { - listener.onFailure(new IllegalStateException("Web search failed: %s".formatted(e.getMessage()))); + listener.onFailure(new IllegalStateException(String.format(Locale.ROOT, "Web search failed: %s", e.getMessage()))); } } private String buildDDGEndpoint(String endpoint, String query) { - return "%s?q=%s".formatted(endpoint, query); + return String.format(Locale.ROOT, "%s?q=%s", endpoint, query); } private String buildGoogleNextPage(String endpoint, String engineId, String query, String apiKey, String currentPage) { - String[] offsetSplit = currentPage.split("&start="); + String[] offsetSplit = currentPage.split("&" + START + "="); int offset = NumberUtils.toInt(offsetSplit[1], 0) + 10; return buildGoogleUrl(endpoint, engineId, query, apiKey, offset); } private String buildGoogleUrl(String endpoint, String engineId, String query, String apiKey, int start) { - return "%s?q=%s&cx=%s&key=%s&start=%d".formatted(endpoint, query, engineId, apiKey, start); + return String.format(Locale.ROOT, "%s?q=%s&cx=%s&key=%s&" + START + "=%d", endpoint, query, engineId, apiKey, start); } private String buildBingNextPage(String endpoint, String query, String currentPage) { - String[] offsetSplit = currentPage.split("&offset="); + String[] offsetSplit = currentPage.split("&" + OFFSET + "="); int offset = 
NumberUtils.toInt(offsetSplit[1], 0) + 10; return buildBingUrl(endpoint, query, offset); } @@ -210,100 +243,28 @@ private String buildCustomNextPage( String offsetKey, String limitKey ) { - String[] pageSplit = currentPage.split("&%s=".formatted(offsetKey)); + String[] pageSplit = currentPage.split(String.format(Locale.ROOT, "&%s=", offsetKey)); int offsetValue = NumberUtils.toInt(pageSplit[1].split("&")[0], 0) + 10; return buildCustomUrl(endpoint, queryKey, query, offsetKey, offsetValue, limitKey); } private String buildCustomUrl(String endpoint, String queryKey, String query, String offsetKey, int offsetValue, String limitKey) { - return "%s?%s=%s&%s=%d&%s=10".formatted(endpoint, queryKey, query, offsetKey, offsetValue, limitKey); + return String.format(Locale.ROOT, "%s?%s=%s&%s=%d&%s=10", endpoint, queryKey, query, offsetKey, offsetValue, limitKey); } private String getDefaultEndpoint(String engine) { return switch (engine.toLowerCase(Locale.ROOT)) { - case "google" -> "https://customsearch.googleapis.com/customsearch/v1"; - case "bing" -> "https://api.bing.microsoft.com/v7.0/search"; - case "duckduckgo" -> "https://duckduckgo.com/html"; - case "custom" -> null; - default -> throw new IllegalArgumentException("Unsupported search engine: %s".formatted(engine)); + case GOOGLE -> "https://customsearch.googleapis.com/customsearch/v1"; + case BING -> "https://api.bing.microsoft.com/v7.0/search"; + case DUCKDUCKGO -> "https://duckduckgo.com/html"; + case CUSTOM -> null; + default -> throw new IllegalArgumentException(String.format(Locale.ROOT, "Unsupported search engine: %s", engine)); }; } // pagination: https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/page-results#paging-through-search-results private String buildBingUrl(String endpoint, String query, int offset) { - return "%s?q%s&textFormat=HTML&count=10&offset=%d".formatted(endpoint, query, offset); - } - - private void parseResponse( - String rawResponse, - String authorization, - String 
nextPage, - String engine, - String customResUrlJsonpath, - ActionListener listener - ) { - JsonObject rawJson = JsonParser.parseString(rawResponse).getAsJsonObject(); - switch (engine.toLowerCase(Locale.ROOT)) { - case "google": - parseGoogleResults(rawJson, nextPage, listener); - break; - case "bing": - parseBingResults(rawJson, nextPage, listener); - break; - case "custom": - List urls = JsonPath.read(rawResponse, customResUrlJsonpath); - parseCustomResults(urls, authorization, nextPage, listener); - break; - default: - listener.onFailure(new RuntimeException("Unsupported search engine: %s".formatted(engine))); - } - } - - private void parseGoogleResults(JsonObject googleResponse, String nextPage, ActionListener listener) { - Map results = new HashMap<>(); - results.put("next_page", nextPage); - // extract search results, each item is a search result: - // https://developers.google.com/custom-search/v1/reference/rest/v1/Search#result - JsonArray items = googleResponse.getAsJsonArray("items"); - List> crawlResults = new ArrayList<>(); - for (int i = 0; i < items.size(); i++) { - JsonObject item = items.get(i).getAsJsonObject(); - // extract the actual link for scrawl. - String link = item.get("link").getAsString(); - // extract title and content. 
- Map crawlResult = crawlPage(link, null); - crawlResults.add(crawlResult); - } - results.put("items", crawlResults); - listener.onResponse((T) StringUtils.gson.toJson(results)); - } - - private void parseBingResults(JsonObject bingResponse, String nextPage, ActionListener listener) { - Map results = new HashMap<>(); - results.put("next_page", nextPage); - List> crawlResults = new ArrayList<>(); - JsonArray values = bingResponse.get("webPages").getAsJsonObject().getAsJsonArray("value"); - for (int i = 0; i < values.size(); i++) { - JsonObject value = values.get(i).getAsJsonObject(); - String link = value.get("url").getAsString(); - Map crawlResult = crawlPage(link, null); - crawlResults.add(crawlResult); - } - results.put("items", crawlResults); - listener.onResponse((T) StringUtils.gson.toJson(results)); - } - - private void parseCustomResults(List urls, String authorization, String nextPage, ActionListener listener) { - Map results = new HashMap<>(); - results.put("next_page", nextPage); - List> crawlResults = new ArrayList<>(); - for (int i = 0; i < urls.size(); i++) { - String link = urls.get(i); - Map crawlResult = crawlPage(link, authorization); - crawlResults.add(crawlResult); - } - results.put("items", crawlResults); - listener.onResponse((T) StringUtils.gson.toJson(results)); + return String.format(Locale.ROOT, "%s?q%s&textFormat=HTML&count=10&" + OFFSET + "=%d", endpoint, query, offset); } private void fetchDuckDuckGoResult(String endpoint, ActionListener listener) { @@ -335,8 +296,8 @@ private void fetchDuckDuckGoResult(String endpoint, ActionListener listen Map crawlResult = crawlPage(link, null); crawlResults.add(crawlResult); } - results.put("next_page", nextPage); - results.put("items", crawlResults); + results.put(NEXT_PAGE, nextPage); + results.put(ITEMS, crawlResults); listener.onResponse((T) StringUtils.gson.toJson(results)); } catch (IOException e) { log.error("Failed to fetch duckduckgo results due to exception!"); @@ -378,48 +339,6 @@ private 
String getDDGNextPageLink(String endpoint, Document doc) { return sb.toString(); } - /** - * crawl a page and put the page content into the results map if it can be crawled successfully. - * - * @param url The url to crawl - */ - private Map crawlPage(String url, String authorization) { - try { - Connection connection = Jsoup.connect(url).timeout(10000).userAgent(USER_AGENT); - if (authorization != null) { - connection.header("Authorization", authorization); - } - Document doc = connection.get(); - Elements parentElements = doc.select("body"); - if (isCaptchaOrLoginPage(doc)) { - log.debug("Skipping {} - CAPTCHA required", url); - return null; - } - - Element bodyElement = parentElements.getFirst(); - String title = bodyElement.select("title").text(); - String content = bodyElement.text(); - return ImmutableMap.of("url", url, "title", title, "content", content); - } catch (Exception e) { - log.error("Failed to crawl link: {}", url); - return null; - } - } - - private boolean isCaptchaOrLoginPage(Document doc) { - String html = doc.html().toLowerCase(Locale.ROOT); - // 1. 
Check for CAPTCHA indicators - return !doc.select("input[name*='captcha'], input[id*='captcha']").isEmpty() || - // Google reCAPTCHA markers - !doc.select(".g-recaptcha, div[data-sitekey]").isEmpty() || - // CAPTCHA image patterns - !doc.select("img[src*='captcha'], img[src*='recaptcha']").isEmpty() || - // Text-based indicators - org.apache.commons.lang3.StringUtils.containsIgnoreCase(html, "verify you are human") || - // hCAPTCHA detection - !doc.select(".h-captcha").isEmpty(); - } - @Override public String getType() { return TYPE; @@ -427,46 +346,46 @@ public String getType() { @Override public boolean validate(Map parameters) { - String engine = parameters.get("engine"); + String engine = parameters.get(ENGINE); if (org.apache.commons.lang3.StringUtils.isEmpty(engine)) { return false; } - boolean isQueryEmpty = org.apache.commons.lang3.StringUtils.isEmpty(parameters.getOrDefault("query", parameters.get("question"))); + boolean isQueryEmpty = org.apache.commons.lang3.StringUtils.isEmpty(parameters.getOrDefault(QUERY, parameters.get(QUESTION))); if (isQueryEmpty) { log.warn("Query is empty"); return false; } boolean isEndpointEmpty = org.apache.commons.lang3.StringUtils - .isEmpty(parameters.getOrDefault("endpoint", getDefaultEndpoint(engine))); + .isEmpty(parameters.getOrDefault(ENDPOINT, getDefaultEndpoint(engine))); if (isEndpointEmpty) { log.warn("Endpoint is empty"); return false; } - if ("google".equalsIgnoreCase(engine)) { - boolean hasEngineIdAndApiKey = parameters.containsKey("engine_id") - && !parameters.get("engine_id").isEmpty() - && parameters.containsKey("api_key") - && !parameters.get("api_key").isEmpty(); + if (GOOGLE.equalsIgnoreCase(engine)) { + boolean hasEngineIdAndApiKey = parameters.containsKey(ENGINE_ID) + && !parameters.get(ENGINE_ID).isEmpty() + && parameters.containsKey(API_KEY) + && !parameters.get(API_KEY).isEmpty(); if (!hasEngineIdAndApiKey) { - log.warn("Google search engine_id or api_key is empty"); + log.warn("Google search" + 
ENGINE_ID + "or api_key is empty"); return false; } return true; - } else if ("duckduckgo".equalsIgnoreCase(engine)) { + } else if (DUCKDUCKGO.equalsIgnoreCase(engine)) { return true; - } else if ("bing".equalsIgnoreCase(engine)) { - boolean hasApiKey = org.apache.commons.lang3.StringUtils.isEmpty(parameters.get("api_key")); + } else if (BING.equalsIgnoreCase(engine)) { + boolean hasApiKey = org.apache.commons.lang3.StringUtils.isEmpty(parameters.get(API_KEY)); if (!hasApiKey) { log.warn("Bing search api_key is empty"); return false; } return true; - } else if ("custom".equalsIgnoreCase(engine)) { - String customApi = parameters.get("custom_api"); - String customResUrlJsonpath = parameters.get("custom_res_url_jsonpath"); + } else if (CUSTOM.equalsIgnoreCase(engine)) { + String customApi = parameters.get(CUSTOM_API); + String customResUrlJsonpath = parameters.get(CUSTOM_RES_URL_JSONPATH); if (org.apache.commons.lang3.StringUtils.isEmpty(customApi) || org.apache.commons.lang3.StringUtils.isEmpty(customResUrlJsonpath)) { log.warn("custom search API is empty or result json path is empty"); @@ -480,6 +399,48 @@ public boolean validate(Map parameters) { return false; } + /** + * crawl a page and put the page content into the results map if it can be crawled successfully. 
+ * + * @param url The url to crawl + */ + public Map crawlPage(String url, String authorization) { + try { + Connection connection = Jsoup.connect(url).timeout(10000).userAgent(USER_AGENT); + if (authorization != null) { + connection.header(AUTHORIZATION, authorization); + } + Document doc = connection.get(); + Elements parentElements = doc.select("body"); + if (isCaptchaOrLoginPage(doc)) { + log.debug("Skipping {} - CAPTCHA required", url); + return null; + } + + Element bodyElement = parentElements.getFirst(); + String title = bodyElement.select(TITLE).text(); + String content = bodyElement.text(); + return ImmutableMap.of(URL, url, TITLE, title, CONTENT, content); + } catch (Exception e) { + log.error("Failed to crawl link: {}", url); + return null; + } + } + + private boolean isCaptchaOrLoginPage(Document doc) { + String html = doc.html().toLowerCase(Locale.ROOT); + // 1. Check for CAPTCHA indicators + return !doc.select("input[name*='captcha'], input[id*='captcha']").isEmpty() || + // Google reCAPTCHA markers + !doc.select(".g-recaptcha, div[data-sitekey]").isEmpty() || + // CAPTCHA image patterns + !doc.select("img[src*='captcha'], img[src*='recaptcha']").isEmpty() || + // Text-based indicators + org.apache.commons.lang3.StringUtils.containsIgnoreCase(html, "verify you are human") || + // hCAPTCHA detection + !doc.select(".h-captcha").isEmpty(); + } + public static class Factory implements Tool.Factory { private static Factory INSTANCE; private ThreadPool threadPool; @@ -524,4 +485,163 @@ public Map getDefaultAttributes() { return DEFAULT_ATTRIBUTES; } } + + private final class WebSearchResponseHandler implements SdkAsyncHttpResponseHandler { + private final String endpoint; + private final String authorization; + private final String parsedNextPage; + private final String engine; + private final String customResUrlJsonpath; + private final ActionListener listener; + + public WebSearchResponseHandler( + String endpoint, + String authorization, + String 
parsedNextPage, + String engine, + String customResUrlJsonpath, + ActionListener listener + ) { + this.endpoint = endpoint; + this.authorization = authorization; + this.parsedNextPage = parsedNextPage; + this.engine = engine; + this.customResUrlJsonpath = customResUrlJsonpath; + this.listener = listener; + } + + @Override + public void onHeaders(SdkHttpResponse response) { + SdkHttpFullResponse sdkResponse = (SdkHttpFullResponse) response; + log.debug("received response headers: " + sdkResponse.headers()); + int statusCode = sdkResponse.statusCode(); + if (statusCode < HttpStatus.SC_OK || statusCode > HttpStatus.SC_MULTIPLE_CHOICES) { + log + .error( + "Received error from endpoint:{} with status code {}, response headers: {}", + endpoint, + statusCode, + sdkResponse.headers() + ); + listener + .onFailure( + new OpenSearchStatusException( + String.format(Locale.ROOT, "Failed to fetch results from endpoint: %s", endpoint), + RestStatus.fromCode(statusCode) + ) + ); + } + } + + @Override + public void onStream(Publisher stream) { + stream.subscribe(new Subscriber<>() { + private final StringBuilder responseBuilder = new StringBuilder(); + private Subscription subscription; + + @Override + public void onSubscribe(Subscription subscription) { + log.debug("Starting to fetch response..."); + this.subscription = subscription; + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + responseBuilder.append(StandardCharsets.UTF_8.decode(byteBuffer)); + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onError(Throwable throwable) { + log.error("Failed to fetch results from endpoint: {}", endpoint, throwable); + listener.onFailure(new RuntimeException(throwable)); + } + + @Override + public void onComplete() { + log.debug("Successfully fetched results from endpoint: {}", endpoint); + parseResponse(responseBuilder.toString(), authorization, parsedNextPage, engine, customResUrlJsonpath, listener); + } + }); + 
} + + @Override + public void onError(Throwable error) { + log.error("Failed to fetch results from endpoint: {}", endpoint, error); + listener.onFailure(new RuntimeException(error)); + } + + private void parseResponse( + String rawResponse, + String authorization, + String nextPage, + String engine, + String customResUrlJsonpath, + ActionListener listener + ) { + JsonObject rawJson = JsonParser.parseString(rawResponse).getAsJsonObject(); + switch (engine.toLowerCase(Locale.ROOT)) { + case GOOGLE: + parseGoogleResults(rawJson, nextPage, listener); + break; + case BING: + parseBingResults(rawJson, nextPage, listener); + break; + case CUSTOM: + List urls = JsonPath.read(rawResponse, customResUrlJsonpath); + parseCustomResults(urls, authorization, nextPage, listener); + break; + default: + listener.onFailure(new RuntimeException(String.format(Locale.ROOT, "Unsupported search engine: %s", engine))); + } + } + + private void parseGoogleResults(JsonObject googleResponse, String nextPage, ActionListener listener) { + Map results = new HashMap<>(); + results.put(NEXT_PAGE, nextPage); + // extract search results, each item is a search result: + // https://developers.google.com/custom-search/v1/reference/rest/v1/Search#result + JsonArray items = googleResponse.getAsJsonArray(ITEMS); + List> crawlResults = new ArrayList<>(); + for (int i = 0; i < items.size(); i++) { + JsonObject item = items.get(i).getAsJsonObject(); + // extract the actual link for scrawl. + String link = item.get("link").getAsString(); + // extract title and content. 
+ Map crawlResult = crawlPage(link, null); + crawlResults.add(crawlResult); + } + results.put(ITEMS, crawlResults); + listener.onResponse((T) StringUtils.gson.toJson(results)); + } + + private void parseBingResults(JsonObject bingResponse, String nextPage, ActionListener listener) { + Map results = new HashMap<>(); + results.put(NEXT_PAGE, nextPage); + List> crawlResults = new ArrayList<>(); + JsonArray values = bingResponse.get("webPages").getAsJsonObject().getAsJsonArray("value"); + for (int i = 0; i < values.size(); i++) { + JsonObject value = values.get(i).getAsJsonObject(); + String link = value.get(URL).getAsString(); + Map crawlResult = crawlPage(link, null); + crawlResults.add(crawlResult); + } + results.put(ITEMS, crawlResults); + listener.onResponse((T) StringUtils.gson.toJson(results)); + } + + private void parseCustomResults(List urls, String authorization, String nextPage, ActionListener listener) { + Map results = new HashMap<>(); + results.put(NEXT_PAGE, nextPage); + List> crawlResults = new ArrayList<>(); + for (int i = 0; i < urls.size(); i++) { + String link = urls.get(i); + Map crawlResult = crawlPage(link, authorization); + crawlResults.add(crawlResult); + } + results.put(ITEMS, crawlResults); + listener.onResponse((T) StringUtils.gson.toJson(results)); + } + } } diff --git a/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java index 66e1c7d8..149a5148 100644 --- a/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java @@ -95,6 +95,7 @@ public void setup() { null, null, null, + null, null ); } diff --git a/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java b/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java index 79154570..dee0a019 100644 --- 
a/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java +++ b/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java @@ -15,6 +15,10 @@ import org.hamcrest.MatcherAssert; import org.junit.After; import org.junit.Before; +import org.opensearch.ml.common.utils.StringUtils; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; import lombok.SneakyThrows; @@ -132,7 +136,7 @@ public void testDataDistributionToolSingleAnalysis() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -166,7 +170,7 @@ public void testDataDistributionToolWithFilter() { String expectedResult = 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -202,7 +206,7 @@ public void testDataDistributionToolPPLSingleAnalysis() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3.0\",\"selectionPercentage\":0.5},{\"value\":\"2.0\",\"selectionPercentage\":0.25},{\"value\":\"4.0\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -238,7 +242,7 @@ public void testDataDistributionToolPPLWithCustomQuery() { String 
expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3.0\",\"selectionPercentage\":0.67},{\"value\":\"4.0\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -255,7 +259,7 @@ public void testDataDistributionToolWithDSLQueryType() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -272,7 +276,7 @@ public void testDataDistributionToolWithMultipleFilters() { 
String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -289,7 +293,7 @@ public void testDataDistributionToolWithCustomSize() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -306,7 +310,7 @@ public void testDataDistributionToolWithCustomTimeField() { 
String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -323,7 +327,7 @@ public void testDataDistributionToolWithRangeFilter() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -340,7 +344,7 @@ public void testDataDistributionToolWithMatchFilter() { String 
expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -357,7 +361,7 @@ public void testDataDistributionToolWithRawDSLQuery() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -374,7 +378,7 @@ public void testDataDistributionToolWithExistsFilter() { String expectedResult = 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -443,6 +447,30 @@ public void testDataDistributionToolPPLWithComplexQuery() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3.0\",\"selectionPercentage\":0.67},{\"value\":\"4.0\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); + } + + private void assertResults(String expectedResult, String result) { + try { + JsonNode resultJson = 
StringUtils.MAPPER.readTree(result); + JsonNode expectedJson = StringUtils.MAPPER.readTree(expectedResult); + JsonNode expectedAnalysis = expectedJson.get("singleAnalysis"); + JsonNode resultAnalysis = resultJson.get("singleAnalysis"); + for (int i = 0; i < expectedAnalysis.size(); i++) { + assertEquals(expectedAnalysis.get(i).get("field").asText(), resultAnalysis.get(i).get("field").asText()); + assertEquals(expectedAnalysis.get(i).get("divergence").asText(), resultAnalysis.get(i).get("divergence").asText()); + JsonNode expectedTopChanges = expectedAnalysis.get(i).get("topChanges"); + JsonNode resultTopChanges = resultAnalysis.get(i).get("topChanges"); + for (int j = 0; j < expectedTopChanges.size(); j++) { + assertEquals( + expectedTopChanges.get(j).get("selectionPercentage").asText(), + resultTopChanges.get(j).get("selectionPercentage").asText() + ); + assertEquals(expectedTopChanges.get(j).get("value").asText(), resultTopChanges.get(j).get("value").asText()); + } + } + } catch (JsonProcessingException e) { + fail("Failed to process jsons"); + } } } diff --git a/src/test/java/org/opensearch/integTest/PPLToolIT.java b/src/test/java/org/opensearch/integTest/PPLToolIT.java index 3d6120ee..9ca930aa 100644 --- a/src/test/java/org/opensearch/integTest/PPLToolIT.java +++ b/src/test/java/org/opensearch/integTest/PPLToolIT.java @@ -54,9 +54,9 @@ public void testPPLTool() { String agentId = registerAgent(); String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"correct\", \"index\": \"employee\"}}"); assertEquals( - "{\"ppl\":\"source\\u003demployee | where age \\u003e 56 | stats COUNT() as cnt\"," + "{\"ppl\":\"source=employee | where age > 56 | stats COUNT() as cnt\"," + "\"executionResult\":\"{\\n \\\"schema\\\": [\\n {\\n \\\"name\\\": \\\"cnt\\\",\\n " - + "\\\"type\\\": \\\"int\\\"\\n }\\n ],\\n \\\"datarows\\\": [\\n [\\n 0\\n ]\\n ],\\n " + + "\\\"type\\\": \\\"bigint\\\"\\n }\\n ],\\n \\\"datarows\\\": [\\n [\\n 0\\n ]\\n ],\\n " + 
"\\\"total\\\": 1,\\n \\\"size\\\": 1\\n}\"}", result ); @@ -114,7 +114,7 @@ public void testPPLTool_withNonExistingIndex_thenThrowException() { exception.getMessage(), allOf( containsString( - "Return this final answer to human directly and do not use other tools: 'Please provide index name'. Please try to directly send this message to human to ask for index name" + "Return this final answer to human directly and do not use other tools: 'Please provide the existing index name(s)'. Please try to directly send this message to human to ask for index name" ) ) ); From 5a3529c1f7348a65ea69e4aaac12e7b3a050e462 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 8 Oct 2025 19:26:23 +0800 Subject: [PATCH 13/30] update 3.3 release note (#650) Signed-off-by: zane-neo --- release-notes/opensearch-skills.release-notes-3.3.0.0.md | 1 + 1 file changed, 1 insertion(+) diff --git a/release-notes/opensearch-skills.release-notes-3.3.0.0.md b/release-notes/opensearch-skills.release-notes-3.3.0.0.md index 83b5be43..8a549a08 100644 --- a/release-notes/opensearch-skills.release-notes-3.3.0.0.md +++ b/release-notes/opensearch-skills.release-notes-3.3.0.0.md @@ -11,6 +11,7 @@ Compatible with OpenSearch and OpenSearch Dashboards version 3.3.0 ### Bug Fixes * Delete-single-baseline ([#641](https://github.com/opensearch-project/skills/pull/641)) +* Fix WebSearchTool issue ([#639](https://github.com/opensearch-project/skills/pull/639)) ### Infrastructure * Update System.env syntax for Gradle 9 compatibility ([#630](https://github.com/opensearch-project/skills/pull/630)) From 0d1bfa8ffa084e52c7033965a382384242567a51 Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Thu, 23 Oct 2025 15:46:42 -0400 Subject: [PATCH 14/30] Onboarding new maven snapshots publishing to s3 (skills) (#657) Signed-off-by: Peter Zhu --- .github/workflows/maven-publish.yml | 10 ++++++++-- build.gradle | 15 +++++++-------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/.github/workflows/maven-publish.yml 
b/.github/workflows/maven-publish.yml index 1e3bc651..8b73d17d 100644 --- a/.github/workflows/maven-publish.yml +++ b/.github/workflows/maven-publish.yml @@ -32,8 +32,14 @@ jobs: export-env: true env: OP_SERVICE_ACCOUNT_TOKEN: ${{ secrets.OP_SERVICE_ACCOUNT_TOKEN }} - SONATYPE_USERNAME: op://opensearch-infra-secrets/maven-central-portal-credentials/username - SONATYPE_PASSWORD: op://opensearch-infra-secrets/maven-central-portal-credentials/password + MAVEN_SNAPSHOTS_S3_REPO: op://opensearch-infra-secrets/maven-snapshots-s3/repo + MAVEN_SNAPSHOTS_S3_ROLE: op://opensearch-infra-secrets/maven-snapshots-s3/role + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v5 + with: + role-to-assume: ${{ env.MAVEN_SNAPSHOTS_S3_ROLE }} + aws-region: us-east-1 - name: publish snapshots to maven run: | diff --git a/build.gradle b/build.gradle index 1eda7969..278551c3 100644 --- a/build.gradle +++ b/build.gradle @@ -31,8 +31,7 @@ buildscript { repositories { mavenLocal() - maven { url "https://central.sonatype.com/repository/maven-snapshots/" } - maven { url "https://aws.oss.sonatype.org/content/repositories/snapshots" } + maven { url "https://ci.opensearch.org/ci/dbc/snapshots/maven/" } maven { url "https://plugins.gradle.org/m2/" } mavenCentral() } @@ -57,8 +56,7 @@ repositories { mavenLocal() mavenCentral() maven { url "https://plugins.gradle.org/m2/" } - maven { url "https://central.sonatype.com/repository/maven-snapshots/" } - maven { url "https://aws.oss.sonatype.org/content/repositories/snapshots" } + maven { url "https://ci.opensearch.org/ci/dbc/snapshots/maven/" } } allprojects { @@ -448,10 +446,11 @@ publishing { repositories { maven { name = "Snapshots" - url = "https://central.sonatype.com/repository/maven-snapshots/" - credentials { - username System.getenv("SONATYPE_USERNAME") - password System.getenv("SONATYPE_PASSWORD") + url = System.getenv("MAVEN_SNAPSHOTS_S3_REPO") + credentials(AwsCredentials) { + accessKey = 
System.getenv("AWS_ACCESS_KEY_ID") + secretKey = System.getenv("AWS_SECRET_ACCESS_KEY") + sessionToken = System.getenv("AWS_SESSION_TOKEN") } } } From 0952482b1cfa19163a6b0dd42b2a388b7b181094 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 28 Oct 2025 07:41:04 +0800 Subject: [PATCH 15/30] fix regex bypass issue (#656) * fix regex bypass issue Signed-off-by: zane-neo * change websearch tool as dependency changed Signed-off-by: zane-neo * increment to patch version Signed-off-by: zane-neo --------- Signed-off-by: zane-neo --- build.gradle | 2 +- .../java/org/opensearch/agent/tools/WebSearchTool.java | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/build.gradle b/build.gradle index 278551c3..28da7c75 100644 --- a/build.gradle +++ b/build.gradle @@ -11,7 +11,7 @@ import com.github.jengelman.gradle.plugins.shadow.ShadowPlugin buildscript { ext { opensearch_group = "org.opensearch" - opensearch_version = System.getProperty("opensearch.version", "3.3.0-SNAPSHOT") + opensearch_version = System.getProperty("opensearch.version", "3.3.2-SNAPSHOT") buildVersionQualifier = System.getProperty("build.version_qualifier", "") isSnapshot = "true" == System.getProperty("build.snapshot", "true") version_tokens = opensearch_version.tokenize('-') diff --git a/src/main/java/org/opensearch/agent/tools/WebSearchTool.java b/src/main/java/org/opensearch/agent/tools/WebSearchTool.java index 1d5500f6..a1d2fa7d 100644 --- a/src/main/java/org/opensearch/agent/tools/WebSearchTool.java +++ b/src/main/java/org/opensearch/agent/tools/WebSearchTool.java @@ -127,7 +127,8 @@ public class WebSearchTool implements Tool { public WebSearchTool(ThreadPool threadPool) { // Use 1s for connection timeout, 3s for read timeout, 30 for max connections of httpclient. - this.httpClient = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(1), Duration.ofSeconds(3), 30); + // For WebSearchTool, we don't allow user to connect to private ip. 
+ this.httpClient = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(1), Duration.ofSeconds(3), 30, false); this.threadPool = threadPool; this.attributes = new HashMap<>(); attributes.put(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA); @@ -207,7 +208,12 @@ public void run(Map originalParameters, ActionListener li new WebSearchResponseHandler(endpoint, authorization, parsedNextPage, engine, customResUrlJsonpath, listener) ) .build(); - httpClient.execute(executeRequest); + try { + httpClient.execute(executeRequest); + } catch (Exception e) { + log.error("Web search failed!", e); + listener.onFailure(new IllegalStateException(String.format(Locale.ROOT, "Web search failed: %s", e.getMessage()))); + } } }); } catch (Exception e) { From 6584fa51e4a26ef75ca4e427b5de873881a99120 Mon Sep 17 00:00:00 2001 From: opensearch-ci <83309141+opensearch-ci-bot@users.noreply.github.com> Date: Tue, 28 Oct 2025 13:09:49 -0400 Subject: [PATCH 16/30] [AUTO] Add release notes for 3.3.2 (#664) * Add release notes for 3.3.2 Signed-off-by: opensearch-ci * Add release notes for 3.3.2 Signed-off-by: opensearch-ci * Remove redundant sections from release notes Signed-off-by: Peter Zhu * Update compatibility information for version 3.3.2 Signed-off-by: Peter Zhu --------- Signed-off-by: opensearch-ci Signed-off-by: Peter Zhu Co-authored-by: Peter Zhu --- release-notes/opensearch-skills.release-notes-3.3.2.0.md | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 release-notes/opensearch-skills.release-notes-3.3.2.0.md diff --git a/release-notes/opensearch-skills.release-notes-3.3.2.0.md b/release-notes/opensearch-skills.release-notes-3.3.2.0.md new file mode 100644 index 00000000..dfddd694 --- /dev/null +++ b/release-notes/opensearch-skills.release-notes-3.3.2.0.md @@ -0,0 +1,6 @@ +## Version 3.3.2 Release Notes + +Compatible with OpenSearch 3.3.2 and OpenSearch Dashboards 3.3.0 + +### Bug Fixes +* Fix regex bypass issue 
([#656](https://github.com/opensearch-project/skills/pull/656)) From ba0a55e26e967f19976e0b55b576c76ab1bd276d Mon Sep 17 00:00:00 2001 From: "mend-for-github-com[bot]" <50673670+mend-for-github-com[bot]@users.noreply.github.com> Date: Thu, 30 Oct 2025 11:48:34 +0800 Subject: [PATCH 17/30] chore(deps): update gradle to v8.14.3 (#649) Signed-off-by: mend-for-github-com[bot] Co-authored-by: mend-for-github-com[bot] <50673670+mend-for-github-com[bot]@users.noreply.github.com> --- gradle/wrapper/gradle-wrapper.properties | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index f373f37a..dbc089ed 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,7 +1,7 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionSha256Sum=efe9a3d147d948d7528a9887fa35abcf24ca1a43ad06439996490f77569b02d1 -distributionUrl=https\://services.gradle.org/distributions/gradle-8.14-all.zip +distributionSha256Sum=ed1a8d686605fd7c23bdf62c7fc7add1c5b23b2bbc3721e661934ef4a4911d7c +distributionUrl=https\://services.gradle.org/distributions/gradle-8.14.3-all.zip networkTimeout=10000 validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME From dc0740144bb42f5479bd9732db03a98db8e5bb6d Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Mon, 1 Dec 2025 12:39:18 -0800 Subject: [PATCH 18/30] Increment version to 3.4.0-SNAPSHOT (#646) Signed-off-by: opensearch-ci-bot Co-authored-by: opensearch-ci-bot --- build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.gradle b/build.gradle index 28da7c75..ef7f3106 100644 --- a/build.gradle +++ b/build.gradle @@ -11,7 +11,7 @@ import com.github.jengelman.gradle.plugins.shadow.ShadowPlugin buildscript { ext { opensearch_group = "org.opensearch" - opensearch_version = 
System.getProperty("opensearch.version", "3.3.2-SNAPSHOT") + opensearch_version = System.getProperty("opensearch.version", "3.4.0-SNAPSHOT") buildVersionQualifier = System.getProperty("build.version_qualifier", "") isSnapshot = "true" == System.getProperty("build.snapshot", "true") version_tokens = opensearch_version.tokenize('-') From 805de2dc38e1829be32754fa966517e3ba375e73 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 2 Dec 2025 13:23:27 +0800 Subject: [PATCH 19/30] Gradle 9.2.0 and GitHub Actions JDK 25 Upgrade (#675) * Gradle 9.2.0 and GitHub Actions JDK 25 Upgrade Signed-off-by: zane-neo * upgrade opensearch version to 3.4.0 to fix forbiddenApiMain ifailure Signed-off-by: zane-neo * fix build error Signed-off-by: zane-neo --------- Signed-off-by: zane-neo --- .github/workflows/ci.yml | 6 +++--- .github/workflows/test_security.yml | 2 +- gradle/wrapper/gradle-wrapper.properties | 4 ++-- .../agent/tools/SearchAnomalyDetectorsToolTests.java | 1 + .../java/org/opensearch/integTest/BaseAgentToolsIT.java | 1 + 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index deeed07e..c3edc539 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,7 +17,7 @@ jobs: needs: Get-CI-Image-Tag strategy: matrix: - java: [21, 24] + java: [21, 25] name: Build and Test skills plugin on Linux runs-on: ubuntu-latest container: @@ -52,7 +52,7 @@ jobs: build-MacOS: strategy: matrix: - java: [21, 24] + java: [21, 25] name: Build and Test skills Plugin on MacOS needs: Get-CI-Image-Tag @@ -77,7 +77,7 @@ jobs: build-windows: strategy: matrix: - java: [21, 24] + java: [21, 25] name: Build and Test skills plugin on Windows needs: Get-CI-Image-Tag runs-on: windows-latest diff --git a/.github/workflows/test_security.yml b/.github/workflows/test_security.yml index 0f22b923..63c99b29 100644 --- a/.github/workflows/test_security.yml +++ b/.github/workflows/test_security.yml @@ -16,7 +16,7 @@ jobs: 
integ-test-with-security-linux: strategy: matrix: - java: [21, 24] + java: [21, 25] env: ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true name: Run Security Integration Tests on Linux diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index dbc089ed..b11741a1 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,7 +1,7 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionSha256Sum=ed1a8d686605fd7c23bdf62c7fc7add1c5b23b2bbc3721e661934ef4a4911d7c -distributionUrl=https\://services.gradle.org/distributions/gradle-8.14.3-all.zip +distributionSha256Sum=16f2b95838c1ddcf7242b1c39e7bbbb43c842f1f1a1a0dc4959b6d4d68abcac3 +distributionUrl=https\://services.gradle.org/distributions/gradle-9.2.0-all.zip networkTimeout=10000 validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME diff --git a/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java index 149a5148..9b5a48e2 100644 --- a/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java @@ -96,6 +96,7 @@ public void setup() { null, null, null, + null, null ); } diff --git a/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java b/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java index 76d5c72e..66cedc2f 100644 --- a/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java +++ b/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java @@ -65,6 +65,7 @@ public void updateClusterSettings() { updateClusterSettings("plugins.ml_commons.jvm_heap_memory_threshold", 100); updateClusterSettings("plugins.ml_commons.allow_registering_model_via_url", true); updateClusterSettings("plugins.ml_commons.agent_framework_enabled", true); + 
updateClusterSettings("plugins.ml_commons.connector.private_ip_enabled", true); } @SneakyThrows From 7d49ceea7af39fb727c33abfa4131287b69cbea0 Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Thu, 4 Dec 2025 09:02:50 +0800 Subject: [PATCH 20/30] increase max_sample_count to 5 for log insight (#677) Signed-off-by: Hailong Cui --- .../java/org/opensearch/agent/tools/LogPatternAnalysisTool.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java b/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java index 71508d95..a0c3abff 100644 --- a/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java +++ b/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java @@ -753,7 +753,7 @@ private void logInsight(AnalysisParameters params, ActionListener listene .format( Locale.ROOT, "source=%s | where %s>'%s' and %s<'%s' | where match(%s, '%s') | patterns %s method=brain " - + "mode=aggregation max_sample_count=2 " + + "mode=aggregation max_sample_count=5 " + "variable_count_threshold=3 | fields patterns_field, pattern_count, sample_logs " + "| sort -pattern_count | head 5", params.index, From 06525fc4f39fbd0d07a7fef73c792b174293a3a5 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Thu, 22 Jan 2026 13:52:27 +0800 Subject: [PATCH 21/30] [AUTO] Increment version to 3.5.0-SNAPSHOT (#683) * Increment version to 3.5.0-SNAPSHOT Signed-off-by: opensearch-ci-bot * fix jackson version Signed-off-by: Hailong Cui --------- Signed-off-by: opensearch-ci-bot Signed-off-by: Hailong Cui Co-authored-by: opensearch-ci-bot Co-authored-by: Hailong Cui --- build.gradle | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/build.gradle b/build.gradle index ef7f3106..be84ec76 100644 --- a/build.gradle +++ b/build.gradle @@ -11,7 +11,7 @@ import 
com.github.jengelman.gradle.plugins.shadow.ShadowPlugin buildscript { ext { opensearch_group = "org.opensearch" - opensearch_version = System.getProperty("opensearch.version", "3.4.0-SNAPSHOT") + opensearch_version = System.getProperty("opensearch.version", "3.5.0-SNAPSHOT") buildVersionQualifier = System.getProperty("build.version_qualifier", "") isSnapshot = "true" == System.getProperty("build.snapshot", "true") version_tokens = opensearch_version.tokenize('-') @@ -139,7 +139,7 @@ dependencies { compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: "${versions.commonslang}" compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.12.0' compileOnly group: 'org.apache.commons', name: 'commons-math3', version: '3.6.1' - compileOnly("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}") + compileOnly("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson_annotations}") compileOnly("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}") compileOnly(group: 'org.apache.httpcomponents.core5', name: 'httpcore5', version: "${versions.httpcore5}") compileOnly(group: 'org.apache.httpcomponents.client5', name: 'httpclient5', version: "${versions.httpclient5}") @@ -165,7 +165,7 @@ dependencies { implementation("org.json4s:json4s-ast_2.13:3.7.0-M11") implementation("org.json4s:json4s-core_2.13:3.7.0-M11") implementation("org.json4s:json4s-jackson_2.13:3.7.0-M11") - implementation 'com.fasterxml.jackson.module:jackson-module-scala_3:2.18.2' + implementation "com.fasterxml.jackson.module:jackson-module-scala_3:${versions.jackson}" implementation group: 'org.scala-lang', name: 'scala3-library_3', version: '3.7.0-RC1-bin-20250119-bd699fc-NIGHTLY' implementation("com.thoughtworks.paranamer:paranamer:2.8") implementation("org.jsoup:jsoup:1.19.1") @@ -175,7 +175,7 @@ dependencies { compileOnly group: 'org.opensearch', name:'opensearch-ml-spi', version: "${opensearch_build}" compileOnly fileTree(dir: 
jsJarDirectory, include: ["opensearch-job-scheduler-${opensearch_build}.jar"]) implementation fileTree(dir: adJarDirectory, include: ["opensearch-anomaly-detection-${opensearch_build}.jar"]) - implementation fileTree(dir: sqlJarDirectory, include: ["opensearch-sql-thin-${opensearch_build}.jar", "ppl-${opensearch_build}.jar", "protocol-${opensearch_build}.jar"]) + implementation fileTree(dir: sqlJarDirectory, include: ["opensearch-sql-thin-${opensearch_build}.jar", "ppl-${opensearch_build}.jar", "protocol-${opensearch_build}.jar", "core-${opensearch_build}.jar"]) implementation fileTree(dir: sparkDir, include: ["spark*.jar"]) compileOnly "org.opensearch:common-utils:${opensearch_build}" compileOnly "org.jetbrains.kotlin:kotlin-stdlib:${kotlin_version}" From 054e9ed7231c70e3828b827162178806c48470be Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Thu, 5 Feb 2026 13:36:38 +0800 Subject: [PATCH 22/30] fix LogPatternAnalysisTool missing attributes (#690) Signed-off-by: Hailong Cui --- .../java/org/opensearch/agent/tools/LogPatternAnalysisTool.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java b/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java index a0c3abff..cd19be9f 100644 --- a/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java +++ b/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java @@ -254,7 +254,7 @@ public String getType() { @Override public Map getAttributes() { - return Map.of(); + return DEFAULT_ATTRIBUTES; } @Override From ec18a106c252c5f8f67bd98144b181fea7a6ecd1 Mon Sep 17 00:00:00 2001 From: Xuesong Luo Date: Fri, 27 Feb 2026 11:21:46 +0800 Subject: [PATCH 23/30] feat: Add MetricChangeAnalysisTool for metric change detection (#698) * feat: Add MetricChangeAnalysisTool for metric change detection - Add MetricChangeAnalysisTool to analyze metric changes via percentile comparison * Analyzes P25, P50, P75, P90 
percentiles between baseline and selection periods * Ranks fields by relative change score to identify significant changes * Uses variance calculation based on relative changes (scale-independent) * Requires both baseline and selection periods for comparison * Configurable topN parameter (default: 5) to return top N fields - Extract DataFetchingHelper utility for code reuse * Shared data fetching logic between DataDistributionTool and MetricChangeAnalysisTool * Handles field type detection, query building, and data retrieval * Reduces code duplication by ~250 lines - Refactor DataDistributionTool to use shared helper * Removed duplicate AnalysisParameters class * Uses DataFetchingHelper.AnalysisParameters instead * Delegated data fetching to DataFetchingHelper * All existing functionality preserved and tests passing - Add comprehensive test coverage * 16 unit tests for MetricChangeAnalysisTool (including topN validation) * All DataDistributionTool tests still passing (97/97) * Full test suite: BUILD SUCCESSFUL Breaking Change: None. This is a new tool; existing functionality is unchanged. 
Test Results: All tests passing (16 unit + 97 DataDistribution tests) Signed-off-by: Xuesong Luo * fix integration test and remove unused methods Signed-off-by: Hailong Cui * enable MetricChangeAnalysisTool Signed-off-by: Hailong Cui * update tool description Signed-off-by: Hailong Cui * update tool description and remove P25/P75 Signed-off-by: Hailong Cui * add back size validation Signed-off-by: Hailong Cui --------- Signed-off-by: Xuesong Luo Signed-off-by: Hailong Cui Co-authored-by: Hailong Cui --- .../java/org/opensearch/agent/ToolPlugin.java | 5 +- .../agent/tools/DataDistributionTool.java | 629 +----------------- .../agent/tools/DataFetchingHelper.java | 482 ++++++++++++++ .../agent/tools/MetricChangeAnalysisTool.java | 556 ++++++++++++++++ .../org/opensearch/agent/ToolPluginTests.java | 2 +- .../tools/DataDistributionToolTests.java | 397 +---------- .../tools/MetricChangeAnalysisToolTests.java | 619 +++++++++++++++++ .../integTest/DataDistributionToolIT.java | 4 +- 8 files changed, 1724 insertions(+), 970 deletions(-) create mode 100644 src/main/java/org/opensearch/agent/tools/DataFetchingHelper.java create mode 100644 src/main/java/org/opensearch/agent/tools/MetricChangeAnalysisTool.java create mode 100644 src/test/java/org/opensearch/agent/tools/MetricChangeAnalysisToolTests.java diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index 5de1227d..97228054 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -16,6 +16,7 @@ import org.opensearch.agent.tools.DataDistributionTool; import org.opensearch.agent.tools.LogPatternAnalysisTool; import org.opensearch.agent.tools.LogPatternTool; +import org.opensearch.agent.tools.MetricChangeAnalysisTool; import org.opensearch.agent.tools.NeuralSparseSearchTool; import org.opensearch.agent.tools.PPLTool; import org.opensearch.agent.tools.RAGTool; @@ -102,6 +103,7 @@ public Collection 
createComponents( WebSearchTool.Factory.getInstance().init(threadPool); LogPatternAnalysisTool.Factory.getInstance().init(client); DataDistributionTool.Factory.getInstance().init(client); + MetricChangeAnalysisTool.Factory.getInstance().init(client); return Collections.emptyList(); } @@ -122,7 +124,8 @@ public List> getToolFactories() { LogPatternTool.Factory.getInstance(), WebSearchTool.Factory.getInstance(), LogPatternAnalysisTool.Factory.getInstance(), - DataDistributionTool.Factory.getInstance() + DataDistributionTool.Factory.getInstance(), + MetricChangeAnalysisTool.Factory.getInstance() ); } diff --git a/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java b/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java index 1e6dde61..9a163126 100644 --- a/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java +++ b/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java @@ -5,16 +5,17 @@ package org.opensearch.agent.tools; +import static org.opensearch.agent.tools.DataFetchingHelper.DATE_FORMAT_PATTERN; +import static org.opensearch.agent.tools.DataFetchingHelper.NUMBER_FIELD_TYPES; +import static org.opensearch.agent.tools.DataFetchingHelper.QUERY_TYPE_PPL; import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; import java.time.LocalDateTime; -import java.time.ZoneOffset; import java.time.ZonedDateTime; import java.time.format.DateTimeFormatter; import java.time.format.DateTimeParseException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; @@ -26,25 +27,14 @@ import java.util.stream.Collectors; import org.apache.commons.lang3.math.NumberUtils; -import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest; -import org.opensearch.action.search.SearchRequest; import org.opensearch.agent.tools.utils.PPLExecuteHelper; -import 
org.opensearch.agent.tools.utils.ToolHelper; -import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.index.query.RangeQueryBuilder; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.utils.ToolUtils; -import org.opensearch.search.SearchHit; -import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.transport.client.Client; -import com.google.gson.reflect.TypeToken; - import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; @@ -114,27 +104,9 @@ public class DataDistributionTool implements Tool { private static final String DEFAULT_DESCRIPTION = "This tool analyzes data distribution differences between time ranges or provides single dataset insights."; - private static final String DEFAULT_TIME_FIELD = "@timestamp"; - - private static final String PARAM_INDEX = "index"; - private static final String PARAM_TIME_FIELD = "timeField"; - private static final String PARAM_SELECTION_TIME_RANGE_START = "selectionTimeRangeStart"; - private static final String PARAM_SELECTION_TIME_RANGE_END = "selectionTimeRangeEnd"; - private static final String PARAM_BASELINE_TIME_RANGE_START = "baselineTimeRangeStart"; - private static final String PARAM_BASELINE_TIME_RANGE_END = "baselineTimeRangeEnd"; - private static final String PARAM_SIZE = "size"; - private static final String PARAM_QUERY_TYPE = "queryType"; - private static final String PARAM_FILTER = "filter"; - private static final String PARAM_DSL = "dsl"; - private static final String QUERY_TYPE_PPL = "ppl"; - private static final String QUERY_TYPE_DSL = "dsl"; - private static final String DEFAULT_SIZE = "1000"; - private static final String DATE_FORMAT_PATTERN = "yyyy-MM-dd 
HH:mm:ss"; private static final Set USEFUL_FIELD_TYPES = Set .of("keyword", "boolean", "text", "byte", "short", "integer", "long", "float", "double", "half_float", "scaled_float"); - private static final Set NUMBER_FIELD_TYPES = Set - .of("byte", "short", "integer", "long", "float", "double", "half_float", "scaled_float"); private static final int DEFAULT_COMPARISON_RESULT_LIMIT = 10; private static final int DEFAULT_SINGLE_ANALYSIS_RESULT_LIMIT = 30; @@ -207,97 +179,6 @@ public class DataDistributionTool implements Tool { public static final Map DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, false); - /** - * Parameter class to hold analysis parameters with validation - */ - private static class AnalysisParameters { - final String index; - final String timeField; - final String selectionTimeRangeStart; - final String selectionTimeRangeEnd; - final String baselineTimeRangeStart; - final String baselineTimeRangeEnd; - final int size; - final String queryType; - final List filter; - final String dsl; - final String ppl; - - /** - * Constructs analysis parameters from input map with default values - * - * @param parameters Input parameter map from user request - */ - AnalysisParameters(Map parameters) { - this.index = parameters.getOrDefault(PARAM_INDEX, ""); - this.timeField = parameters.getOrDefault(PARAM_TIME_FIELD, DEFAULT_TIME_FIELD); - this.selectionTimeRangeStart = parameters.getOrDefault(PARAM_SELECTION_TIME_RANGE_START, ""); - this.selectionTimeRangeEnd = parameters.getOrDefault(PARAM_SELECTION_TIME_RANGE_END, ""); - this.baselineTimeRangeStart = parameters.getOrDefault(PARAM_BASELINE_TIME_RANGE_START, ""); - this.baselineTimeRangeEnd = parameters.getOrDefault(PARAM_BASELINE_TIME_RANGE_END, ""); - - try { - this.size = Integer.parseInt(parameters.getOrDefault(PARAM_SIZE, DEFAULT_SIZE)); - if (this.size > MAX_SIZE_LIMIT) { - throw new IllegalArgumentException("Size parameter exceeds maximum limit of " + MAX_SIZE_LIMIT 
+ ", got: " + this.size); - } - } catch (NumberFormatException e) { - throw new IllegalArgumentException( - "Invalid 'size' parameter: must be a valid integer, got '" + parameters.get(PARAM_SIZE) + "'" - ); - } - - this.queryType = parameters.getOrDefault(PARAM_QUERY_TYPE, QUERY_TYPE_DSL); - - String filterParam = parameters.getOrDefault(PARAM_FILTER, ""); - if (Strings.isEmpty(filterParam)) { - this.filter = List.of(); - } else { - try { - this.filter = Arrays.asList(gson.fromJson(filterParam, String[].class)); - } catch (Exception e) { - throw new IllegalArgumentException( - "Invalid 'filter' parameter: must be a valid JSON array of strings, got '" - + filterParam - + "'. Example: [\"{'term': {'status': 'error'}}\", \"{'range': {'level': {'gte': 3}}}\"]" - ); - } - } - - this.dsl = parameters.getOrDefault(PARAM_DSL, ""); - this.ppl = parameters.getOrDefault(QUERY_TYPE_PPL, ""); - } - - /** - * Validates required parameters are present - * - * @throws IllegalArgumentException if required parameters are missing - */ - void validate() { - List missingParams = new ArrayList<>(); - if (Strings.isEmpty(index)) - missingParams.add(PARAM_INDEX); - if (Strings.isEmpty(selectionTimeRangeStart)) - missingParams.add(PARAM_SELECTION_TIME_RANGE_START); - if (Strings.isEmpty(selectionTimeRangeEnd)) - missingParams.add(PARAM_SELECTION_TIME_RANGE_END); - if (Strings.isEmpty(timeField)) - missingParams.add(PARAM_TIME_FIELD); - if (!missingParams.isEmpty()) { - throw new IllegalArgumentException("Missing required parameters: " + String.join(", ", missingParams)); - } - } - - /** - * Checks if baseline time range is provided for comparison analysis - * - * @return true if both baseline start and end times are provided - */ - boolean hasBaselineTime() { - return !Strings.isEmpty(baselineTimeRangeStart) && !Strings.isEmpty(baselineTimeRangeEnd); - } - } - /** * Result class for data distribution analysis */ @@ -319,6 +200,7 @@ private record ChangeItem(String value, double 
selectionPercentage, Double basel @Getter private String version; private Client client; + private DataFetchingHelper dataFetchingHelper; /** * Constructs a DataDistributionTool with the given OpenSearch client @@ -327,6 +209,7 @@ private record ChangeItem(String value, double selectionPercentage, Double basel */ public DataDistributionTool(Client client) { this.client = client; + this.dataFetchingHelper = new DataFetchingHelper(client); } @Override @@ -345,7 +228,7 @@ public void setAttributes(Map map) {} @Override public boolean validate(Map map) { try { - new AnalysisParameters(map).validate(); + new DataFetchingHelper.AnalysisParameters(map).validate(); } catch (Exception e) { log.error("Failed to validate the data distribution analysis parameter: {}", e.getMessage()); return false; @@ -366,7 +249,7 @@ public void run(Map originalParameters, ActionListener li try { Map parameters = ToolUtils.extractInputParameters(originalParameters, DEFAULT_ATTRIBUTES); log.debug("Starting data distribution analysis with parameters: {}", parameters.keySet()); - AnalysisParameters params = new AnalysisParameters(parameters); + DataFetchingHelper.AnalysisParameters params = new DataFetchingHelper.AnalysisParameters(parameters); if (QUERY_TYPE_PPL.equals(params.queryType)) { executePPLAnalysis(params, listener); @@ -389,8 +272,8 @@ public void run(Map originalParameters, ActionListener li * @param params Analysis parameters containing query details * @param listener Action listener for handling results */ - private void executePPLAnalysis(AnalysisParameters params, ActionListener listener) { - if (params.hasBaselineTime()) { + private void executePPLAnalysis(DataFetchingHelper.AnalysisParameters params, ActionListener listener) { + if (params.hasBaselineTimeRange()) { fetchPPLComparisonData(params, listener); } else { String pplQuery = buildPPLQuery( @@ -423,8 +306,8 @@ private void executePPLAnalysis(AnalysisParameters params, ActionListener * @param params Analysis parameters 
containing query details * @param listener Action listener for handling results */ - private void executeDSLAnalysis(AnalysisParameters params, ActionListener listener) { - if (params.hasBaselineTime()) { + private void executeDSLAnalysis(DataFetchingHelper.AnalysisParameters params, ActionListener listener) { + if (params.hasBaselineTimeRange()) { fetchComparisonData(params, listener); } else { getSingleDataDistribution(params, listener); @@ -438,7 +321,7 @@ private void executeDSLAnalysis(AnalysisParameters params, ActionListener * @param params Analysis parameters containing time ranges * @param listener Action listener for handling comparison results */ - private void fetchComparisonData(AnalysisParameters params, ActionListener listener) { + private void fetchComparisonData(DataFetchingHelper.AnalysisParameters params, ActionListener listener) { fetchIndexData(params.selectionTimeRangeStart, params.selectionTimeRangeEnd, params, ActionListener.wrap(selectionData -> { fetchIndexData(params.baselineTimeRangeStart, params.baselineTimeRangeEnd, params, ActionListener.wrap(baselineData -> { try { @@ -465,7 +348,7 @@ private void fetchComparisonData(AnalysisParameters params, ActionListener void getSingleDataDistribution(AnalysisParameters params, ActionListener listener) { + private void getSingleDataDistribution(DataFetchingHelper.AnalysisParameters params, ActionListener listener) { fetchIndexData(params.selectionTimeRangeStart, params.selectionTimeRangeEnd, params, ActionListener.wrap(data -> { try { if (data.isEmpty()) { @@ -480,48 +363,6 @@ private void getSingleDataDistribution(AnalysisParameters params, ActionList }, listener::onFailure)); } - /** - * Formats time string to ISO 8601 format for OpenSearch compatibility - * - * @param timeString Input time string - * @return Formatted time string in ISO 8601 format - * @throws DateTimeParseException if time string cannot be parsed - */ - private String formatTimeString(String timeString) throws 
DateTimeParseException { - log.debug("Attempting to parse time string: {}", timeString); - - // Try parsing with zone first - try { - if (timeString.endsWith("Z")) { - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss'Z'", Locale.ROOT); - ZonedDateTime dateTime = ZonedDateTime.parse(timeString, formatter.withZone(ZoneOffset.UTC)); - return dateTime.format(DateTimeFormatter.ISO_INSTANT); - } - } catch (DateTimeParseException e) { - log.debug("Failed to parse as UTC time: {}", e.getMessage()); - } - - // Try parsing as local time without zone - try { - DateTimeFormatter formatter = DateTimeFormatter.ofPattern(DATE_FORMAT_PATTERN, Locale.ROOT); - LocalDateTime localDateTime = LocalDateTime.parse(timeString, formatter); - ZonedDateTime zonedDateTime = localDateTime.atOffset(ZoneOffset.UTC).toZonedDateTime(); - return zonedDateTime.format(DateTimeFormatter.ISO_INSTANT); - } catch (DateTimeParseException e) { - log.debug("Failed to parse as local time: {}", e.getMessage()); - } - - // Try ISO format - try { - ZonedDateTime dateTime = ZonedDateTime.parse(timeString); - return dateTime.format(DateTimeFormatter.ISO_INSTANT); - } catch (DateTimeParseException e) { - log.debug("Failed to parse as ISO format: {}", e.getMessage()); - } - - throw new DateTimeParseException("Unable to parse time string: " + timeString, timeString, 0); - } - /** * Fetches data from the specified index within the given time range * @@ -533,82 +374,28 @@ private String formatTimeString(String timeString) throws DateTimeParseException private void fetchIndexData( String startTime, String endTime, - AnalysisParameters params, + DataFetchingHelper.AnalysisParameters params, ActionListener>> listener ) { - try { - String formattedStartTime = formatTimeString(startTime); - String formattedEndTime = formatTimeString(endTime); - BoolQueryBuilder query; - - // Use raw DSL query if provided - if (!Strings.isEmpty(params.dsl)) { - try { - Map dslMap = gson.fromJson(params.dsl, new 
TypeToken>() { - }.getType()); - query = QueryBuilders.boolQuery(); - - // Handle DSL query structure - check if it has "query" wrapper - if (dslMap.containsKey("query")) { - @SuppressWarnings("unchecked") - Map queryMap = (Map) dslMap.get("query"); - log.debug("Processing DSL query with wrapper: {}", queryMap); - - // Build the DSL query directly into the main query - buildQueryFromMap(queryMap, query); - - // Add time range filter - query.filter(new RangeQueryBuilder(params.timeField).gte(formattedStartTime).lte(formattedEndTime)); - } else { - log.debug("Processing DSL query without wrapper: {}", dslMap); - buildQueryFromMap(dslMap, query); - // Add time range filter to the raw DSL query - query.filter(new RangeQueryBuilder(params.timeField).gte(formattedStartTime).lte(formattedEndTime)); - } - - log.debug("Final DSL query: {}", query.toString()); - } catch (Exception e) { - log.warn("Failed to parse raw DSL query: {}, falling back to time range only", params.dsl, e); - query = QueryBuilders - .boolQuery() - .filter(new RangeQueryBuilder(params.timeField).gte(formattedStartTime).lte(formattedEndTime)); - } - } else { - query = QueryBuilders - .boolQuery() - .filter(new RangeQueryBuilder(params.timeField).gte(formattedStartTime).lte(formattedEndTime)); - - // Add additional filters if provided - if (!params.filter.isEmpty()) { - for (String filterStr : params.filter) { - try { - Map filterMap = gson.fromJson(filterStr, new TypeToken>() { - }.getType()); - BoolQueryBuilder filterQuery = QueryBuilders.boolQuery(); - buildQueryFromMap(filterMap, filterQuery); - query.must(filterQuery); - } catch (Exception e) { - log.warn("Failed to parse filter parameter: {}", filterStr, e); - } - } - } - } - - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(query).size(params.size); + // Convert AnalysisParameters to helper parameters + Map helperParams = new HashMap<>(); + helperParams.put("index", params.index); + helperParams.put("timeField", 
params.timeField); + helperParams.put("size", String.valueOf(params.size)); + helperParams.put("queryType", params.queryType); + if (!Strings.isEmpty(params.dsl)) { + helperParams.put("dsl", params.dsl); + } + if (!params.filter.isEmpty()) { + helperParams.put("filter", gson.toJson(params.filter)); + } + if (!Strings.isEmpty(params.ppl)) { + helperParams.put("ppl", params.ppl); + } - SearchRequest request = new SearchRequest(params.index).source(sourceBuilder); + DataFetchingHelper.AnalysisParameters helperAnalysisParams = new DataFetchingHelper.AnalysisParameters(helperParams); - client.search(request, ActionListener.wrap(response -> { - List> data = Arrays - .stream(response.getHits().getHits()) - .map(SearchHit::getSourceAsMap) - .collect(Collectors.toList()); - listener.onResponse(data); - }, listener::onFailure)); - } catch (Exception e) { - log.error("Failed to format time strings: {}", e.getMessage()); - listener.onFailure(new IllegalArgumentException("Invalid time format: " + e.getMessage(), e)); - } + dataFetchingHelper.fetchIndexData(startTime, endTime, helperAnalysisParams, listener); } /** @@ -618,7 +405,7 @@ private void fetchIndexData( * @param params Analysis parameters containing time ranges * @param listener Action listener for handling comparison results */ - private void fetchPPLComparisonData(AnalysisParameters params, ActionListener listener) { + private void fetchPPLComparisonData(DataFetchingHelper.AnalysisParameters params, ActionListener listener) { String selectionQuery = buildPPLQuery( params.index, params.timeField, @@ -840,38 +627,7 @@ private record GroupedDistributions(Map groupedSelectionDist, Ma * @param listener Action listener for handling field types result */ private void getFieldTypes(String index, ActionListener> listener) { - try { - GetMappingsRequest getMappingsRequest = new GetMappingsRequest().indices(index); - client.admin().indices().getMappings(getMappingsRequest, ActionListener.wrap(response -> { - try { - Map 
mappings = response.getMappings(); - if (mappings.isEmpty()) { - listener.onResponse(Map.of()); - return; - } - - MappingMetadata mappingMetadata = mappings.values().iterator().next(); - Map mappingSource = (Map) mappingMetadata.getSourceAsMap().get("properties"); - if (mappingSource == null) { - listener.onResponse(Map.of()); - return; - } - - Map fieldsToType = new HashMap<>(); - ToolHelper.extractFieldNamesTypes(mappingSource, fieldsToType, "", true); - listener.onResponse(fieldsToType); - } catch (Exception e) { - log.error("Failed to process field types for index: {}", index, e); - listener.onResponse(Map.of()); - } - }, e -> { - log.error("Failed to get field types for index: {}", index, e); - listener.onResponse(Map.of()); - })); - } catch (Exception e) { - log.error("Failed to create getMappings request for index: {}", index, e); - listener.onResponse(Map.of()); - } + dataFetchingHelper.getFieldTypes(index, listener); } /** @@ -940,20 +696,7 @@ private List getUsefulFields(List> data, Map doc, String field) { - String[] parts = field.split("\\."); - Object current = doc; - - for (String part : parts) { - if (current instanceof Map) { - current = ((Map) current).get(part); - } else if (current instanceof List) { - return gson.toJson(current); - } else { - return null; - } - } - - return current; + return dataFetchingHelper.getFlattenedValue(doc, field); } /** @@ -1043,12 +786,7 @@ private List getFieldsFromData(List> data) { * @return Set of number field names */ private Set getNumberFields(Map fieldTypes) { - return fieldTypes - .entrySet() - .stream() - .filter(entry -> NUMBER_FIELD_TYPES.contains(entry.getValue())) - .map(Map.Entry::getKey) - .collect(Collectors.toSet()); + return dataFetchingHelper.getNumberFields(fieldTypes); } /** @@ -1150,303 +888,6 @@ private List formatComparisonSummary(List differ }).collect(Collectors.toList()); } - /** - * Builds query conditions from filter map for DSL queries - * - * @param filterMap Filter conditions as map - 
* @param queryBuilder Query builder to add conditions to - */ - private void buildQueryFromMap(Map filterMap, BoolQueryBuilder queryBuilder) { - log.debug("Building query from map: {}", filterMap); - - for (Map.Entry entry : filterMap.entrySet()) { - String key = entry.getKey(); - Object value = entry.getValue(); - - log.debug("Processing query key: {}, value: {}", key, value); - - // Handle special query types - switch (key) { - case "match_all" -> { - // {"match_all": {}} - log.debug("Adding match_all query"); - queryBuilder.must(QueryBuilders.matchAllQuery()); - } - case "match_none" -> { - // {"match_none": {}} - log.debug("Adding match_none query"); - queryBuilder.mustNot(QueryBuilders.matchAllQuery()); - } - case "bool" -> { - if (value instanceof Map) { - log.debug("Processing bool query: {}", value); - processBoolQuery((Map) value, queryBuilder); - } - } - case "term" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - log.debug("Adding term query: {}", valueMap); - // {"term": {"field": "value"}} - for (Map.Entry termEntry : valueMap.entrySet()) { - log.debug("Term query - field: {}, value: {}", termEntry.getKey(), termEntry.getValue()); - queryBuilder.must(QueryBuilders.termQuery(termEntry.getKey(), termEntry.getValue())); - } - } - } - case "wildcard" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - log.debug("Adding wildcard query: {}", valueMap); - // {"wildcard": {"field": "pattern"}} - for (Map.Entry wildcardEntry : valueMap.entrySet()) { - log.debug("Wildcard query - field: {}, pattern: {}", wildcardEntry.getKey(), wildcardEntry.getValue()); - queryBuilder.must(QueryBuilders.wildcardQuery(wildcardEntry.getKey(), wildcardEntry.getValue().toString())); - } - } - } - case "range" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - // {"range": {"field": {"gte": 1, "lte": 10}}} - for (Map.Entry rangeEntry : 
valueMap.entrySet()) { - String field = rangeEntry.getKey(); - Object rangeValue = rangeEntry.getValue(); - if (rangeValue instanceof Map) { - processRangeQuery(field, rangeValue, queryBuilder); - } - } - } - } - case "match" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - // {"match": {"field": "value"}} - for (Map.Entry matchEntry : valueMap.entrySet()) { - queryBuilder.must(QueryBuilders.matchQuery(matchEntry.getKey(), matchEntry.getValue())); - } - } - } - case "match_phrase" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - // {"match_phrase": {"field": "value"}} - for (Map.Entry matchPhraseEntry : valueMap.entrySet()) { - queryBuilder.must(QueryBuilders.matchPhraseQuery(matchPhraseEntry.getKey(), matchPhraseEntry.getValue())); - } - } - } - case "prefix" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - // {"prefix": {"field": "value"}} - for (Map.Entry prefixEntry : valueMap.entrySet()) { - queryBuilder.must(QueryBuilders.prefixQuery(prefixEntry.getKey(), prefixEntry.getValue().toString())); - } - } - } - case "exists" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - // {"exists": {"field": "fieldname"}} - Object fieldValue = valueMap.get("field"); - if (fieldValue != null) { - queryBuilder.must(QueryBuilders.existsQuery(fieldValue.toString())); - } - } - } - case "regexp" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - // {"regexp": {"field": "pattern"}} - for (Map.Entry regexpEntry : valueMap.entrySet()) { - queryBuilder.must(QueryBuilders.regexpQuery(regexpEntry.getKey(), regexpEntry.getValue().toString())); - } - } - } - case "terms" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - // {"terms": {"field": ["value1", "value2"]}} - for (Map.Entry termsEntry 
: valueMap.entrySet()) { - if (termsEntry.getValue() instanceof List) { - queryBuilder.must(QueryBuilders.termsQuery(termsEntry.getKey(), (List) termsEntry.getValue())); - } - } - } - } - case "multi_match" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - Object queryValue = valueMap.get("query"); - Object fieldsValue = valueMap.get("fields"); - if (queryValue != null && fieldsValue instanceof List) { - @SuppressWarnings("unchecked") - List fields = (List) fieldsValue; - queryBuilder.must(QueryBuilders.multiMatchQuery(queryValue, fields.toArray(new String[0]))); - } - } - } - default -> { - // Handle direct field-value pairs or unknown query types - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - // This might be a field with nested operators like {"field": {"term": "value"}} - processNestedQuery(key, valueMap, queryBuilder); - } else { - // Direct field-value mapping - queryBuilder.must(QueryBuilders.termQuery(key, value)); - } - } - } - } - } - - /** - * Processes bool query conditions - * - * @param boolMap Bool query conditions - * @param queryBuilder Query builder to add conditions to - */ - private void processBoolQuery(Map boolMap, BoolQueryBuilder queryBuilder) { - for (Map.Entry boolEntry : boolMap.entrySet()) { - String boolType = boolEntry.getKey(); - Object boolValue = boolEntry.getValue(); - - if (boolValue instanceof List) { - @SuppressWarnings("unchecked") - List> clauses = (List>) boolValue; - for (Map clause : clauses) { - BoolQueryBuilder subQuery = QueryBuilders.boolQuery(); - buildQueryFromMap(clause, subQuery); - switch (boolType) { - case "must" -> queryBuilder.must(subQuery); - case "should" -> queryBuilder.should(subQuery); - case "must_not" -> queryBuilder.mustNot(subQuery); - case "filter" -> queryBuilder.filter(subQuery); - default -> log.warn("Unsupported bool query type: {}", boolType); - } - } - } - } - } - - /** - * Processes nested query 
conditions for a field - * - * @param field Field name - * @param nestedMap Nested query conditions - * @param queryBuilder Query builder to add conditions to - */ - private void processNestedQuery(String field, Map nestedMap, BoolQueryBuilder queryBuilder) { - for (Map.Entry nestedEntry : nestedMap.entrySet()) { - String operator = nestedEntry.getKey(); - Object operatorValue = nestedEntry.getValue(); - - switch (operator) { - case "term" -> queryBuilder.must(QueryBuilders.termQuery(field, operatorValue)); - case "range" -> processRangeQuery(field, operatorValue, queryBuilder); - case "match" -> queryBuilder.must(QueryBuilders.matchQuery(field, operatorValue)); - case "match_phrase" -> queryBuilder.must(QueryBuilders.matchPhraseQuery(field, operatorValue)); - case "prefix" -> queryBuilder.must(QueryBuilders.prefixQuery(field, operatorValue.toString())); - case "wildcard" -> processWildcardQuery(field, operatorValue, queryBuilder); - case "exists" -> queryBuilder.must(QueryBuilders.existsQuery(field)); - case "regexp" -> processRegexpQuery(field, operatorValue, queryBuilder); - default -> { - // Handle direct field-value mapping for nested structures - if (operatorValue instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) operatorValue; - BoolQueryBuilder nestedQuery = QueryBuilders.boolQuery(); - buildQueryFromMap(Map.of(operator, valueMap), nestedQuery); - queryBuilder.must(nestedQuery); - } else { - log.warn("Unsupported query operator: {}", operator); - } - } - } - } - } - - /** - * Processes range query conditions - * - * @param field Field name - * @param operatorValue Range conditions - * @param queryBuilder Query builder to add conditions to - */ - private void processRangeQuery(String field, Object operatorValue, BoolQueryBuilder queryBuilder) { - if (!(operatorValue instanceof Map)) { - return; - } - - @SuppressWarnings("unchecked") - Map rangeMap = (Map) operatorValue; - RangeQueryBuilder rangeQuery = QueryBuilders.rangeQuery(field); 
- - rangeMap.forEach((rangeOp, rangeVal) -> { - switch (rangeOp) { - case "gte" -> rangeQuery.gte(rangeVal); - case "lte" -> rangeQuery.lte(rangeVal); - case "gt" -> rangeQuery.gt(rangeVal); - case "lt" -> rangeQuery.lt(rangeVal); - } - }); - - queryBuilder.must(rangeQuery); - } - - /** - * Processes wildcard query conditions - * - * @param field Field name - * @param operatorValue Wildcard conditions - * @param queryBuilder Query builder to add conditions to - */ - private void processWildcardQuery(String field, Object operatorValue, BoolQueryBuilder queryBuilder) { - if (operatorValue instanceof Map) { - @SuppressWarnings("unchecked") - Map wildcardMap = (Map) operatorValue; - Object wildcardValue = wildcardMap.get("value"); - if (wildcardValue != null) { - queryBuilder.must(QueryBuilders.wildcardQuery(field, wildcardValue.toString())); - } - } else { - queryBuilder.must(QueryBuilders.wildcardQuery(field, operatorValue.toString())); - } - } - - /** - * Processes regexp query conditions - * - * @param field Field name - * @param operatorValue Regexp conditions - * @param queryBuilder Query builder to add conditions to - */ - private void processRegexpQuery(String field, Object operatorValue, BoolQueryBuilder queryBuilder) { - if (operatorValue instanceof Map) { - @SuppressWarnings("unchecked") - Map regexpMap = (Map) operatorValue; - Object regexpValue = regexpMap.get("value"); - if (regexpValue != null) { - queryBuilder.must(QueryBuilders.regexpQuery(field, regexpValue.toString())); - } - } else { - queryBuilder.must(QueryBuilders.regexpQuery(field, operatorValue.toString())); - } - } - /** * Parses PPL query result into list of documents * diff --git a/src/main/java/org/opensearch/agent/tools/DataFetchingHelper.java b/src/main/java/org/opensearch/agent/tools/DataFetchingHelper.java new file mode 100644 index 00000000..39b8e45d --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/DataFetchingHelper.java @@ -0,0 +1,482 @@ +/* + * Copyright OpenSearch 
Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.time.LocalDateTime; +import java.time.ZoneOffset; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeParseException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + +import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.agent.tools.utils.PPLExecuteHelper; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.transport.client.Client; + +import com.google.gson.reflect.TypeToken; + +import lombok.extern.log4j.Log4j2; + +/** + * Helper class for fetching and processing data from OpenSearch indices. 
+ * Provides common functionality for data analysis tools including: + * - Field type detection and mapping + * - Time-based data fetching with DSL/PPL query support + * - Query building and parameter validation + * - Nested field value extraction + */ +@Log4j2 +public class DataFetchingHelper { + + private static final String DEFAULT_TIME_FIELD = "@timestamp"; + public static final String DATE_FORMAT_PATTERN = "yyyy-MM-dd HH:mm:ss"; + public static final String QUERY_TYPE_PPL = "ppl"; + public static final String QUERY_TYPE_DSL = "dsl"; + private static final String DEFAULT_SIZE = "1000"; + private static final int MAX_SIZE_LIMIT = 10000; + + public static final Set NUMBER_FIELD_TYPES = Set + .of("byte", "short", "integer", "long", "float", "double", "half_float", "scaled_float"); + + private final Client client; + + /** + * Constructs a DataFetchingHelper with the given OpenSearch client + * + * @param client The OpenSearch client for executing queries + */ + public DataFetchingHelper(Client client) { + this.client = client; + } + + /** + * Parameters for data analysis operations + */ + public static class AnalysisParameters { + public final String index; + public final String timeField; + public final String selectionTimeRangeStart; + public final String selectionTimeRangeEnd; + public final String baselineTimeRangeStart; + public final String baselineTimeRangeEnd; + public final int size; + public final String queryType; + public final List filter; + public final String dsl; + public final String ppl; + + public AnalysisParameters(Map parameters) { + this.index = parameters.getOrDefault("index", ""); + this.timeField = parameters.getOrDefault("timeField", DEFAULT_TIME_FIELD); + this.selectionTimeRangeStart = parameters.getOrDefault("selectionTimeRangeStart", ""); + this.selectionTimeRangeEnd = parameters.getOrDefault("selectionTimeRangeEnd", ""); + this.baselineTimeRangeStart = parameters.getOrDefault("baselineTimeRangeStart", ""); + this.baselineTimeRangeEnd = 
parameters.getOrDefault("baselineTimeRangeEnd", ""); + + String sizeStr = parameters.getOrDefault("size", DEFAULT_SIZE); + int parsedSize; + try { + parsedSize = Double.valueOf(sizeStr).intValue(); + if (parsedSize <= 0 || parsedSize > MAX_SIZE_LIMIT) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Invalid 'size' parameter: must be between 1 and %d, got '%s'", MAX_SIZE_LIMIT, sizeStr) + ); + } + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Invalid 'size' parameter: '%s', must be a valid integer", sizeStr) + ); + } + this.size = parsedSize; + + this.queryType = parameters.getOrDefault("queryType", QUERY_TYPE_DSL); + + // Parse filter from JSON string to List + String filterParam = parameters.getOrDefault("filter", ""); + if (Strings.isNullOrEmpty(filterParam)) { + this.filter = List.of(); + } else { + try { + this.filter = List.of(gson.fromJson(filterParam, String[].class)); + } catch (Exception e) { + throw new IllegalArgumentException( + String + .format( + Locale.ROOT, + "Invalid 'filter' parameter: must be a valid JSON array of strings, got '%s'. 
" + + "Example: [\"{'term': {'status': 'error'}}\", \"{'range': {'level': {'gte': 3}}}\"]", + filterParam + ) + ); + } + } + + this.dsl = parameters.getOrDefault("dsl", ""); + this.ppl = parameters.getOrDefault("ppl", ""); + } + + /** + * Validates required parameters are present + * + * @throws IllegalArgumentException if required parameters are missing + */ + public void validate() { + if (Strings.isNullOrEmpty(index)) { + throw new IllegalArgumentException("Missing required parameter: 'index'"); + } + if (Strings.isNullOrEmpty(selectionTimeRangeStart) || Strings.isNullOrEmpty(selectionTimeRangeEnd)) { + throw new IllegalArgumentException("Missing required parameters: 'selectionTimeRangeStart' and 'selectionTimeRangeEnd'"); + } + if (!QUERY_TYPE_DSL.equals(queryType) && !QUERY_TYPE_PPL.equals(queryType)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Invalid 'queryType': must be 'dsl' or 'ppl', got '%s'", queryType) + ); + } + } + + public boolean hasBaselineTimeRange() { + return !Strings.isNullOrEmpty(baselineTimeRangeStart) && !Strings.isNullOrEmpty(baselineTimeRangeEnd); + } + } + + /** + * Retrieves field type mappings from the specified index + * + * @param index The index name + * @param listener Action listener for handling the field type map or failures + */ + public void getFieldTypes(String index, ActionListener> listener) { + GetMappingsRequest request = new GetMappingsRequest().indices(index); + + client.admin().indices().getMappings(request, ActionListener.wrap(response -> { + Map fieldTypes = new HashMap<>(); + + for (Map.Entry entry : response.getMappings().entrySet()) { + MappingMetadata metadata = entry.getValue(); + Map mappingSource = metadata.getSourceAsMap(); + + if (mappingSource.containsKey("properties")) { + @SuppressWarnings("unchecked") + Map properties = (Map) mappingSource.get("properties"); + extractFieldTypes(properties, "", fieldTypes); + } + } + + listener.onResponse(fieldTypes); + }, listener::onFailure)); 
+ } + + /** + * Recursively extracts field types from mapping properties + */ + private void extractFieldTypes(Map properties, String prefix, Map fieldTypes) { + for (Map.Entry entry : properties.entrySet()) { + String fieldName = prefix.isEmpty() ? entry.getKey() : prefix + "." + entry.getKey(); + + @SuppressWarnings("unchecked") + Map fieldProps = (Map) entry.getValue(); + + if (fieldProps.containsKey("type")) { + fieldTypes.put(fieldName, (String) fieldProps.get("type")); + } + + if (fieldProps.containsKey("properties")) { + @SuppressWarnings("unchecked") + Map nestedProps = (Map) fieldProps.get("properties"); + extractFieldTypes(nestedProps, fieldName, fieldTypes); + } + } + } + + /** + * Fetches data from index within the specified time range + * + * @param timeRangeStart Start time string + * @param timeRangeEnd End time string + * @param params Analysis parameters + * @param listener Action listener for handling the fetched data or failures + */ + public void fetchIndexData( + String timeRangeStart, + String timeRangeEnd, + AnalysisParameters params, + ActionListener>> listener + ) { + try { + if (QUERY_TYPE_PPL.equals(params.queryType)) { + fetchDataWithPPL(timeRangeStart, timeRangeEnd, params, listener); + } else { + fetchDataWithDSL(timeRangeStart, timeRangeEnd, params, listener); + } + } catch (Exception e) { + log.error("Failed to fetch index data", e); + listener.onFailure(e); + } + } + + /** + * Builds a BoolQueryBuilder with time range filter and optional custom filters. + * Can be used with any SearchSourceBuilder (for documents or aggregations). 
+ * + * @param timeRangeStart Start time string + * @param timeRangeEnd End time string + * @param params Analysis parameters containing timeField, dsl, and filter settings + * @return BoolQueryBuilder with time range and custom filters applied + */ + public BoolQueryBuilder buildFilterQuery(String timeRangeStart, String timeRangeEnd, AnalysisParameters params) { + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + + // Add time range filter + RangeQueryBuilder timeRangeQuery = QueryBuilders + .rangeQuery(params.timeField) + .gte(formatTimeString(timeRangeStart)) + .lte(formatTimeString(timeRangeEnd)) + .format("strict_date_optional_time||epoch_millis"); + boolQuery.must(timeRangeQuery); + + // Add custom query if provided + if (!Strings.isNullOrEmpty(params.dsl)) { + Map dslMap = buildQueryFromMap(params.dsl); + boolQuery.must(QueryBuilders.wrapperQuery(gson.toJson(dslMap))); + } else if (!params.filter.isEmpty()) { + for (String filterStr : params.filter) { + Map filterMap = buildQueryFromMap(filterStr); + boolQuery.must(QueryBuilders.wrapperQuery(gson.toJson(filterMap))); + } + } + + return boolQuery; + } + + /** + * Fetches data using DSL query + */ + private void fetchDataWithDSL( + String timeRangeStart, + String timeRangeEnd, + AnalysisParameters params, + ActionListener>> listener + ) { + try { + BoolQueryBuilder boolQuery = buildFilterQuery(timeRangeStart, timeRangeEnd, params); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(boolQuery).size(params.size).fetchSource(true); + + SearchRequest searchRequest = new SearchRequest(params.index).source(searchSourceBuilder); + + client.search(searchRequest, ActionListener.wrap(response -> { + List> data = new ArrayList<>(); + for (SearchHit hit : response.getHits().getHits()) { + data.add(hit.getSourceAsMap()); + } + listener.onResponse(data); + }, listener::onFailure)); + + } catch (Exception e) { + log.error("Failed to fetch data with DSL", e); + listener.onFailure(e); + } + } 
+ + /** + * Fetches data using PPL query + */ + private void fetchDataWithPPL( + String timeRangeStart, + String timeRangeEnd, + AnalysisParameters params, + ActionListener>> listener + ) { + try { + String pplQuery = buildPPLQuery(timeRangeStart, timeRangeEnd, params); + + // Use PPLExecuteHelper with parser function + Function, List>> pplResultParser = this::parsePPLResult; + + PPLExecuteHelper.executePPLAndParseResult(client, pplQuery, pplResultParser, listener); + + } catch (Exception e) { + log.error("Failed to fetch data with PPL", e); + listener.onFailure(e); + } + } + + /** + * Parses PPL query result into list of documents + */ + private List> parsePPLResult(Map pplResult) { + Object datarowsObj = pplResult.get("datarows"); + Object schemaObj = pplResult.get("schema"); + + if (!(datarowsObj instanceof List) || !(schemaObj instanceof List)) { + log.warn("Invalid PPL result format"); + return new ArrayList<>(); + } + + @SuppressWarnings("unchecked") + List> datarows = (List>) datarowsObj; + @SuppressWarnings("unchecked") + List> schema = (List>) schemaObj; + + List fieldNames = new ArrayList<>(); + for (Map field : schema) { + fieldNames.add((String) field.get("name")); + } + + List> documents = new ArrayList<>(); + for (List row : datarows) { + Map doc = new HashMap<>(); + for (int i = 0; i < Math.min(row.size(), fieldNames.size()); i++) { + doc.put(fieldNames.get(i), row.get(i)); + } + documents.add(doc); + } + + return documents; + } + + /** + * Builds PPL query with time range filter + */ + private String buildPPLQuery(String timeRangeStart, String timeRangeEnd, AnalysisParameters params) { + String pplBase = !Strings.isNullOrEmpty(params.ppl) ? 
params.ppl : "source=" + params.index; + + String timeFilter = String + .format( + Locale.ROOT, + "%s >= '%s' AND %s <= '%s'", + params.timeField, + formatTimeString(timeRangeStart), + params.timeField, + formatTimeString(timeRangeEnd) + ); + + return String.format(Locale.ROOT, "%s | where %s | head %d", pplBase, timeFilter, params.size); + } + + /** + * Formats time string to ISO 8601 format for OpenSearch compatibility + * + * @param timeString Input time string + * @return Formatted time string in ISO 8601 format + * @throws DateTimeParseException if time string cannot be parsed + */ + private String formatTimeString(String timeString) throws DateTimeParseException { + log.debug("Attempting to parse time string: {}", timeString); + + // Try parsing with zone first + try { + if (timeString.endsWith("Z")) { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss'Z'", Locale.ROOT); + ZonedDateTime dateTime = ZonedDateTime.parse(timeString, formatter.withZone(ZoneOffset.UTC)); + return dateTime.format(DateTimeFormatter.ISO_INSTANT); + } + } catch (DateTimeParseException e) { + log.debug("Failed to parse as UTC time: {}", e.getMessage()); + } + + // Try parsing as local time without zone + try { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern(DATE_FORMAT_PATTERN, Locale.ROOT); + LocalDateTime localDateTime = LocalDateTime.parse(timeString, formatter); + ZonedDateTime zonedDateTime = localDateTime.atOffset(ZoneOffset.UTC).toZonedDateTime(); + return zonedDateTime.format(DateTimeFormatter.ISO_INSTANT); + } catch (DateTimeParseException e) { + log.debug("Failed to parse as local time: {}", e.getMessage()); + } + + // Try ISO format + try { + ZonedDateTime dateTime = ZonedDateTime.parse(timeString); + return dateTime.format(DateTimeFormatter.ISO_INSTANT); + } catch (DateTimeParseException e) { + log.debug("Failed to parse as ISO format: {}", e.getMessage()); + } + + throw new RuntimeException("Invalid time format: " + timeString); + } 
+ + /** + * Builds query map from JSON string + */ + public Map buildQueryFromMap(String queryStr) { + if (Strings.isNullOrEmpty(queryStr)) { + return new HashMap<>(); + } + + try { + String normalizedQuery = queryStr.trim().replace("'", "\""); + return gson.fromJson(normalizedQuery, new TypeToken>() { + }.getType()); + } catch (Exception e) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Invalid query format: %s. Error: %s", queryStr, e.getMessage())); + } + } + + /** + * Filters numeric fields from field type mappings + * + * @param fieldTypes Map of field names to their types + * @return Set of numeric field names + */ + public Set getNumberFields(Map fieldTypes) { + Set numberFields = new HashSet<>(); + for (Map.Entry entry : fieldTypes.entrySet()) { + if (NUMBER_FIELD_TYPES.contains(entry.getValue())) { + numberFields.add(entry.getKey()); + } + } + return numberFields; + } + + /** + * Extracts nested field value from document using dot notation + * + * @param doc The document map + * @param field The field path (e.g., "metrics.response_time") + * @return The field value or null if not found + */ + public Object getFlattenedValue(Map doc, String field) { + if (doc == null || field == null) { + return null; + } + + String[] parts = field.split("\\."); + Object current = doc; + + for (String part : parts) { + if (!(current instanceof Map)) { + return null; + } + @SuppressWarnings("unchecked") + Map currentMap = (Map) current; + current = currentMap.get(part); + if (current == null) { + return null; + } + } + + return current; + } +} diff --git a/src/main/java/org/opensearch/agent/tools/MetricChangeAnalysisTool.java b/src/main/java/org/opensearch/agent/tools/MetricChangeAnalysisTool.java new file mode 100644 index 00000000..e755b174 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/MetricChangeAnalysisTool.java @@ -0,0 +1,556 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package 
org.opensearch.agent.tools; + +import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; +import org.opensearch.transport.client.Client; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * Tool for analyzing metric changes by comparing percentile distributions between time periods. + * Uses relative change (percentage change) to rank fields by significance. + * + * Usage: + * POST /_plugins/_ml/agents/{agent_id}/_execute + * { + * "parameters": { + * "index": "logs-2025.01.15", + * "timeField": "@timestamp", + * "selectionTimeRangeStart": "2025-01-15 10:00:00", + * "selectionTimeRangeEnd": "2025-01-15 11:00:00", + * "baselineTimeRangeStart": "2025-01-15 08:00:00", + * "baselineTimeRangeEnd": "2025-01-15 09:00:00", + * "size": 1000 + * } + * } + */ +@Log4j2 +@Setter +@Getter +@ToolAnnotation(MetricChangeAnalysisTool.TYPE) +public class MetricChangeAnalysisTool implements Tool { + public static final String TYPE = "MetricChangeAnalysisTool"; + + private static final String DEFAULT_DESCRIPTION = + "This tool analyzes a metric index to identify which numeric metrics changed most significantly between a selection time range and a baseline time range. " + + "It compares percentile distributions (P50, P90) of all numeric fields and returns the top N fields ranked by log-ratio change score. " + + "Use this tool for root cause analysis when investigating performance degradation, anomalies, or incidents in metric data. " + + "Keep both time ranges short and focused (e.g. 
15-30 minutes) and similar in duration for accurate comparison."; + + private static final int DEFAULT_TOP_N = 10; + + public static final String DEFAULT_INPUT_SCHEMA = + """ + { + "type": "object", + "properties": { + "index": { + "type": "string", + "description": "Target OpenSearch index name" + }, + "timeField": { + "type": "string", + "description": "Date/time field for filtering (default: @timestamp)" + }, + "selectionTimeRangeStart": { + "type": "string", + "description": "Start time for the analysis period (format: yyyy-MM-dd HH:mm:ss). The selection period is the time range where the change or anomaly occurred." + }, + "selectionTimeRangeEnd": { + "type": "string", + "description": "End time for the analysis period (format: yyyy-MM-dd HH:mm:ss)." + }, + "baselineTimeRangeStart": { + "type": "string", + "description": "Start time for the baseline period (format: yyyy-MM-dd HH:mm:ss). The baseline represents normal behavior. Its duration should be same to the selection period for a fair comparison." + }, + "baselineTimeRangeEnd": { + "type": "string", + "description": "End time for the baseline period (format: yyyy-MM-dd HH:mm:ss). Should be at or before selectionTimeRangeStart." 
+ }, + "size": { + "type": "integer", + "description": "Maximum number of documents to analyze (default: 1000, max: 10000)" + }, + "topN": { + "type": "integer", + "description": "Number of top fields to return, ranked by change score (default: 10)" + }, + "queryType": { + "type": "string", + "description": "Query type: 'ppl' or 'dsl' (default: 'dsl')" + }, + "filter": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Additional DSL query conditions (optional)" + }, + "dsl": { + "type": "string", + "description": "Complete raw DSL query as JSON string (optional)" + }, + "ppl": { + "type": "string", + "description": "Complete PPL statement without time information (optional)" + } + }, + "required": ["index", "timeField", "selectionTimeRangeStart", "selectionTimeRangeEnd", "baselineTimeRangeStart", "baselineTimeRangeEnd"] + } + """; + + public static final Map DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA); + + /** + * Record for percentile statistics + */ + private record PercentileStats(double p50, double p90) { + } + + /** + * Record for field percentile analysis with variance + */ + private record FieldPercentileAnalysis(String field, double variance, PercentileStats selectionStats, PercentileStats baselineStats) { + } + + /** + * Result item for JSON output + */ + private record PercentileAnalysisResult(String field, Double changeScore, Map selectionPercentiles, + Map baselinePercentiles, Map logRatios) { + } + + @Setter + @Getter + private String name = TYPE; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + @Getter + private String version; + private Client client; + private DataFetchingHelper dataFetchingHelper; + + /** + * Constructs a MetricChangeAnalysisTool with the given OpenSearch client + * + * @param client The OpenSearch client for executing queries + */ + public MetricChangeAnalysisTool(Client client) { + this.client = client; + this.dataFetchingHelper = new 
DataFetchingHelper(client); + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public Map getAttributes() { + return DEFAULT_ATTRIBUTES; + } + + @Override + public void setAttributes(Map map) {} + + @Override + public boolean validate(Map parameters) { + try { + // Use helper's validation logic + DataFetchingHelper.AnalysisParameters params = new DataFetchingHelper.AnalysisParameters(parameters); + params.validate(); + return true; + } catch (IllegalArgumentException e) { + log.error("Invalid parameters: {}", e.getMessage()); + return false; + } + } + + /** + * Executes percentile analysis on numeric fields between selection and baseline periods + * + * @param The response type + * @param originalParameters Input parameters for analysis + * @param listener Action listener for handling results or failures + */ + @Override + public void run(Map originalParameters, ActionListener listener) { + try { + Map parameters = ToolUtils.extractInputParameters(originalParameters, DEFAULT_ATTRIBUTES); + log.debug("Starting metric change analysis with parameters: {}", parameters.keySet()); + + // Extract topN parameter (use Double.valueOf to handle "30.0" from Gson serialization) + int topN = DEFAULT_TOP_N; + String topNStr = parameters.get("topN"); + if (topNStr != null && !topNStr.isEmpty()) { + try { + topN = Double.valueOf(topNStr).intValue(); + if (topN <= 0) { + topN = DEFAULT_TOP_N; + } + } catch (NumberFormatException e) { + log.warn("Invalid topN parameter '{}', using default: {}", topNStr, DEFAULT_TOP_N); + } + } + + // Use DataDistributionTool's data fetching mechanism + fetchDataAndAnalyze(parameters, topN, listener); + + } catch (IllegalArgumentException e) { + log.error("Invalid parameters for MetricChangeAnalysisTool: {}", e.getMessage()); + listener.onFailure(e); + } catch (Exception e) { + log.error("Unexpected error in MetricChangeAnalysisTool", e); + listener.onFailure(e); + } + } + + /** + * Fetches data using 
DataDistributionTool's mechanism and performs percentile analysis + */ + private void fetchDataAndAnalyze(Map parameters, int topN, ActionListener listener) { + try { + // Create analysis parameters + DataFetchingHelper.AnalysisParameters params = new DataFetchingHelper.AnalysisParameters(parameters); + + // Get field types first + String index = parameters.get("index"); + dataFetchingHelper.getFieldTypes(index, ActionListener.wrap((Map fieldTypes) -> { + // Get number fields + Set numberFields = dataFetchingHelper.getNumberFields(fieldTypes); + + if (numberFields.isEmpty()) { + listener + .onFailure( + new IllegalStateException("No numeric fields found in index. Percentile analysis requires numeric fields.") + ); + return; + } + + // Fetch selection and baseline data + fetchBothDatasets(params, numberFields, topN, listener); + }, listener::onFailure)); + + } catch (Exception e) { + log.error("Failed to fetch data for percentile analysis", e); + listener.onFailure(e); + } + } + + /** + * Fetches both selection and baseline datasets + */ + private void fetchBothDatasets( + DataFetchingHelper.AnalysisParameters params, + Set numberFields, + int topN, + ActionListener listener + ) { + try { + // Fetch selection data + dataFetchingHelper + .fetchIndexData( + params.selectionTimeRangeStart, + params.selectionTimeRangeEnd, + params, + ActionListener.wrap((List> selectionData) -> { + // Fetch baseline data + dataFetchingHelper + .fetchIndexData( + params.baselineTimeRangeStart, + params.baselineTimeRangeEnd, + params, + ActionListener.wrap((List> baselineData) -> { + // Perform metric change analysis + performAnalysis(selectionData, baselineData, numberFields, topN, listener); + }, listener::onFailure) + ); + }, listener::onFailure) + ); + + } catch (Exception e) { + log.error("Failed to fetch datasets", e); + listener.onFailure(e); + } + } + + /** + * Performs metric change analysis on the fetched data + */ + private void performAnalysis( + List> selectionData, + List> 
baselineData, + Set numberFields, + int topN, + ActionListener listener + ) { + try { + if (selectionData.isEmpty()) { + listener.onFailure(new IllegalStateException("No data found for selection time range")); + return; + } + + if (baselineData.isEmpty()) { + listener.onFailure(new IllegalStateException("No data found for baseline time range")); + return; + } + + // Calculate percentiles and relative changes + List analyses = calculateMetricChangeAnalysis(selectionData, baselineData, numberFields); + List results = formatResults(analyses, topN); + listener.onResponse((T) gson.toJson(Map.of("percentileAnalysis", results))); + + } catch (Exception e) { + log.error("Failed to perform metric change analysis", e); + listener.onFailure(e); + } + } + + /** + * Formats analysis results for JSON output, limiting to top N results + */ + private List formatResults(List analyses, int topN) { + return analyses.stream().limit(topN).map(analysis -> { + Map selectionStats = Map.of("p50", analysis.selectionStats.p50, "p90", analysis.selectionStats.p90); + + Map baselineStats = Map.of("p50", analysis.baselineStats.p50, "p90", analysis.baselineStats.p90); + + Map logRatios = Map + .of( + "p50", + safeLogRatio(analysis.selectionStats.p50, analysis.baselineStats.p50), + "p90", + safeLogRatio(analysis.selectionStats.p90, analysis.baselineStats.p90) + ); + + return new PercentileAnalysisResult(analysis.field, analysis.variance, selectionStats, baselineStats, logRatios); + }).toList(); + } + + // ========== Metric Change Analysis Functions ========== + + /** + * Calculates metric change analysis for all numeric fields + */ + private List calculateMetricChangeAnalysis( + List> selectionData, + List> baselineData, + Set numberFields + ) { + List analyses = new ArrayList<>(); + + for (String field : numberFields) { + List selectionValues = extractNumericValues(selectionData, field); + List baselineValues = extractNumericValues(baselineData, field); + + if (selectionValues.isEmpty() || 
baselineValues.isEmpty()) { + continue; + } + + PercentileStats selectionStats = calculatePercentiles(selectionValues); + PercentileStats baselineStats = calculatePercentiles(baselineValues); + double variance = calculatePercentileVariance(selectionStats, baselineStats); + + analyses.add(new FieldPercentileAnalysis(field, variance, selectionStats, baselineStats)); + } + + analyses.sort(Comparator.comparingDouble((FieldPercentileAnalysis a) -> a.variance).reversed()); + return analyses; + } + + /** + * Extracts numeric values from dataset for a specific field + */ + private List extractNumericValues(List> data, String field) { + List values = new ArrayList<>(); + + for (Map doc : data) { + Object value = getFlattenedValue(doc, field); + if (value != null) { + try { + if (value instanceof Number) { + values.add(((Number) value).doubleValue()); + } else { + values.add(Double.parseDouble(value.toString())); + } + } catch (NumberFormatException e) { + // Skip non-numeric values + } + } + } + + return values; + } + + /** + * Extracts nested field values from document using dot notation + */ + private Object getFlattenedValue(Map doc, String field) { + return dataFetchingHelper.getFlattenedValue(doc, field); + } + + /** + * Calculates statistics (avg, P50, P90) for a list of values + */ + private PercentileStats calculatePercentiles(List values) { + if (values.isEmpty()) { + return new PercentileStats(0.0, 0.0); + } + + List sorted = new ArrayList<>(values); + sorted.sort(Double::compareTo); + + double p50 = calculatePercentile(sorted, 50); + double p90 = calculatePercentile(sorted, 90); + + return new PercentileStats(p50, p90); + } + + /** + * Calculates a specific percentile from sorted values + */ + private double calculatePercentile(List sortedValues, int percentile) { + if (sortedValues.isEmpty()) { + return 0.0; + } + + if (sortedValues.size() == 1) { + return sortedValues.get(0); + } + + double index = (percentile / 100.0) * (sortedValues.size() - 1); + int 
lowerIndex = (int) Math.floor(index); + int upperIndex = (int) Math.ceil(index); + + if (lowerIndex == upperIndex) { + return sortedValues.get(lowerIndex); + } + + double lowerValue = sortedValues.get(lowerIndex); + double upperValue = sortedValues.get(upperIndex); + double fraction = index - lowerIndex; + + return lowerValue + (upperValue - lowerValue) * fraction; + } + + private static final double LOG_RATIO_CAP = 10.0; + private static final double EPSILON = 1e-10; + + /** + * Calculates change score between selection and baseline statistics. + * Uses weighted log-ratio scoring on avg, P50, and P90. + * Skips a metric if its baseline is near zero to avoid inflated scores. + * Redistributes weight equally among valid metrics. + */ + private double calculatePercentileVariance(PercentileStats selectionStats, PercentileStats baselineStats) { + boolean p50Valid = Math.abs(baselineStats.p50) >= EPSILON; + boolean p90Valid = Math.abs(baselineStats.p90) >= EPSILON; + + if (!p50Valid && !p90Valid) { + return 0.0; + } + if (p50Valid && p90Valid) { + return 0.5 * safeLogRatio(selectionStats.p50, baselineStats.p50) + 0.5 * safeLogRatio(selectionStats.p90, baselineStats.p90); + } + if (p50Valid) { + return safeLogRatio(selectionStats.p50, baselineStats.p50); + } + return safeLogRatio(selectionStats.p90, baselineStats.p90); + } + + /** + * Computes |log(selection / baseline)| with safe handling of near-zero values. + * Returns 0 when both values are near zero, caps at LOG_RATIO_CAP when only baseline is near zero. 
+ */ + private double safeLogRatio(double selection, double baseline) { + if (Math.abs(baseline) < EPSILON && Math.abs(selection) < EPSILON) { + return 0.0; + } + if (Math.abs(baseline) < EPSILON) { + return LOG_RATIO_CAP; + } + double ratio = selection / baseline; + if (ratio <= 0) { + return 0.0; + } + return Math.abs(Math.log(ratio)); + } + + /** + * Factory class for creating MetricChangeAnalysisTool instances + */ + public static class Factory implements Tool.Factory { + private Client client; + private static Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (MetricChangeAnalysisTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + /** + * Initialize this factory + * + * @param client The OpenSearch client + */ + public void init(Client client) { + this.client = client; + } + + @Override + public MetricChangeAnalysisTool create(Map map) { + return new MetricChangeAnalysisTool(client); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public Map getDefaultAttributes() { + return DEFAULT_ATTRIBUTES; + } + + @Override + public String getDefaultVersion() { + return null; + } + } +} diff --git a/src/test/java/org/opensearch/agent/ToolPluginTests.java b/src/test/java/org/opensearch/agent/ToolPluginTests.java index 0bbee973..311fa88c 100644 --- a/src/test/java/org/opensearch/agent/ToolPluginTests.java +++ b/src/test/java/org/opensearch/agent/ToolPluginTests.java @@ -96,7 +96,7 @@ public void test_getRestHandlers_successful() { @Test public void test_getToolFactories_successful() { - assertEquals(14, toolPlugin.getToolFactories().size()); + assertEquals(15, toolPlugin.getToolFactories().size()); } @Test diff --git 
a/src/test/java/org/opensearch/agent/tools/DataDistributionToolTests.java b/src/test/java/org/opensearch/agent/tools/DataDistributionToolTests.java index d53bf0d8..d79336b0 100644 --- a/src/test/java/org/opensearch/agent/tools/DataDistributionToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/DataDistributionToolTests.java @@ -5,6 +5,7 @@ package org.opensearch.agent.tools; +import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.containsString; import static org.jsoup.helper.Validate.fail; import static org.junit.Assert.assertEquals; @@ -883,8 +884,8 @@ public void testExecutionWithSizeExceedsMaxLimit() { "15000" ), ActionListener.wrap(response -> fail("Should have failed with size exceeding limit"), e -> { - MatcherAssert.assertThat(e.getMessage(), containsString("Size parameter exceeds maximum limit of 10000")); - MatcherAssert.assertThat(e.getMessage(), containsString("got: 15000")); + MatcherAssert.assertThat(e.getMessage(), containsString("must be between 1 and 10000")); + MatcherAssert.assertThat(e.getMessage(), containsString("15000")); }) ); } @@ -1168,284 +1169,6 @@ private List> createHighCardinalityTestData() { return data; } - @Test - @SneakyThrows - public void testBuildQueryFromMapWithTermQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("status", Map.of("term", "error")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("term")); - assertTrue(queryBuilder.toString().contains("status")); - 
assertTrue(queryBuilder.toString().contains("error")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithRangeQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("level", Map.of("range", Map.of("gte", 3, "lte", 5))); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("range")); - assertTrue(queryBuilder.toString().contains("level")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithMatchQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("message", Map.of("match", "test message")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("match")); - assertTrue(queryBuilder.toString().contains("message")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithExistsQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - 
buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("status", Map.of("exists", true)); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("exists")); - assertTrue(queryBuilder.toString().contains("status")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithDirectTermQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("status", "error"); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("term")); - assertTrue(queryBuilder.toString().contains("status")); - assertTrue(queryBuilder.toString().contains("error")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithMatchPhraseQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("message", Map.of("match_phrase", "exact phrase")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("match_phrase")); - 
} - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithPrefixQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("host", Map.of("prefix", "server")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("prefix")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithWildcardQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("host", Map.of("wildcard", "server*")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("wildcard")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithWildcardMapQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("host", Map.of("wildcard", Map.of("value", "server*"))); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = 
org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("wildcard")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithRegexpQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("host", Map.of("regexp", "server-[0-9]+")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("regexp")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithRegexpMapQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("host", Map.of("regexp", Map.of("value", "server-[0-9]+"))); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("regexp")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithTermsQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, 
org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("terms", Map.of("status", List.of("error", "warning"))); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("terms")); - assertTrue(queryBuilder.toString().contains("status")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithMultiMatchQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map - .of("multi_match", Map.of("query", "error message", "fields", List.of("message", "description"))); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("multi_match")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithComplexRangeQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("level", Map.of("range", Map.of("gt", 1, "lt", 10))); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - 
assertTrue(queryBuilder.toString().contains("range")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithUnsupportedOperator() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("status", Map.of("unsupported_op", "value")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - } - @Test @SneakyThrows public void testGroupNumericKeysWithManyNumericValues() { @@ -1852,6 +1575,7 @@ public void testDSLWithInvalidRawDSLQuery() { mockSearchResponse(); DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + // Test that invalid DSL query causes execution to fail tool .run( ImmutableMap @@ -1868,13 +1592,14 @@ public void testDSLWithInvalidRawDSLQuery() { "invalid-json-query" ), ActionListener.wrap(response -> { - // Should fallback to time range only query when DSL parsing fails - JsonElement result = gson.fromJson(response, JsonElement.class); + fail("Should have failed with invalid DSL query"); + }, e -> { + // Expect failure due to invalid DSL format assertTrue( - "Response should contain singleAnalysis even with invalid DSL", - result.getAsJsonObject().has("singleAnalysis") + "Should fail with exception for invalid DSL", + e instanceof IllegalArgumentException || e.getMessage().contains("Invalid query format") ); - }, e -> fail("Tool execution failed: " + e.getMessage())) + }) ); } @@ -2101,6 +1826,7 @@ public void testDSLWithFilterArrayInvalidFilter() { mockSearchResponse(); DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + // Test that 
invalid filter JSON causes parameter validation to fail tool .run( ImmutableMap @@ -2117,13 +1843,15 @@ public void testDSLWithFilterArrayInvalidFilter() { "[\"{'term': {'status': 'error'}}\", \"invalid-json-filter\"]" ), ActionListener.wrap(response -> { - // Should continue processing valid filters and ignore invalid ones - JsonElement result = gson.fromJson(response, JsonElement.class); + fail("Should have failed with invalid filter JSON"); + }, e -> { + // Expect IllegalArgumentException due to invalid filter format + assertTrue("Should fail with IllegalArgumentException for invalid filter", e instanceof IllegalArgumentException); assertTrue( - "Response should contain singleAnalysis even with some invalid filters", - result.getAsJsonObject().has("singleAnalysis") + "Error message should mention invalid filter", + e.getMessage().contains("Invalid 'filter' parameter") || e.getMessage().contains("Invalid query format") ); - }, e -> fail("Tool execution failed: " + e.getMessage())) + }) ); } @@ -2484,8 +2212,13 @@ public void testDSLWithMalformedFilterJSON() { "[malformed-json]" ), ActionListener.wrap(response -> fail("Should have failed with malformed filter JSON"), e -> { - MatcherAssert.assertThat(e.getMessage(), containsString("Invalid 'filter' parameter")); - MatcherAssert.assertThat(e.getMessage(), containsString("must be a valid JSON array of strings")); + // The filter "[malformed-json]" is valid JSON array syntax, but "malformed-json" is not valid query JSON + // Error occurs during query execution, not parameter validation + MatcherAssert + .assertThat( + e.getMessage(), + anyOf(containsString("Invalid query format"), containsString("Invalid 'filter' parameter")) + ); }) ); } @@ -2637,84 +2370,4 @@ public void testDSLQueryPrecedenceOverFilter() { }, e -> fail("Tool execution failed: " + e.getMessage())) ); } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithTermsQueryNonListValue() { - DataDistributionTool tool = 
DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - // Test terms query with non-list value (should be ignored) - Map filterMap = Map.of("terms", Map.of("status", "error")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - // Should not contain terms query since value is not a list - assertFalse(queryBuilder.toString().contains("terms")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithMultiMatchQueryMissingFields() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - // Test multi_match query with missing fields (should be ignored) - Map filterMap = Map.of("multi_match", Map.of("query", "error message")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - // Should not contain multi_match query since fields is missing - assertFalse(queryBuilder.toString().contains("multi_match")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithMultiMatchQueryMissingQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - 
buildQueryMethod.setAccessible(true); - - // Test multi_match query with missing query (should be ignored) - Map filterMap = Map.of("multi_match", Map.of("fields", List.of("message", "description"))); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - // Should not contain multi_match query since query is missing - assertFalse(queryBuilder.toString().contains("multi_match")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithMultiMatchQueryNonListFields() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - // Test multi_match query with non-list fields (should be ignored) - Map filterMap = Map.of("multi_match", Map.of("query", "error message", "fields", "message")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - // Should not contain multi_match query since fields is not a list - assertFalse(queryBuilder.toString().contains("multi_match")); - } } diff --git a/src/test/java/org/opensearch/agent/tools/MetricChangeAnalysisToolTests.java b/src/test/java/org/opensearch/agent/tools/MetricChangeAnalysisToolTests.java new file mode 100644 index 00000000..0ffcf06e --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/MetricChangeAnalysisToolTests.java @@ -0,0 +1,619 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static 
org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.core.action.ActionListener; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.transport.client.AdminClient; +import org.opensearch.transport.client.Client; +import org.opensearch.transport.client.IndicesAdminClient; + +import com.google.common.collect.ImmutableMap; + +import lombok.SneakyThrows; + +public class MetricChangeAnalysisToolTests { + + private Map params = new HashMap<>(); + private final Client client = mock(Client.class); + @Mock + private AdminClient adminClient; + @Mock + private IndicesAdminClient indicesAdminClient; + @Mock + private GetMappingsResponse getMappingsResponse; + @Mock + private MappingMetadata mappingMetadata; + @Mock + private SearchResponse searchResponse; + + @SneakyThrows + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + setupMockMappings(); + MetricChangeAnalysisTool.Factory.getInstance().init(client); + } + + private void setupMockMappings() { + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesAdminClient); + + Map properties = ImmutableMap + .builder() + .put("response_time", ImmutableMap.of("type", "long")) + .put("cpu_usage", ImmutableMap.of("type", "double")) + .put("memory_usage", 
ImmutableMap.of("type", "float")) + .put("status", ImmutableMap.of("type", "keyword")) + .put("@timestamp", ImmutableMap.of("type", "date")) + .build(); + + Map mappingSource = ImmutableMap.of("properties", properties); + when(mappingMetadata.getSourceAsMap()).thenReturn(mappingSource); + when(getMappingsResponse.getMappings()).thenReturn(ImmutableMap.of("test-index", mappingMetadata)); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(getMappingsResponse); + return null; + }).when(indicesAdminClient).getMappings(any(), any()); + } + + private SearchHit[] createSampleHits(Map... sources) { + SearchHit[] hits = new SearchHit[sources.length]; + for (int i = 0; i < sources.length; i++) { + hits[i] = new SearchHit(i); + hits[i].sourceRef(null); + hits[i].score(1.0f); + // Use reflection to set source + try { + java.lang.reflect.Field sourceField = SearchHit.class.getDeclaredField("source"); + sourceField.setAccessible(true); + sourceField.set(hits[i], sources[i]); + } catch (Exception e) { + // Fallback: create hit with source + } + } + return hits; + } + + private void mockSearchResponse(SearchHit[] hits) { + SearchHits searchHits = new SearchHits(hits, new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO), 1.0f); + when(searchResponse.getHits()).thenReturn(searchHits); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + } + + // ========== Percentile Calculation Tests ========== + + @Test + @SneakyThrows + public void testCalculatePercentile() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculatePercentile", List.class, int.class); + calculatePercentileMethod.setAccessible(true); + + 
// Test with sorted values + List values = List.of(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0); + + double p25 = (double) calculatePercentileMethod.invoke(tool, values, 25); + double p50 = (double) calculatePercentileMethod.invoke(tool, values, 50); + double p75 = (double) calculatePercentileMethod.invoke(tool, values, 75); + double p90 = (double) calculatePercentileMethod.invoke(tool, values, 90); + + assertEquals(3.25, p25, 0.01); + assertEquals(5.5, p50, 0.01); + assertEquals(7.75, p75, 0.01); + assertEquals(9.1, p90, 0.01); + } + + @Test + @SneakyThrows + public void testCalculatePercentileWithSingleValue() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculatePercentile", List.class, int.class); + calculatePercentileMethod.setAccessible(true); + + List values = List.of(5.0); + + double p25 = (double) calculatePercentileMethod.invoke(tool, values, 25); + double p50 = (double) calculatePercentileMethod.invoke(tool, values, 50); + double p75 = (double) calculatePercentileMethod.invoke(tool, values, 75); + double p90 = (double) calculatePercentileMethod.invoke(tool, values, 90); + + assertEquals(5.0, p25, 0.01); + assertEquals(5.0, p50, 0.01); + assertEquals(5.0, p75, 0.01); + assertEquals(5.0, p90, 0.01); + } + + @Test + @SneakyThrows + public void testCalculatePercentileWithEmptyList() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculatePercentile", List.class, int.class); + calculatePercentileMethod.setAccessible(true); + + List values = List.of(); + + double p50 = (double) calculatePercentileMethod.invoke(tool, values, 50); + + assertEquals(0.0, p50, 0.01); + } + + @Test + @SneakyThrows + public void testCalculatePercentiles() { 
+ MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentilesMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculatePercentiles", List.class); + calculatePercentilesMethod.setAccessible(true); + + List values = List.of(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0); + + Object result = calculatePercentilesMethod.invoke(tool, values); + + assertNotNull(result); + java.lang.reflect.Method p50Method = result.getClass().getDeclaredMethod("p50"); + java.lang.reflect.Method p90Method = result.getClass().getDeclaredMethod("p90"); + + double p50 = (double) p50Method.invoke(result); + double p90 = (double) p90Method.invoke(result); + + assertEquals(5.5, p50, 0.01); + assertEquals(9.1, p90, 0.01); + } + + // ========== Numeric Value Extraction Tests ========== + + @Test + @SneakyThrows + public void testExtractNumericValues() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method extractNumericValuesMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("extractNumericValues", List.class, String.class); + extractNumericValuesMethod.setAccessible(true); + + List> data = List + .of( + Map.of("level", 1, "status", "error"), + Map.of("level", 2, "status", "warning"), + Map.of("level", 3, "status", "info"), + Map.of("level", "4", "status", "debug"), // String number + Map.of("status", "error") // Missing level field + ); + + @SuppressWarnings("unchecked") + List values = (List) extractNumericValuesMethod.invoke(tool, data, "level"); + + assertNotNull(values); + assertEquals(4, values.size()); + assertTrue(values.contains(1.0)); + assertTrue(values.contains(2.0)); + assertTrue(values.contains(3.0)); + assertTrue(values.contains(4.0)); + } + + @Test + @SneakyThrows + public void testExtractNumericValuesWithNonNumericData() { + MetricChangeAnalysisTool tool = 
MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method extractNumericValuesMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("extractNumericValues", List.class, String.class); + extractNumericValuesMethod.setAccessible(true); + + List> data = List.of(Map.of("status", "error"), Map.of("status", "warning"), Map.of("status", "info")); + + @SuppressWarnings("unchecked") + List values = (List) extractNumericValuesMethod.invoke(tool, data, "status"); + + assertNotNull(values); + assertTrue(values.isEmpty()); + } + + @Test + @SneakyThrows + public void testExtractNumericValuesWithNestedField() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method extractNumericValuesMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("extractNumericValues", List.class, String.class); + extractNumericValuesMethod.setAccessible(true); + + List> data = List + .of( + Map.of("metrics", Map.of("response_time", 100)), + Map.of("metrics", Map.of("response_time", 200)), + Map.of("metrics", Map.of("response_time", 300)) + ); + + @SuppressWarnings("unchecked") + List values = (List) extractNumericValuesMethod.invoke(tool, data, "metrics.response_time"); + + assertNotNull(values); + assertEquals(3, values.size()); + assertTrue(values.contains(100.0)); + assertTrue(values.contains(200.0)); + assertTrue(values.contains(300.0)); + } + + // ========== Log Ratio Tests ========== + + @Test + @SneakyThrows + public void testSafeLogRatio() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method safeLogRatioMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("safeLogRatio", double.class, double.class); + safeLogRatioMethod.setAccessible(true); + + // Test 2x increase: |log(2/1)| = log(2) ≈ 0.693 + double ratio1 = (double) safeLogRatioMethod.invoke(tool, 2.0, 1.0); + assertEquals(Math.log(2.0), ratio1, 
0.01); + + // Test 1.8x increase: |log(180/100)| = log(1.8) ≈ 0.588 + double ratio2 = (double) safeLogRatioMethod.invoke(tool, 180.0, 100.0); + assertEquals(Math.log(1.8), ratio2, 0.01); + + // Test decrease: |log(5/10)| = |log(0.5)| = log(2) ≈ 0.693 + double ratio3 = (double) safeLogRatioMethod.invoke(tool, 5.0, 10.0); + assertEquals(Math.log(2.0), ratio3, 0.01); + + // Test zero baseline: should return cap (10.0) + double ratio4 = (double) safeLogRatioMethod.invoke(tool, 10.0, 0.0); + assertEquals(10.0, ratio4, 0.01); + + // Test both near zero: should return 0.0 + double ratio5 = (double) safeLogRatioMethod.invoke(tool, 0.0, 0.0); + assertEquals(0.0, ratio5, 0.01); + + // Test no change: |log(1)| = 0 + double ratio6 = (double) safeLogRatioMethod.invoke(tool, 100.0, 100.0); + assertEquals(0.0, ratio6, 0.01); + } + + // ========== Variance Calculation Tests ========== + + @Test + @SneakyThrows + public void testCalculatePercentileVarianceSkipsZeroBaseline() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileVarianceMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod( + "calculatePercentileVariance", + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"), + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats") + ); + calculatePercentileVarianceMethod.setAccessible(true); + + Class percentileStatsClass = Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"); + java.lang.reflect.Constructor constructor = percentileStatsClass.getDeclaredConstructor(double.class, double.class); + constructor.setAccessible(true); + + // Both baselines zero → score 0 + Object sel1 = constructor.newInstance(10.0, 5.0); + Object base1 = constructor.newInstance(0.0, 0.0); + assertEquals(0.0, (double) calculatePercentileVarianceMethod.invoke(tool, sel1, base1), 0.01); + + // Only P50 baseline 
zero → score based on P90 only + Object sel2 = constructor.newInstance(10.0, 20.0); + Object base2 = constructor.newInstance(0.0, 10.0); + double expected = Math.abs(Math.log(20.0 / 10.0)); // log(2) ≈ 0.693 + assertEquals(expected, (double) calculatePercentileVarianceMethod.invoke(tool, sel2, base2), 0.01); + + // Only P90 baseline zero → score based on P50 only + Object sel3 = constructor.newInstance(20.0, 5.0); + Object base3 = constructor.newInstance(10.0, 0.0); + double expected3 = Math.abs(Math.log(20.0 / 10.0)); // log(2) ≈ 0.693 + assertEquals(expected3, (double) calculatePercentileVarianceMethod.invoke(tool, sel3, base3), 0.01); + } + + // ========== Variance Calculation Tests ========== + + @Test + @SneakyThrows + public void testCalculatePercentileVariance() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileVarianceMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod( + "calculatePercentileVariance", + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"), + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats") + ); + calculatePercentileVarianceMethod.setAccessible(true); + + // Create PercentileStats using reflection (p50, p90) + Class percentileStatsClass = Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"); + java.lang.reflect.Constructor constructor = percentileStatsClass.getDeclaredConstructor(double.class, double.class); + constructor.setAccessible(true); + + // selection p50=20, p90=40; baseline p50=10, p90=20 → both are 2x + Object selectionStats = constructor.newInstance(20.0, 40.0); + Object baselineStats = constructor.newInstance(10.0, 20.0); + + double variance = (double) calculatePercentileVarianceMethod.invoke(tool, selectionStats, baselineStats); + + // score = 0.5 * log(2) + 0.5 * log(2) = log(2) ≈ 0.693 + assertEquals(Math.log(2.0), 
variance, 0.01); + } + + @Test + @SneakyThrows + public void testCalculatePercentileVarianceWithNoChange() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileVarianceMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod( + "calculatePercentileVariance", + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"), + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats") + ); + calculatePercentileVarianceMethod.setAccessible(true); + + Class percentileStatsClass = Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"); + java.lang.reflect.Constructor constructor = percentileStatsClass.getDeclaredConstructor(double.class, double.class); + constructor.setAccessible(true); + + Object selectionStats = constructor.newInstance(20.0, 40.0); + Object baselineStats = constructor.newInstance(20.0, 40.0); + + double variance = (double) calculatePercentileVarianceMethod.invoke(tool, selectionStats, baselineStats); + + assertEquals(0.0, variance, 0.01); + } + + @Test + @SneakyThrows + public void testCalculatePercentileVarianceRelativeChange() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileVarianceMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod( + "calculatePercentileVariance", + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"), + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats") + ); + calculatePercentileVarianceMethod.setAccessible(true); + + Class percentileStatsClass = Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"); + java.lang.reflect.Constructor constructor = percentileStatsClass.getDeclaredConstructor(double.class, double.class); + constructor.setAccessible(true); + + 
// Test that 1->2 (2x) ranks higher than 100->180 (1.8x) + Object smallChangeStats = constructor.newInstance(2.0, 2.0); + Object smallBaselineStats = constructor.newInstance(1.0, 1.0); + + Object largeChangeStats = constructor.newInstance(180.0, 180.0); + Object largeBaselineStats = constructor.newInstance(100.0, 100.0); + + double smallVariance = (double) calculatePercentileVarianceMethod.invoke(tool, smallChangeStats, smallBaselineStats); + double largeVariance = (double) calculatePercentileVarianceMethod.invoke(tool, largeChangeStats, largeBaselineStats); + + // 2x change: 0.5 * log(2) + 0.5 * log(2) = log(2) ≈ 0.693 + // 1.8x change: 0.5 * log(1.8) + 0.5 * log(1.8) = log(1.8) ≈ 0.588 + assertEquals(Math.log(2.0), smallVariance, 0.01); + assertEquals(Math.log(1.8), largeVariance, 0.01); + assertTrue("2x change should rank higher than 1.8x change", smallVariance > largeVariance); + } + + // ========== Percentile Analysis Tests ========== + + @Test + @SneakyThrows + public void testCalculateMetricChangeAnalysis() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculateMetricChangeAnalysisMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculateMetricChangeAnalysis", List.class, List.class, java.util.Set.class); + calculateMetricChangeAnalysisMethod.setAccessible(true); + + // Use non-monotonic (gauge-like) data to avoid counter detection + List> selectionData = List + .of( + Map.of("response_time", 400, "cpu_usage", 80), + Map.of("response_time", 100, "cpu_usage", 50), + Map.of("response_time", 300, "cpu_usage", 60), + Map.of("response_time", 200, "cpu_usage", 70) + ); + + List> baselineData = List + .of( + Map.of("response_time", 150, "cpu_usage", 65), + Map.of("response_time", 50, "cpu_usage", 45), + Map.of("response_time", 200, "cpu_usage", 75), + Map.of("response_time", 100, "cpu_usage", 55) + ); + + java.util.Set numberFields = java.util.Set.of("response_time", 
"cpu_usage"); + + @SuppressWarnings("unchecked") + List analyses = (List) calculateMetricChangeAnalysisMethod.invoke(tool, selectionData, baselineData, numberFields); + + assertNotNull(analyses); + assertEquals(2, analyses.size()); + + // Verify first analysis has highest variance + java.lang.reflect.Method varianceMethod = analyses.get(0).getClass().getDeclaredMethod("variance"); + double firstVariance = (double) varianceMethod.invoke(analyses.get(0)); + double secondVariance = (double) varianceMethod.invoke(analyses.get(1)); + + assertTrue(firstVariance >= secondVariance); + } + + @Test + @SneakyThrows + public void testCalculateMetricChangeAnalysisWithEmptyData() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculateMetricChangeAnalysisMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculateMetricChangeAnalysis", List.class, List.class, java.util.Set.class); + calculateMetricChangeAnalysisMethod.setAccessible(true); + + List> selectionData = List.of(); + List> baselineData = List.of(); + java.util.Set numberFields = java.util.Set.of("response_time"); + + @SuppressWarnings("unchecked") + List analyses = (List) calculateMetricChangeAnalysisMethod.invoke(tool, selectionData, baselineData, numberFields); + + assertNotNull(analyses); + assertTrue(analyses.isEmpty()); + } + + @Test + @SneakyThrows + public void testCalculateMetricChangeAnalysisWithMissingFields() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculateMetricChangeAnalysisMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculateMetricChangeAnalysis", List.class, List.class, java.util.Set.class); + calculateMetricChangeAnalysisMethod.setAccessible(true); + + List> selectionData = List.of(Map.of("response_time", 100), Map.of("response_time", 200)); + + List> baselineData = List.of(Map.of("cpu_usage", 50), 
Map.of("cpu_usage", 60)); + + java.util.Set numberFields = java.util.Set.of("response_time", "cpu_usage"); + + @SuppressWarnings("unchecked") + List analyses = (List) calculateMetricChangeAnalysisMethod.invoke(tool, selectionData, baselineData, numberFields); + + assertNotNull(analyses); + // Should skip fields that don't have data in both datasets + assertTrue(analyses.isEmpty()); + } + + @Test + @SneakyThrows + public void testCalculateMetricChangeAnalysisRanking() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculateMetricChangeAnalysisMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculateMetricChangeAnalysis", List.class, List.class, java.util.Set.class); + calculateMetricChangeAnalysisMethod.setAccessible(true); + + // Create non-monotonic data where response_time has high change and cpu_usage has low change + List> selectionData = List + .of( + Map.of("response_time", 3000, "cpu_usage", 52), + Map.of("response_time", 1000, "cpu_usage", 54), + Map.of("response_time", 4000, "cpu_usage", 51), + Map.of("response_time", 2000, "cpu_usage", 53) + ); + + List> baselineData = List + .of( + Map.of("response_time", 300, "cpu_usage", 52), + Map.of("response_time", 100, "cpu_usage", 50), + Map.of("response_time", 400, "cpu_usage", 53), + Map.of("response_time", 200, "cpu_usage", 51) + ); + + java.util.Set numberFields = java.util.Set.of("response_time", "cpu_usage"); + + @SuppressWarnings("unchecked") + List analyses = (List) calculateMetricChangeAnalysisMethod.invoke(tool, selectionData, baselineData, numberFields); + + assertNotNull(analyses); + assertEquals(2, analyses.size()); + + // First field should be response_time (higher variance) + java.lang.reflect.Method fieldMethod = analyses.get(0).getClass().getDeclaredMethod("field"); + String firstField = (String) fieldMethod.invoke(analyses.get(0)); + assertEquals("response_time", firstField); + + // Second field should 
be cpu_usage (lower variance) + String secondField = (String) fieldMethod.invoke(analyses.get(1)); + assertEquals("cpu_usage", secondField); + } + + @Test + @SneakyThrows + public void testFormatResultsLimitsToTopTen() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method formatResultsMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("formatResults", List.class, int.class); + formatResultsMethod.setAccessible(true); + + // Create 15 fields with different variance scores + List analyses = new ArrayList<>(); + Class fieldAnalysisClass = Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$FieldPercentileAnalysis"); + Class percentileStatsClass = Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"); + + java.lang.reflect.Constructor statsConstructor = percentileStatsClass.getDeclaredConstructor(double.class, double.class); + statsConstructor.setAccessible(true); + + java.lang.reflect.Constructor analysisConstructor = fieldAnalysisClass + .getDeclaredConstructor(String.class, double.class, percentileStatsClass, percentileStatsClass); + analysisConstructor.setAccessible(true); + + Object stats = statsConstructor.newInstance(20.0, 40.0); + + // Create 15 fields with descending variance scores + for (int i = 0; i < 15; i++) { + double variance = 15.0 - i; // 15.0, 14.0, 13.0, ..., 1.0 + Object analysis = analysisConstructor.newInstance("field_" + i, variance, stats, stats); + analyses.add(analysis); + } + + // Test with topN = 10 + @SuppressWarnings("unchecked") + List results10 = (List) formatResultsMethod.invoke(tool, analyses, 10); + assertNotNull(results10); + assertEquals("Should return only top 10 results", 10, results10.size()); + + // Test with topN = 5 (default) + @SuppressWarnings("unchecked") + List results5 = (List) formatResultsMethod.invoke(tool, analyses, 5); + assertNotNull(results5); + assertEquals("Should return only top 5 
results", 5, results5.size()); + + // Test with topN = 3 + @SuppressWarnings("unchecked") + List results3 = (List) formatResultsMethod.invoke(tool, analyses, 3); + assertNotNull(results3); + assertEquals("Should return only top 3 results", 3, results3.size()); + } +} diff --git a/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java b/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java index dee0a019..8c2f197a 100644 --- a/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java +++ b/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java @@ -176,7 +176,7 @@ public void testDataDistributionToolWithFilter() { @SneakyThrows public void testDataDistributionToolMissingRequiredParameters() { Exception exception = assertThrows(Exception.class, () -> executeAgent(agentId, "{\"parameters\": {\"index\": \"test_index\"}}")); - MatcherAssert.assertThat(exception.getMessage(), containsString("Unable to parse time string")); + MatcherAssert.assertThat(exception.getMessage(), containsString("Invalid time format")); } @SneakyThrows @@ -429,7 +429,7 @@ public void testDataDistributionToolInvalidTimeFormat() { ) ) ); - MatcherAssert.assertThat(exception.getMessage(), containsString("Unable to parse time string")); + MatcherAssert.assertThat(exception.getMessage(), containsString("Invalid time format")); } @SneakyThrows From 4be3a13fa543497f56bac82b52149ff06b482b87 Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Wed, 4 Mar 2026 17:18:59 +0800 Subject: [PATCH 24/30] update tool description to be more clear (#703) spotlessApply update tool description to be more clear Signed-off-by: Hailong Cui --- .../agent/tools/DataDistributionTool.java | 26 +++-- .../agent/tools/LogPatternAnalysisTool.java | 104 +++++++++--------- .../agent/tools/MetricChangeAnalysisTool.java | 21 ++-- 3 files changed, 80 insertions(+), 71 deletions(-) diff --git a/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java
b/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java index 9a163126..b3d44505 100644 --- a/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java +++ b/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java @@ -103,7 +103,10 @@ public class DataDistributionTool implements Tool { public static final String STRICT_FIELD = "strict"; private static final String DEFAULT_DESCRIPTION = - "This tool analyzes data distribution differences between time ranges or provides single dataset insights."; + "Analyzes field value distributions in a target time range, optionally compared to a baseline. " + + "Use to identify which fields changed most when investigating anomalies. " + + "Two modes: (1) Comparison (baseline provided): ranks fields by divergence. " + + "(2) Single (no baseline): summarizes field distributions."; private static final Set USEFUL_FIELD_TYPES = Set .of("keyword", "boolean", "text", "byte", "short", "integer", "long", "float", "double", "half_float", "scaled_float"); @@ -134,42 +137,42 @@ public class DataDistributionTool implements Tool { }, "selectionTimeRangeStart": { "type": "string", - "description": "Start time for analysis period" + "description": "Start of target period (format: yyyy-MM-dd HH:mm:ss)" }, "selectionTimeRangeEnd": { "type": "string", - "description": "End time for analysis period" + "description": "End of target period (format: yyyy-MM-dd HH:mm:ss)" }, "baselineTimeRangeStart": { "type": "string", - "description": "Start time for baseline period (optional)" + "description": "Start of baseline period (format: yyyy-MM-dd HH:mm:ss). Must pair with baselineTimeRangeEnd" }, "baselineTimeRangeEnd": { "type": "string", - "description": "End time for baseline period (optional)" + "description": "End of baseline period (format: yyyy-MM-dd HH:mm:ss). 
Must pair with baselineTimeRangeStart" }, "size": { "type": "integer", - "description": "Maximum number of documents to analyze (default: 1000)" + "description": "Max documents to sample (default: 1000, max: 10000)" }, "queryType": { "type": "string", - "description": "Query type: 'ppl' or 'dsl' (default: 'dsl')" + "description": "Query type: 'dsl' (default) or 'ppl'" }, "filter": { "type": "array", "items": { "type": "string" }, - "description": "Additional DSL query conditions for filtering (optional)" + "description": "Additional DSL filter clauses as JSON strings" }, "dsl": { "type": "string", - "description": "Complete raw DSL query as JSON string (optional)" + "description": "Complete DSL query as JSON string" }, "ppl": { "type": "string", - "description": "Complete PPL statement without time information (optional)" + "description": "PPL query without time filtering (added automatically)" } }, "required": ["index", "selectionTimeRangeStart", "selectionTimeRangeEnd"], @@ -177,7 +180,8 @@ public class DataDistributionTool implements Tool { } """; - public static final Map DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, false); + public static final Map DEFAULT_ATTRIBUTES = Map + .of(TOOL_INPUT_SCHEMA_FIELD, gson.toJson(gson.fromJson(DEFAULT_INPUT_SCHEMA, Map.class)), STRICT_FIELD, false); /** * Result class for data distribution analysis diff --git a/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java b/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java index cd19be9f..777be285 100644 --- a/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java +++ b/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java @@ -92,63 +92,67 @@ public class LogPatternAnalysisTool implements Tool { public static final String STRICT_FIELD = "strict"; // Constants - private static final String DEFAULT_DESCRIPTION = - "This is a tool used to detect selection log patterns by the 
patterns command in PPL or to detect selection log sequences by the log clustering algorithm."; + private static final String DEFAULT_DESCRIPTION = "Analyzes log patterns in a target time range, optionally compared to a baseline. " + + "Use when investigating incidents to find new, spiking, or anomalous log messages. " + + "Three modes: " + + "(1) Sequence analysis (traceFieldName + baseline): groups logs by trace ID, returns exceptional trace sequences. " + + "(2) Pattern diff (baseline, no traceFieldName): compares pattern frequencies, returns highest-lift patterns. " + + "(3) Log insight (no baseline): finds top error/warning patterns with sample logs."; private static final double LOG_VECTORS_CLUSTERING_THRESHOLD = 0.5; private static final double LOG_PATTERN_THRESHOLD = 0.75; private static final double LOG_PATTERN_LIFT = 3; private static final String DEFAULT_TIME_FIELD = "@timestamp"; - public static final String DEFAULT_INPUT_SCHEMA = - """ - { - "type": "object", - "properties": { - "index": { - "type": "string", - "description": "Target OpenSearch index name containing log data (e.g., 'ss4o_logs-otel-2025.06.24')" - }, - "timeField": { - "type": "string", - "description": "Date/time field in the index mapping used for time-based filtering" - }, - "logFieldName": { - "type": "string", - "description": "Field containing raw log messages to analyze (e.g., 'body', 'message', 'log')" - }, - "traceFieldName": { - "type": "string", - "description": "[OPTIONAL] Field for trace/correlation ID to enable sequence analysis (e.g., 'traceId', 'correlationId'). Leave empty for pattern-only analysis." 
- }, - "baseTimeRangeStart": { - "type": "string", - "description": "Start time for baseline comparison period (date string in utc timezone, e.g., '2025-06-24 07:33:05')" - }, - "baseTimeRangeEnd": { - "type": "string", - "description": "End time for baseline comparison period (date string in utc timezone, e.g., '2025-06-24 07:51:27')" - }, - "selectionTimeRangeStart": { - "type": "string", - "description": "Start time for analysis target period (date string in utc timezone, e.g., '2025-06-24 07:50:26')" - }, - "selectionTimeRangeEnd": { - "type": "string", - "description": "End time for analysis target period (date string in utc timezone, e.g., '2025-06-24 07:55:56')" - } + public static final String DEFAULT_INPUT_SCHEMA = """ + { + "type": "object", + "properties": { + "index": { + "type": "string", + "description": "Target OpenSearch index name" }, - "required": [ - "index", - "timeField", - "logFieldName", - "selectionTimeRangeStart", - "selectionTimeRangeEnd" - ], - "additionalProperties": false - } - """; + "timeField": { + "type": "string", + "description": "Date/time field for filtering" + }, + "logFieldName": { + "type": "string", + "description": "Field containing log message text" + }, + "traceFieldName": { + "type": "string", + "description": "Trace/correlation ID field. Enables sequence analysis mode when provided with baseline time range" + }, + "baseTimeRangeStart": { + "type": "string", + "description": "Start of baseline period (format: yyyy-MM-dd HH:mm:ss). Must pair with baseTimeRangeEnd" + }, + "baseTimeRangeEnd": { + "type": "string", + "description": "End of baseline period (format: yyyy-MM-dd HH:mm:ss). 
Must pair with baseTimeRangeStart" + }, + "selectionTimeRangeStart": { + "type": "string", + "description": "Start of target/incident period (format: yyyy-MM-dd HH:mm:ss)" + }, + "selectionTimeRangeEnd": { + "type": "string", + "description": "End of target/incident period (format: yyyy-MM-dd HH:mm:ss)" + } + }, + "required": [ + "index", + "timeField", + "logFieldName", + "selectionTimeRangeStart", + "selectionTimeRangeEnd" + ], + "additionalProperties": false + } + """; - public static final Map DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, false); + public static final Map DEFAULT_ATTRIBUTES = Map + .of(TOOL_INPUT_SCHEMA_FIELD, gson.toJson(gson.fromJson(DEFAULT_INPUT_SCHEMA, Map.class)), STRICT_FIELD, false); // Compiled regex patterns for better performance private static final Pattern REPEATED_WILDCARDS_PATTERN = Pattern.compile("(<\\*>)(\\s+<\\*>)+"); diff --git a/src/main/java/org/opensearch/agent/tools/MetricChangeAnalysisTool.java b/src/main/java/org/opensearch/agent/tools/MetricChangeAnalysisTool.java index e755b174..d2af20d1 100644 --- a/src/main/java/org/opensearch/agent/tools/MetricChangeAnalysisTool.java +++ b/src/main/java/org/opensearch/agent/tools/MetricChangeAnalysisTool.java @@ -50,10 +50,10 @@ public class MetricChangeAnalysisTool implements Tool { public static final String TYPE = "MetricChangeAnalysisTool"; private static final String DEFAULT_DESCRIPTION = - "This tool analyzes a metric index to identify which numeric metrics changed most significantly between a selection time range and a baseline time range. " - + "It compares percentile distributions (P50, P90) of all numeric fields and returns the top N fields ranked by log-ratio change score. " - + "Use this tool for root cause analysis when investigating performance degradation, anomalies, or incidents in metric data. " - + "Keep both time ranges short and focused (e.g. 
15-30 minutes) and similar in duration for accurate comparison."; + "Compares percentile distributions (P50, P90) of numeric fields between two time ranges. " + + "Returns top fields ranked by change score. " + + "Use for root cause analysis when investigating metric anomalies. " + + "Keep both time ranges short (e.g. 15-30 minutes) and similar in duration for accurate comparison."; private static final int DEFAULT_TOP_N = 10; @@ -68,23 +68,23 @@ public class MetricChangeAnalysisTool implements Tool { }, "timeField": { "type": "string", - "description": "Date/time field for filtering (default: @timestamp)" + "description": "Date/time field for filtering" }, "selectionTimeRangeStart": { "type": "string", - "description": "Start time for the analysis period (format: yyyy-MM-dd HH:mm:ss). The selection period is the time range where the change or anomaly occurred." + "description": "Start of target period (format: yyyy-MM-dd HH:mm:ss)" }, "selectionTimeRangeEnd": { "type": "string", - "description": "End time for the analysis period (format: yyyy-MM-dd HH:mm:ss)." + "description": "End of target period (format: yyyy-MM-dd HH:mm:ss)" }, "baselineTimeRangeStart": { "type": "string", - "description": "Start time for the baseline period (format: yyyy-MM-dd HH:mm:ss). The baseline represents normal behavior. Its duration should be same to the selection period for a fair comparison." + "description": "Start of baseline period (format: yyyy-MM-dd HH:mm:ss)" }, "baselineTimeRangeEnd": { "type": "string", - "description": "End time for the baseline period (format: yyyy-MM-dd HH:mm:ss). Should be at or before selectionTimeRangeStart." + "description": "End of baseline period (format: yyyy-MM-dd HH:mm:ss). 
Should be at or before selectionTimeRangeStart" }, "size": { "type": "integer", @@ -118,7 +118,8 @@ public class MetricChangeAnalysisTool implements Tool { } """; - public static final Map DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA); + public static final Map DEFAULT_ATTRIBUTES = Map + .of(TOOL_INPUT_SCHEMA_FIELD, gson.toJson(gson.fromJson(DEFAULT_INPUT_SCHEMA, Map.class))); /** * Record for percentile statistics From 0f2516240d7e8609fa64a3e72711b2930f906a31 Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Tue, 10 Mar 2026 10:38:53 +0800 Subject: [PATCH 25/30] Search around tool (#702) --- .../java/org/opensearch/agent/ToolPlugin.java | 3 + .../agent/tools/SearchAroundDocumentTool.java | 370 ++++++++++ .../org/opensearch/agent/ToolPluginTests.java | 2 +- .../tools/SearchAroundDocumentToolTests.java | 687 ++++++++++++++++++ .../integTest/SearchAroundDocumentToolIT.java | 280 +++++++ ...rch_around_document_tool_request_body.json | 9 + 6 files changed, 1350 insertions(+), 1 deletion(-) create mode 100644 src/main/java/org/opensearch/agent/tools/SearchAroundDocumentTool.java create mode 100644 src/test/java/org/opensearch/agent/tools/SearchAroundDocumentToolTests.java create mode 100644 src/test/java/org/opensearch/integTest/SearchAroundDocumentToolIT.java create mode 100644 src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_search_around_document_tool_request_body.json diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index 97228054..a8faae88 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -23,6 +23,7 @@ import org.opensearch.agent.tools.SearchAlertsTool; import org.opensearch.agent.tools.SearchAnomalyDetectorsTool; import org.opensearch.agent.tools.SearchAnomalyResultsTool; +import org.opensearch.agent.tools.SearchAroundDocumentTool; import 
org.opensearch.agent.tools.SearchMonitorsTool; import org.opensearch.agent.tools.VectorDBTool; import org.opensearch.agent.tools.WebSearchTool; @@ -103,6 +104,7 @@ public Collection createComponents( WebSearchTool.Factory.getInstance().init(threadPool); LogPatternAnalysisTool.Factory.getInstance().init(client); DataDistributionTool.Factory.getInstance().init(client); + SearchAroundDocumentTool.Factory.getInstance().init(client, xContentRegistry); MetricChangeAnalysisTool.Factory.getInstance().init(client); return Collections.emptyList(); } @@ -125,6 +127,7 @@ public List> getToolFactories() { WebSearchTool.Factory.getInstance(), LogPatternAnalysisTool.Factory.getInstance(), DataDistributionTool.Factory.getInstance(), + SearchAroundDocumentTool.Factory.getInstance(), MetricChangeAnalysisTool.Factory.getInstance() ); } diff --git a/src/main/java/org/opensearch/agent/tools/SearchAroundDocumentTool.java b/src/main/java/org/opensearch/agent/tools/SearchAroundDocumentTool.java new file mode 100644 index 00000000..104c0bc0 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/SearchAroundDocumentTool.java @@ -0,0 +1,370 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; +import org.opensearch.search.SearchHit; 
+import org.opensearch.search.builder.SearchSourceBuilder;
+import org.opensearch.search.sort.FieldSortBuilder;
+import org.opensearch.search.sort.SortOrder;
+import org.opensearch.transport.client.Client;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonSyntaxException;
+
+import lombok.Getter;
+import lombok.Setter;
+import lombok.extern.log4j.Log4j2;
+
+/**
+ * Tool that returns the documents surrounding a given document ID, ordered by a timestamp field.
+ *
+ * <p>The target document is fetched first to obtain its {@code (timestamp_field, _doc)} sort
+ * values. Two {@code search_after} queries are then started from those values: one sorted
+ * descending (documents before the target) and one sorted ascending (documents after it). The
+ * response is a JSON array holding up to {@code count} earlier documents in chronological order,
+ * the target document, and up to {@code count} later documents.
+ *
+ * <p>NOTE(review): {@code _doc} is the tie-breaker for equal timestamps; confirm that this is
+ * precise enough for indices with more than one shard.
+ */
+@Getter
+@Setter
+@Log4j2
+@ToolAnnotation(SearchAroundDocumentTool.TYPE)
+public class SearchAroundDocumentTool implements Tool {
+
+    public static final String INPUT_FIELD = "input";
+    public static final String INDEX_FIELD = "index";
+    public static final String DOC_ID_FIELD = "doc_id";
+    public static final String TIMESTAMP_FIELD = "timestamp_field";
+    public static final String COUNT_FIELD = "count";
+    public static final String INPUT_SCHEMA_FIELD = "input_schema";
+
+    public static final String TYPE = "SearchAroundDocumentTool";
+    private static final String DEFAULT_DESCRIPTION = """
+        Use this tool to search documents around a specific document by providing: \
+        'index' for the index name, 'doc_id' for the target document ID, \
+        'timestamp_field' for the field to order by, and 'count' for the number of documents before and after. \
+        Returns N documents before, the target document, and N documents after.""";
+
+    public static final String DEFAULT_INPUT_SCHEMA = """
+        {
+            "type": "object",
+            "properties": {
+                "index": {
+                    "type": "string",
+                    "description": "OpenSearch index name"
+                },
+                "doc_id": {
+                    "type": "string",
+                    "description": "Target document ID"
+                },
+                "timestamp_field": {
+                    "type": "string",
+                    "description": "Timestamp field for ordering (e.g., @timestamp)"
+                },
+                "count": {
+                    "type": "integer",
+                    "description": "Number of documents before and after the target"
+                }
+            },
+            "required": ["index", "doc_id", "timestamp_field", "count"],
+            "additionalProperties": false
+        }
+        """;
+
+    private static final Gson GSON = new GsonBuilder().serializeSpecialFloatingPointValues().create();
+
+    public static final Map<String, Object> DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA);
+
+    private String name = TYPE;
+    private Map<String, Object> attributes;
+    private String description = DEFAULT_DESCRIPTION;
+
+    private Client client;
+
+    private NamedXContentRegistry xContentRegistry;
+
+    /** Fully resolved arguments of one tool invocation. */
+    private record SearchAroundRequest(String index, String docId, String timestampField, int count) {
+    }
+
+    public SearchAroundDocumentTool(Client client, NamedXContentRegistry xContentRegistry) {
+        this.client = client;
+        this.xContentRegistry = xContentRegistry;
+
+        this.attributes = new HashMap<>();
+        attributes.put(INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA);
+    }
+
+    @Override
+    public String getType() {
+        return TYPE;
+    }
+
+    @Override
+    public String getVersion() {
+        return null;
+    }
+
+    /**
+     * Accepts the parameters when either a non-empty {@code input} JSON string is present or all of
+     * {@code index}, {@code doc_id}, {@code timestamp_field} and {@code count} are non-empty.
+     * The content of {@code input} is not parsed here; {@link #run} reports missing values.
+     *
+     * @param parameters raw tool parameters, may be {@code null}
+     * @return {@code true} when the parameters look sufficient to run the tool
+     */
+    @Override
+    public boolean validate(Map<String, String> parameters) {
+        if (parameters == null || parameters.isEmpty()) {
+            return false;
+        }
+        boolean argumentsFromInput = !StringUtils.isEmpty(parameters.get(INPUT_FIELD));
+        boolean argumentsFromParameters = !StringUtils.isEmpty(parameters.get(INDEX_FIELD))
+            && !StringUtils.isEmpty(parameters.get(DOC_ID_FIELD))
+            && !StringUtils.isEmpty(parameters.get(TIMESTAMP_FIELD))
+            && !StringUtils.isEmpty(parameters.get(COUNT_FIELD));
+        if (!argumentsFromInput && !argumentsFromParameters) {
+            log.error("SearchAroundDocumentTool requires: index, doc_id, timestamp_field, and count parameters!");
+            return false;
+        }
+        return true;
+    }
+
+    /** Converts a hit into the map that is serialized for it; {@code sort} is added only when sort values exist. */
+    private static Map<String, Object> processResponse(SearchHit hit) {
+        Map<String, Object> docContent = new HashMap<>();
+        docContent.put("_index", hit.getIndex());
+        docContent.put("_id", hit.getId());
+        docContent.put("_score", hit.getScore());
+        docContent.put("_source", hit.getSourceAsMap());
+        Object[] sortValues = hit.getSortValues();
+        if (sortValues != null && sortValues.length > 0) {
+            docContent.put("sort", sortValues);
+        }
+        return docContent;
+    }
+
+    /**
+     * Resolves the request from the parameters and runs the target, "before" and "after" searches.
+     *
+     * @param originalParameters raw tool parameters; values inside the {@code input} JSON string take
+     *                           precedence over the direct {@code index}/{@code doc_id}/
+     *                           {@code timestamp_field}/{@code count} parameters
+     * @param listener           receives the JSON array as a {@code String}, or the failure. An
+     *                           {@link IllegalArgumentException} is reported when a required value is
+     *                           missing, {@code count} is negative, the target document is not found,
+     *                           or the target hit carries fewer than two sort values
+     */
+    @Override
+    public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
+        try {
+            Map<String, String> parameters = ToolUtils.extractInputParameters(originalParameters, attributes);
+            SearchAroundRequest request = resolveRequest(parameters);
+            if (request == null) {
+                listener
+                    .onFailure(
+                        new IllegalArgumentException(
+                            "SearchAroundDocumentTool requires: index, doc_id, timestamp_field, and count parameters"
+                        )
+                    );
+                return;
+            }
+            // Reject a negative size up front instead of failing after the target search has already run.
+            if (request.count() < 0) {
+                listener.onFailure(new IllegalArgumentException("count must be a non-negative integer, but was " + request.count()));
+                return;
+            }
+
+            log
+                .debug(
+                    "Calling SearchAroundDocumentTool: index={}, doc_id={}, timestamp_field={}, count={}",
+                    request.index(),
+                    request.docId(),
+                    request.timestampField(),
+                    request.count()
+                );
+            fetchTargetThenNeighbors(request, listener);
+        } catch (Exception e) {
+            log.error("Failed to run SearchAroundDocumentTool", e);
+            listener.onFailure(e);
+        }
+    }
+
+    /**
+     * Reads the four arguments from the {@code input} JSON string first and falls back to the direct
+     * parameters for anything still missing.
+     *
+     * @return the resolved request, or {@code null} when a required value is missing
+     */
+    private static SearchAroundRequest resolveRequest(Map<String, String> parameters) {
+        String index = null;
+        String docId = null;
+        String timestampField = null;
+        Integer count = null;
+
+        String input = parameters.get(INPUT_FIELD);
+        if (!StringUtils.isEmpty(input)) {
+            try {
+                JsonObject jsonObject = GSON.fromJson(input, JsonObject.class);
+                if (jsonObject != null) {
+                    index = stringMember(jsonObject, INDEX_FIELD);
+                    docId = stringMember(jsonObject, DOC_ID_FIELD);
+                    timestampField = stringMember(jsonObject, TIMESTAMP_FIELD);
+                    if (hasValue(jsonObject, COUNT_FIELD)) {
+                        count = jsonObject.get(COUNT_FIELD).getAsInt();
+                    }
+                }
+            } catch (JsonSyntaxException e) {
+                // Best effort: malformed input is ignored so the direct parameters can still be used.
+                log.error("Invalid JSON input: {}", input, e);
+            }
+        }
+
+        // Fall back to direct parameters
+        if (StringUtils.isEmpty(index)) {
+            index = parameters.get(INDEX_FIELD);
+        }
+        if (StringUtils.isEmpty(docId)) {
+            docId = parameters.get(DOC_ID_FIELD);
+        }
+        if (StringUtils.isEmpty(timestampField)) {
+            timestampField = parameters.get(TIMESTAMP_FIELD);
+        }
+        String rawCount = parameters.get(COUNT_FIELD);
+        if (count == null && !StringUtils.isEmpty(rawCount)) {
+            try {
+                // Parsed as a double so that values such as "2.0" are accepted.
+                count = Double.valueOf(rawCount).intValue();
+            } catch (NumberFormatException e) {
+                log.error("Invalid count parameter: {}", rawCount, e);
+            }
+        }
+
+        if (StringUtils.isEmpty(index) || StringUtils.isEmpty(docId) || StringUtils.isEmpty(timestampField) || count == null) {
+            return null;
+        }
+        return new SearchAroundRequest(index, docId, timestampField, count);
+    }
+
+    /** Returns whether the member exists and is not JSON {@code null}. */
+    private static boolean hasValue(JsonObject jsonObject, String key) {
+        return jsonObject.has(key) && !jsonObject.get(key).isJsonNull();
+    }
+
+    /** Returns the member as a string, or {@code null} when it is absent or JSON {@code null}. */
+    private static String stringMember(JsonObject jsonObject, String key) {
+        return hasValue(jsonObject, key) ? jsonObject.get(key).getAsString() : null;
+    }
+
+    /** Step 1: fetches the target document by ID together with its sort values, then starts the neighbour searches. */
+    private <T> void fetchTargetThenNeighbors(SearchAroundRequest request, ActionListener<T> listener) {
+        SearchSourceBuilder targetSource = new SearchSourceBuilder()
+            .query(QueryBuilders.idsQuery().addIds(request.docId()))
+            .sort(fieldSort(request.timestampField(), SortOrder.DESC))
+            .sort(fieldSort("_doc", SortOrder.DESC))
+            .size(1);
+        SearchRequest targetRequest = new SearchRequest(request.index()).source(targetSource);
+
+        client.search(targetRequest, ActionListener.wrap(targetResponse -> {
+            SearchHit[] targetHits = targetResponse.getHits().getHits();
+            if (targetHits == null || targetHits.length == 0) {
+                listener.onFailure(new IllegalArgumentException("Document not found: " + request.docId()));
+                return;
+            }
+
+            SearchHit targetHit = targetHits[0];
+            Object[] sortValues = targetHit.getSortValues();
+            if (sortValues == null || sortValues.length < 2) {
+                listener.onFailure(new IllegalArgumentException("Could not get sort values from target document"));
+                return;
+            }
+            searchNeighbors(request, processResponse(targetHit), sortValues, listener);
+        }, e -> {
+            log.error("Failed to fetch target document", e);
+            listener.onFailure(e);
+        }));
+    }
+
+    /** Steps 2 and 3: runs the "before" search, then the "after" search, and reports the combined result. */
+    private <T> void searchNeighbors(
+        SearchAroundRequest request,
+        Map<String, Object> targetDoc,
+        Object[] sortValues,
+        ActionListener<T> listener
+    ) {
+        client.search(neighborRequest(request, sortValues, SortOrder.DESC), ActionListener.wrap(beforeResponse -> {
+            client.search(neighborRequest(request, sortValues, SortOrder.ASC), ActionListener.wrap(afterResponse -> {
+                List<Map<String, Object>> result = new ArrayList<>();
+
+                // "Before" hits arrive newest-first; walk them backwards to get chronological order.
+                SearchHit[] beforeHits = beforeResponse.getHits().getHits();
+                for (int i = beforeHits.length - 1; i >= 0; i--) {
+                    result.add(processResponse(beforeHits[i]));
+                }
+
+                result.add(targetDoc);
+
+                for (SearchHit hit : afterResponse.getHits().getHits()) {
+                    result.add(processResponse(hit));
+                }
+
+                respond(listener, GSON.toJson(result));
+            }, e -> {
+                log.error("Failed to search for documents after target", e);
+                listener.onFailure(e);
+            }));
+        }, e -> {
+            log.error("Failed to search for documents before target", e);
+            listener.onFailure(e);
+        }));
+    }
+
+    /**
+     * Builds the search for the neighbours on one side of the target: every document except the
+     * target itself, sorted in {@code order} and continued from the target's sort values.
+     */
+    private static SearchRequest neighborRequest(SearchAroundRequest request, Object[] searchAfter, SortOrder order) {
+        BoolQueryBuilder excludeTarget = QueryBuilders.boolQuery().mustNot(QueryBuilders.idsQuery().addIds(request.docId()));
+        SearchSourceBuilder source = new SearchSourceBuilder()
+            .query(excludeTarget)
+            .sort(fieldSort(request.timestampField(), order))
+            .sort(fieldSort("_doc", order))
+            .searchAfter(searchAfter)
+            .size(request.count());
+        return new SearchRequest(request.index()).source(source);
+    }
+
+    /** Sort clause shared by all three searches so that their sort values stay comparable. */
+    private static FieldSortBuilder fieldSort(String field, SortOrder order) {
+        return new FieldSortBuilder(field).order(order).unmappedType("boolean");
+    }
+
+    /** Delivers the JSON string through the generically typed listener. */
+    @SuppressWarnings("unchecked") // the tool contract hands the result to the listener as a String
+    private static <T> void respond(ActionListener<T> listener, String json) {
+        listener.onResponse((T) json);
+    }
+
+    public static class Factory implements Tool.Factory<SearchAroundDocumentTool> {
+
+        private Client client;
+        // volatile: the field is read outside the lock in getInstance()
+        private static volatile Factory INSTANCE;
+
+        private NamedXContentRegistry xContentRegistry;
+
+        /**
+         * Create or return the singleton factory instance
+         */
+        public static Factory getInstance() {
+            Factory instance = INSTANCE;
+            if (instance != null) {
+                return instance;
+            }
+            synchronized (SearchAroundDocumentTool.class) {
+                if (INSTANCE == null) {
+                    INSTANCE = new Factory();
+                }
+                return INSTANCE;
+            }
+        }
+
+        public void init(Client client, NamedXContentRegistry xContentRegistry) {
+            this.client = client;
+            this.xContentRegistry = xContentRegistry;
+        }
+
+        @Override
+        public SearchAroundDocumentTool create(Map<String, Object> params) {
+            return new SearchAroundDocumentTool(client, xContentRegistry);
+        }
+
+        @Override
+        public String getDefaultDescription() {
+            return DEFAULT_DESCRIPTION;
+        }
+
+        @Override
+        public String getDefaultType() {
+            return TYPE;
+        }
+
+        @Override
+        public String getDefaultVersion() {
+            return null;
+        }
+
+        @Override
+        public Map<String, Object> getDefaultAttributes() {
+            return DEFAULT_ATTRIBUTES;
+        }
+    }
+}
diff --git a/src/test/java/org/opensearch/agent/ToolPluginTests.java b/src/test/java/org/opensearch/agent/ToolPluginTests.java
index 311fa88c..d8cbc388 100644
--- a/src/test/java/org/opensearch/agent/ToolPluginTests.java
+++ b/src/test/java/org/opensearch/agent/ToolPluginTests.java
@@ -96,7 +96,7 @@ public void test_getRestHandlers_successful() {
 
     @Test
     public void test_getToolFactories_successful() {
-        assertEquals(15, 
toolPlugin.getToolFactories().size());
+        assertEquals(16, toolPlugin.getToolFactories().size());
     }
 
     @Test
diff --git a/src/test/java/org/opensearch/agent/tools/SearchAroundDocumentToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchAroundDocumentToolTests.java
new file mode 100644
index 00000000..fb8dba22
--- /dev/null
+++ b/src/test/java/org/opensearch/agent/tools/SearchAroundDocumentToolTests.java
@@ -0,0 +1,430 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.agent.tools;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.MockitoAnnotations;
+import org.opensearch.action.search.SearchRequest;
+import org.opensearch.action.search.SearchResponse;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.search.SearchHit;
+import org.opensearch.search.SearchHits;
+import org.opensearch.transport.client.Client;
+
+import com.google.gson.Gson;
+import com.google.gson.reflect.TypeToken;
+
+/**
+ * Unit tests for {@link SearchAroundDocumentTool}; all searches are answered by a mocked {@link Client}.
+ */
+public class SearchAroundDocumentToolTests {
+
+    private static final Gson GSON = new Gson();
+
+    private Client client;
+    private SearchAroundDocumentTool tool;
+
+    @Before
+    public void setup() {
+        MockitoAnnotations.openMocks(this);
+        client = mock(Client.class);
+        SearchAroundDocumentTool.Factory.getInstance().init(client, NamedXContentRegistry.EMPTY);
+        tool = SearchAroundDocumentTool.Factory.getInstance().create(Collections.emptyMap());
+    }
+
+    // ========== Helpers ==========
+
+    private SearchHit createMockHit(String id, String index, Map<String, Object> source, Object[] sortValues) {
+        SearchHit hit = mock(SearchHit.class);
+        when(hit.getId()).thenReturn(id);
+        when(hit.getIndex()).thenReturn(index);
+        when(hit.getScore()).thenReturn(1.0f);
+        when(hit.getSourceAsMap()).thenReturn(source);
+        when(hit.getSortValues()).thenReturn(sortValues);
+        return hit;
+    }
+
+    private SearchResponse createMockSearchResponse(SearchHit... hits) {
+        SearchResponse response = mock(SearchResponse.class);
+        SearchHits searchHits = mock(SearchHits.class);
+        when(searchHits.getHits()).thenReturn(hits);
+        when(response.getHits()).thenReturn(searchHits);
+        return response;
+    }
+
+    /**
+     * Stubs {@code client.search} so that the n-th call receives the n-th outcome: a
+     * {@link SearchResponse} is passed to {@code onResponse}, an {@link Exception} to
+     * {@code onFailure}. Any call beyond the configured outcomes fails the listener.
+     */
+    private void mockSearchOutcomes(Object... outcomes) {
+        AtomicInteger callCount = new AtomicInteger(0);
+        doAnswer(invocation -> {
+            ActionListener<SearchResponse> listener = invocation.getArgument(1);
+            int call = callCount.getAndIncrement();
+            if (call >= outcomes.length) {
+                listener.onFailure(new RuntimeException("Unexpected search call"));
+            } else if (outcomes[call] instanceof Exception) {
+                listener.onFailure((Exception) outcomes[call]);
+            } else {
+                listener.onResponse((SearchResponse) outcomes[call]);
+            }
+            return null;
+        }).when(client).search(any(SearchRequest.class), any());
+    }
+
+    /** Stubs the three searches of a run whose target has no neighbouring documents. */
+    private void mockTargetWithoutNeighbors(SearchHit targetHit) {
+        mockSearchOutcomes(createMockSearchResponse(targetHit), createMockSearchResponse(), createMockSearchResponse());
+    }
+
+    private static Map<String, String> directParams(String index, String docId, String timestampField, String count) {
+        Map<String, String> params = new HashMap<>();
+        params.put("index", index);
+        params.put("doc_id", docId);
+        params.put("timestamp_field", timestampField);
+        params.put("count", count);
+        return params;
+    }
+
+    private CompletableFuture<String> runTool(Map<String, String> params) {
+        CompletableFuture<String> future = new CompletableFuture<>();
+        tool.run(params, ActionListener.<String>wrap(future::complete, future::completeExceptionally));
+        return future;
+    }
+
+    /** Runs the tool and returns the failure it reported; fails the test if the run succeeded. */
+    private Throwable runExpectingFailure(Map<String, String> params) throws InterruptedException {
+        try {
+            runTool(params).get();
+        } catch (ExecutionException e) {
+            return e.getCause();
+        }
+        throw new AssertionError("Should have thrown");
+    }
+
+    private static List<Map<String, Object>> parseDocs(String json) {
+        return GSON.fromJson(json, new TypeToken<List<Map<String, Object>>>() {
+        }.getType());
+    }
+
+    private static void assertDocIds(List<Map<String, Object>> docs, String... expectedIds) {
+        assertEquals(expectedIds.length, docs.size());
+        for (int i = 0; i < expectedIds.length; i++) {
+            assertEquals(expectedIds[i], docs.get(i).get("_id"));
+        }
+    }
+
+    // ========== Validate Tests ==========
+
+    @Test
+    public void testValidateWithNullParameters() {
+        assertFalse(tool.validate(null));
+    }
+
+    @Test
+    public void testValidateWithEmptyParameters() {
+        assertFalse(tool.validate(Collections.emptyMap()));
+    }
+
+    @Test
+    public void testValidateWithJsonInput() {
+        Map<String, String> params = new HashMap<>();
+        params.put("input", "{\"index\":\"test\",\"doc_id\":\"1\",\"timestamp_field\":\"@timestamp\",\"count\":5}");
+        assertTrue(tool.validate(params));
+    }
+
+    @Test
+    public void testValidateWithDirectParameters() {
+        assertTrue(tool.validate(directParams("test-index", "doc1", "@timestamp", "5")));
+    }
+
+    @Test
+    public void testValidateWithMissingParameters() {
+        Map<String, String> params = new HashMap<>();
+        params.put("index", "test-index");
+        params.put("doc_id", "doc1");
+        assertFalse(tool.validate(params));
+    }
+
+    @Test
+    public void testValidateWithEmptyValues() {
+        assertFalse(tool.validate(directParams("", "doc1", "@timestamp", "5")));
+    }
+
+    // ========== Factory Tests ==========
+
+    @Test
+    public void testFactoryCreate() {
+        SearchAroundDocumentTool createdTool = SearchAroundDocumentTool.Factory.getInstance().create(Collections.emptyMap());
+        assertNotNull(createdTool);
+        assertEquals(SearchAroundDocumentTool.TYPE, createdTool.getType());
+    }
+
+    @Test
+    public void testFactoryGetInstance() {
+        assertSame(SearchAroundDocumentTool.Factory.getInstance(), SearchAroundDocumentTool.Factory.getInstance());
+    }
+
+    @Test
+    public void testFactoryDefaults() {
+        SearchAroundDocumentTool.Factory factory = SearchAroundDocumentTool.Factory.getInstance();
+        assertEquals(SearchAroundDocumentTool.TYPE, factory.getDefaultType());
+        assertNotNull(factory.getDefaultDescription());
+        assertNull(factory.getDefaultVersion());
+        assertNotNull(factory.getDefaultAttributes());
+    }
+
+    // ========== Tool Metadata Tests ==========
+
+    @Test
+    public void testGetType() {
+        assertEquals("SearchAroundDocumentTool", tool.getType());
+    }
+
+    @Test
+    public void testGetVersion() {
+        assertNull(tool.getVersion());
+    }
+
+    // ========== Run - Success Tests ==========
+
+    @Test
+    public void testRunWithDirectParameters() throws Exception {
+        SearchHit targetHit = createMockHit(
+            "target-doc",
+            "test-index",
+            Map.of("@timestamp", "2024-01-01T00:00:05", "message", "target"),
+            new Object[] { 1000L, 5L }
+        );
+        SearchHit beforeHit1 = createMockHit(
+            "before-1",
+            "test-index",
+            Map.of("@timestamp", "2024-01-01T00:00:03", "message", "before1"),
+            new Object[] { 998L, 3L }
+        );
+        SearchHit beforeHit2 = createMockHit(
+            "before-2",
+            "test-index",
+            Map.of("@timestamp", "2024-01-01T00:00:04", "message", "before2"),
+            new Object[] { 999L, 4L }
+        );
+        SearchHit afterHit1 = createMockHit(
+            "after-1",
+            "test-index",
+            Map.of("@timestamp", "2024-01-01T00:00:06", "message", "after1"),
+            new Object[] { 1001L, 6L }
+        );
+        SearchHit afterHit2 = createMockHit(
+            "after-2",
+            "test-index",
+            Map.of("@timestamp", "2024-01-01T00:00:07", "message", "after2"),
+            new Object[] { 1002L, 7L }
+        );
+
+        // The "before" search returns newest-first (before2, before1); the tool reverses it.
+        mockSearchOutcomes(
+            createMockSearchResponse(targetHit),
+            createMockSearchResponse(beforeHit2, beforeHit1),
+            createMockSearchResponse(afterHit1, afterHit2)
+        );
+
+        String result = runTool(directParams("test-index", "target-doc", "@timestamp", "2")).get();
+        assertNotNull(result);
+
+        // Chronological order: before1, before2, target, after1, after2
+        assertDocIds(parseDocs(result), "before-1", "before-2", "target-doc", "after-1", "after-2");
+    }
+
+    @Test
+    public void testRunWithJsonInput() throws Exception {
+        SearchHit targetHit = createMockHit(
+            "doc-1",
+            "my-index",
+            Map.of("@timestamp", "2024-01-01T00:00:05", "message", "target"),
+            new Object[] { 1000L, 5L }
+        );
+        mockTargetWithoutNeighbors(targetHit);
+
+        Map<String, String> params = new HashMap<>();
+        params.put("input", "{\"index\":\"my-index\",\"doc_id\":\"doc-1\",\"timestamp_field\":\"@timestamp\",\"count\":2}");
+
+        String result = runTool(params).get();
+        assertNotNull(result);
+        assertDocIds(parseDocs(result), "doc-1");
+    }
+
+    @Test
+    public void testRunWithNoBeforeOrAfterDocuments() throws Exception {
+        mockTargetWithoutNeighbors(createMockHit("doc-1", "test-index", Map.of("message", "only doc"), new Object[] { 1000L, 5L }));
+
+        String result = runTool(directParams("test-index", "doc-1", "@timestamp", "5")).get();
+        assertDocIds(parseDocs(result), "doc-1");
+    }
+
+    @Test
+    public void testRunWithCountAsDouble() throws Exception {
+        mockTargetWithoutNeighbors(createMockHit("doc-1", "test-index", Map.of("message", "target"), new Object[] { 1000L, 5L }));
+
+        String result = runTool(directParams("test-index", "doc-1", "@timestamp", "2.0")).get();
+        assertNotNull(result);
+    }
+
+    @Test
+    public void testRunResponseContainsSortValues() throws Exception {
+        mockTargetWithoutNeighbors(createMockHit("doc-1", "test-index", Map.of("message", "target"), new Object[] { 1000L, 5L }));
+
+        List<Map<String, Object>> docs = parseDocs(runTool(directParams("test-index", "doc-1", "@timestamp", "1")).get());
+        assertEquals(1, docs.size());
+
+        Map<String, Object> doc = docs.get(0);
+        assertEquals("doc-1", doc.get("_id"));
+        assertEquals("test-index", doc.get("_index"));
+        assertNotNull(doc.get("_source"));
+        assertNotNull(doc.get("sort"));
+    }
+
+    // ========== Run - Error Tests ==========
+
+    @Test
+    public void testRunWithDocumentNotFound() throws Exception {
+        mockSearchOutcomes(createMockSearchResponse());
+
+        Throwable cause = runExpectingFailure(directParams("test-index", "nonexistent", "@timestamp", "2"));
+        assertTrue(cause instanceof IllegalArgumentException);
+        assertTrue(cause.getMessage().contains("Document not found"));
+    }
+
+    @Test
+    public void testRunWithMissingSortValues() throws Exception {
+        assertMissingSortValuesFailure(new Object[] {});
+    }
+
+    @Test
+    public void testRunWithNullSortValues() throws Exception {
+        assertMissingSortValuesFailure(null);
+    }
+
+    @Test
+    public void testRunWithOnlySingleSortValue() throws Exception {
+        assertMissingSortValuesFailure(new Object[] { 1000L });
+    }
+
+    /** Runs the tool against a target hit carrying the given sort values and expects the sort-value failure. */
+    private void assertMissingSortValuesFailure(Object[] sortValues) throws Exception {
+        SearchHit targetHit = createMockHit("doc-1", "test-index", Map.of("message", "target"), sortValues);
+        mockSearchOutcomes(createMockSearchResponse(targetHit));
+
+        Throwable cause = runExpectingFailure(directParams("test-index", "doc-1", "@timestamp", "2"));
+        assertTrue(cause instanceof IllegalArgumentException);
+        assertTrue(cause.getMessage().contains("sort values"));
+    }
+
+    @Test
+    public void testRunWithMissingRequiredParameters() throws Exception {
+        Map<String, String> params = new HashMap<>();
+        params.put("index", "test-index");
+
+        Throwable cause = runExpectingFailure(params);
+        assertTrue(cause instanceof IllegalArgumentException);
+        assertTrue(cause.getMessage().contains("requires"));
+    }
+
+    @Test
+    public void testRunWithTargetSearchFailure() throws Exception {
+        mockSearchOutcomes(new RuntimeException("Target search failed"));
+
+        Throwable cause = runExpectingFailure(directParams("test-index", "doc-1", "@timestamp", "2"));
+        assertTrue(cause instanceof RuntimeException);
+        assertEquals("Target search failed", cause.getMessage());
+    }
+
+    @Test
+    public void testRunWithBeforeSearchFailure() throws Exception {
+        SearchHit targetHit = createMockHit("doc-1", "test-index", Map.of("message", "target"), new Object[] { 1000L, 5L });
+        mockSearchOutcomes(createMockSearchResponse(targetHit), new RuntimeException("Before search failed"));
+
+        Throwable cause = runExpectingFailure(directParams("test-index", "doc-1", "@timestamp", "2"));
+        assertTrue(cause instanceof RuntimeException);
+        assertEquals("Before search failed", cause.getMessage());
+    }
+
+    @Test
+    public void testRunWithAfterSearchFailure() throws Exception {
+        SearchHit targetHit = createMockHit("doc-1", "test-index", Map.of("message", "target"), new Object[] { 1000L, 5L });
+        mockSearchOutcomes(createMockSearchResponse(targetHit), createMockSearchResponse(), new RuntimeException("After search failed"));
+
+        Throwable cause = runExpectingFailure(directParams("test-index", "doc-1", "@timestamp", "2"));
+        assertTrue(cause instanceof RuntimeException);
+        assertEquals("After search failed", cause.getMessage());
+    }
+
+    // ========== Run - Ordering Tests ==========
+
+    @Test
+    public void testRunBeforeDocsAreReversedToChronological() throws Exception {
+        SearchHit targetHit = createMockHit("target", "idx", Map.of("ts", 1000), new Object[] { 1000L, 5L });
+
+        // The "before" search returns newest-first: before-3, before-2, before-1 (oldest)
+        SearchHit beforeHit3 = createMockHit("before-3", "idx", Map.of("ts", 999), new Object[] { 999L, 4L });
+        SearchHit beforeHit2 = createMockHit("before-2", "idx", Map.of("ts", 998), new Object[] { 998L, 3L });
+        SearchHit beforeHit1 = createMockHit("before-1", "idx", Map.of("ts", 997), new Object[] { 997L, 2L });
+        SearchHit afterHit1 = createMockHit("after-1", "idx", Map.of("ts", 1001), new Object[] { 1001L, 6L });
+
+        mockSearchOutcomes(
+            createMockSearchResponse(targetHit),
+            createMockSearchResponse(beforeHit3, beforeHit2, beforeHit1),
+            createMockSearchResponse(afterHit1)
+        );
+
+        String result = runTool(directParams("idx", "target", "ts", "3")).get();
+
+        // Before docs must come back oldest-first, followed by the target and the after docs
+        assertDocIds(parseDocs(result), "before-1", "before-2", "before-3", "target", "after-1");
+    }
+
+    @Test
+    public void testRunResponseDocStructure() throws Exception {
+        Map<String, Object> source = Map.of("@timestamp", "2024-01-01", "level", "INFO", "message", "test log");
+        mockTargetWithoutNeighbors(createMockHit("doc-1", "logs-index", source, new Object[] { 1000L, 5L }));
+
+        List<Map<String, Object>> docs = parseDocs(runTool(directParams("logs-index", "doc-1", "@timestamp", "1")).get());
+        assertEquals(1, docs.size());
+
+        Map<String, Object> doc = docs.get(0);
+        assertEquals("doc-1", doc.get("_id"));
+        assertEquals("logs-index", doc.get("_index"));
+        assertNotNull(doc.get("_score"));
+        assertNotNull(doc.get("sort"));
+
+        @SuppressWarnings("unchecked")
+        Map<String, Object> returnedSource = (Map<String, Object>) doc.get("_source");
+        assertEquals("2024-01-01", returnedSource.get("@timestamp"));
+        assertEquals("INFO", returnedSource.get("level"));
+        assertEquals("test log", returnedSource.get("message"));
+    }
+}
diff --git a/src/test/java/org/opensearch/integTest/SearchAroundDocumentToolIT.java b/src/test/java/org/opensearch/integTest/SearchAroundDocumentToolIT.java
new file mode 100644
index 00000000..50f45b29
--- /dev/null
+++ b/src/test/java/org/opensearch/integTest/SearchAroundDocumentToolIT.java
@@ -0,0 +1,280 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.integTest;
+
+import static org.hamcrest.Matchers.containsString;
+
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.List;
+import java.util.Locale;
+
+import org.hamcrest.MatcherAssert;
+import org.junit.After;
+import org.junit.Before;
+
+import com.google.gson.JsonArray;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonParser;
+
+import lombok.SneakyThrows;
+
+public class SearchAroundDocumentToolIT extends BaseAgentToolsIT {
+
+    private static final String TEST_INDEX_NAME = "test_search_around_document_index";
+    private static final String REGISTER_AGENT_RESOURCE =
+        "org/opensearch/agent/tools/register_flow_agent_of_search_around_document_tool_request_body.json";
+
+    private String registerAgentRequestBody;
+    private String agentId;
+
+    @Before
+    @SneakyThrows
+    public void setUp() {
+        super.setUp();
+        registerAgentRequestBody = Files.readString(Path.of(this.getClass().getClassLoader().getResource(REGISTER_AGENT_RESOURCE).toURI()));
+        prepareDataIndex();
+        agentId = createAgent(registerAgentRequestBody);
+    }
+
+    @After
+    @SneakyThrows
+    public 
void tearDown() { + super.tearDown(); + deleteExternalIndices(); + } + + @SneakyThrows + private void prepareDataIndex() { + createIndexWithConfiguration(TEST_INDEX_NAME, """ + { + "mappings": { + "properties": { + "@timestamp": { + "type": "date", + "format": "yyyy-MM-dd HH:mm:ss||strict_date_optional_time||epoch_millis" + }, + "message": { + "type": "text" + }, + "level": { + "type": "keyword" + } + } + } + }"""); + + // Index 7 documents with known timestamps and IDs + addDocToIndex( + TEST_INDEX_NAME, + "doc1", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 09:00:00", "First log entry", "INFO") + ); + addDocToIndex( + TEST_INDEX_NAME, + "doc2", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 09:10:00", "Second log entry", "INFO") + ); + addDocToIndex( + TEST_INDEX_NAME, + "doc3", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 09:20:00", "Third log entry", "WARN") + ); + addDocToIndex( + TEST_INDEX_NAME, + "doc4", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 09:30:00", "Fourth log entry - target", "ERROR") + ); + addDocToIndex( + TEST_INDEX_NAME, + "doc5", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 09:40:00", "Fifth log entry", "WARN") + ); + addDocToIndex( + TEST_INDEX_NAME, + "doc6", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 09:50:00", "Sixth log entry", "INFO") + ); + addDocToIndex( + TEST_INDEX_NAME, + "doc7", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 10:00:00", "Seventh log entry", "ERROR") + ); + } + + @SneakyThrows + public void testSearchAroundDocument_basicSearch() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"doc_id\": \"doc4\", \"timestamp_field\": \"@timestamp\", \"count\": \"2\"}}", + TEST_INDEX_NAME + ) + ); + + JsonArray docs = JsonParser.parseString(result).getAsJsonArray(); + + // Should have 5 documents: 2 before + 
target + 2 after + assertEquals(5, docs.size()); + + // Verify chronological order: doc2, doc3, doc4 (target), doc5, doc6 + assertEquals("doc2", docs.get(0).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc3", docs.get(1).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc4", docs.get(2).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc5", docs.get(3).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc6", docs.get(4).getAsJsonObject().get("_id").getAsString()); + } + + @SneakyThrows + public void testSearchAroundDocument_countExceedsAvailable() { + // doc1 is the first document, requesting 5 before but only 0 exist before it + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"doc_id\": \"doc1\", \"timestamp_field\": \"@timestamp\", \"count\": \"5\"}}", + TEST_INDEX_NAME + ) + ); + + JsonArray docs = JsonParser.parseString(result).getAsJsonArray(); + + // Should have target + up to 5 after (doc2-doc7 = 6 after, but count=5) + // No before docs since doc1 is the earliest + assertEquals(6, docs.size()); + + // First should be the target (doc1), followed by 5 after docs + assertEquals("doc1", docs.get(0).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc2", docs.get(1).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc3", docs.get(2).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc4", docs.get(3).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc5", docs.get(4).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc6", docs.get(5).getAsJsonObject().get("_id").getAsString()); + } + + @SneakyThrows + public void testSearchAroundDocument_lastDocument() { + // doc7 is the last document, requesting 3 before + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"doc_id\": \"doc7\", \"timestamp_field\": \"@timestamp\", 
\"count\": \"3\"}}", + TEST_INDEX_NAME + ) + ); + + JsonArray docs = JsonParser.parseString(result).getAsJsonArray(); + + // Should have 3 before + target, no after docs + assertEquals(4, docs.size()); + + assertEquals("doc4", docs.get(0).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc5", docs.get(1).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc6", docs.get(2).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc7", docs.get(3).getAsJsonObject().get("_id").getAsString()); + } + + @SneakyThrows + public void testSearchAroundDocument_countOfOne() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"doc_id\": \"doc4\", \"timestamp_field\": \"@timestamp\", \"count\": \"1\"}}", + TEST_INDEX_NAME + ) + ); + + JsonArray docs = JsonParser.parseString(result).getAsJsonArray(); + + // 1 before + target + 1 after = 3 + assertEquals(3, docs.size()); + + assertEquals("doc3", docs.get(0).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc4", docs.get(1).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc5", docs.get(2).getAsJsonObject().get("_id").getAsString()); + } + + @SneakyThrows + public void testSearchAroundDocument_jsonInput() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"input\": \"{\\\"index\\\": \\\"%s\\\", \\\"doc_id\\\": \\\"doc4\\\", \\\"timestamp_field\\\": \\\"@timestamp\\\", \\\"count\\\": 2}\"}}", + TEST_INDEX_NAME + ) + ); + + JsonArray docs = JsonParser.parseString(result).getAsJsonArray(); + assertEquals(5, docs.size()); + assertEquals("doc4", docs.get(2).getAsJsonObject().get("_id").getAsString()); + } + + @SneakyThrows + public void testSearchAroundDocument_responseContainsSource() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"doc_id\": \"doc4\", \"timestamp_field\": \"@timestamp\", 
\"count\": \"1\"}}", + TEST_INDEX_NAME + ) + ); + + JsonArray docs = JsonParser.parseString(result).getAsJsonArray(); + // Verify the target document has _source with correct fields + JsonElement targetDoc = docs.get(1); + assertTrue(targetDoc.getAsJsonObject().has("_source")); + assertTrue(targetDoc.getAsJsonObject().has("_id")); + assertTrue(targetDoc.getAsJsonObject().has("_index")); + + String source = targetDoc.getAsJsonObject().get("_source").toString(); + assertTrue(source.contains("Fourth log entry - target")); + assertTrue(source.contains("ERROR")); + } + + @SneakyThrows + public void testSearchAroundDocument_nonExistentDoc() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"doc_id\": \"nonexistent\", \"timestamp_field\": \"@timestamp\", \"count\": \"2\"}}", + TEST_INDEX_NAME + ) + ) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("Document not found")); + } + + @SneakyThrows + public void testSearchAroundDocument_missingRequiredParameters() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent(agentId, String.format(Locale.ROOT, "{\"parameters\": {\"index\": \"%s\"}}", TEST_INDEX_NAME)) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("requires")); + } +} diff --git a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_search_around_document_tool_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_search_around_document_tool_request_body.json new file mode 100644 index 00000000..c97af635 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_search_around_document_tool_request_body.json @@ -0,0 +1,9 @@ +{ + "name": "Test_Search_Around_Document_Agent", + "type": "flow", + "tools": [ + { + "type": "SearchAroundDocumentTool" + } + ] +} \ No newline at end of file From 
e3f247c01552298e721a3b2bcc41df3f8ffd5630 Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Fri, 20 Mar 2026 11:49:19 +0800 Subject: [PATCH 26/30] Add filter support for LogPatternAnalysisTool (#707) * add filter for LogPatternAnalysisTool Signed-off-by: Hailong Cui remove filter for trace analysis Signed-off-by: Hailong Cui spotlessApply Signed-off-by: Hailong Cui * use fields command to narrow down query fields Signed-off-by: Hailong Cui * limit the size of log insight analysis Signed-off-by: Hailong Cui * fix wrong log order Signed-off-by: Hailong Cui --------- Signed-off-by: Hailong Cui --- .../agent/tools/DataDistributionTool.java | 2 +- .../agent/tools/LogPatternAnalysisTool.java | 222 ++++++++++-------- .../tools/LogPatternAnalysisToolTests.java | 101 ++++++++ .../integTest/LogPatternAnalysisToolIT.java | 80 +++++-- 4 files changed, 289 insertions(+), 116 deletions(-) diff --git a/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java b/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java index b3d44505..b133f999 100644 --- a/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java +++ b/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java @@ -175,7 +175,7 @@ public class DataDistributionTool implements Tool { "description": "PPL query without time filtering (added automatically)" } }, - "required": ["index", "selectionTimeRangeStart", "selectionTimeRangeEnd"], + "required": ["index", "timeField", "selectionTimeRangeStart", "selectionTimeRangeEnd"], "additionalProperties": false } """; diff --git a/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java b/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java index 777be285..741cc94c 100644 --- a/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java +++ b/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java @@ -102,54 +102,60 @@ public class LogPatternAnalysisTool implements Tool { private static final
double LOG_PATTERN_THRESHOLD = 0.75; private static final double LOG_PATTERN_LIFT = 3; private static final String DEFAULT_TIME_FIELD = "@timestamp"; - - public static final String DEFAULT_INPUT_SCHEMA = """ - { - "type": "object", - "properties": { - "index": { - "type": "string", - "description": "Target OpenSearch index name" - }, - "timeField": { - "type": "string", - "description": "Date/time field for filtering" - }, - "logFieldName": { - "type": "string", - "description": "Field containing log message text" - }, - "traceFieldName": { - "type": "string", - "description": "Trace/correlation ID field. Enables sequence analysis mode when provided with baseline time range" - }, - "baseTimeRangeStart": { - "type": "string", - "description": "Start of baseline period (format: yyyy-MM-dd HH:mm:ss). Must pair with baseTimeRangeEnd" - }, - "baseTimeRangeEnd": { - "type": "string", - "description": "End of baseline period (format: yyyy-MM-dd HH:mm:ss). Must pair with baseTimeRangeStart" - }, - "selectionTimeRangeStart": { - "type": "string", - "description": "Start of target/incident period (format: yyyy-MM-dd HH:mm:ss)" + private static final int MAX_LOG_SAMPLE_SIZE = 10000; + + public static final String DEFAULT_INPUT_SCHEMA = + """ + { + "type": "object", + "properties": { + "index": { + "type": "string", + "description": "Target OpenSearch index name" + }, + "timeField": { + "type": "string", + "description": "Date/time field for filtering" + }, + "logFieldName": { + "type": "string", + "description": "Field containing log message text" + }, + "traceFieldName": { + "type": "string", + "description": "Trace/correlation ID field. Enables sequence analysis mode when provided with baseline time range" + }, + "baseTimeRangeStart": { + "type": "string", + "description": "Start of baseline period (format: yyyy-MM-dd HH:mm:ss). 
Must pair with baseTimeRangeEnd" + }, + "baseTimeRangeEnd": { + "type": "string", + "description": "End of baseline period (format: yyyy-MM-dd HH:mm:ss). Must pair with baseTimeRangeStart" + }, + "selectionTimeRangeStart": { + "type": "string", + "description": "Start of target/incident period (format: yyyy-MM-dd HH:mm:ss)" + }, + "selectionTimeRangeEnd": { + "type": "string", + "description": "End of target/incident period (format: yyyy-MM-dd HH:mm:ss)" + }, + "filter": { + "type": "string", + "description": "PPL boolean expression to filter logs (e.g. serviceName='ts-auth-service' or severity='ERROR'). Applied as additional where clause. Not applicable for sequence analysis mode (when traceFieldName is provided with baseline), as sequence analysis requires all logs within a trace" + } }, - "selectionTimeRangeEnd": { - "type": "string", - "description": "End of target/incident period (format: yyyy-MM-dd HH:mm:ss)" - } - }, - "required": [ - "index", - "timeField", - "logFieldName", - "selectionTimeRangeStart", - "selectionTimeRangeEnd" - ], - "additionalProperties": false - } - """; + "required": [ + "index", + "timeField", + "logFieldName", + "selectionTimeRangeStart", + "selectionTimeRangeEnd" + ], + "additionalProperties": false + } + """; public static final Map DEFAULT_ATTRIBUTES = Map .of(TOOL_INPUT_SCHEMA_FIELD, gson.toJson(gson.fromJson(DEFAULT_INPUT_SCHEMA, Map.class)), STRICT_FIELD, false); @@ -169,6 +175,7 @@ private static class AnalysisParameters { final String baseTimeRangeEnd; final String selectionTimeRangeStart; final String selectionTimeRangeEnd; + final String filter; AnalysisParameters(Map parameters) { this.index = parameters.getOrDefault("index", ""); @@ -179,6 +186,7 @@ private static class AnalysisParameters { this.baseTimeRangeEnd = parameters.getOrDefault("baseTimeRangeEnd", ""); this.selectionTimeRangeStart = parameters.getOrDefault("selectionTimeRangeStart", ""); this.selectionTimeRangeEnd = 
parameters.getOrDefault("selectionTimeRangeEnd", ""); + this.filter = parameters.getOrDefault("filter", ""); } private void validate() { @@ -303,9 +311,12 @@ public void run(Map originalParameters, ActionListener li } private void logSequenceAnalysis(AnalysisParameters params, ActionListener listener) { + if (!Strings.isEmpty(params.filter)) { + log.warn("Filter parameter is ignored for sequence analysis mode as it requires all logs within a trace"); + } // Step 1: Analyze selection time range analyzeSelectionTimeRange(params, ActionListener.wrap(selectionResult -> { - log.debug("Base time range analysis completed, found {} traces", selectionResult.tracePatternMap.size()); + log.debug("Selection time range analysis completed, found {} traces", selectionResult.tracePatternMap.size()); if (selectionResult.tracePatternMap.isEmpty()) { Map> emptyResult = buildFinalResult( @@ -320,7 +331,7 @@ private void logSequenceAnalysis(AnalysisParameters params, ActionListener { - log.debug("Selection time range analysis completed, found {} traces", baseResult.tracePatternMap.size()); + log.debug("Base time range analysis completed, found {} traces", baseResult.tracePatternMap.size()); // Step 3: Generate comparison result generateSequenceComparisonResult(baseResult, selectionResult, listener); @@ -338,7 +349,8 @@ private void analyzeBaseTimeRange(AnalysisParameters params, ActionListener'%s' and %s<'%s' | patterns %s method=brain " - + "variable_count_threshold=3 | fields %s, patterns_field, %s | sort %s", - index, - traceFieldName, - timeField, - startTime, - timeField, - endTime, - logFieldName, - traceFieldName, - timeField, - timeField - ); + String filterClause = Strings.isEmpty(filter) ? 
"" : String.format(Locale.ROOT, " | where %s", filter); + + String pplTemplate = + "source={INDEX} | where {TRACE_FIELD}!='' | where {TIME_FIELD}>'{START_TIME}' and {TIME_FIELD}<'{END_TIME}'{FILTER} " + + "| fields {TRACE_FIELD}, {LOG_FIELD}, {TIME_FIELD} | patterns {LOG_FIELD} method=brain variable_count_threshold=3 " + + "| fields {TRACE_FIELD}, patterns_field, {TIME_FIELD} | sort {TIME_FIELD}"; + + return pplTemplate + .replace("{INDEX}", index) + .replace("{TRACE_FIELD}", traceFieldName) + .replace("{TIME_FIELD}", timeField) + .replace("{START_TIME}", startTime) + .replace("{END_TIME}", endTime) + .replace("{FILTER}", filterClause) + .replace("{LOG_FIELD}", logFieldName); } private Map vectorizePattern(Map> patternCountMap, int totalTraceCount) { @@ -620,7 +633,8 @@ private void logPatternDiffAnalysis(AnalysisParameters params, ActionListene params.timeField, params.logFieldName, params.baseTimeRangeStart, - params.baseTimeRangeEnd + params.baseTimeRangeEnd, + params.filter ); Function>, Map> dataRowsParser = dataRows -> { Map patternMap = new HashMap<>(); @@ -652,7 +666,8 @@ private void logPatternDiffAnalysis(AnalysisParameters params, ActionListene params.timeField, params.logFieldName, params.selectionTimeRangeStart, - params.selectionTimeRangeEnd + params.selectionTimeRangeEnd, + params.filter ); log.debug("Executing selection time range pattern PPL: {}", selectionTimeRangeLogPatternPPL); @@ -753,22 +768,23 @@ private void logInsight(AnalysisParameters params, ActionListener listene "violation" ); - String selectionTimeRangeLogPatternPPL = String - .format( - Locale.ROOT, - "source=%s | where %s>'%s' and %s<'%s' | where match(%s, '%s') | patterns %s method=brain " - + "mode=aggregation max_sample_count=5 " - + "variable_count_threshold=3 | fields patterns_field, pattern_count, sample_logs " - + "| sort -pattern_count | head 5", - params.index, - params.timeField, - params.selectionTimeRangeStart, - params.timeField, - params.selectionTimeRangeEnd, - 
params.logFieldName, - String.join(" ", errorKeywords), - params.logFieldName - ); + String filterClause = Strings.isEmpty(params.filter) ? "" : String.format(Locale.ROOT, " | where %s", params.filter); + + String pplTemplate = "source={INDEX} | where {TIME_FIELD}>'{START_TIME}' and {TIME_FIELD}<'{END_TIME}'{FILTER} " + + "| where match({LOG_FIELD}, '{ERROR_KEYWORDS}') | head " + + MAX_LOG_SAMPLE_SIZE + + " | fields {LOG_FIELD} | patterns {LOG_FIELD} method=brain " + + "mode=aggregation max_sample_count=5 variable_count_threshold=3 " + + "| fields patterns_field, pattern_count, sample_logs | sort -pattern_count | head 5"; + + String selectionTimeRangeLogPatternPPL = pplTemplate + .replace("{INDEX}", params.index) + .replace("{TIME_FIELD}", params.timeField) + .replace("{START_TIME}", params.selectionTimeRangeStart) + .replace("{END_TIME}", params.selectionTimeRangeEnd) + .replace("{FILTER}", filterClause) + .replace("{LOG_FIELD}", params.logFieldName) + .replace("{ERROR_KEYWORDS}", String.join(" ", errorKeywords)); Function>, List> dataRowsParser = dataRows -> { List patternWithSamplesList = new ArrayList<>(); @@ -804,19 +820,27 @@ private void logInsight(AnalysisParameters params, ActionListener listene ); } - private String buildLogPatternPPL(String index, String timeField, String logFieldName, String startTime, String endTime) { - return String - .format( - Locale.ROOT, - "source=%s | where %s>'%s' and %s<'%s' | patterns %s method=brain mode=aggregation " - + "variable_count_threshold=3 | fields pattern_count, patterns_field", - index, - timeField, - startTime, - timeField, - endTime, - logFieldName - ); + private String buildLogPatternPPL( + String index, + String timeField, + String logFieldName, + String startTime, + String endTime, + String filter + ) { + String filterClause = Strings.isEmpty(filter) ? 
"" : String.format(Locale.ROOT, " | where %s", filter); + + String pplTemplate = "source={INDEX} | where {TIME_FIELD}>'{START_TIME}' and {TIME_FIELD}<'{END_TIME}'{FILTER} " + + "| fields {LOG_FIELD} | patterns {LOG_FIELD} method=brain mode=aggregation variable_count_threshold=3 " + + "| fields pattern_count, patterns_field"; + + return pplTemplate + .replace("{INDEX}", index) + .replace("{TIME_FIELD}", timeField) + .replace("{START_TIME}", startTime) + .replace("{END_TIME}", endTime) + .replace("{FILTER}", filterClause) + .replace("{LOG_FIELD}", logFieldName); } private List calculatePatternDifferences(Map basePatterns, Map selectionPatterns) { diff --git a/src/test/java/org/opensearch/agent/tools/LogPatternAnalysisToolTests.java b/src/test/java/org/opensearch/agent/tools/LogPatternAnalysisToolTests.java index 3adffe17..dd17e50d 100644 --- a/src/test/java/org/opensearch/agent/tools/LogPatternAnalysisToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/LogPatternAnalysisToolTests.java @@ -20,6 +20,7 @@ import java.util.HashMap; import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; import org.hamcrest.MatcherAssert; import org.junit.Before; @@ -28,6 +29,7 @@ import org.mockito.MockitoAnnotations; import org.opensearch.core.action.ActionListener; import org.opensearch.sql.plugin.transport.PPLQueryAction; +import org.opensearch.sql.plugin.transport.TransportPPLQueryRequest; import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; import org.opensearch.transport.client.Client; @@ -466,6 +468,105 @@ public void testExecutionWithNonExistentIndex() { ); } + @Test + @SneakyThrows + public void testLogInsightWithFilter() { + String pplResponse = + """ + {"schema":[{"name":"patterns_field","type":"string"},{"name":"pattern_count","type":"long"},{"name":"sample_logs","type":"array"}], + "datarows":[["Auth error for user <*>",3,["Auth error for user admin","Auth error for user guest"]]], + "total":1,"size":1} + """; + + AtomicReference 
capturedPPL = new AtomicReference<>(); + doAnswer(invocation -> { + TransportPPLQueryRequest request = (TransportPPLQueryRequest) invocation.getArguments()[1]; + capturedPPL.set(request.getRequest()); + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(pplQueryResponse); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + when(pplQueryResponse.getResult()).thenReturn(pplResponse); + + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .builder() + .put("index", "test_index") + .put("timeField", "@timestamp") + .put("logFieldName", "message") + .put("selectionTimeRangeStart", "2025-01-01T00:00:00Z") + .put("selectionTimeRangeEnd", "2025-01-01T01:00:00Z") + .put("filter", "serviceName='ts-auth-service'") + .build(), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue(result.getAsJsonObject().has("logInsights")); + // Verify the PPL query contains the filter clause + assertTrue(capturedPPL.get().contains("where serviceName='ts-auth-service'")); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testLogPatternDiffWithFilter() { + String baseResponse = """ + {"schema":[{"name":"cnt","type":"long"},{"name":"patterns_field","type":"string"}], + "datarows":[[100,"User login successful"],[20,"Database query executed"]], + "total":2,"size":2} + """; + + String selectionResponse = """ + {"schema":[{"name":"cnt","type":"long"},{"name":"patterns_field","type":"string"}], + "datarows":[[50,"User login successful"],[80,"Error in authentication <*>"]], + "total":2,"size":2} + """; + + AtomicReference firstPPL = new AtomicReference<>(); + AtomicReference secondPPL = new AtomicReference<>(); + doAnswer(invocation -> { + TransportPPLQueryRequest request = (TransportPPLQueryRequest) invocation.getArguments()[1]; + 
String ppl = request.getRequest(); + if (firstPPL.get() == null) { + firstPPL.set(ppl); + } else { + secondPPL.set(ppl); + } + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(pplQueryResponse); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + when(pplQueryResponse.getResult()).thenReturn(baseResponse).thenReturn(selectionResponse); + + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .builder() + .put("index", "test_index") + .put("timeField", "@timestamp") + .put("logFieldName", "message") + .put("baseTimeRangeStart", "2025-01-01T00:00:00Z") + .put("baseTimeRangeEnd", "2025-01-01T01:00:00Z") + .put("selectionTimeRangeStart", "2025-01-01T01:00:00Z") + .put("selectionTimeRangeEnd", "2025-01-01T02:00:00Z") + .put("filter", "severity='ERROR'") + .build(), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue(result.getAsJsonObject().has("patternMapDifference")); + // Verify both PPL queries contain the filter clause + assertTrue(firstPPL.get().contains("where severity='ERROR'")); + assertTrue(secondPPL.get().contains("where severity='ERROR'")); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + @Test @SneakyThrows public void testExecutionWithNonExistentLogField() { diff --git a/src/test/java/org/opensearch/integTest/LogPatternAnalysisToolIT.java b/src/test/java/org/opensearch/integTest/LogPatternAnalysisToolIT.java index 4ad50662..9ea60485 100644 --- a/src/test/java/org/opensearch/integTest/LogPatternAnalysisToolIT.java +++ b/src/test/java/org/opensearch/integTest/LogPatternAnalysisToolIT.java @@ -59,6 +59,9 @@ private void prepareLogIndex() { + " },\n" + " \"traceId\": {\n" + " \"type\": \"keyword\"\n" + + " },\n" + + " \"serviceName\": {\n" + + " \"type\": \"keyword\"\n" + " }\n" + " }\n" + " }\n" @@ -69,52 +72,52 @@ private 
void prepareLogIndex() { addDocToIndex( TEST_LOG_INDEX_NAME, "base1", - List.of("@timestamp", "message", "traceId"), - List.of("2025-01-01 09:30:00", "System startup completed", "trace-base-001") + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 09:30:00", "System startup completed", "trace-base-001", "auth-service") ); addDocToIndex( TEST_LOG_INDEX_NAME, "base2", - List.of("@timestamp", "message", "traceId"), - List.of("2025-01-01 09:45:00", "Database connection established", "trace-base-002") + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 09:45:00", "Database connection established", "trace-base-002", "db-service") ); addDocToIndex( TEST_LOG_INDEX_NAME, "base3", - List.of("@timestamp", "message", "traceId"), - List.of("2025-01-01 09:50:00", "User session initialized", "trace-base-003") + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 09:50:00", "User session initialized", "trace-base-003", "auth-service") ); // Add test log data with error keywords for logInsight addDocToIndex( TEST_LOG_INDEX_NAME, "1", - List.of("@timestamp", "message", "traceId"), - List.of("2025-01-01 10:00:00", "User login successful", "trace-001") + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 10:00:00", "User login successful", "trace-001", "auth-service") ); addDocToIndex( TEST_LOG_INDEX_NAME, "2", - List.of("@timestamp", "message", "traceId"), - List.of("2025-01-01 10:01:00", "Database connection established", "trace-001") + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 10:01:00", "Database connection established", "trace-001", "db-service") ); addDocToIndex( TEST_LOG_INDEX_NAME, "3", - List.of("@timestamp", "message", "traceId"), - List.of("2025-01-01 10:02:00", "Error connection timeout failed", "trace-002") + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 10:02:00", "Error 
connection timeout failed", "trace-002", "db-service") ); addDocToIndex( TEST_LOG_INDEX_NAME, "4", - List.of("@timestamp", "message", "traceId"), - List.of("2025-01-01 10:03:00", "User logout completed", "trace-001") + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 10:03:00", "User logout completed", "trace-001", "auth-service") ); addDocToIndex( TEST_LOG_INDEX_NAME, "5", - List.of("@timestamp", "message", "traceId"), - List.of("2025-01-01 10:04:00", "Exception in authentication service", "trace-003") + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 10:04:00", "Exception in authentication service", "trace-003", "auth-service") ); } @@ -206,6 +209,51 @@ public void testLogPatternAnalysisToolInvalidTimeFormat() { MatcherAssert.assertThat(exception.getMessage(), containsString("not a valid term")); } + @SneakyThrows + public void testLogPatternAnalysisToolLogInsightWithFilter() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\", \"filter\": \"serviceName='db-service'\"}}", + TEST_LOG_INDEX_NAME + ) + ); + assertNotNull(result); + assertTrue(result.contains("logInsights")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolWithBaseTimeRangeAndFilter() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", \"baseTimeRangeStart\": \"2025-01-01 09:00:00\", \"baseTimeRangeEnd\": \"2025-01-01 10:00:00\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\", \"filter\": \"serviceName='auth-service'\"}}", + TEST_LOG_INDEX_NAME + ) + ); + assertNotNull(result); + 
assertTrue(result.contains("patternMapDifference")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolWithTraceFieldAndFilter() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", \"traceFieldName\": \"traceId\", \"baseTimeRangeStart\": \"2025-01-01 09:00:00\", \"baseTimeRangeEnd\": \"2025-01-01 10:00:00\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\", \"filter\": \"serviceName='auth-service'\"}}", + TEST_LOG_INDEX_NAME + ) + ); + assertNotNull(result); + assertTrue(result.contains("BASE") || result.contains("EXCEPTIONAL")); + } + @SneakyThrows public void testLogPatternAnalysisToolEmptyTimeRange() { Exception exception = assertThrows( From 5a07968a9520cd6a3a2a3647f474904206924958 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Tue, 24 Mar 2026 14:33:41 -0400 Subject: [PATCH 27/30] Increment version to 3.6.0-SNAPSHOT (#693) Signed-off-by: opensearch-ci-bot Co-authored-by: opensearch-ci-bot --- build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.gradle b/build.gradle index be84ec76..7b72bee5 100644 --- a/build.gradle +++ b/build.gradle @@ -11,7 +11,7 @@ import com.github.jengelman.gradle.plugins.shadow.ShadowPlugin buildscript { ext { opensearch_group = "org.opensearch" - opensearch_version = System.getProperty("opensearch.version", "3.5.0-SNAPSHOT") + opensearch_version = System.getProperty("opensearch.version", "3.6.0-SNAPSHOT") buildVersionQualifier = System.getProperty("build.version_qualifier", "") isSnapshot = "true" == System.getProperty("build.snapshot", "true") version_tokens = opensearch_version.tokenize('-') From d9e2e8b97152456e783378ed0e43d65fed1b391f Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" 
<98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 1 Apr 2026 10:49:20 -0700 Subject: [PATCH 28/30] fix(deps): update dependency org.apache.spark:spark-common-utils_2.13 to v3.5.8 (#713) (#717) (cherry picked from commit 89351f0bbd7e2a0b5839d9e23f6bc9d5aa148233) Signed-off-by: mend-for-github-com[bot] Signed-off-by: Daniel Widdis Signed-off-by: github-actions[bot] Co-authored-by: github-actions[bot] Co-authored-by: mend-for-github-com[bot] <50673670+mend-for-github-com[bot]@users.noreply.github.com> --- build.gradle | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/build.gradle b/build.gradle index 7b72bee5..b96229fe 100644 --- a/build.gradle +++ b/build.gradle @@ -154,11 +154,11 @@ dependencies { compileOnly(group: 'software.amazon.awssdk', name: 'utils', version: "${versions.aws}") compileOnly(group: 'software.amazon.awssdk', name: 'sdk-core', version: "${versions.aws}") - spark 'org.apache.spark:spark-sql-api_2.13:3.5.4' - spark ('org.apache.spark:spark-core_2.13:3.5.4') { + spark 'org.apache.spark:spark-sql-api_2.13:3.5.8' + spark ('org.apache.spark:spark-core_2.13:3.5.8') { exclude group: 'org.eclipse.jetty', module: 'jetty-server' } - spark group: 'org.apache.spark', name: 'spark-common-utils_2.13', version: '3.5.4' + spark group: 'org.apache.spark', name: 'spark-common-utils_2.13', version: '3.5.8' implementation 'org.scala-lang:scala-library:2.13.9' implementation group: 'org.antlr', name: 'antlr4-runtime', version: '4.9.3' @@ -217,9 +217,9 @@ task addSparkJar(type: Copy) { into sparkDir doLast { - def jarA = file("$sparkDir/spark-sql-api_2.13-3.5.4.jar") - def jarB = file("$sparkDir/spark-core_2.13-3.5.4.jar") - def jarC = file("$sparkDir/spark-common-utils_2.13-3.5.4.jar") + def jarA = file("$sparkDir/spark-sql-api_2.13-3.5.8.jar") + def jarB = file("$sparkDir/spark-core_2.13-3.5.8.jar") + def jarC = file("$sparkDir/spark-common-utils_2.13-3.5.8.jar") // 3a. 
Extract jar A to manipulate it def jarAContents = file("$buildDir/tmp/JarAContents") From 18ec6636faceb8d954ceb2c5f33de71093e356b6 Mon Sep 17 00:00:00 2001 From: opensearch-ci <83309141+opensearch-ci-bot@users.noreply.github.com> Date: Sat, 4 Apr 2026 00:04:06 -0400 Subject: [PATCH 29/30] [AUTO] Add release notes for 3.6.0 (#719) * Add release notes for 3.6.0 Signed-off-by: opensearch-ci * Add release notes for 3.6.0 Signed-off-by: opensearch-ci --------- Signed-off-by: opensearch-ci --- .../opensearch-skills.release-notes-3.6.0.0.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 release-notes/opensearch-skills.release-notes-3.6.0.0.md diff --git a/release-notes/opensearch-skills.release-notes-3.6.0.0.md b/release-notes/opensearch-skills.release-notes-3.6.0.0.md new file mode 100644 index 00000000..6982daac --- /dev/null +++ b/release-notes/opensearch-skills.release-notes-3.6.0.0.md @@ -0,0 +1,17 @@ +## Version 3.6.0 Release Notes + +Compatible with OpenSearch and OpenSearch Dashboards version 3.6.0 + +### Features + +* Add SearchAroundTool to search N documents around a given document ([#702](https://github.com/opensearch-project/skills/pull/702)) +* Add MetricChangeAnalysisTool for detecting and analyzing metric changes via percentile comparison between baseline and selection periods ([#698](https://github.com/opensearch-project/skills/pull/698)) + +### Enhancements + +* Add filter support for LogPatternAnalysisTool to enable log pattern analysis for specific services ([#707](https://github.com/opensearch-project/skills/pull/707)) +* Update default tool descriptions for LogPatternAnalysisTool and DataDistributionTool to improve clarity for LLM usage ([#703](https://github.com/opensearch-project/skills/pull/703)) + +### Maintenance + +* Update Apache Spark dependencies (spark-common-utils_2.13) from 3.5.4 to 3.5.8 ([#713](https://github.com/opensearch-project/skills/pull/713)) From 770da34192dd1397200fccb00db028c3ea687bdb Mon Sep 17 
00:00:00 2001 From: opensearch-ci-bot Date: Tue, 7 Apr 2026 22:06:32 +0000 Subject: [PATCH 30/30] Incremented version to 3.6.1 Signed-off-by: GitHub --- build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.gradle b/build.gradle index b96229fe..be0d853d 100644 --- a/build.gradle +++ b/build.gradle @@ -11,7 +11,7 @@ import com.github.jengelman.gradle.plugins.shadow.ShadowPlugin buildscript { ext { opensearch_group = "org.opensearch" - opensearch_version = System.getProperty("opensearch.version", "3.6.0-SNAPSHOT") + opensearch_version = System.getProperty("opensearch.version", "3.6.1-SNAPSHOT") buildVersionQualifier = System.getProperty("build.version_qualifier", "") isSnapshot = "true" == System.getProperty("build.snapshot", "true") version_tokens = opensearch_version.tokenize('-')