diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index deeed07e..c3edc539 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,7 +17,7 @@ jobs: needs: Get-CI-Image-Tag strategy: matrix: - java: [21, 24] + java: [21, 25] name: Build and Test skills plugin on Linux runs-on: ubuntu-latest container: @@ -52,7 +52,7 @@ jobs: build-MacOS: strategy: matrix: - java: [21, 24] + java: [21, 25] name: Build and Test skills Plugin on MacOS needs: Get-CI-Image-Tag @@ -77,7 +77,7 @@ jobs: build-windows: strategy: matrix: - java: [21, 24] + java: [21, 25] name: Build and Test skills plugin on Windows needs: Get-CI-Image-Tag runs-on: windows-latest diff --git a/.github/workflows/maven-publish.yml b/.github/workflows/maven-publish.yml index 1e3bc651..8b73d17d 100644 --- a/.github/workflows/maven-publish.yml +++ b/.github/workflows/maven-publish.yml @@ -32,8 +32,14 @@ jobs: export-env: true env: OP_SERVICE_ACCOUNT_TOKEN: ${{ secrets.OP_SERVICE_ACCOUNT_TOKEN }} - SONATYPE_USERNAME: op://opensearch-infra-secrets/maven-central-portal-credentials/username - SONATYPE_PASSWORD: op://opensearch-infra-secrets/maven-central-portal-credentials/password + MAVEN_SNAPSHOTS_S3_REPO: op://opensearch-infra-secrets/maven-snapshots-s3/repo + MAVEN_SNAPSHOTS_S3_ROLE: op://opensearch-infra-secrets/maven-snapshots-s3/role + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v5 + with: + role-to-assume: ${{ env.MAVEN_SNAPSHOTS_S3_ROLE }} + aws-region: us-east-1 - name: publish snapshots to maven run: | diff --git a/.github/workflows/test_security.yml b/.github/workflows/test_security.yml index 0f22b923..63c99b29 100644 --- a/.github/workflows/test_security.yml +++ b/.github/workflows/test_security.yml @@ -16,7 +16,7 @@ jobs: integ-test-with-security-linux: strategy: matrix: - java: [21, 24] + java: [21, 25] env: ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true name: Run Security Integration Tests on Linux diff --git a/.gitignore b/.gitignore index 722e14d1..21b99ea6 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ out/ .settings .vscode bin/ +.factorypath diff --git a/build.gradle b/build.gradle index 423333bb..be0d853d 100644 --- a/build.gradle +++ b/build.gradle @@ -11,7 +11,7 @@ import com.github.jengelman.gradle.plugins.shadow.ShadowPlugin buildscript { ext { opensearch_group = "org.opensearch" - opensearch_version = System.getProperty("opensearch.version", "3.3.0-SNAPSHOT") + opensearch_version = System.getProperty("opensearch.version", "3.6.1-SNAPSHOT") buildVersionQualifier = System.getProperty("build.version_qualifier", "") isSnapshot = "true" == System.getProperty("build.snapshot", "true") version_tokens = opensearch_version.tokenize('-') @@ -31,8 +31,7 @@ buildscript { repositories { mavenLocal() - maven { url "https://central.sonatype.com/repository/maven-snapshots/" } - maven { url "https://aws.oss.sonatype.org/content/repositories/snapshots" } + maven { url "https://ci.opensearch.org/ci/dbc/snapshots/maven/" } maven { url "https://plugins.gradle.org/m2/" } mavenCentral() } @@ -57,8 +56,7 @@ repositories { mavenLocal() mavenCentral() maven { url "https://plugins.gradle.org/m2/" } - maven { url "https://central.sonatype.com/repository/maven-snapshots/" } - maven { url "https://aws.oss.sonatype.org/content/repositories/snapshots" } + maven { url "https://ci.opensearch.org/ci/dbc/snapshots/maven/" } } allprojects { @@ -138,29 +136,36 @@ dependencies { compileOnly "org.apache.logging.log4j:log4j-slf4j-impl:2.23.1" compileOnly group: 'org.json', name: 'json', version: '20240303' compileOnly("com.google.guava:guava:33.2.1-jre") - compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: '3.16.0' + compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: "${versions.commonslang}" compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.12.0' compileOnly group: 'org.apache.commons', name: 'commons-math3', version: '3.6.1' - compileOnly("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}") + compileOnly("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson_annotations}") compileOnly("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}") compileOnly(group: 'org.apache.httpcomponents.core5', name: 'httpcore5', version: "${versions.httpcore5}") compileOnly(group: 'org.apache.httpcomponents.client5', name: 'httpclient5', version: "${versions.httpclient5}") compileOnly ('com.jayway.jsonpath:json-path:2.9.0') { exclude group: 'net.minidev', module: 'json-smart' } + compileOnly (group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "${versions.aws}") { + exclude(group: 'org.reactivestreams', module: 'reactive-streams') + exclude(group: 'org.slf4j', module: 'slf4j-api') + } + compileOnly(group: 'software.amazon.awssdk', name: 'http-client-spi', version: "${versions.aws}") + compileOnly(group: 'software.amazon.awssdk', name: 'utils', version: "${versions.aws}") + compileOnly(group: 'software.amazon.awssdk', name: 'sdk-core', version: "${versions.aws}") - spark 'org.apache.spark:spark-sql-api_2.13:3.5.4' - spark ('org.apache.spark:spark-core_2.13:3.5.4') { + spark 'org.apache.spark:spark-sql-api_2.13:3.5.8' + spark ('org.apache.spark:spark-core_2.13:3.5.8') { exclude group: 'org.eclipse.jetty', module: 'jetty-server' } - spark group: 'org.apache.spark', name: 'spark-common-utils_2.13', version: '3.5.4' + spark group: 'org.apache.spark', name: 'spark-common-utils_2.13', version: '3.5.8' implementation 'org.scala-lang:scala-library:2.13.9' implementation group: 'org.antlr', name: 'antlr4-runtime', version: '4.9.3' implementation("org.json4s:json4s-ast_2.13:3.7.0-M11") implementation("org.json4s:json4s-core_2.13:3.7.0-M11") implementation("org.json4s:json4s-jackson_2.13:3.7.0-M11") - implementation 'com.fasterxml.jackson.module:jackson-module-scala_3:2.18.2' + implementation "com.fasterxml.jackson.module:jackson-module-scala_3:${versions.jackson}" implementation group: 'org.scala-lang', name: 'scala3-library_3', version: '3.7.0-RC1-bin-20250119-bd699fc-NIGHTLY' implementation("com.thoughtworks.paranamer:paranamer:2.8") implementation("org.jsoup:jsoup:1.19.1") @@ -170,7 +175,7 @@ dependencies { compileOnly group: 'org.opensearch', name:'opensearch-ml-spi', version: "${opensearch_build}" compileOnly fileTree(dir: jsJarDirectory, include: ["opensearch-job-scheduler-${opensearch_build}.jar"]) implementation fileTree(dir: adJarDirectory, include: ["opensearch-anomaly-detection-${opensearch_build}.jar"]) - implementation fileTree(dir: sqlJarDirectory, include: ["opensearch-sql-thin-${opensearch_build}.jar", "ppl-${opensearch_build}.jar", "protocol-${opensearch_build}.jar"]) + implementation fileTree(dir: sqlJarDirectory, include: ["opensearch-sql-thin-${opensearch_build}.jar", "ppl-${opensearch_build}.jar", "protocol-${opensearch_build}.jar", "core-${opensearch_build}.jar"]) implementation fileTree(dir: sparkDir, include: ["spark*.jar"]) compileOnly "org.opensearch:common-utils:${opensearch_build}" compileOnly "org.jetbrains.kotlin:kotlin-stdlib:${kotlin_version}" @@ -194,7 +199,7 @@ dependencies { testImplementation group: 'org.json', name: 'json', version: '20240303' testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.14.2' testImplementation group: 'org.mockito', name: 'mockito-inline', version: '5.2.0' - testImplementation("net.bytebuddy:byte-buddy:1.17.5") + testImplementation("net.bytebuddy:byte-buddy:1.17.7") testImplementation("net.bytebuddy:byte-buddy-agent:1.17.5") testImplementation 'org.junit.jupiter:junit-jupiter-api:5.11.2' testImplementation 'org.mockito:mockito-junit-jupiter:5.14.2' @@ -212,9 +217,9 @@ task addSparkJar(type: Copy) { into sparkDir doLast { - def jarA = file("$sparkDir/spark-sql-api_2.13-3.5.4.jar") - def jarB = file("$sparkDir/spark-core_2.13-3.5.4.jar") - def jarC = file("$sparkDir/spark-common-utils_2.13-3.5.4.jar") + def jarA = file("$sparkDir/spark-sql-api_2.13-3.5.8.jar") + def jarB = file("$sparkDir/spark-core_2.13-3.5.8.jar") + def jarC = file("$sparkDir/spark-common-utils_2.13-3.5.8.jar") // 3a. Extract jar A to manipulate it def jarAContents = file("$buildDir/tmp/JarAContents") @@ -441,10 +446,11 @@ publishing { repositories { maven { name = "Snapshots" - url = "https://central.sonatype.com/repository/maven-snapshots/" - credentials { - username System.getenv("SONATYPE_USERNAME") - password System.getenv("SONATYPE_PASSWORD") + url = System.getenv("MAVEN_SNAPSHOTS_S3_REPO") + credentials(AwsCredentials) { + accessKey = System.getenv("AWS_ACCESS_KEY_ID") + secretKey = System.getenv("AWS_SECRET_ACCESS_KEY") + sessionToken = System.getenv("AWS_SESSION_TOKEN") } } } diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index f373f37a..b11741a1 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,7 +1,7 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionSha256Sum=efe9a3d147d948d7528a9887fa35abcf24ca1a43ad06439996490f77569b02d1 -distributionUrl=https\://services.gradle.org/distributions/gradle-8.14-all.zip +distributionSha256Sum=16f2b95838c1ddcf7242b1c39e7bbbb43c842f1f1a1a0dc4959b6d4d68abcac3 +distributionUrl=https\://services.gradle.org/distributions/gradle-9.2.0-all.zip networkTimeout=10000 validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME diff --git a/release-notes/opensearch-skills.release-notes-3.3.0.0.md b/release-notes/opensearch-skills.release-notes-3.3.0.0.md new file mode 100644 index 00000000..8a549a08 --- /dev/null +++ b/release-notes/opensearch-skills.release-notes-3.3.0.0.md @@ -0,0 +1,20 @@ +## Version 3.3.0 Release Notes + +Compatible with OpenSearch and OpenSearch Dashboards version 3.3.0 + +### Features +* Log patterns analysis tool ([#625](https://github.com/opensearch-project/skills/pull/625)) +* Data Distribution Tool ([#634](https://github.com/opensearch-project/skills/pull/634)) + +### Enhancements +* Add more information in ppl tool when passing to sagemaker ([#636](https://github.com/opensearch-project/skills/pull/636)) + +### Bug Fixes +* Delete-single-baseline ([#641](https://github.com/opensearch-project/skills/pull/641)) +* Fix WebSearchTool issue ([#639](https://github.com/opensearch-project/skills/pull/639)) + +### Infrastructure +* Update System.env syntax for Gradle 9 compatibility ([#630](https://github.com/opensearch-project/skills/pull/630)) + +### Maintenance +* Increment version to 3.3.0-SNAPSHOT ([#626](https://github.com/opensearch-project/skills/pull/626)) \ No newline at end of file diff --git a/release-notes/opensearch-skills.release-notes-3.3.2.0.md b/release-notes/opensearch-skills.release-notes-3.3.2.0.md new file mode 100644 index 00000000..dfddd694 --- /dev/null +++ b/release-notes/opensearch-skills.release-notes-3.3.2.0.md @@ -0,0 +1,6 @@ +## Version 3.3.2 Release Notes + +Compatible with OpenSearch 3.3.2 and OpenSearch Dashboards 3.3.0 + +### Bug Fixes +* Fix regex bypass issue ([#656](https://github.com/opensearch-project/skills/pull/656)) diff --git a/release-notes/opensearch-skills.release-notes-3.6.0.0.md b/release-notes/opensearch-skills.release-notes-3.6.0.0.md new file mode 100644 index 00000000..6982daac --- /dev/null +++ b/release-notes/opensearch-skills.release-notes-3.6.0.0.md @@ -0,0 +1,17 @@ +## Version 3.6.0 Release Notes + +Compatible with OpenSearch and OpenSearch Dashboards version 3.6.0 + +### Features + +* Add SearchAroundTool to search N documents around a given document ([#702](https://github.com/opensearch-project/skills/pull/702)) +* Add MetricChangeAnalysisTool for detecting and analyzing metric changes via percentile comparison between baseline and selection periods ([#698](https://github.com/opensearch-project/skills/pull/698)) + +### Enhancements + +* Add filter support for LogPatternAnalysisTool to enable log pattern analysis for specific services ([#707](https://github.com/opensearch-project/skills/pull/707)) +* Update default tool descriptions for LogPatternAnalysisTool and DataDistributionTool to improve clarity for LLM usage ([#703](https://github.com/opensearch-project/skills/pull/703)) + +### Maintenance + +* Update Apache Spark dependencies (spark-common-utils_2.13) from 3.5.4 to 3.5.8 ([#713](https://github.com/opensearch-project/skills/pull/713)) diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index 5de1227d..a8faae88 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -16,12 +16,14 @@ import org.opensearch.agent.tools.DataDistributionTool; import org.opensearch.agent.tools.LogPatternAnalysisTool; import org.opensearch.agent.tools.LogPatternTool; +import org.opensearch.agent.tools.MetricChangeAnalysisTool; import org.opensearch.agent.tools.NeuralSparseSearchTool; import org.opensearch.agent.tools.PPLTool; import org.opensearch.agent.tools.RAGTool; import org.opensearch.agent.tools.SearchAlertsTool; import org.opensearch.agent.tools.SearchAnomalyDetectorsTool; import org.opensearch.agent.tools.SearchAnomalyResultsTool; +import org.opensearch.agent.tools.SearchAroundDocumentTool; import org.opensearch.agent.tools.SearchMonitorsTool; import org.opensearch.agent.tools.VectorDBTool; import org.opensearch.agent.tools.WebSearchTool; @@ -102,6 +104,8 @@ public Collection createComponents( WebSearchTool.Factory.getInstance().init(threadPool); LogPatternAnalysisTool.Factory.getInstance().init(client); DataDistributionTool.Factory.getInstance().init(client); + SearchAroundDocumentTool.Factory.getInstance().init(client, xContentRegistry); + MetricChangeAnalysisTool.Factory.getInstance().init(client); return Collections.emptyList(); } @@ -122,7 +126,9 @@ public List> getToolFactories() { LogPatternTool.Factory.getInstance(), WebSearchTool.Factory.getInstance(), LogPatternAnalysisTool.Factory.getInstance(), - DataDistributionTool.Factory.getInstance() + DataDistributionTool.Factory.getInstance(), + SearchAroundDocumentTool.Factory.getInstance(), + MetricChangeAnalysisTool.Factory.getInstance() ); } diff --git a/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java b/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java index 1e6dde61..b133f999 100644 --- a/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java +++ b/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java @@ -5,16 +5,17 @@ package org.opensearch.agent.tools; +import static org.opensearch.agent.tools.DataFetchingHelper.DATE_FORMAT_PATTERN; +import static org.opensearch.agent.tools.DataFetchingHelper.NUMBER_FIELD_TYPES; +import static org.opensearch.agent.tools.DataFetchingHelper.QUERY_TYPE_PPL; import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; import java.time.LocalDateTime; -import java.time.ZoneOffset; import java.time.ZonedDateTime; import java.time.format.DateTimeFormatter; import java.time.format.DateTimeParseException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; @@ -26,25 +27,14 @@ import java.util.stream.Collectors; import org.apache.commons.lang3.math.NumberUtils; -import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest; -import org.opensearch.action.search.SearchRequest; import org.opensearch.agent.tools.utils.PPLExecuteHelper; -import org.opensearch.agent.tools.utils.ToolHelper; -import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.index.query.RangeQueryBuilder; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.utils.ToolUtils; -import org.opensearch.search.SearchHit; -import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.transport.client.Client; -import com.google.gson.reflect.TypeToken; - import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; @@ -113,28 +103,13 @@ public class DataDistributionTool implements Tool { public static final String STRICT_FIELD = "strict"; private static final String DEFAULT_DESCRIPTION = - "This tool analyzes data distribution differences between time ranges or provides single dataset insights."; - private static final String DEFAULT_TIME_FIELD = "@timestamp"; - - private static final String PARAM_INDEX = "index"; - private static final String PARAM_TIME_FIELD = "timeField"; - private static final String PARAM_SELECTION_TIME_RANGE_START = "selectionTimeRangeStart"; - private static final String PARAM_SELECTION_TIME_RANGE_END = "selectionTimeRangeEnd"; - private static final String PARAM_BASELINE_TIME_RANGE_START = "baselineTimeRangeStart"; - private static final String PARAM_BASELINE_TIME_RANGE_END = "baselineTimeRangeEnd"; - private static final String PARAM_SIZE = "size"; - private static final String PARAM_QUERY_TYPE = "queryType"; - private static final String PARAM_FILTER = "filter"; - private static final String PARAM_DSL = "dsl"; - private static final String QUERY_TYPE_PPL = "ppl"; - private static final String QUERY_TYPE_DSL = "dsl"; - private static final String DEFAULT_SIZE = "1000"; - private static final String DATE_FORMAT_PATTERN = "yyyy-MM-dd HH:mm:ss"; + "Analyzes field value distributions in a target time range, optionally compared to a baseline. " + + "Use to identify which fields changed most when investigating anomalies. " + + "Two modes: (1) Comparison (baseline provided): ranks fields by divergence. " + + "(2) Single (no baseline): summarizes field distributions."; private static final Set USEFUL_FIELD_TYPES = Set .of("keyword", "boolean", "text", "byte", "short", "integer", "long", "float", "double", "half_float", "scaled_float"); - private static final Set NUMBER_FIELD_TYPES = Set - .of("byte", "short", "integer", "long", "float", "double", "half_float", "scaled_float"); private static final int DEFAULT_COMPARISON_RESULT_LIMIT = 10; private static final int DEFAULT_SINGLE_ANALYSIS_RESULT_LIMIT = 30; @@ -162,141 +137,51 @@ public class DataDistributionTool implements Tool { }, "selectionTimeRangeStart": { "type": "string", - "description": "Start time for analysis period" + "description": "Start of target period (format: yyyy-MM-dd HH:mm:ss)" }, "selectionTimeRangeEnd": { "type": "string", - "description": "End time for analysis period" + "description": "End of target period (format: yyyy-MM-dd HH:mm:ss)" }, "baselineTimeRangeStart": { "type": "string", - "description": "Start time for baseline period (optional)" + "description": "Start of baseline period (format: yyyy-MM-dd HH:mm:ss). Must pair with baselineTimeRangeEnd" }, "baselineTimeRangeEnd": { "type": "string", - "description": "End time for baseline period (optional)" + "description": "End of baseline period (format: yyyy-MM-dd HH:mm:ss). Must pair with baselineTimeRangeStart" }, "size": { "type": "integer", - "description": "Maximum number of documents to analyze (default: 1000)" + "description": "Max documents to sample (default: 1000, max: 10000)" }, "queryType": { "type": "string", - "description": "Query type: 'ppl' or 'dsl' (default: 'dsl')" + "description": "Query type: 'dsl' (default) or 'ppl'" }, "filter": { "type": "array", "items": { "type": "string" }, - "description": "Additional DSL query conditions for filtering (optional)" + "description": "Additional DSL filter clauses as JSON strings" }, "dsl": { "type": "string", - "description": "Complete raw DSL query as JSON string (optional)" + "description": "Complete DSL query as JSON string" }, "ppl": { "type": "string", - "description": "Complete PPL statement without time information (optional)" + "description": "PPL query without time filtering (added automatically)" } }, - "required": ["index", "selectionTimeRangeStart", "selectionTimeRangeEnd"], + "required": ["index", "timeField", "selectionTimeRangeStart", "selectionTimeRangeEnd"], "additionalProperties": false } """; - public static final Map DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, false); - - /** - * Parameter class to hold analysis parameters with validation - */ - private static class AnalysisParameters { - final String index; - final String timeField; - final String selectionTimeRangeStart; - final String selectionTimeRangeEnd; - final String baselineTimeRangeStart; - final String baselineTimeRangeEnd; - final int size; - final String queryType; - final List filter; - final String dsl; - final String ppl; - - /** - * Constructs analysis parameters from input map with default values - * - * @param parameters Input parameter map from user request - */ - AnalysisParameters(Map parameters) { - this.index = parameters.getOrDefault(PARAM_INDEX, ""); - this.timeField = parameters.getOrDefault(PARAM_TIME_FIELD, DEFAULT_TIME_FIELD); - this.selectionTimeRangeStart = parameters.getOrDefault(PARAM_SELECTION_TIME_RANGE_START, ""); - this.selectionTimeRangeEnd = parameters.getOrDefault(PARAM_SELECTION_TIME_RANGE_END, ""); - this.baselineTimeRangeStart = parameters.getOrDefault(PARAM_BASELINE_TIME_RANGE_START, ""); - this.baselineTimeRangeEnd = parameters.getOrDefault(PARAM_BASELINE_TIME_RANGE_END, ""); - - try { - this.size = Integer.parseInt(parameters.getOrDefault(PARAM_SIZE, DEFAULT_SIZE)); - if (this.size > MAX_SIZE_LIMIT) { - throw new IllegalArgumentException("Size parameter exceeds maximum limit of " + MAX_SIZE_LIMIT + ", got: " + this.size); - } - } catch (NumberFormatException e) { - throw new IllegalArgumentException( - "Invalid 'size' parameter: must be a valid integer, got '" + parameters.get(PARAM_SIZE) + "'" - ); - } - - this.queryType = parameters.getOrDefault(PARAM_QUERY_TYPE, QUERY_TYPE_DSL); - - String filterParam = parameters.getOrDefault(PARAM_FILTER, ""); - if (Strings.isEmpty(filterParam)) { - this.filter = List.of(); - } else { - try { - this.filter = Arrays.asList(gson.fromJson(filterParam, String[].class)); - } catch (Exception e) { - throw new IllegalArgumentException( - "Invalid 'filter' parameter: must be a valid JSON array of strings, got '" - + filterParam - + "'. Example: [\"{'term': {'status': 'error'}}\", \"{'range': {'level': {'gte': 3}}}\"]" - ); - } - } - - this.dsl = parameters.getOrDefault(PARAM_DSL, ""); - this.ppl = parameters.getOrDefault(QUERY_TYPE_PPL, ""); - } - - /** - * Validates required parameters are present - * - * @throws IllegalArgumentException if required parameters are missing - */ - void validate() { - List missingParams = new ArrayList<>(); - if (Strings.isEmpty(index)) - missingParams.add(PARAM_INDEX); - if (Strings.isEmpty(selectionTimeRangeStart)) - missingParams.add(PARAM_SELECTION_TIME_RANGE_START); - if (Strings.isEmpty(selectionTimeRangeEnd)) - missingParams.add(PARAM_SELECTION_TIME_RANGE_END); - if (Strings.isEmpty(timeField)) - missingParams.add(PARAM_TIME_FIELD); - if (!missingParams.isEmpty()) { - throw new IllegalArgumentException("Missing required parameters: " + String.join(", ", missingParams)); - } - } - - /** - * Checks if baseline time range is provided for comparison analysis - * - * @return true if both baseline start and end times are provided - */ - boolean hasBaselineTime() { - return !Strings.isEmpty(baselineTimeRangeStart) && !Strings.isEmpty(baselineTimeRangeEnd); - } - } + public static final Map DEFAULT_ATTRIBUTES = Map + .of(TOOL_INPUT_SCHEMA_FIELD, gson.toJson(gson.fromJson(DEFAULT_INPUT_SCHEMA, Map.class)), STRICT_FIELD, false); /** * Result class for data distribution analysis @@ -319,6 +204,7 @@ private record ChangeItem(String value, double selectionPercentage, Double basel @Getter private String version; private Client client; + private DataFetchingHelper dataFetchingHelper; /** * Constructs a DataDistributionTool with the given OpenSearch client @@ -327,6 +213,7 @@ private record ChangeItem(String value, double selectionPercentage, Double basel */ public DataDistributionTool(Client client) { this.client = client; + this.dataFetchingHelper = new DataFetchingHelper(client); } @Override @@ -345,7 +232,7 @@ public void setAttributes(Map map) {} @Override public boolean validate(Map map) { try { - new AnalysisParameters(map).validate(); + new DataFetchingHelper.AnalysisParameters(map).validate(); } catch (Exception e) { log.error("Failed to validate the data distribution analysis parameter: {}", e.getMessage()); return false; @@ -366,7 +253,7 @@ public void run(Map originalParameters, ActionListener li try { Map parameters = ToolUtils.extractInputParameters(originalParameters, DEFAULT_ATTRIBUTES); log.debug("Starting data distribution analysis with parameters: {}", parameters.keySet()); - AnalysisParameters params = new AnalysisParameters(parameters); + DataFetchingHelper.AnalysisParameters params = new DataFetchingHelper.AnalysisParameters(parameters); if (QUERY_TYPE_PPL.equals(params.queryType)) { executePPLAnalysis(params, listener); @@ -389,8 +276,8 @@ public void run(Map originalParameters, ActionListener li * @param params Analysis parameters containing query details * @param listener Action listener for handling results */ - private void executePPLAnalysis(AnalysisParameters params, ActionListener listener) { - if (params.hasBaselineTime()) { + private void executePPLAnalysis(DataFetchingHelper.AnalysisParameters params, ActionListener listener) { + if (params.hasBaselineTimeRange()) { fetchPPLComparisonData(params, listener); } else { String pplQuery = buildPPLQuery( @@ -423,8 +310,8 @@ private void executePPLAnalysis(AnalysisParameters params, ActionListener * @param params Analysis parameters containing query details * @param listener Action listener for handling results */ - private void executeDSLAnalysis(AnalysisParameters params, ActionListener listener) { - if (params.hasBaselineTime()) { + private void executeDSLAnalysis(DataFetchingHelper.AnalysisParameters params, ActionListener listener) { + if (params.hasBaselineTimeRange()) { fetchComparisonData(params, listener); } else { getSingleDataDistribution(params, listener); @@ -438,7 +325,7 @@ private void executeDSLAnalysis(AnalysisParameters params, ActionListener * @param params Analysis parameters containing time ranges * @param listener Action listener for handling comparison results */ - private void fetchComparisonData(AnalysisParameters params, ActionListener listener) { + private void fetchComparisonData(DataFetchingHelper.AnalysisParameters params, ActionListener listener) { fetchIndexData(params.selectionTimeRangeStart, params.selectionTimeRangeEnd, params, ActionListener.wrap(selectionData -> { fetchIndexData(params.baselineTimeRangeStart, params.baselineTimeRangeEnd, params, ActionListener.wrap(baselineData -> { try { @@ -465,7 +352,7 @@ private void fetchComparisonData(AnalysisParameters params, ActionListener void getSingleDataDistribution(AnalysisParameters params, ActionListener listener) { + private void getSingleDataDistribution(DataFetchingHelper.AnalysisParameters params, ActionListener listener) { fetchIndexData(params.selectionTimeRangeStart, params.selectionTimeRangeEnd, params, ActionListener.wrap(data -> { try { if (data.isEmpty()) { @@ -480,48 +367,6 @@ private void getSingleDataDistribution(AnalysisParameters params, ActionList }, listener::onFailure)); } - /** - * Formats time string to ISO 8601 format for OpenSearch compatibility - * - * @param timeString Input time string - * @return Formatted time string in ISO 8601 format - * @throws DateTimeParseException if time string cannot be parsed - */ - private String formatTimeString(String timeString) throws DateTimeParseException { - log.debug("Attempting to parse time string: {}", timeString); - - // Try parsing with zone first - try { - if (timeString.endsWith("Z")) { - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss'Z'", Locale.ROOT); - ZonedDateTime dateTime = ZonedDateTime.parse(timeString, formatter.withZone(ZoneOffset.UTC)); - return dateTime.format(DateTimeFormatter.ISO_INSTANT); - } - } catch (DateTimeParseException e) { - log.debug("Failed to parse as UTC time: {}", e.getMessage()); - } - - // Try parsing as local time without zone - try { - DateTimeFormatter formatter = DateTimeFormatter.ofPattern(DATE_FORMAT_PATTERN, Locale.ROOT); - LocalDateTime localDateTime = LocalDateTime.parse(timeString, formatter); - ZonedDateTime zonedDateTime = localDateTime.atOffset(ZoneOffset.UTC).toZonedDateTime(); - return zonedDateTime.format(DateTimeFormatter.ISO_INSTANT); - } catch (DateTimeParseException e) { - log.debug("Failed to parse as local time: {}", e.getMessage()); - } - - // Try ISO format - try { - ZonedDateTime dateTime = ZonedDateTime.parse(timeString); - return dateTime.format(DateTimeFormatter.ISO_INSTANT); - } catch (DateTimeParseException e) { - log.debug("Failed to parse as ISO format: {}", e.getMessage()); - } - - throw new DateTimeParseException("Unable to parse time string: " + timeString, timeString, 0); - } - /** * Fetches data from the specified index within the given time range * @@ -533,82 +378,28 @@ private String formatTimeString(String timeString) throws DateTimeParseException private void fetchIndexData( String startTime, String endTime, - AnalysisParameters params, + DataFetchingHelper.AnalysisParameters params, ActionListener>> listener ) { - try { - String formattedStartTime = formatTimeString(startTime); - String formattedEndTime = formatTimeString(endTime); - BoolQueryBuilder query; - - // Use raw DSL query if provided - if (!Strings.isEmpty(params.dsl)) { - try { - Map dslMap = gson.fromJson(params.dsl, new TypeToken>() { - }.getType()); - query = QueryBuilders.boolQuery(); - - // Handle DSL query structure - check if it has "query" wrapper - if (dslMap.containsKey("query")) { - @SuppressWarnings("unchecked") - Map queryMap = (Map) dslMap.get("query"); - log.debug("Processing DSL query with wrapper: {}", queryMap); - - // Build the DSL query directly into the main query - buildQueryFromMap(queryMap, query); - - // Add time range filter - query.filter(new RangeQueryBuilder(params.timeField).gte(formattedStartTime).lte(formattedEndTime)); - } else { - log.debug("Processing DSL query without wrapper: {}", dslMap); - buildQueryFromMap(dslMap, query); - // Add time range filter to the raw DSL query - query.filter(new RangeQueryBuilder(params.timeField).gte(formattedStartTime).lte(formattedEndTime)); - } - - log.debug("Final DSL query: {}", query.toString()); - } catch (Exception e) { - log.warn("Failed to parse raw DSL query: {}, falling back to time range only", params.dsl, e); - query = QueryBuilders - .boolQuery() - .filter(new RangeQueryBuilder(params.timeField).gte(formattedStartTime).lte(formattedEndTime)); - } - } else { - query = QueryBuilders - .boolQuery() - .filter(new RangeQueryBuilder(params.timeField).gte(formattedStartTime).lte(formattedEndTime)); - - // Add additional filters if provided - if (!params.filter.isEmpty()) { - for (String filterStr : params.filter) { - try { - Map filterMap = gson.fromJson(filterStr, new TypeToken>() { - }.getType()); - BoolQueryBuilder filterQuery = QueryBuilders.boolQuery(); - buildQueryFromMap(filterMap, filterQuery); - query.must(filterQuery); - } catch (Exception e) { - log.warn("Failed to parse filter parameter: {}", filterStr, e); - } - } - } - } - - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(query).size(params.size); + // Convert AnalysisParameters to helper parameters + Map helperParams = new HashMap<>(); + helperParams.put("index", params.index); + helperParams.put("timeField", params.timeField); + helperParams.put("size", String.valueOf(params.size)); + helperParams.put("queryType", params.queryType); + if (!Strings.isEmpty(params.dsl)) { + helperParams.put("dsl", params.dsl); + } + if (!params.filter.isEmpty()) { + helperParams.put("filter", gson.toJson(params.filter)); + } + if (!Strings.isEmpty(params.ppl)) { + helperParams.put("ppl", params.ppl); + } - SearchRequest request = new SearchRequest(params.index).source(sourceBuilder); + DataFetchingHelper.AnalysisParameters helperAnalysisParams = new DataFetchingHelper.AnalysisParameters(helperParams); - client.search(request, ActionListener.wrap(response -> { - List> data = Arrays - .stream(response.getHits().getHits()) - .map(SearchHit::getSourceAsMap) - .collect(Collectors.toList()); - listener.onResponse(data); - }, listener::onFailure)); - } catch (Exception e) { - log.error("Failed to format time strings: {}", e.getMessage()); - listener.onFailure(new IllegalArgumentException("Invalid time format: " + e.getMessage(), e)); - } + dataFetchingHelper.fetchIndexData(startTime, endTime, helperAnalysisParams, listener); } /** @@ -618,7 +409,7 @@ private void fetchIndexData( * @param params Analysis parameters containing time ranges * @param listener Action listener for handling comparison results */ - private void fetchPPLComparisonData(AnalysisParameters params, ActionListener listener) { + private void fetchPPLComparisonData(DataFetchingHelper.AnalysisParameters params, ActionListener listener) { String selectionQuery = buildPPLQuery( params.index, params.timeField, @@ -840,38 +631,7 @@ private record GroupedDistributions(Map groupedSelectionDist, Ma * @param listener Action listener for handling field types result */ private void getFieldTypes(String index, ActionListener> listener) { - try { - GetMappingsRequest getMappingsRequest = new GetMappingsRequest().indices(index); - client.admin().indices().getMappings(getMappingsRequest, ActionListener.wrap(response -> { - try { - Map mappings = response.getMappings(); - if (mappings.isEmpty()) { - listener.onResponse(Map.of()); - return; - } - - MappingMetadata mappingMetadata = mappings.values().iterator().next(); - Map mappingSource = (Map) mappingMetadata.getSourceAsMap().get("properties"); - if (mappingSource == null) { - listener.onResponse(Map.of()); - return; - } - - Map fieldsToType = new HashMap<>(); - ToolHelper.extractFieldNamesTypes(mappingSource, fieldsToType, "", true); - listener.onResponse(fieldsToType); - } catch (Exception e) { - log.error("Failed to process field types for index: {}", index, e); - listener.onResponse(Map.of()); - } - }, e -> { - log.error("Failed to get field types for index: {}", index, e); - listener.onResponse(Map.of()); - })); - } catch (Exception e) { - log.error("Failed to create getMappings request for index: {}", index, e); - listener.onResponse(Map.of()); - } + dataFetchingHelper.getFieldTypes(index, listener); } /** @@ -940,20 +700,7 @@ private List getUsefulFields(List> data, Map doc, String field) { - String[] parts = field.split("\\."); - Object current = doc; - - for (String part : parts) { - if (current instanceof Map) { - current = ((Map) current).get(part); - } else if (current instanceof List) { - return gson.toJson(current); - } else { - return null; - } - } - - return current; + return dataFetchingHelper.getFlattenedValue(doc, field); } /** @@ -1043,12 +790,7 @@ private List getFieldsFromData(List> data) { * @return Set of number field names */ private Set getNumberFields(Map fieldTypes) { - return fieldTypes - .entrySet() - .stream() - .filter(entry -> NUMBER_FIELD_TYPES.contains(entry.getValue())) - .map(Map.Entry::getKey) - .collect(Collectors.toSet()); + return dataFetchingHelper.getNumberFields(fieldTypes); } /** @@ -1150,303 +892,6 @@ private List formatComparisonSummary(List differ }).collect(Collectors.toList()); } - /** - * Builds query conditions from filter map for DSL queries - * - * @param filterMap Filter conditions as map - * @param queryBuilder Query builder to add conditions to - */ - private void buildQueryFromMap(Map filterMap, BoolQueryBuilder queryBuilder) { - log.debug("Building query from map: {}", filterMap); - - for (Map.Entry entry : filterMap.entrySet()) { - String key = entry.getKey(); - Object value = entry.getValue(); - - log.debug("Processing query key: {}, value: {}", key, value); - - // Handle special query types - switch (key) { - case "match_all" -> { - // {"match_all": {}} - log.debug("Adding match_all query"); - queryBuilder.must(QueryBuilders.matchAllQuery()); - } - case "match_none" -> { - // {"match_none": {}} - log.debug("Adding match_none query"); - queryBuilder.mustNot(QueryBuilders.matchAllQuery()); - } - case "bool" -> { - if (value instanceof Map) { - log.debug("Processing bool query: {}", value); - processBoolQuery((Map) value, queryBuilder); - } - } - case "term" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - log.debug("Adding term query: {}", valueMap); - // {"term": {"field": "value"}} - for (Map.Entry termEntry : valueMap.entrySet()) { - log.debug("Term query - field: {}, value: {}", termEntry.getKey(), termEntry.getValue()); - queryBuilder.must(QueryBuilders.termQuery(termEntry.getKey(), termEntry.getValue())); - } - } - } - case "wildcard" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - log.debug("Adding wildcard query: {}", valueMap); - // {"wildcard": {"field": "pattern"}} - for (Map.Entry wildcardEntry : valueMap.entrySet()) { - log.debug("Wildcard query - field: {}, pattern: {}", wildcardEntry.getKey(), wildcardEntry.getValue()); - queryBuilder.must(QueryBuilders.wildcardQuery(wildcardEntry.getKey(), wildcardEntry.getValue().toString())); - } - } - } - case "range" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - // {"range": {"field": {"gte": 1, "lte": 10}}} - for (Map.Entry rangeEntry : valueMap.entrySet()) { - String field = rangeEntry.getKey(); - Object rangeValue = rangeEntry.getValue(); - if (rangeValue instanceof Map) { - processRangeQuery(field, rangeValue, queryBuilder); - } - } - } - } - case "match" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - // {"match": {"field": "value"}} - for (Map.Entry matchEntry : valueMap.entrySet()) { - queryBuilder.must(QueryBuilders.matchQuery(matchEntry.getKey(), matchEntry.getValue())); - } - } - } - case "match_phrase" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - // {"match_phrase": {"field": "value"}} - for (Map.Entry matchPhraseEntry : valueMap.entrySet()) { - queryBuilder.must(QueryBuilders.matchPhraseQuery(matchPhraseEntry.getKey(), matchPhraseEntry.getValue())); - } - } - } - case "prefix" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - // {"prefix": {"field": "value"}} - for (Map.Entry prefixEntry : valueMap.entrySet()) { - queryBuilder.must(QueryBuilders.prefixQuery(prefixEntry.getKey(), prefixEntry.getValue().toString())); - } - } - } - case "exists" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - // {"exists": {"field": "fieldname"}} - Object fieldValue = valueMap.get("field"); - if (fieldValue != null) { - queryBuilder.must(QueryBuilders.existsQuery(fieldValue.toString())); - } - } - } - case "regexp" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - // {"regexp": {"field": "pattern"}} - for (Map.Entry regexpEntry : valueMap.entrySet()) { - queryBuilder.must(QueryBuilders.regexpQuery(regexpEntry.getKey(), regexpEntry.getValue().toString())); - } - } - } - case "terms" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - // {"terms": {"field": ["value1", "value2"]}} - for (Map.Entry termsEntry : valueMap.entrySet()) { - if (termsEntry.getValue() instanceof List) { - queryBuilder.must(QueryBuilders.termsQuery(termsEntry.getKey(), (List) termsEntry.getValue())); - } - } - } - } - case "multi_match" -> { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - Object queryValue = valueMap.get("query"); - Object fieldsValue = valueMap.get("fields"); - if (queryValue != null && fieldsValue instanceof List) { - @SuppressWarnings("unchecked") - List fields = (List) fieldsValue; - queryBuilder.must(QueryBuilders.multiMatchQuery(queryValue, fields.toArray(new String[0]))); - } - } - } - default -> { - // Handle direct field-value pairs or unknown query types - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) value; - // This might be a field with nested operators like {"field": {"term": "value"}} - processNestedQuery(key, valueMap, queryBuilder); - } else { - // Direct field-value mapping - queryBuilder.must(QueryBuilders.termQuery(key, value)); - } - } - } - } - } - - /** - * Processes bool query conditions - * - * @param boolMap Bool query conditions - * @param queryBuilder Query builder to add conditions to - */ - private void processBoolQuery(Map boolMap, BoolQueryBuilder queryBuilder) { - for (Map.Entry boolEntry : boolMap.entrySet()) { - String boolType = boolEntry.getKey(); - Object boolValue = boolEntry.getValue(); - - if (boolValue instanceof List) { - @SuppressWarnings("unchecked") - List> clauses = (List>) boolValue; - for (Map clause : clauses) { - BoolQueryBuilder subQuery = QueryBuilders.boolQuery(); - buildQueryFromMap(clause, subQuery); - switch (boolType) { - case "must" -> queryBuilder.must(subQuery); - case "should" -> queryBuilder.should(subQuery); - case "must_not" -> queryBuilder.mustNot(subQuery); - case "filter" -> queryBuilder.filter(subQuery); - default -> log.warn("Unsupported bool query type: {}", boolType); - } - } - } - } - } - - /** - * Processes nested query conditions for a field - * - * @param field Field name - * @param nestedMap Nested query conditions - * @param queryBuilder Query builder to add conditions to - */ - private void processNestedQuery(String field, Map nestedMap, BoolQueryBuilder queryBuilder) { - for (Map.Entry nestedEntry : nestedMap.entrySet()) { - String operator = nestedEntry.getKey(); - Object operatorValue = nestedEntry.getValue(); - - switch (operator) { - case "term" -> queryBuilder.must(QueryBuilders.termQuery(field, operatorValue)); - case "range" -> processRangeQuery(field, operatorValue, queryBuilder); - case "match" -> queryBuilder.must(QueryBuilders.matchQuery(field, operatorValue)); - case "match_phrase" -> queryBuilder.must(QueryBuilders.matchPhraseQuery(field, operatorValue)); - case "prefix" -> queryBuilder.must(QueryBuilders.prefixQuery(field, operatorValue.toString())); - case "wildcard" -> processWildcardQuery(field, operatorValue, queryBuilder); - case "exists" -> queryBuilder.must(QueryBuilders.existsQuery(field)); - case "regexp" -> processRegexpQuery(field, operatorValue, queryBuilder); - default -> { - // Handle direct field-value mapping for nested structures - if (operatorValue instanceof Map) { - @SuppressWarnings("unchecked") - Map valueMap = (Map) operatorValue; - BoolQueryBuilder nestedQuery = QueryBuilders.boolQuery(); - buildQueryFromMap(Map.of(operator, valueMap), nestedQuery); - queryBuilder.must(nestedQuery); - } else { - log.warn("Unsupported query operator: {}", operator); - } - } - } - } - } - - /** - * Processes range query conditions - * - * @param field Field name - * @param operatorValue Range conditions - * @param queryBuilder Query builder to add conditions to - */ - private void processRangeQuery(String field, Object operatorValue, BoolQueryBuilder queryBuilder) { - if (!(operatorValue instanceof Map)) { - return; - } - - @SuppressWarnings("unchecked") - Map rangeMap = (Map) operatorValue; - RangeQueryBuilder rangeQuery = QueryBuilders.rangeQuery(field); - - rangeMap.forEach((rangeOp, rangeVal) -> { - switch (rangeOp) { - case "gte" -> rangeQuery.gte(rangeVal); - case "lte" -> rangeQuery.lte(rangeVal); - case "gt" -> rangeQuery.gt(rangeVal); - case "lt" -> rangeQuery.lt(rangeVal); - } - }); - - queryBuilder.must(rangeQuery); - } - - /** - * Processes wildcard query conditions - * - * @param field Field name - * @param operatorValue Wildcard conditions - * @param queryBuilder Query builder to add conditions to - */ - private void processWildcardQuery(String field, Object operatorValue, BoolQueryBuilder queryBuilder) { - if (operatorValue instanceof Map) { - @SuppressWarnings("unchecked") - Map wildcardMap = (Map) operatorValue; - Object wildcardValue = wildcardMap.get("value"); - if (wildcardValue != null) { - queryBuilder.must(QueryBuilders.wildcardQuery(field, wildcardValue.toString())); - } - } else { - queryBuilder.must(QueryBuilders.wildcardQuery(field, operatorValue.toString())); - } - } - - /** - * Processes regexp query conditions - * - * @param field Field name - * @param operatorValue Regexp conditions - * @param queryBuilder Query builder to add conditions to - */ - private void processRegexpQuery(String field, Object operatorValue, BoolQueryBuilder queryBuilder) { - if (operatorValue instanceof Map) { - @SuppressWarnings("unchecked") - Map regexpMap = (Map) operatorValue; - Object regexpValue = regexpMap.get("value"); - if (regexpValue != null) { - queryBuilder.must(QueryBuilders.regexpQuery(field, regexpValue.toString())); - } - } else { - queryBuilder.must(QueryBuilders.regexpQuery(field, operatorValue.toString())); - } - } - /** * Parses PPL query result into list of documents * diff --git a/src/main/java/org/opensearch/agent/tools/DataFetchingHelper.java b/src/main/java/org/opensearch/agent/tools/DataFetchingHelper.java new file mode 100644 index 00000000..39b8e45d --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/DataFetchingHelper.java @@ -0,0 +1,482 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.time.LocalDateTime; +import java.time.ZoneOffset; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeParseException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + +import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.agent.tools.utils.PPLExecuteHelper; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.transport.client.Client; + +import com.google.gson.reflect.TypeToken; + +import lombok.extern.log4j.Log4j2; + +/** + * Helper class for fetching and processing data from OpenSearch indices. + * Provides common functionality for data analysis tools including: + * - Field type detection and mapping + * - Time-based data fetching with DSL/PPL query support + * - Query building and parameter validation + * - Nested field value extraction + */ +@Log4j2 +public class DataFetchingHelper { + + private static final String DEFAULT_TIME_FIELD = "@timestamp"; + public static final String DATE_FORMAT_PATTERN = "yyyy-MM-dd HH:mm:ss"; + public static final String QUERY_TYPE_PPL = "ppl"; + public static final String QUERY_TYPE_DSL = "dsl"; + private static final String DEFAULT_SIZE = "1000"; + private static final int MAX_SIZE_LIMIT = 10000; + + public static final Set NUMBER_FIELD_TYPES = Set + .of("byte", "short", "integer", "long", "float", "double", "half_float", "scaled_float"); + + private final Client client; + + /** + * Constructs a DataFetchingHelper with the given OpenSearch client + * + * @param client The OpenSearch client for executing queries + */ + public DataFetchingHelper(Client client) { + this.client = client; + } + + /** + * Parameters for data analysis operations + */ + public static class AnalysisParameters { + public final String index; + public final String timeField; + public final String selectionTimeRangeStart; + public final String selectionTimeRangeEnd; + public final String baselineTimeRangeStart; + public final String baselineTimeRangeEnd; + public final int size; + public final String queryType; + public final List filter; + public final String dsl; + public final String ppl; + + public AnalysisParameters(Map parameters) { + this.index = parameters.getOrDefault("index", ""); + this.timeField = parameters.getOrDefault("timeField", DEFAULT_TIME_FIELD); + this.selectionTimeRangeStart = parameters.getOrDefault("selectionTimeRangeStart", ""); + this.selectionTimeRangeEnd = parameters.getOrDefault("selectionTimeRangeEnd", ""); + this.baselineTimeRangeStart = parameters.getOrDefault("baselineTimeRangeStart", ""); + this.baselineTimeRangeEnd = parameters.getOrDefault("baselineTimeRangeEnd", ""); + + String sizeStr = parameters.getOrDefault("size", DEFAULT_SIZE); + int parsedSize; + try { + parsedSize = Double.valueOf(sizeStr).intValue(); + if (parsedSize <= 0 || parsedSize > MAX_SIZE_LIMIT) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Invalid 'size' parameter: must be between 1 and %d, got '%s'", MAX_SIZE_LIMIT, sizeStr) + ); + } + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Invalid 'size' parameter: '%s', must be a valid integer", sizeStr) + ); + } + this.size = parsedSize; + + this.queryType = parameters.getOrDefault("queryType", QUERY_TYPE_DSL); + + // Parse filter from JSON string to List + String filterParam = parameters.getOrDefault("filter", ""); + if (Strings.isNullOrEmpty(filterParam)) { + this.filter = List.of(); + } else { + try { + this.filter = List.of(gson.fromJson(filterParam, String[].class)); + } catch (Exception e) { + throw new IllegalArgumentException( + String + .format( + Locale.ROOT, + "Invalid 'filter' parameter: must be a valid JSON array of strings, got '%s'. " + + "Example: [\"{'term': {'status': 'error'}}\", \"{'range': {'level': {'gte': 3}}}\"]", + filterParam + ) + ); + } + } + + this.dsl = parameters.getOrDefault("dsl", ""); + this.ppl = parameters.getOrDefault("ppl", ""); + } + + /** + * Validates required parameters are present + * + * @throws IllegalArgumentException if required parameters are missing + */ + public void validate() { + if (Strings.isNullOrEmpty(index)) { + throw new IllegalArgumentException("Missing required parameter: 'index'"); + } + if (Strings.isNullOrEmpty(selectionTimeRangeStart) || Strings.isNullOrEmpty(selectionTimeRangeEnd)) { + throw new IllegalArgumentException("Missing required parameters: 'selectionTimeRangeStart' and 'selectionTimeRangeEnd'"); + } + if (!QUERY_TYPE_DSL.equals(queryType) && !QUERY_TYPE_PPL.equals(queryType)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Invalid 'queryType': must be 'dsl' or 'ppl', got '%s'", queryType) + ); + } + } + + public boolean hasBaselineTimeRange() { + return !Strings.isNullOrEmpty(baselineTimeRangeStart) && !Strings.isNullOrEmpty(baselineTimeRangeEnd); + } + } + + /** + * Retrieves field type mappings from the specified index + * + * @param index The index name + * @param listener Action listener for handling the field type map or failures + */ + public void getFieldTypes(String index, ActionListener> listener) { + GetMappingsRequest request = new GetMappingsRequest().indices(index); + + client.admin().indices().getMappings(request, ActionListener.wrap(response -> { + Map fieldTypes = new HashMap<>(); + + for (Map.Entry entry : response.getMappings().entrySet()) { + MappingMetadata metadata = entry.getValue(); + Map mappingSource = metadata.getSourceAsMap(); + + if (mappingSource.containsKey("properties")) { + @SuppressWarnings("unchecked") + Map properties = (Map) mappingSource.get("properties"); + extractFieldTypes(properties, "", fieldTypes); + } + } + + listener.onResponse(fieldTypes); + }, listener::onFailure)); + } + + /** + * Recursively extracts field types from mapping properties + */ + private void extractFieldTypes(Map properties, String prefix, Map fieldTypes) { + for (Map.Entry entry : properties.entrySet()) { + String fieldName = prefix.isEmpty() ? entry.getKey() : prefix + "." + entry.getKey(); + + @SuppressWarnings("unchecked") + Map fieldProps = (Map) entry.getValue(); + + if (fieldProps.containsKey("type")) { + fieldTypes.put(fieldName, (String) fieldProps.get("type")); + } + + if (fieldProps.containsKey("properties")) { + @SuppressWarnings("unchecked") + Map nestedProps = (Map) fieldProps.get("properties"); + extractFieldTypes(nestedProps, fieldName, fieldTypes); + } + } + } + + /** + * Fetches data from index within the specified time range + * + * @param timeRangeStart Start time string + * @param timeRangeEnd End time string + * @param params Analysis parameters + * @param listener Action listener for handling the fetched data or failures + */ + public void fetchIndexData( + String timeRangeStart, + String timeRangeEnd, + AnalysisParameters params, + ActionListener>> listener + ) { + try { + if (QUERY_TYPE_PPL.equals(params.queryType)) { + fetchDataWithPPL(timeRangeStart, timeRangeEnd, params, listener); + } else { + fetchDataWithDSL(timeRangeStart, timeRangeEnd, params, listener); + } + } catch (Exception e) { + log.error("Failed to fetch index data", e); + listener.onFailure(e); + } + } + + /** + * Builds a BoolQueryBuilder with time range filter and optional custom filters. + * Can be used with any SearchSourceBuilder (for documents or aggregations). + * + * @param timeRangeStart Start time string + * @param timeRangeEnd End time string + * @param params Analysis parameters containing timeField, dsl, and filter settings + * @return BoolQueryBuilder with time range and custom filters applied + */ + public BoolQueryBuilder buildFilterQuery(String timeRangeStart, String timeRangeEnd, AnalysisParameters params) { + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + + // Add time range filter + RangeQueryBuilder timeRangeQuery = QueryBuilders + .rangeQuery(params.timeField) + .gte(formatTimeString(timeRangeStart)) + .lte(formatTimeString(timeRangeEnd)) + .format("strict_date_optional_time||epoch_millis"); + boolQuery.must(timeRangeQuery); + + // Add custom query if provided + if (!Strings.isNullOrEmpty(params.dsl)) { + Map dslMap = buildQueryFromMap(params.dsl); + boolQuery.must(QueryBuilders.wrapperQuery(gson.toJson(dslMap))); + } else if (!params.filter.isEmpty()) { + for (String filterStr : params.filter) { + Map filterMap = buildQueryFromMap(filterStr); + boolQuery.must(QueryBuilders.wrapperQuery(gson.toJson(filterMap))); + } + } + + return boolQuery; + } + + /** + * Fetches data using DSL query + */ + private void fetchDataWithDSL( + String timeRangeStart, + String timeRangeEnd, + AnalysisParameters params, + ActionListener>> listener + ) { + try { + BoolQueryBuilder boolQuery = buildFilterQuery(timeRangeStart, timeRangeEnd, params); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(boolQuery).size(params.size).fetchSource(true); + + SearchRequest searchRequest = new SearchRequest(params.index).source(searchSourceBuilder); + + client.search(searchRequest, ActionListener.wrap(response -> { + List> data = new ArrayList<>(); + for (SearchHit hit : response.getHits().getHits()) { + data.add(hit.getSourceAsMap()); + } + listener.onResponse(data); + }, listener::onFailure)); + + } catch (Exception e) { + log.error("Failed to fetch data with DSL", e); + listener.onFailure(e); + } + } + + /** + * Fetches data using PPL query + */ + private void fetchDataWithPPL( + String timeRangeStart, + String timeRangeEnd, + AnalysisParameters params, + ActionListener>> listener + ) { + try { + String pplQuery = buildPPLQuery(timeRangeStart, timeRangeEnd, params); + + // Use PPLExecuteHelper with parser function + Function, List>> pplResultParser = this::parsePPLResult; + + PPLExecuteHelper.executePPLAndParseResult(client, pplQuery, pplResultParser, listener); + + } catch (Exception e) { + log.error("Failed to fetch data with PPL", e); + listener.onFailure(e); + } + } + + /** + * Parses PPL query result into list of documents + */ + private List> parsePPLResult(Map pplResult) { + Object datarowsObj = pplResult.get("datarows"); + Object schemaObj = pplResult.get("schema"); + + if (!(datarowsObj instanceof List) || !(schemaObj instanceof List)) { + log.warn("Invalid PPL result format"); + return new ArrayList<>(); + } + + @SuppressWarnings("unchecked") + List> datarows = (List>) datarowsObj; + @SuppressWarnings("unchecked") + List> schema = (List>) schemaObj; + + List fieldNames = new ArrayList<>(); + for (Map field : schema) { + fieldNames.add((String) field.get("name")); + } + + List> documents = new ArrayList<>(); + for (List row : datarows) { + Map doc = new HashMap<>(); + for (int i = 0; i < Math.min(row.size(), fieldNames.size()); i++) { + doc.put(fieldNames.get(i), row.get(i)); + } + documents.add(doc); + } + + return documents; + } + + /** + * Builds PPL query with time range filter + */ + private String buildPPLQuery(String timeRangeStart, String timeRangeEnd, AnalysisParameters params) { + String pplBase = !Strings.isNullOrEmpty(params.ppl) ? params.ppl : "source=" + params.index; + + String timeFilter = String + .format( + Locale.ROOT, + "%s >= '%s' AND %s <= '%s'", + params.timeField, + formatTimeString(timeRangeStart), + params.timeField, + formatTimeString(timeRangeEnd) + ); + + return String.format(Locale.ROOT, "%s | where %s | head %d", pplBase, timeFilter, params.size); + } + + /** + * Formats time string to ISO 8601 format for OpenSearch compatibility + * + * @param timeString Input time string + * @return Formatted time string in ISO 8601 format + * @throws DateTimeParseException if time string cannot be parsed + */ + private String formatTimeString(String timeString) throws DateTimeParseException { + log.debug("Attempting to parse time string: {}", timeString); + + // Try parsing with zone first + try { + if (timeString.endsWith("Z")) { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss'Z'", Locale.ROOT); + ZonedDateTime dateTime = ZonedDateTime.parse(timeString, formatter.withZone(ZoneOffset.UTC)); + return dateTime.format(DateTimeFormatter.ISO_INSTANT); + } + } catch (DateTimeParseException e) { + log.debug("Failed to parse as UTC time: {}", e.getMessage()); + } + + // Try parsing as local time without zone + try { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern(DATE_FORMAT_PATTERN, Locale.ROOT); + LocalDateTime localDateTime = LocalDateTime.parse(timeString, formatter); + ZonedDateTime zonedDateTime = localDateTime.atOffset(ZoneOffset.UTC).toZonedDateTime(); + return zonedDateTime.format(DateTimeFormatter.ISO_INSTANT); + } catch (DateTimeParseException e) { + log.debug("Failed to parse as local time: {}", e.getMessage()); + } + + // Try ISO format + try { + ZonedDateTime dateTime = ZonedDateTime.parse(timeString); + return dateTime.format(DateTimeFormatter.ISO_INSTANT); + } catch (DateTimeParseException e) { + log.debug("Failed to parse as ISO format: {}", e.getMessage()); + } + + throw new RuntimeException("Invalid time format: " + timeString); + } + + /** + * Builds query map from JSON string + */ + public Map buildQueryFromMap(String queryStr) { + if (Strings.isNullOrEmpty(queryStr)) { + return new HashMap<>(); + } + + try { + String normalizedQuery = queryStr.trim().replace("'", "\""); + return gson.fromJson(normalizedQuery, new TypeToken>() { + }.getType()); + } catch (Exception e) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Invalid query format: %s. Error: %s", queryStr, e.getMessage())); + } + } + + /** + * Filters numeric fields from field type mappings + * + * @param fieldTypes Map of field names to their types + * @return Set of numeric field names + */ + public Set getNumberFields(Map fieldTypes) { + Set numberFields = new HashSet<>(); + for (Map.Entry entry : fieldTypes.entrySet()) { + if (NUMBER_FIELD_TYPES.contains(entry.getValue())) { + numberFields.add(entry.getKey()); + } + } + return numberFields; + } + + /** + * Extracts nested field value from document using dot notation + * + * @param doc The document map + * @param field The field path (e.g., "metrics.response_time") + * @return The field value or null if not found + */ + public Object getFlattenedValue(Map doc, String field) { + if (doc == null || field == null) { + return null; + } + + String[] parts = field.split("\\."); + Object current = doc; + + for (String part : parts) { + if (!(current instanceof Map)) { + return null; + } + @SuppressWarnings("unchecked") + Map currentMap = (Map) current; + current = currentMap.get(part); + if (current == null) { + return null; + } + } + + return current; + } +} diff --git a/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java b/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java index 71508d95..741cc94c 100644 --- a/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java +++ b/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java @@ -92,12 +92,17 @@ public class LogPatternAnalysisTool implements Tool { public static final String STRICT_FIELD = "strict"; // Constants - private static final String DEFAULT_DESCRIPTION = - "This is a tool used to detect selection log patterns by the patterns command in PPL or to detect selection log sequences by the log clustering algorithm."; + private static final String DEFAULT_DESCRIPTION = "Analyzes log patterns in a target time range, optionally compared to a baseline. " + + "Use when investigating incidents to find new, spiking, or anomalous log messages. " + + "Three modes: " + + "(1) Sequence analysis (traceFieldName + baseline): groups logs by trace ID, returns exceptional trace sequences. " + + "(2) Pattern diff (baseline, no traceFieldName): compares pattern frequencies, returns highest-lift patterns. " + + "(3) Log insight (no baseline): finds top error/warning patterns with sample logs."; private static final double LOG_VECTORS_CLUSTERING_THRESHOLD = 0.5; private static final double LOG_PATTERN_THRESHOLD = 0.75; private static final double LOG_PATTERN_LIFT = 3; private static final String DEFAULT_TIME_FIELD = "@timestamp"; + private static final int MAX_LOG_SAMPLE_SIZE = 10000; public static final String DEFAULT_INPUT_SCHEMA = """ @@ -106,35 +111,39 @@ public class LogPatternAnalysisTool implements Tool { "properties": { "index": { "type": "string", - "description": "Target OpenSearch index name containing log data (e.g., 'ss4o_logs-otel-2025.06.24')" + "description": "Target OpenSearch index name" }, "timeField": { "type": "string", - "description": "Date/time field in the index mapping used for time-based filtering" + "description": "Date/time field for filtering" }, "logFieldName": { "type": "string", - "description": "Field containing raw log messages to analyze (e.g., 'body', 'message', 'log')" + "description": "Field containing log message text" }, "traceFieldName": { "type": "string", - "description": "[OPTIONAL] Field for trace/correlation ID to enable sequence analysis (e.g., 'traceId', 'correlationId'). Leave empty for pattern-only analysis." + "description": "Trace/correlation ID field. Enables sequence analysis mode when provided with baseline time range" }, "baseTimeRangeStart": { "type": "string", - "description": "Start time for baseline comparison period (date string in utc timezone, e.g., '2025-06-24 07:33:05')" + "description": "Start of baseline period (format: yyyy-MM-dd HH:mm:ss). Must pair with baseTimeRangeEnd" }, "baseTimeRangeEnd": { "type": "string", - "description": "End time for baseline comparison period (date string in utc timezone, e.g., '2025-06-24 07:51:27')" + "description": "End of baseline period (format: yyyy-MM-dd HH:mm:ss). Must pair with baseTimeRangeStart" }, "selectionTimeRangeStart": { "type": "string", - "description": "Start time for analysis target period (date string in utc timezone, e.g., '2025-06-24 07:50:26')" + "description": "Start of target/incident period (format: yyyy-MM-dd HH:mm:ss)" }, "selectionTimeRangeEnd": { "type": "string", - "description": "End time for analysis target period (date string in utc timezone, e.g., '2025-06-24 07:55:56')" + "description": "End of target/incident period (format: yyyy-MM-dd HH:mm:ss)" + }, + "filter": { + "type": "string", + "description": "PPL boolean expression to filter logs (e.g. serviceName='ts-auth-service' or severity='ERROR'). Applied as additional where clause. Not applicable for sequence analysis mode (when traceFieldName is provided with baseline), as sequence analysis requires all logs within a trace" } }, "required": [ @@ -148,7 +157,8 @@ public class LogPatternAnalysisTool implements Tool { } """; - public static final Map DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, false); + public static final Map DEFAULT_ATTRIBUTES = Map + .of(TOOL_INPUT_SCHEMA_FIELD, gson.toJson(gson.fromJson(DEFAULT_INPUT_SCHEMA, Map.class)), STRICT_FIELD, false); // Compiled regex patterns for better performance private static final Pattern REPEATED_WILDCARDS_PATTERN = Pattern.compile("(<\\*>)(\\s+<\\*>)+"); @@ -165,6 +175,7 @@ private static class AnalysisParameters { final String baseTimeRangeEnd; final String selectionTimeRangeStart; final String selectionTimeRangeEnd; + final String filter; AnalysisParameters(Map parameters) { this.index = parameters.getOrDefault("index", ""); @@ -175,6 +186,7 @@ private static class AnalysisParameters { this.baseTimeRangeEnd = parameters.getOrDefault("baseTimeRangeEnd", ""); this.selectionTimeRangeStart = parameters.getOrDefault("selectionTimeRangeStart", ""); this.selectionTimeRangeEnd = parameters.getOrDefault("selectionTimeRangeEnd", ""); + this.filter = parameters.getOrDefault("filter", ""); } private void validate() { @@ -254,7 +266,7 @@ public String getType() { @Override public Map getAttributes() { - return Map.of(); + return DEFAULT_ATTRIBUTES; } @Override @@ -299,9 +311,12 @@ public void run(Map originalParameters, ActionListener li } private void logSequenceAnalysis(AnalysisParameters params, ActionListener listener) { + if (!Strings.isEmpty(params.filter)) { + log.warn("Filter parameter is ignored for sequence analysis mode as it requires all logs within a trace"); + } // Step 1: Analyze selection time range analyzeSelectionTimeRange(params, ActionListener.wrap(selectionResult -> { - log.debug("Base time range analysis completed, found {} traces", selectionResult.tracePatternMap.size()); + log.debug("Selection time range analysis completed, found {} traces", selectionResult.tracePatternMap.size()); if (selectionResult.tracePatternMap.isEmpty()) { Map> emptyResult = buildFinalResult( @@ -316,7 +331,7 @@ private void logSequenceAnalysis(AnalysisParameters params, ActionListener { - log.debug("Selection time range analysis completed, found {} traces", baseResult.tracePatternMap.size()); + log.debug("Base time range analysis completed, found {} traces", baseResult.tracePatternMap.size()); // Step 3: Generate comparison result generateSequenceComparisonResult(baseResult, selectionResult, listener); @@ -334,7 +349,8 @@ private void analyzeBaseTimeRange(AnalysisParameters params, ActionListener'%s' and %s<'%s' | patterns %s method=brain " - + "variable_count_threshold=3 | fields %s, patterns_field, %s | sort %s", - index, - traceFieldName, - timeField, - startTime, - timeField, - endTime, - logFieldName, - traceFieldName, - timeField, - timeField - ); + String filterClause = Strings.isEmpty(filter) ? "" : String.format(Locale.ROOT, " | where %s", filter); + + String pplTemplate = + "source={INDEX} | where {TRACE_FIELD}!='' | where {TIME_FIELD}>'{START_TIME}' and {TIME_FIELD}<'{END_TIME}'{FILTER} " + + "| fields {TRACE_FIELD}, {LOG_FIELD}, {TIME_FIELD} | patterns {LOG_FIELD} method=brain variable_count_threshold=3 " + + "| fields {TRACE_FIELD}, patterns_field, {TIME_FIELD} | sort {TIME_FIELD}"; + + return pplTemplate + .replace("{INDEX}", index) + .replace("{TRACE_FIELD}", traceFieldName) + .replace("{TIME_FIELD}", timeField) + .replace("{START_TIME}", startTime) + .replace("{END_TIME}", endTime) + .replace("{FILTER}", filterClause) + .replace("{LOG_FIELD}", logFieldName); } private Map vectorizePattern(Map> patternCountMap, int totalTraceCount) { @@ -616,7 +633,8 @@ private void logPatternDiffAnalysis(AnalysisParameters params, ActionListene params.timeField, params.logFieldName, params.baseTimeRangeStart, - params.baseTimeRangeEnd + params.baseTimeRangeEnd, + params.filter ); Function>, Map> dataRowsParser = dataRows -> { Map patternMap = new HashMap<>(); @@ -648,7 +666,8 @@ private void logPatternDiffAnalysis(AnalysisParameters params, ActionListene params.timeField, params.logFieldName, params.selectionTimeRangeStart, - params.selectionTimeRangeEnd + params.selectionTimeRangeEnd, + params.filter ); log.debug("Executing selection time range pattern PPL: {}", selectionTimeRangeLogPatternPPL); @@ -749,22 +768,23 @@ private void logInsight(AnalysisParameters params, ActionListener listene "violation" ); - String selectionTimeRangeLogPatternPPL = String - .format( - Locale.ROOT, - "source=%s | where %s>'%s' and %s<'%s' | where match(%s, '%s') | patterns %s method=brain " - + "mode=aggregation max_sample_count=2 " - + "variable_count_threshold=3 | fields patterns_field, pattern_count, sample_logs " - + "| sort -pattern_count | head 5", - params.index, - params.timeField, - params.selectionTimeRangeStart, - params.timeField, - params.selectionTimeRangeEnd, - params.logFieldName, - String.join(" ", errorKeywords), - params.logFieldName - ); + String filterClause = Strings.isEmpty(params.filter) ? "" : String.format(Locale.ROOT, " | where %s", params.filter); + + String pplTemplate = "source={INDEX} | where {TIME_FIELD}>'{START_TIME}' and {TIME_FIELD}<'{END_TIME}'{FILTER} " + + "| where match({LOG_FIELD}, '{ERROR_KEYWORDS}') | head " + + MAX_LOG_SAMPLE_SIZE + + " | fields {LOG_FIELD} | patterns {LOG_FIELD} method=brain " + + "mode=aggregation max_sample_count=5 variable_count_threshold=3 " + + "| fields patterns_field, pattern_count, sample_logs | sort -pattern_count | head 5"; + + String selectionTimeRangeLogPatternPPL = pplTemplate + .replace("{INDEX}", params.index) + .replace("{TIME_FIELD}", params.timeField) + .replace("{START_TIME}", params.selectionTimeRangeStart) + .replace("{END_TIME}", params.selectionTimeRangeEnd) + .replace("{FILTER}", filterClause) + .replace("{LOG_FIELD}", params.logFieldName) + .replace("{ERROR_KEYWORDS}", String.join(" ", errorKeywords)); Function>, List> dataRowsParser = dataRows -> { List patternWithSamplesList = new ArrayList<>(); @@ -800,19 +820,27 @@ private void logInsight(AnalysisParameters params, ActionListener listene ); } - private String buildLogPatternPPL(String index, String timeField, String logFieldName, String startTime, String endTime) { - return String - .format( - Locale.ROOT, - "source=%s | where %s>'%s' and %s<'%s' | patterns %s method=brain mode=aggregation " - + "variable_count_threshold=3 | fields pattern_count, patterns_field", - index, - timeField, - startTime, - timeField, - endTime, - logFieldName - ); + private String buildLogPatternPPL( + String index, + String timeField, + String logFieldName, + String startTime, + String endTime, + String filter + ) { + String filterClause = Strings.isEmpty(filter) ? "" : String.format(Locale.ROOT, " | where %s", filter); + + String pplTemplate = "source={INDEX} | where {TIME_FIELD}>'{START_TIME}' and {TIME_FIELD}<'{END_TIME}'{FILTER} " + + "| fields {LOG_FIELD} | patterns {LOG_FIELD} method=brain mode=aggregation variable_count_threshold=3 " + + "| fields pattern_count, patterns_field"; + + return pplTemplate + .replace("{INDEX}", index) + .replace("{TIME_FIELD}", timeField) + .replace("{START_TIME}", startTime) + .replace("{END_TIME}", endTime) + .replace("{FILTER}", filterClause) + .replace("{LOG_FIELD}", logFieldName); } private List calculatePatternDifferences(Map basePatterns, Map selectionPatterns) { diff --git a/src/main/java/org/opensearch/agent/tools/MetricChangeAnalysisTool.java b/src/main/java/org/opensearch/agent/tools/MetricChangeAnalysisTool.java new file mode 100644 index 00000000..d2af20d1 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/MetricChangeAnalysisTool.java @@ -0,0 +1,557 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; +import org.opensearch.transport.client.Client; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * Tool for analyzing metric changes by comparing percentile distributions between time periods. + * Uses relative change (percentage change) to rank fields by significance. + * + * Usage: + * POST /_plugins/_ml/agents/{agent_id}/_execute + * { + * "parameters": { + * "index": "logs-2025.01.15", + * "timeField": "@timestamp", + * "selectionTimeRangeStart": "2025-01-15 10:00:00", + * "selectionTimeRangeEnd": "2025-01-15 11:00:00", + * "baselineTimeRangeStart": "2025-01-15 08:00:00", + * "baselineTimeRangeEnd": "2025-01-15 09:00:00", + * "size": 1000 + * } + * } + */ +@Log4j2 +@Setter +@Getter +@ToolAnnotation(MetricChangeAnalysisTool.TYPE) +public class MetricChangeAnalysisTool implements Tool { + public static final String TYPE = "MetricChangeAnalysisTool"; + + private static final String DEFAULT_DESCRIPTION = + "Compares percentile distributions (P50, P90) of numeric fields between two time ranges. " + + "Returns top fields ranked by change score. " + + "Use for root cause analysis when investigating metric anomalies. " + + "Keep both time ranges short (e.g. 15-30 minutes) and similar in duration for accurate comparison."; + + private static final int DEFAULT_TOP_N = 10; + + public static final String DEFAULT_INPUT_SCHEMA = + """ + { + "type": "object", + "properties": { + "index": { + "type": "string", + "description": "Target OpenSearch index name" + }, + "timeField": { + "type": "string", + "description": "Date/time field for filtering" + }, + "selectionTimeRangeStart": { + "type": "string", + "description": "Start of target period (format: yyyy-MM-dd HH:mm:ss)" + }, + "selectionTimeRangeEnd": { + "type": "string", + "description": "End of target period (format: yyyy-MM-dd HH:mm:ss)" + }, + "baselineTimeRangeStart": { + "type": "string", + "description": "Start of baseline period (format: yyyy-MM-dd HH:mm:ss)" + }, + "baselineTimeRangeEnd": { + "type": "string", + "description": "End of baseline period (format: yyyy-MM-dd HH:mm:ss). Should be at or before selectionTimeRangeStart" + }, + "size": { + "type": "integer", + "description": "Maximum number of documents to analyze (default: 1000, max: 10000)" + }, + "topN": { + "type": "integer", + "description": "Number of top fields to return, ranked by change score (default: 10)" + }, + "queryType": { + "type": "string", + "description": "Query type: 'ppl' or 'dsl' (default: 'dsl')" + }, + "filter": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Additional DSL query conditions (optional)" + }, + "dsl": { + "type": "string", + "description": "Complete raw DSL query as JSON string (optional)" + }, + "ppl": { + "type": "string", + "description": "Complete PPL statement without time information (optional)" + } + }, + "required": ["index", "timeField", "selectionTimeRangeStart", "selectionTimeRangeEnd", "baselineTimeRangeStart", "baselineTimeRangeEnd"] + } + """; + + public static final Map DEFAULT_ATTRIBUTES = Map + .of(TOOL_INPUT_SCHEMA_FIELD, gson.toJson(gson.fromJson(DEFAULT_INPUT_SCHEMA, Map.class))); + + /** + * Record for percentile statistics + */ + private record PercentileStats(double p50, double p90) { + } + + /** + * Record for field percentile analysis with variance + */ + private record FieldPercentileAnalysis(String field, double variance, PercentileStats selectionStats, PercentileStats baselineStats) { + } + + /** + * Result item for JSON output + */ + private record PercentileAnalysisResult(String field, Double changeScore, Map selectionPercentiles, + Map baselinePercentiles, Map logRatios) { + } + + @Setter + @Getter + private String name = TYPE; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + @Getter + private String version; + private Client client; + private DataFetchingHelper dataFetchingHelper; + + /** + * Constructs a MetricChangeAnalysisTool with the given OpenSearch client + * + * @param client The OpenSearch client for executing queries + */ + public MetricChangeAnalysisTool(Client client) { + this.client = client; + this.dataFetchingHelper = new DataFetchingHelper(client); + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public Map getAttributes() { + return DEFAULT_ATTRIBUTES; + } + + @Override + public void setAttributes(Map map) {} + + @Override + public boolean validate(Map parameters) { + try { + // Use helper's validation logic + DataFetchingHelper.AnalysisParameters params = new DataFetchingHelper.AnalysisParameters(parameters); + params.validate(); + return true; + } catch (IllegalArgumentException e) { + log.error("Invalid parameters: {}", e.getMessage()); + return false; + } + } + + /** + * Executes percentile analysis on numeric fields between selection and baseline periods + * + * @param The response type + * @param originalParameters Input parameters for analysis + * @param listener Action listener for handling results or failures + */ + @Override + public void run(Map originalParameters, ActionListener listener) { + try { + Map parameters = ToolUtils.extractInputParameters(originalParameters, DEFAULT_ATTRIBUTES); + log.debug("Starting metric change analysis with parameters: {}", parameters.keySet()); + + // Extract topN parameter (use Double.valueOf to handle "30.0" from Gson serialization) + int topN = DEFAULT_TOP_N; + String topNStr = parameters.get("topN"); + if (topNStr != null && !topNStr.isEmpty()) { + try { + topN = Double.valueOf(topNStr).intValue(); + if (topN <= 0) { + topN = DEFAULT_TOP_N; + } + } catch (NumberFormatException e) { + log.warn("Invalid topN parameter '{}', using default: {}", topNStr, DEFAULT_TOP_N); + } + } + + // Use DataDistributionTool's data fetching mechanism + fetchDataAndAnalyze(parameters, topN, listener); + + } catch (IllegalArgumentException e) { + log.error("Invalid parameters for MetricChangeAnalysisTool: {}", e.getMessage()); + listener.onFailure(e); + } catch (Exception e) { + log.error("Unexpected error in MetricChangeAnalysisTool", e); + listener.onFailure(e); + } + } + + /** + * Fetches data using DataDistributionTool's mechanism and performs percentile analysis + */ + private void fetchDataAndAnalyze(Map parameters, int topN, ActionListener listener) { + try { + // Create analysis parameters + DataFetchingHelper.AnalysisParameters params = new DataFetchingHelper.AnalysisParameters(parameters); + + // Get field types first + String index = parameters.get("index"); + dataFetchingHelper.getFieldTypes(index, ActionListener.wrap((Map fieldTypes) -> { + // Get number fields + Set numberFields = dataFetchingHelper.getNumberFields(fieldTypes); + + if (numberFields.isEmpty()) { + listener + .onFailure( + new IllegalStateException("No numeric fields found in index. Percentile analysis requires numeric fields.") + ); + return; + } + + // Fetch selection and baseline data + fetchBothDatasets(params, numberFields, topN, listener); + }, listener::onFailure)); + + } catch (Exception e) { + log.error("Failed to fetch data for percentile analysis", e); + listener.onFailure(e); + } + } + + /** + * Fetches both selection and baseline datasets + */ + private void fetchBothDatasets( + DataFetchingHelper.AnalysisParameters params, + Set numberFields, + int topN, + ActionListener listener + ) { + try { + // Fetch selection data + dataFetchingHelper + .fetchIndexData( + params.selectionTimeRangeStart, + params.selectionTimeRangeEnd, + params, + ActionListener.wrap((List> selectionData) -> { + // Fetch baseline data + dataFetchingHelper + .fetchIndexData( + params.baselineTimeRangeStart, + params.baselineTimeRangeEnd, + params, + ActionListener.wrap((List> baselineData) -> { + // Perform metric change analysis + performAnalysis(selectionData, baselineData, numberFields, topN, listener); + }, listener::onFailure) + ); + }, listener::onFailure) + ); + + } catch (Exception e) { + log.error("Failed to fetch datasets", e); + listener.onFailure(e); + } + } + + /** + * Performs metric change analysis on the fetched data + */ + private void performAnalysis( + List> selectionData, + List> baselineData, + Set numberFields, + int topN, + ActionListener listener + ) { + try { + if (selectionData.isEmpty()) { + listener.onFailure(new IllegalStateException("No data found for selection time range")); + return; + } + + if (baselineData.isEmpty()) { + listener.onFailure(new IllegalStateException("No data found for baseline time range")); + return; + } + + // Calculate percentiles and relative changes + List analyses = calculateMetricChangeAnalysis(selectionData, baselineData, numberFields); + List results = formatResults(analyses, topN); + listener.onResponse((T) gson.toJson(Map.of("percentileAnalysis", results))); + + } catch (Exception e) { + log.error("Failed to perform metric change analysis", e); + listener.onFailure(e); + } + } + + /** + * Formats analysis results for JSON output, limiting to top N results + */ + private List formatResults(List analyses, int topN) { + return analyses.stream().limit(topN).map(analysis -> { + Map selectionStats = Map.of("p50", analysis.selectionStats.p50, "p90", analysis.selectionStats.p90); + + Map baselineStats = Map.of("p50", analysis.baselineStats.p50, "p90", analysis.baselineStats.p90); + + Map logRatios = Map + .of( + "p50", + safeLogRatio(analysis.selectionStats.p50, analysis.baselineStats.p50), + "p90", + safeLogRatio(analysis.selectionStats.p90, analysis.baselineStats.p90) + ); + + return new PercentileAnalysisResult(analysis.field, analysis.variance, selectionStats, baselineStats, logRatios); + }).toList(); + } + + // ========== Metric Change Analysis Functions ========== + + /** + * Calculates metric change analysis for all numeric fields + */ + private List calculateMetricChangeAnalysis( + List> selectionData, + List> baselineData, + Set numberFields + ) { + List analyses = new ArrayList<>(); + + for (String field : numberFields) { + List selectionValues = extractNumericValues(selectionData, field); + List baselineValues = extractNumericValues(baselineData, field); + + if (selectionValues.isEmpty() || baselineValues.isEmpty()) { + continue; + } + + PercentileStats selectionStats = calculatePercentiles(selectionValues); + PercentileStats baselineStats = calculatePercentiles(baselineValues); + double variance = calculatePercentileVariance(selectionStats, baselineStats); + + analyses.add(new FieldPercentileAnalysis(field, variance, selectionStats, baselineStats)); + } + + analyses.sort(Comparator.comparingDouble((FieldPercentileAnalysis a) -> a.variance).reversed()); + return analyses; + } + + /** + * Extracts numeric values from dataset for a specific field + */ + private List extractNumericValues(List> data, String field) { + List values = new ArrayList<>(); + + for (Map doc : data) { + Object value = getFlattenedValue(doc, field); + if (value != null) { + try { + if (value instanceof Number) { + values.add(((Number) value).doubleValue()); + } else { + values.add(Double.parseDouble(value.toString())); + } + } catch (NumberFormatException e) { + // Skip non-numeric values + } + } + } + + return values; + } + + /** + * Extracts nested field values from document using dot notation + */ + private Object getFlattenedValue(Map doc, String field) { + return dataFetchingHelper.getFlattenedValue(doc, field); + } + + /** + * Calculates statistics (avg, P50, P90) for a list of values + */ + private PercentileStats calculatePercentiles(List values) { + if (values.isEmpty()) { + return new PercentileStats(0.0, 0.0); + } + + List sorted = new ArrayList<>(values); + sorted.sort(Double::compareTo); + + double p50 = calculatePercentile(sorted, 50); + double p90 = calculatePercentile(sorted, 90); + + return new PercentileStats(p50, p90); + } + + /** + * Calculates a specific percentile from sorted values + */ + private double calculatePercentile(List sortedValues, int percentile) { + if (sortedValues.isEmpty()) { + return 0.0; + } + + if (sortedValues.size() == 1) { + return sortedValues.get(0); + } + + double index = (percentile / 100.0) * (sortedValues.size() - 1); + int lowerIndex = (int) Math.floor(index); + int upperIndex = (int) Math.ceil(index); + + if (lowerIndex == upperIndex) { + return sortedValues.get(lowerIndex); + } + + double lowerValue = sortedValues.get(lowerIndex); + double upperValue = sortedValues.get(upperIndex); + double fraction = index - lowerIndex; + + return lowerValue + (upperValue - lowerValue) * fraction; + } + + private static final double LOG_RATIO_CAP = 10.0; + private static final double EPSILON = 1e-10; + + /** + * Calculates change score between selection and baseline statistics. + * Uses weighted log-ratio scoring on avg, P50, and P90. + * Skips a metric if its baseline is near zero to avoid inflated scores. + * Redistributes weight equally among valid metrics. + */ + private double calculatePercentileVariance(PercentileStats selectionStats, PercentileStats baselineStats) { + boolean p50Valid = Math.abs(baselineStats.p50) >= EPSILON; + boolean p90Valid = Math.abs(baselineStats.p90) >= EPSILON; + + if (!p50Valid && !p90Valid) { + return 0.0; + } + if (p50Valid && p90Valid) { + return 0.5 * safeLogRatio(selectionStats.p50, baselineStats.p50) + 0.5 * safeLogRatio(selectionStats.p90, baselineStats.p90); + } + if (p50Valid) { + return safeLogRatio(selectionStats.p50, baselineStats.p50); + } + return safeLogRatio(selectionStats.p90, baselineStats.p90); + } + + /** + * Computes |log(selection / baseline)| with safe handling of near-zero values. + * Returns 0 when both values are near zero, caps at LOG_RATIO_CAP when only baseline is near zero. + */ + private double safeLogRatio(double selection, double baseline) { + if (Math.abs(baseline) < EPSILON && Math.abs(selection) < EPSILON) { + return 0.0; + } + if (Math.abs(baseline) < EPSILON) { + return LOG_RATIO_CAP; + } + double ratio = selection / baseline; + if (ratio <= 0) { + return 0.0; + } + return Math.abs(Math.log(ratio)); + } + + /** + * Factory class for creating MetricChangeAnalysisTool instances + */ + public static class Factory implements Tool.Factory { + private Client client; + private static Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (MetricChangeAnalysisTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + /** + * Initialize this factory + * + * @param client The OpenSearch client + */ + public void init(Client client) { + this.client = client; + } + + @Override + public MetricChangeAnalysisTool create(Map map) { + return new MetricChangeAnalysisTool(client); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public Map getDefaultAttributes() { + return DEFAULT_ATTRIBUTES; + } + + @Override + public String getDefaultVersion() { + return null; + } + } +} diff --git a/src/main/java/org/opensearch/agent/tools/SearchAroundDocumentTool.java b/src/main/java/org/opensearch/agent/tools/SearchAroundDocumentTool.java new file mode 100644 index 00000000..104c0bc0 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/SearchAroundDocumentTool.java @@ -0,0 +1,370 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.FieldSortBuilder; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.transport.client.Client; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.JsonObject; +import com.google.gson.JsonSyntaxException; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * Tool to search N documents before and N documents after a specific document ID, + * ordered by a timestamp field using search_after pagination. + */ +@Getter +@Setter +@Log4j2 +@ToolAnnotation(SearchAroundDocumentTool.TYPE) +public class SearchAroundDocumentTool implements Tool { + + public static final String INPUT_FIELD = "input"; + public static final String INDEX_FIELD = "index"; + public static final String DOC_ID_FIELD = "doc_id"; + public static final String TIMESTAMP_FIELD = "timestamp_field"; + public static final String COUNT_FIELD = "count"; + public static final String INPUT_SCHEMA_FIELD = "input_schema"; + + public static final String TYPE = "SearchAroundDocumentTool"; + private static final String DEFAULT_DESCRIPTION = """ + Use this tool to search documents around a specific document by providing: \ + 'index' for the index name, 'doc_id' for the target document ID, \ + 'timestamp_field' for the field to order by, and 'count' for the number of documents before and after. \ + Returns N documents before, the target document, and N documents after."""; + + public static final String DEFAULT_INPUT_SCHEMA = """ + { + "type": "object", + "properties": { + "index": { + "type": "string", + "description": "OpenSearch index name" + }, + "doc_id": { + "type": "string", + "description": "Target document ID" + }, + "timestamp_field": { + "type": "string", + "description": "Timestamp field for ordering (e.g., @timestamp)" + }, + "count": { + "type": "integer", + "description": "Number of documents before and after the target" + } + }, + "required": ["index", "doc_id", "timestamp_field", "count"], + "additionalProperties": false + } + """; + + private static final Gson GSON = new GsonBuilder().serializeSpecialFloatingPointValues().create(); + + public static final Map DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA); + + private String name = TYPE; + private Map attributes; + private String description = DEFAULT_DESCRIPTION; + + private Client client; + + private NamedXContentRegistry xContentRegistry; + + public SearchAroundDocumentTool(Client client, NamedXContentRegistry xContentRegistry) { + this.client = client; + this.xContentRegistry = xContentRegistry; + + this.attributes = new HashMap<>(); + attributes.put(INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA); + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public String getVersion() { + return null; + } + + @Override + public boolean validate(Map parameters) { + if (parameters == null || parameters.isEmpty()) { + return false; + } + boolean argumentsFromInput = parameters.containsKey(INPUT_FIELD) && !StringUtils.isEmpty(parameters.get(INPUT_FIELD)); + boolean argumentsFromParameters = parameters.containsKey(INDEX_FIELD) + && parameters.containsKey(DOC_ID_FIELD) + && parameters.containsKey(TIMESTAMP_FIELD) + && parameters.containsKey(COUNT_FIELD) + && !StringUtils.isEmpty(parameters.get(INDEX_FIELD)) + && !StringUtils.isEmpty(parameters.get(DOC_ID_FIELD)) + && !StringUtils.isEmpty(parameters.get(TIMESTAMP_FIELD)) + && !StringUtils.isEmpty(parameters.get(COUNT_FIELD)); + boolean validRequest = argumentsFromInput || argumentsFromParameters; + if (!validRequest) { + log.error("SearchAroundDocumentTool requires: index, doc_id, timestamp_field, and count parameters!"); + return false; + } + return true; + } + + private static Map processResponse(SearchHit hit) { + Map docContent = new HashMap<>(); + docContent.put("_index", hit.getIndex()); + docContent.put("_id", hit.getId()); + docContent.put("_score", hit.getScore()); + docContent.put("_source", hit.getSourceAsMap()); + if (hit.getSortValues() != null && hit.getSortValues().length > 0) { + docContent.put("sort", hit.getSortValues()); + } + return docContent; + } + + @Override + public void run(Map originalParameters, ActionListener listener) { + try { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); + String input = parameters.get(INPUT_FIELD); + String index = null; + String docId = null; + String timestampField = null; + Integer count = null; + + if (!StringUtils.isEmpty(input)) { + try { + JsonObject jsonObject = GSON.fromJson(input, JsonObject.class); + if (jsonObject != null) { + if (jsonObject.has(INDEX_FIELD)) { + index = jsonObject.get(INDEX_FIELD).getAsString(); + } + if (jsonObject.has(DOC_ID_FIELD)) { + docId = jsonObject.get(DOC_ID_FIELD).getAsString(); + } + if (jsonObject.has(TIMESTAMP_FIELD)) { + timestampField = jsonObject.get(TIMESTAMP_FIELD).getAsString(); + } + if (jsonObject.has(COUNT_FIELD)) { + count = jsonObject.get(COUNT_FIELD).getAsInt(); + } + } + } catch (JsonSyntaxException e) { + log.error("Invalid JSON input: {}", input, e); + } + } + + // Fall back to direct parameters + if (StringUtils.isEmpty(index)) { + index = parameters.get(INDEX_FIELD); + } + if (StringUtils.isEmpty(docId)) { + docId = parameters.get(DOC_ID_FIELD); + } + if (StringUtils.isEmpty(timestampField)) { + timestampField = parameters.get(TIMESTAMP_FIELD); + } + if (count == null && parameters.containsKey(COUNT_FIELD)) { + try { + count = Double.valueOf(parameters.get(COUNT_FIELD)).intValue(); + } catch (NumberFormatException e) { + log.error("Invalid count parameter: {}", parameters.get(COUNT_FIELD), e); + } + } + + // Validate all required parameters + if (StringUtils.isEmpty(index) || StringUtils.isEmpty(docId) || StringUtils.isEmpty(timestampField) || count == null) { + listener + .onFailure( + new IllegalArgumentException( + "SearchAroundDocumentTool requires: index, doc_id, timestamp_field, and count parameters" + ) + ); + return; + } + + final String finalIndex = index; + final String finalDocId = docId; + final String finalTimestampField = timestampField; + final int finalCount = count; + + // Step 1: Fetch the target document by ID with sort values + log + .debug( + "Calling SearchAroundDocumentTool: index={}, doc_id={}, timestamp_field={}, count={}", + index, + docId, + timestampField, + count + ); + + SearchSourceBuilder targetSource = new SearchSourceBuilder() + .query(QueryBuilders.idsQuery().addIds(docId)) + .sort(new FieldSortBuilder(timestampField).order(SortOrder.DESC).unmappedType("boolean")) + .sort(new FieldSortBuilder("_doc").order(SortOrder.DESC).unmappedType("boolean")) + .size(1); + SearchRequest targetRequest = new SearchRequest(index).source(targetSource); + + client.search(targetRequest, ActionListener.wrap(targetResponse -> { + SearchHit[] targetHits = targetResponse.getHits().getHits(); + if (targetHits == null || targetHits.length == 0) { + listener.onFailure(new IllegalArgumentException("Document not found: " + finalDocId)); + return; + } + + SearchHit targetHit = targetHits[0]; + Object[] sortValues = targetHit.getSortValues(); + if (sortValues == null || sortValues.length < 2) { + listener.onFailure(new IllegalArgumentException("Could not get sort values from target document")); + return; + } + + // Build target document response + Map targetDoc = processResponse(targetHit); + + // Step 2: Search for documents BEFORE using search_after with DESC sort + BoolQueryBuilder beforeQuery = QueryBuilders.boolQuery().mustNot(QueryBuilders.idsQuery().addIds(finalDocId)); + + SearchSourceBuilder beforeSource = new SearchSourceBuilder() + .query(beforeQuery) + .sort(new FieldSortBuilder(finalTimestampField).order(SortOrder.DESC).unmappedType("boolean")) + .sort(new FieldSortBuilder("_doc").order(SortOrder.DESC).unmappedType("boolean")) + .searchAfter(sortValues) + .size(finalCount); + SearchRequest beforeRequest = new SearchRequest(finalIndex).source(beforeSource); + + client.search(beforeRequest, ActionListener.wrap(beforeResponse -> { + // Step 3: Search for documents AFTER using search_after with ASC sort + BoolQueryBuilder afterQuery = QueryBuilders.boolQuery().mustNot(QueryBuilders.idsQuery().addIds(finalDocId)); + + SearchSourceBuilder afterSource = new SearchSourceBuilder() + .query(afterQuery) + .sort(new FieldSortBuilder(finalTimestampField).order(SortOrder.ASC).unmappedType("boolean")) + .sort(new FieldSortBuilder("_doc").order(SortOrder.ASC).unmappedType("boolean")) + .searchAfter(sortValues) + .size(finalCount); + SearchRequest afterRequest = new SearchRequest(finalIndex).source(afterSource); + + client.search(afterRequest, ActionListener.wrap(afterSearchResponse -> { + + // Process "before" results (need to reverse to get chronological order) + SearchHit[] beforeHits = beforeResponse.getHits().getHits(); + List> beforeDocs = new ArrayList<>(); + for (SearchHit hit : beforeHits) { + beforeDocs.add(processResponse(hit)); + } + Collections.reverse(beforeDocs); + List> result = new ArrayList<>(beforeDocs); + + // Add target document + result.add(targetDoc); + + // Process "after" results + SearchHit[] afterHits = afterSearchResponse.getHits().getHits(); + for (SearchHit hit : afterHits) { + result.add(processResponse(hit)); + } + + String resultJson = GSON.toJson(result); + listener.onResponse((T) resultJson); + }, e -> { + log.error("Failed to search for documents after target", e); + listener.onFailure(e); + })); + }, e -> { + log.error("Failed to search for documents before target", e); + listener.onFailure(e); + })); + }, e -> { + log.error("Failed to fetch target document", e); + listener.onFailure(e); + })); + } catch (Exception e) { + log.error("Failed to run SearchAroundDocumentTool", e); + listener.onFailure(e); + } + } + + public static class Factory implements Tool.Factory { + + private Client client; + private static Factory INSTANCE; + + private NamedXContentRegistry xContentRegistry; + + /** + * Create or return the singleton factory instance + */ + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (SearchAroundDocumentTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + public void init(Client client, NamedXContentRegistry xContentRegistry) { + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + public SearchAroundDocumentTool create(Map params) { + return new SearchAroundDocumentTool(client, xContentRegistry); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + + @Override + public Map getDefaultAttributes() { + return DEFAULT_ATTRIBUTES; + } + } +} diff --git a/src/main/java/org/opensearch/agent/tools/WebSearchTool.java b/src/main/java/org/opensearch/agent/tools/WebSearchTool.java index 047d6769..a1d2fa7d 100644 --- a/src/main/java/org/opensearch/agent/tools/WebSearchTool.java +++ b/src/main/java/org/opensearch/agent/tools/WebSearchTool.java @@ -8,6 +8,9 @@ import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -17,24 +20,25 @@ import java.util.Optional; import org.apache.commons.lang3.math.NumberUtils; -import org.apache.hc.client5.http.classic.methods.HttpGet; -import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; -import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse; -import org.apache.hc.client5.http.impl.classic.HttpClients; import org.apache.hc.core5.http.HttpStatus; -import org.apache.hc.core5.http.io.entity.EntityUtils; import org.jsoup.Connection; import org.jsoup.Jsoup; import org.jsoup.nodes.Document; import org.jsoup.nodes.Element; import org.jsoup.select.Elements; +import org.opensearch.OpenSearchStatusException; import org.opensearch.agent.ToolPlugin; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.httpclient.MLHttpClientFactory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.threadpool.ThreadPool; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import com.google.common.collect.ImmutableMap; import com.google.gson.JsonArray; @@ -45,6 +49,14 @@ import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; +import software.amazon.awssdk.core.internal.http.async.SimpleHttpContentPublisher; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpFullResponse; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.http.async.AsyncExecuteRequest; +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; +import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler; @Log4j2 @Setter @@ -77,6 +89,28 @@ public class WebSearchTool implements Tool { + "}"; public static final Map DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, "strict", false); + public static final String NEXT_PAGE = "next_page"; + public static final String ENGINE_ID = "engine_id"; + public static final String OFFSET = "offset"; + public static final String DUCKDUCKGO = "duckduckgo"; + public static final String GOOGLE = "google"; + public static final String BING = "bing"; + public static final String CUSTOM = "custom"; + public static final String ITEMS = "items"; + public static final String ENGINE = "engine"; + public static final String ENDPOINT = "endpoint"; + public static final String API_KEY = "api_key"; + public static final String CUSTOM_API = "custom_api"; + public static final String AUTHORIZATION = "Authorization"; + public static final String TITLE = "title"; + public static final String URL = "url"; + public static final String CONTENT = "content"; + public static final String QUERY = "query"; + public static final String QUESTION = "question"; + public static final String QUERY_KEY = "query_key"; + public static final String LIMIT_KEY = "limit_key"; + public static final String CUSTOM_RES_URL_JSONPATH = "custom_res_url_jsonpath"; + public static final String START = "start"; @Setter @Getter @@ -86,13 +120,15 @@ public class WebSearchTool implements Tool { private String description = DEFAULT_DESCRIPTION; @Getter private String version; - private CloseableHttpClient httpClient; + private final SdkAsyncHttpClient httpClient; private final ThreadPool threadPool; private Map attributes; public WebSearchTool(ThreadPool threadPool) { - this.httpClient = HttpClients.createDefault(); + // Use 1s for connection timeout, 3s for read timeout, 30 for max connections of httpclient. + // For WebSearchTool, we don't allow user to connect to private ip. + this.httpClient = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(1), Duration.ofSeconds(3), 30, false); this.threadPool = threadPool; this.attributes = new HashMap<>(); attributes.put(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA); @@ -101,103 +137,106 @@ public WebSearchTool(ThreadPool threadPool) { @Override public void run(Map originalParameters, ActionListener listener) { - Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); try { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); // common search parameters - String query = parameters.getOrDefault("query", parameters.get("question")).replaceAll(" ", "+"); - String engine = parameters.getOrDefault("engine", "google"); - String endpoint = parameters.getOrDefault("endpoint", getDefaultEndpoint(engine)); - String apiKey = parameters.get("api_key"); - String nextPage = parameters.get("next_page"); + String query = parameters.getOrDefault(QUERY, parameters.get(QUESTION)).replaceAll(" ", "+"); + String engine = parameters.getOrDefault(ENGINE, GOOGLE); + String endpoint = parameters.getOrDefault(ENDPOINT, getDefaultEndpoint(engine)); + String apiKey = parameters.get(API_KEY); + String nextPage = parameters.get(NEXT_PAGE); // Google search parameters - String engineId = parameters.get("engine_id"); + String engineId = parameters.get(ENGINE_ID); // Custom search parameters - String authorization = parameters.get("Authorization"); - String queryKey = parameters.getOrDefault("query_key", "q"); - String offsetKey = parameters.getOrDefault("offset_key", "offset"); - String limitKey = parameters.getOrDefault("limit_key", "limit"); - String customResUrlJsonpath = parameters.get("custom_res_url_jsonpath"); + String authorization = parameters.get(AUTHORIZATION); + String queryKey = parameters.getOrDefault(QUERY_KEY, "q"); + String offsetKey = parameters.getOrDefault(OFFSET + "_key", OFFSET); + String limitKey = parameters.getOrDefault(LIMIT_KEY, "limit"); + String customResUrlJsonpath = parameters.get(CUSTOM_RES_URL_JSONPATH); + threadPool.executor(ToolPlugin.WEBSEARCH_CRAWLER_THREADPOOL).submit(() -> { - try { - String parsedNextPage = null; - if ("duckduckgo".equalsIgnoreCase(engine)) { - // duckduckgo has different approach to other APIs as it's not a standard public API. + String parsedNextPage; + if (DUCKDUCKGO.equalsIgnoreCase(engine)) { + // duckduckgo has different approach to other APIs as it's not a standard public API. + if (nextPage != null) { + fetchDuckDuckGoResult(nextPage, listener); + } else { + fetchDuckDuckGoResult(buildDDGEndpoint(getDefaultEndpoint(engine), query), listener); + } + } else { + SdkHttpFullRequest.Builder builder = SdkHttpFullRequest.builder().method(SdkHttpMethod.GET); + if (GOOGLE.equalsIgnoreCase(engine)) { if (nextPage != null) { - fetchDuckDuckGoResult(nextPage, listener); + builder.uri(nextPage); + parsedNextPage = buildGoogleNextPage(endpoint, engineId, query, apiKey, nextPage); } else { - fetchDuckDuckGoResult(buildDDGEndpoint(getDefaultEndpoint(engine), query), listener); + builder.uri(buildGoogleUrl(endpoint, engineId, query, apiKey, 0)); + parsedNextPage = buildGoogleUrl(endpoint, engineId, query, apiKey, 10); } - } else { - HttpGet getRequest = null; - if ("google".equalsIgnoreCase(engine)) { - if (nextPage != null) { - getRequest = new HttpGet(nextPage); - parsedNextPage = buildGoogleNextPage(endpoint, engineId, query, apiKey, nextPage); - } else { - getRequest = new HttpGet(buildGoogleUrl(endpoint, engineId, query, apiKey, 0)); - parsedNextPage = buildGoogleUrl(endpoint, engineId, query, apiKey, 10); - } - } else if ("bing".equalsIgnoreCase(engine)) { - if (nextPage != null) { - getRequest = new HttpGet(nextPage); - parsedNextPage = buildBingNextPage(endpoint, query, nextPage); - } else { - getRequest = new HttpGet(buildBingUrl(endpoint, query, 0)); - parsedNextPage = buildBingUrl(endpoint, query, 10); - } - getRequest.addHeader("Ocp-Apim-Subscription-Key", apiKey); - } else if ("custom".equalsIgnoreCase(engine)) { - if (nextPage != null) { - getRequest = new HttpGet(nextPage); - parsedNextPage = buildCustomNextPage(endpoint, nextPage, queryKey, query, offsetKey, limitKey); - } else { - getRequest = new HttpGet(buildCustomUrl(endpoint, queryKey, query, offsetKey, 0, limitKey)); - parsedNextPage = buildCustomUrl(endpoint, queryKey, query, offsetKey, 10, limitKey); - } - getRequest.addHeader("Authorization", authorization); + } else if (BING.equalsIgnoreCase(engine)) { + if (nextPage != null) { + builder.uri(nextPage); + parsedNextPage = buildBingNextPage(endpoint, query, nextPage); } else { - // Search engine not supported. - listener.onFailure(new IllegalArgumentException("Unsupported search engine: %s".formatted(engine))); - return; + builder.uri(buildBingUrl(endpoint, query, 0)); + parsedNextPage = buildBingUrl(endpoint, query, 10); } - CloseableHttpResponse res = httpClient.execute(getRequest); - if (res.getCode() >= HttpStatus.SC_BAD_REQUEST) { - listener - .onFailure( - new IllegalArgumentException("Web search failed: %d %s".formatted(res.getCode(), res.getReasonPhrase())) - ); + builder.putHeader("Ocp-Apim-Subscription-Key", apiKey); + } else if (CUSTOM.equalsIgnoreCase(engine)) { + if (nextPage != null) { + builder.uri(nextPage); + parsedNextPage = buildCustomNextPage(endpoint, nextPage, queryKey, query, offsetKey, limitKey); } else { - String responseString = EntityUtils.toString(res.getEntity()); - parseResponse(responseString, authorization, parsedNextPage, engine, customResUrlJsonpath, listener); + builder.uri(buildCustomUrl(endpoint, queryKey, query, offsetKey, 0, limitKey)); + parsedNextPage = buildCustomUrl(endpoint, queryKey, query, offsetKey, 10, limitKey); } + builder.putHeader(AUTHORIZATION, authorization); + } else { + // Search engine not supported. + listener + .onFailure(new IllegalArgumentException(String.format(Locale.ROOT, "Unsupported search engine: %s", engine))); + return; + } + SdkHttpFullRequest getRequest = builder.build(); + AsyncExecuteRequest executeRequest = AsyncExecuteRequest + .builder() + .request(getRequest) + .requestContentPublisher(new SimpleHttpContentPublisher(getRequest)) + .responseHandler( + new WebSearchResponseHandler(endpoint, authorization, parsedNextPage, engine, customResUrlJsonpath, listener) + ) + .build(); + try { + httpClient.execute(executeRequest); + } catch (Exception e) { + log.error("Web search failed!", e); + listener.onFailure(new IllegalStateException(String.format(Locale.ROOT, "Web search failed: %s", e.getMessage()))); } - } catch (Exception e) { - listener.onFailure(new IllegalStateException("Web search failed: %s".formatted(e.getMessage()))); } }); } catch (Exception e) { - listener.onFailure(new IllegalStateException("Web search failed: %s".formatted(e.getMessage()))); + listener.onFailure(new IllegalStateException(String.format(Locale.ROOT, "Web search failed: %s", e.getMessage()))); } } private String buildDDGEndpoint(String endpoint, String query) { - return "%s?q=%s".formatted(endpoint, query); + return String.format(Locale.ROOT, "%s?q=%s", endpoint, query); } private String buildGoogleNextPage(String endpoint, String engineId, String query, String apiKey, String currentPage) { - String[] offsetSplit = currentPage.split("&start="); + String[] offsetSplit = currentPage.split("&" + START + "="); int offset = NumberUtils.toInt(offsetSplit[1], 0) + 10; return buildGoogleUrl(endpoint, engineId, query, apiKey, offset); } private String buildGoogleUrl(String endpoint, String engineId, String query, String apiKey, int start) { - return "%s?q=%s&cx=%s&key=%s&start=%d".formatted(endpoint, query, engineId, apiKey, start); + return String.format(Locale.ROOT, "%s?q=%s&cx=%s&key=%s&" + START + "=%d", endpoint, query, engineId, apiKey, start); } private String buildBingNextPage(String endpoint, String query, String currentPage) { - String[] offsetSplit = currentPage.split("&offset="); + String[] offsetSplit = currentPage.split("&" + OFFSET + "="); int offset = NumberUtils.toInt(offsetSplit[1], 0) + 10; return buildBingUrl(endpoint, query, offset); } @@ -210,100 +249,28 @@ private String buildCustomNextPage( String offsetKey, String limitKey ) { - String[] pageSplit = currentPage.split("&%s=".formatted(offsetKey)); + String[] pageSplit = currentPage.split(String.format(Locale.ROOT, "&%s=", offsetKey)); int offsetValue = NumberUtils.toInt(pageSplit[1].split("&")[0], 0) + 10; return buildCustomUrl(endpoint, queryKey, query, offsetKey, offsetValue, limitKey); } private String buildCustomUrl(String endpoint, String queryKey, String query, String offsetKey, int offsetValue, String limitKey) { - return "%s?%s=%s&%s=%d&%s=10".formatted(endpoint, queryKey, query, offsetKey, offsetValue, limitKey); + return String.format(Locale.ROOT, "%s?%s=%s&%s=%d&%s=10", endpoint, queryKey, query, offsetKey, offsetValue, limitKey); } private String getDefaultEndpoint(String engine) { return switch (engine.toLowerCase(Locale.ROOT)) { - case "google" -> "https://customsearch.googleapis.com/customsearch/v1"; - case "bing" -> "https://api.bing.microsoft.com/v7.0/search"; - case "duckduckgo" -> "https://duckduckgo.com/html"; - case "custom" -> null; - default -> throw new IllegalArgumentException("Unsupported search engine: %s".formatted(engine)); + case GOOGLE -> "https://customsearch.googleapis.com/customsearch/v1"; + case BING -> "https://api.bing.microsoft.com/v7.0/search"; + case DUCKDUCKGO -> "https://duckduckgo.com/html"; + case CUSTOM -> null; + default -> throw new IllegalArgumentException(String.format(Locale.ROOT, "Unsupported search engine: %s", engine)); }; } // pagination: https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/page-results#paging-through-search-results private String buildBingUrl(String endpoint, String query, int offset) { - return "%s?q%s&textFormat=HTML&count=10&offset=%d".formatted(endpoint, query, offset); - } - - private void parseResponse( - String rawResponse, - String authorization, - String nextPage, - String engine, - String customResUrlJsonpath, - ActionListener listener - ) { - JsonObject rawJson = JsonParser.parseString(rawResponse).getAsJsonObject(); - switch (engine.toLowerCase(Locale.ROOT)) { - case "google": - parseGoogleResults(rawJson, nextPage, listener); - break; - case "bing": - parseBingResults(rawJson, nextPage, listener); - break; - case "custom": - List urls = JsonPath.read(rawResponse, customResUrlJsonpath); - parseCustomResults(urls, authorization, nextPage, listener); - break; - default: - listener.onFailure(new RuntimeException("Unsupported search engine: %s".formatted(engine))); - } - } - - private void parseGoogleResults(JsonObject googleResponse, String nextPage, ActionListener listener) { - Map results = new HashMap<>(); - results.put("next_page", nextPage); - // extract search results, each item is a search result: - // https://developers.google.com/custom-search/v1/reference/rest/v1/Search#result - JsonArray items = googleResponse.getAsJsonArray("items"); - List> crawlResults = new ArrayList<>(); - for (int i = 0; i < items.size(); i++) { - JsonObject item = items.get(i).getAsJsonObject(); - // extract the actual link for scrawl. - String link = item.get("link").getAsString(); - // extract title and content. - Map crawlResult = crawlPage(link, null); - crawlResults.add(crawlResult); - } - results.put("items", crawlResults); - listener.onResponse((T) StringUtils.gson.toJson(results)); - } - - private void parseBingResults(JsonObject bingResponse, String nextPage, ActionListener listener) { - Map results = new HashMap<>(); - results.put("next_page", nextPage); - List> crawlResults = new ArrayList<>(); - JsonArray values = bingResponse.get("webPages").getAsJsonObject().getAsJsonArray("value"); - for (int i = 0; i < values.size(); i++) { - JsonObject value = values.get(i).getAsJsonObject(); - String link = value.get("url").getAsString(); - Map crawlResult = crawlPage(link, null); - crawlResults.add(crawlResult); - } - results.put("items", crawlResults); - listener.onResponse((T) StringUtils.gson.toJson(results)); - } - - private void parseCustomResults(List urls, String authorization, String nextPage, ActionListener listener) { - Map results = new HashMap<>(); - results.put("next_page", nextPage); - List> crawlResults = new ArrayList<>(); - for (int i = 0; i < urls.size(); i++) { - String link = urls.get(i); - Map crawlResult = crawlPage(link, authorization); - crawlResults.add(crawlResult); - } - results.put("items", crawlResults); - listener.onResponse((T) StringUtils.gson.toJson(results)); + return String.format(Locale.ROOT, "%s?q%s&textFormat=HTML&count=10&" + OFFSET + "=%d", endpoint, query, offset); } private void fetchDuckDuckGoResult(String endpoint, ActionListener listener) { @@ -335,8 +302,8 @@ private void fetchDuckDuckGoResult(String endpoint, ActionListener listen Map crawlResult = crawlPage(link, null); crawlResults.add(crawlResult); } - results.put("next_page", nextPage); - results.put("items", crawlResults); + results.put(NEXT_PAGE, nextPage); + results.put(ITEMS, crawlResults); listener.onResponse((T) StringUtils.gson.toJson(results)); } catch (IOException e) { log.error("Failed to fetch duckduckgo results due to exception!"); @@ -378,48 +345,6 @@ private String getDDGNextPageLink(String endpoint, Document doc) { return sb.toString(); } - /** - * crawl a page and put the page content into the results map if it can be crawled successfully. - * - * @param url The url to crawl - */ - private Map crawlPage(String url, String authorization) { - try { - Connection connection = Jsoup.connect(url).timeout(10000).userAgent(USER_AGENT); - if (authorization != null) { - connection.header("Authorization", authorization); - } - Document doc = connection.get(); - Elements parentElements = doc.select("body"); - if (isCaptchaOrLoginPage(doc)) { - log.debug("Skipping {} - CAPTCHA required", url); - return null; - } - - Element bodyElement = parentElements.getFirst(); - String title = bodyElement.select("title").text(); - String content = bodyElement.text(); - return ImmutableMap.of("url", url, "title", title, "content", content); - } catch (Exception e) { - log.error("Failed to crawl link: {}", url); - return null; - } - } - - private boolean isCaptchaOrLoginPage(Document doc) { - String html = doc.html().toLowerCase(Locale.ROOT); - // 1. Check for CAPTCHA indicators - return !doc.select("input[name*='captcha'], input[id*='captcha']").isEmpty() || - // Google reCAPTCHA markers - !doc.select(".g-recaptcha, div[data-sitekey]").isEmpty() || - // CAPTCHA image patterns - !doc.select("img[src*='captcha'], img[src*='recaptcha']").isEmpty() || - // Text-based indicators - org.apache.commons.lang3.StringUtils.containsIgnoreCase(html, "verify you are human") || - // hCAPTCHA detection - !doc.select(".h-captcha").isEmpty(); - } - @Override public String getType() { return TYPE; @@ -427,46 +352,46 @@ public String getType() { @Override public boolean validate(Map parameters) { - String engine = parameters.get("engine"); + String engine = parameters.get(ENGINE); if (org.apache.commons.lang3.StringUtils.isEmpty(engine)) { return false; } - boolean isQueryEmpty = org.apache.commons.lang3.StringUtils.isEmpty(parameters.getOrDefault("query", parameters.get("question"))); + boolean isQueryEmpty = org.apache.commons.lang3.StringUtils.isEmpty(parameters.getOrDefault(QUERY, parameters.get(QUESTION))); if (isQueryEmpty) { log.warn("Query is empty"); return false; } boolean isEndpointEmpty = org.apache.commons.lang3.StringUtils - .isEmpty(parameters.getOrDefault("endpoint", getDefaultEndpoint(engine))); + .isEmpty(parameters.getOrDefault(ENDPOINT, getDefaultEndpoint(engine))); if (isEndpointEmpty) { log.warn("Endpoint is empty"); return false; } - if ("google".equalsIgnoreCase(engine)) { - boolean hasEngineIdAndApiKey = parameters.containsKey("engine_id") - && !parameters.get("engine_id").isEmpty() - && parameters.containsKey("api_key") - && !parameters.get("api_key").isEmpty(); + if (GOOGLE.equalsIgnoreCase(engine)) { + boolean hasEngineIdAndApiKey = parameters.containsKey(ENGINE_ID) + && !parameters.get(ENGINE_ID).isEmpty() + && parameters.containsKey(API_KEY) + && !parameters.get(API_KEY).isEmpty(); if (!hasEngineIdAndApiKey) { - log.warn("Google search engine_id or api_key is empty"); + log.warn("Google search" + ENGINE_ID + "or api_key is empty"); return false; } return true; - } else if ("duckduckgo".equalsIgnoreCase(engine)) { + } else if (DUCKDUCKGO.equalsIgnoreCase(engine)) { return true; - } else if ("bing".equalsIgnoreCase(engine)) { - boolean hasApiKey = org.apache.commons.lang3.StringUtils.isEmpty(parameters.get("api_key")); + } else if (BING.equalsIgnoreCase(engine)) { + boolean hasApiKey = org.apache.commons.lang3.StringUtils.isEmpty(parameters.get(API_KEY)); if (!hasApiKey) { log.warn("Bing search api_key is empty"); return false; } return true; - } else if ("custom".equalsIgnoreCase(engine)) { - String customApi = parameters.get("custom_api"); - String customResUrlJsonpath = parameters.get("custom_res_url_jsonpath"); + } else if (CUSTOM.equalsIgnoreCase(engine)) { + String customApi = parameters.get(CUSTOM_API); + String customResUrlJsonpath = parameters.get(CUSTOM_RES_URL_JSONPATH); if (org.apache.commons.lang3.StringUtils.isEmpty(customApi) || org.apache.commons.lang3.StringUtils.isEmpty(customResUrlJsonpath)) { log.warn("custom search API is empty or result json path is empty"); @@ -480,6 +405,48 @@ public boolean validate(Map parameters) { return false; } + /** + * crawl a page and put the page content into the results map if it can be crawled successfully. + * + * @param url The url to crawl + */ + public Map crawlPage(String url, String authorization) { + try { + Connection connection = Jsoup.connect(url).timeout(10000).userAgent(USER_AGENT); + if (authorization != null) { + connection.header(AUTHORIZATION, authorization); + } + Document doc = connection.get(); + Elements parentElements = doc.select("body"); + if (isCaptchaOrLoginPage(doc)) { + log.debug("Skipping {} - CAPTCHA required", url); + return null; + } + + Element bodyElement = parentElements.getFirst(); + String title = bodyElement.select(TITLE).text(); + String content = bodyElement.text(); + return ImmutableMap.of(URL, url, TITLE, title, CONTENT, content); + } catch (Exception e) { + log.error("Failed to crawl link: {}", url); + return null; + } + } + + private boolean isCaptchaOrLoginPage(Document doc) { + String html = doc.html().toLowerCase(Locale.ROOT); + // 1. Check for CAPTCHA indicators + return !doc.select("input[name*='captcha'], input[id*='captcha']").isEmpty() || + // Google reCAPTCHA markers + !doc.select(".g-recaptcha, div[data-sitekey]").isEmpty() || + // CAPTCHA image patterns + !doc.select("img[src*='captcha'], img[src*='recaptcha']").isEmpty() || + // Text-based indicators + org.apache.commons.lang3.StringUtils.containsIgnoreCase(html, "verify you are human") || + // hCAPTCHA detection + !doc.select(".h-captcha").isEmpty(); + } + public static class Factory implements Tool.Factory { private static Factory INSTANCE; private ThreadPool threadPool; @@ -524,4 +491,163 @@ public Map getDefaultAttributes() { return DEFAULT_ATTRIBUTES; } } + + private final class WebSearchResponseHandler implements SdkAsyncHttpResponseHandler { + private final String endpoint; + private final String authorization; + private final String parsedNextPage; + private final String engine; + private final String customResUrlJsonpath; + private final ActionListener listener; + + public WebSearchResponseHandler( + String endpoint, + String authorization, + String parsedNextPage, + String engine, + String customResUrlJsonpath, + ActionListener listener + ) { + this.endpoint = endpoint; + this.authorization = authorization; + this.parsedNextPage = parsedNextPage; + this.engine = engine; + this.customResUrlJsonpath = customResUrlJsonpath; + this.listener = listener; + } + + @Override + public void onHeaders(SdkHttpResponse response) { + SdkHttpFullResponse sdkResponse = (SdkHttpFullResponse) response; + log.debug("received response headers: " + sdkResponse.headers()); + int statusCode = sdkResponse.statusCode(); + if (statusCode < HttpStatus.SC_OK || statusCode > HttpStatus.SC_MULTIPLE_CHOICES) { + log + .error( + "Received error from endpoint:{} with status code {}, response headers: {}", + endpoint, + statusCode, + sdkResponse.headers() + ); + listener + .onFailure( + new OpenSearchStatusException( + String.format(Locale.ROOT, "Failed to fetch results from endpoint: %s", endpoint), + RestStatus.fromCode(statusCode) + ) + ); + } + } + + @Override + public void onStream(Publisher stream) { + stream.subscribe(new Subscriber<>() { + private final StringBuilder responseBuilder = new StringBuilder(); + private Subscription subscription; + + @Override + public void onSubscribe(Subscription subscription) { + log.debug("Starting to fetch response..."); + this.subscription = subscription; + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + responseBuilder.append(StandardCharsets.UTF_8.decode(byteBuffer)); + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onError(Throwable throwable) { + log.error("Failed to fetch results from endpoint: {}", endpoint, throwable); + listener.onFailure(new RuntimeException(throwable)); + } + + @Override + public void onComplete() { + log.debug("Successfully fetched results from endpoint: {}", endpoint); + parseResponse(responseBuilder.toString(), authorization, parsedNextPage, engine, customResUrlJsonpath, listener); + } + }); + } + + @Override + public void onError(Throwable error) { + log.error("Failed to fetch results from endpoint: {}", endpoint, error); + listener.onFailure(new RuntimeException(error)); + } + + private void parseResponse( + String rawResponse, + String authorization, + String nextPage, + String engine, + String customResUrlJsonpath, + ActionListener listener + ) { + JsonObject rawJson = JsonParser.parseString(rawResponse).getAsJsonObject(); + switch (engine.toLowerCase(Locale.ROOT)) { + case GOOGLE: + parseGoogleResults(rawJson, nextPage, listener); + break; + case BING: + parseBingResults(rawJson, nextPage, listener); + break; + case CUSTOM: + List urls = JsonPath.read(rawResponse, customResUrlJsonpath); + parseCustomResults(urls, authorization, nextPage, listener); + break; + default: + listener.onFailure(new RuntimeException(String.format(Locale.ROOT, "Unsupported search engine: %s", engine))); + } + } + + private void parseGoogleResults(JsonObject googleResponse, String nextPage, ActionListener listener) { + Map results = new HashMap<>(); + results.put(NEXT_PAGE, nextPage); + // extract search results, each item is a search result: + // https://developers.google.com/custom-search/v1/reference/rest/v1/Search#result + JsonArray items = googleResponse.getAsJsonArray(ITEMS); + List> crawlResults = new ArrayList<>(); + for (int i = 0; i < items.size(); i++) { + JsonObject item = items.get(i).getAsJsonObject(); + // extract the actual link for scrawl. + String link = item.get("link").getAsString(); + // extract title and content. + Map crawlResult = crawlPage(link, null); + crawlResults.add(crawlResult); + } + results.put(ITEMS, crawlResults); + listener.onResponse((T) StringUtils.gson.toJson(results)); + } + + private void parseBingResults(JsonObject bingResponse, String nextPage, ActionListener listener) { + Map results = new HashMap<>(); + results.put(NEXT_PAGE, nextPage); + List> crawlResults = new ArrayList<>(); + JsonArray values = bingResponse.get("webPages").getAsJsonObject().getAsJsonArray("value"); + for (int i = 0; i < values.size(); i++) { + JsonObject value = values.get(i).getAsJsonObject(); + String link = value.get(URL).getAsString(); + Map crawlResult = crawlPage(link, null); + crawlResults.add(crawlResult); + } + results.put(ITEMS, crawlResults); + listener.onResponse((T) StringUtils.gson.toJson(results)); + } + + private void parseCustomResults(List urls, String authorization, String nextPage, ActionListener listener) { + Map results = new HashMap<>(); + results.put(NEXT_PAGE, nextPage); + List> crawlResults = new ArrayList<>(); + for (int i = 0; i < urls.size(); i++) { + String link = urls.get(i); + Map crawlResult = crawlPage(link, authorization); + crawlResults.add(crawlResult); + } + results.put(ITEMS, crawlResults); + listener.onResponse((T) StringUtils.gson.toJson(results)); + } + } } diff --git a/src/test/java/org/opensearch/agent/ToolPluginTests.java b/src/test/java/org/opensearch/agent/ToolPluginTests.java index 0bbee973..d8cbc388 100644 --- a/src/test/java/org/opensearch/agent/ToolPluginTests.java +++ b/src/test/java/org/opensearch/agent/ToolPluginTests.java @@ -96,7 +96,7 @@ public void test_getRestHandlers_successful() { @Test public void test_getToolFactories_successful() { - assertEquals(14, toolPlugin.getToolFactories().size()); + assertEquals(16, toolPlugin.getToolFactories().size()); } @Test diff --git a/src/test/java/org/opensearch/agent/tools/DataDistributionToolTests.java b/src/test/java/org/opensearch/agent/tools/DataDistributionToolTests.java index d53bf0d8..d79336b0 100644 --- a/src/test/java/org/opensearch/agent/tools/DataDistributionToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/DataDistributionToolTests.java @@ -5,6 +5,7 @@ package org.opensearch.agent.tools; +import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.containsString; import static org.jsoup.helper.Validate.fail; import static org.junit.Assert.assertEquals; @@ -883,8 +884,8 @@ public void testExecutionWithSizeExceedsMaxLimit() { "15000" ), ActionListener.wrap(response -> fail("Should have failed with size exceeding limit"), e -> { - MatcherAssert.assertThat(e.getMessage(), containsString("Size parameter exceeds maximum limit of 10000")); - MatcherAssert.assertThat(e.getMessage(), containsString("got: 15000")); + MatcherAssert.assertThat(e.getMessage(), containsString("must be between 1 and 10000")); + MatcherAssert.assertThat(e.getMessage(), containsString("15000")); }) ); } @@ -1168,284 +1169,6 @@ private List> createHighCardinalityTestData() { return data; } - @Test - @SneakyThrows - public void testBuildQueryFromMapWithTermQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("status", Map.of("term", "error")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("term")); - assertTrue(queryBuilder.toString().contains("status")); - assertTrue(queryBuilder.toString().contains("error")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithRangeQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("level", Map.of("range", Map.of("gte", 3, "lte", 5))); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("range")); - assertTrue(queryBuilder.toString().contains("level")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithMatchQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("message", Map.of("match", "test message")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("match")); - assertTrue(queryBuilder.toString().contains("message")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithExistsQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("status", Map.of("exists", true)); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("exists")); - assertTrue(queryBuilder.toString().contains("status")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithDirectTermQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("status", "error"); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("term")); - assertTrue(queryBuilder.toString().contains("status")); - assertTrue(queryBuilder.toString().contains("error")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithMatchPhraseQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("message", Map.of("match_phrase", "exact phrase")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("match_phrase")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithPrefixQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("host", Map.of("prefix", "server")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("prefix")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithWildcardQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("host", Map.of("wildcard", "server*")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("wildcard")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithWildcardMapQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("host", Map.of("wildcard", Map.of("value", "server*"))); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("wildcard")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithRegexpQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("host", Map.of("regexp", "server-[0-9]+")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("regexp")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithRegexpMapQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("host", Map.of("regexp", Map.of("value", "server-[0-9]+"))); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("regexp")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithTermsQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("terms", Map.of("status", List.of("error", "warning"))); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("terms")); - assertTrue(queryBuilder.toString().contains("status")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithMultiMatchQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map - .of("multi_match", Map.of("query", "error message", "fields", List.of("message", "description"))); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("multi_match")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithComplexRangeQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("level", Map.of("range", Map.of("gt", 1, "lt", 10))); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - assertTrue(queryBuilder.toString().contains("range")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithUnsupportedOperator() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - Map filterMap = Map.of("status", Map.of("unsupported_op", "value")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - } - @Test @SneakyThrows public void testGroupNumericKeysWithManyNumericValues() { @@ -1852,6 +1575,7 @@ public void testDSLWithInvalidRawDSLQuery() { mockSearchResponse(); DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + // Test that invalid DSL query causes execution to fail tool .run( ImmutableMap @@ -1868,13 +1592,14 @@ public void testDSLWithInvalidRawDSLQuery() { "invalid-json-query" ), ActionListener.wrap(response -> { - // Should fallback to time range only query when DSL parsing fails - JsonElement result = gson.fromJson(response, JsonElement.class); + fail("Should have failed with invalid DSL query"); + }, e -> { + // Expect failure due to invalid DSL format assertTrue( - "Response should contain singleAnalysis even with invalid DSL", - result.getAsJsonObject().has("singleAnalysis") + "Should fail with exception for invalid DSL", + e instanceof IllegalArgumentException || e.getMessage().contains("Invalid query format") ); - }, e -> fail("Tool execution failed: " + e.getMessage())) + }) ); } @@ -2101,6 +1826,7 @@ public void testDSLWithFilterArrayInvalidFilter() { mockSearchResponse(); DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + // Test that invalid filter JSON causes parameter validation to fail tool .run( ImmutableMap @@ -2117,13 +1843,15 @@ public void testDSLWithFilterArrayInvalidFilter() { "[\"{'term': {'status': 'error'}}\", \"invalid-json-filter\"]" ), ActionListener.wrap(response -> { - // Should continue processing valid filters and ignore invalid ones - JsonElement result = gson.fromJson(response, JsonElement.class); + fail("Should have failed with invalid filter JSON"); + }, e -> { + // Expect IllegalArgumentException due to invalid filter format + assertTrue("Should fail with IllegalArgumentException for invalid filter", e instanceof IllegalArgumentException); assertTrue( - "Response should contain singleAnalysis even with some invalid filters", - result.getAsJsonObject().has("singleAnalysis") + "Error message should mention invalid filter", + e.getMessage().contains("Invalid 'filter' parameter") || e.getMessage().contains("Invalid query format") ); - }, e -> fail("Tool execution failed: " + e.getMessage())) + }) ); } @@ -2484,8 +2212,13 @@ public void testDSLWithMalformedFilterJSON() { "[malformed-json]" ), ActionListener.wrap(response -> fail("Should have failed with malformed filter JSON"), e -> { - MatcherAssert.assertThat(e.getMessage(), containsString("Invalid 'filter' parameter")); - MatcherAssert.assertThat(e.getMessage(), containsString("must be a valid JSON array of strings")); + // The filter "[malformed-json]" is valid JSON array syntax, but "malformed-json" is not valid query JSON + // Error occurs during query execution, not parameter validation + MatcherAssert + .assertThat( + e.getMessage(), + anyOf(containsString("Invalid query format"), containsString("Invalid 'filter' parameter")) + ); }) ); } @@ -2637,84 +2370,4 @@ public void testDSLQueryPrecedenceOverFilter() { }, e -> fail("Tool execution failed: " + e.getMessage())) ); } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithTermsQueryNonListValue() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - // Test terms query with non-list value (should be ignored) - Map filterMap = Map.of("terms", Map.of("status", "error")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - // Should not contain terms query since value is not a list - assertFalse(queryBuilder.toString().contains("terms")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithMultiMatchQueryMissingFields() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - // Test multi_match query with missing fields (should be ignored) - Map filterMap = Map.of("multi_match", Map.of("query", "error message")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - // Should not contain multi_match query since fields is missing - assertFalse(queryBuilder.toString().contains("multi_match")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithMultiMatchQueryMissingQuery() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - // Test multi_match query with missing query (should be ignored) - Map filterMap = Map.of("multi_match", Map.of("fields", List.of("message", "description"))); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - // Should not contain multi_match query since query is missing - assertFalse(queryBuilder.toString().contains("multi_match")); - } - - @Test - @SneakyThrows - public void testBuildQueryFromMapWithMultiMatchQueryNonListFields() { - DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); - - java.lang.reflect.Method buildQueryMethod = DataDistributionTool.class - .getDeclaredMethod("buildQueryFromMap", Map.class, org.opensearch.index.query.BoolQueryBuilder.class); - buildQueryMethod.setAccessible(true); - - // Test multi_match query with non-list fields (should be ignored) - Map filterMap = Map.of("multi_match", Map.of("query", "error message", "fields", "message")); - org.opensearch.index.query.BoolQueryBuilder queryBuilder = org.opensearch.index.query.QueryBuilders.boolQuery(); - - buildQueryMethod.invoke(tool, filterMap, queryBuilder); - - assertNotNull(queryBuilder); - // Should not contain multi_match query since fields is not a list - assertFalse(queryBuilder.toString().contains("multi_match")); - } } diff --git a/src/test/java/org/opensearch/agent/tools/LogPatternAnalysisToolTests.java b/src/test/java/org/opensearch/agent/tools/LogPatternAnalysisToolTests.java index 3adffe17..dd17e50d 100644 --- a/src/test/java/org/opensearch/agent/tools/LogPatternAnalysisToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/LogPatternAnalysisToolTests.java @@ -20,6 +20,7 @@ import java.util.HashMap; import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; import org.hamcrest.MatcherAssert; import org.junit.Before; @@ -28,6 +29,7 @@ import org.mockito.MockitoAnnotations; import org.opensearch.core.action.ActionListener; import org.opensearch.sql.plugin.transport.PPLQueryAction; +import org.opensearch.sql.plugin.transport.TransportPPLQueryRequest; import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; import org.opensearch.transport.client.Client; @@ -466,6 +468,105 @@ public void testExecutionWithNonExistentIndex() { ); } + @Test + @SneakyThrows + public void testLogInsightWithFilter() { + String pplResponse = + """ + {"schema":[{"name":"patterns_field","type":"string"},{"name":"pattern_count","type":"long"},{"name":"sample_logs","type":"array"}], + "datarows":[["Auth error for user <*>",3,["Auth error for user admin","Auth error for user guest"]]], + "total":1,"size":1} + """; + + AtomicReference capturedPPL = new AtomicReference<>(); + doAnswer(invocation -> { + TransportPPLQueryRequest request = (TransportPPLQueryRequest) invocation.getArguments()[1]; + capturedPPL.set(request.getRequest()); + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(pplQueryResponse); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + when(pplQueryResponse.getResult()).thenReturn(pplResponse); + + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .builder() + .put("index", "test_index") + .put("timeField", "@timestamp") + .put("logFieldName", "message") + .put("selectionTimeRangeStart", "2025-01-01T00:00:00Z") + .put("selectionTimeRangeEnd", "2025-01-01T01:00:00Z") + .put("filter", "serviceName='ts-auth-service'") + .build(), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue(result.getAsJsonObject().has("logInsights")); + // Verify the PPL query contains the filter clause + assertTrue(capturedPPL.get().contains("where serviceName='ts-auth-service'")); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testLogPatternDiffWithFilter() { + String baseResponse = """ + {"schema":[{"name":"cnt","type":"long"},{"name":"patterns_field","type":"string"}], + "datarows":[[100,"User login successful"],[20,"Database query executed"]], + "total":2,"size":2} + """; + + String selectionResponse = """ + {"schema":[{"name":"cnt","type":"long"},{"name":"patterns_field","type":"string"}], + "datarows":[[50,"User login successful"],[80,"Error in authentication <*>"]], + "total":2,"size":2} + """; + + AtomicReference firstPPL = new AtomicReference<>(); + AtomicReference secondPPL = new AtomicReference<>(); + doAnswer(invocation -> { + TransportPPLQueryRequest request = (TransportPPLQueryRequest) invocation.getArguments()[1]; + String ppl = request.getRequest(); + if (firstPPL.get() == null) { + firstPPL.set(ppl); + } else { + secondPPL.set(ppl); + } + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(pplQueryResponse); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + when(pplQueryResponse.getResult()).thenReturn(baseResponse).thenReturn(selectionResponse); + + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .builder() + .put("index", "test_index") + .put("timeField", "@timestamp") + .put("logFieldName", "message") + .put("baseTimeRangeStart", "2025-01-01T00:00:00Z") + .put("baseTimeRangeEnd", "2025-01-01T01:00:00Z") + .put("selectionTimeRangeStart", "2025-01-01T01:00:00Z") + .put("selectionTimeRangeEnd", "2025-01-01T02:00:00Z") + .put("filter", "severity='ERROR'") + .build(), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue(result.getAsJsonObject().has("patternMapDifference")); + // Verify both PPL queries contain the filter clause + assertTrue(firstPPL.get().contains("where severity='ERROR'")); + assertTrue(secondPPL.get().contains("where severity='ERROR'")); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + @Test @SneakyThrows public void testExecutionWithNonExistentLogField() { diff --git a/src/test/java/org/opensearch/agent/tools/MetricChangeAnalysisToolTests.java b/src/test/java/org/opensearch/agent/tools/MetricChangeAnalysisToolTests.java new file mode 100644 index 00000000..0ffcf06e --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/MetricChangeAnalysisToolTests.java @@ -0,0 +1,619 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.core.action.ActionListener; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.transport.client.AdminClient; +import org.opensearch.transport.client.Client; +import org.opensearch.transport.client.IndicesAdminClient; + +import com.google.common.collect.ImmutableMap; + +import lombok.SneakyThrows; + +public class MetricChangeAnalysisToolTests { + + private Map params = new HashMap<>(); + private final Client client = mock(Client.class); + @Mock + private AdminClient adminClient; + @Mock + private IndicesAdminClient indicesAdminClient; + @Mock + private GetMappingsResponse getMappingsResponse; + @Mock + private MappingMetadata mappingMetadata; + @Mock + private SearchResponse searchResponse; + + @SneakyThrows + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + setupMockMappings(); + MetricChangeAnalysisTool.Factory.getInstance().init(client); + } + + private void setupMockMappings() { + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesAdminClient); + + Map properties = ImmutableMap + .builder() + .put("response_time", ImmutableMap.of("type", "long")) + .put("cpu_usage", ImmutableMap.of("type", "double")) + .put("memory_usage", ImmutableMap.of("type", "float")) + .put("status", ImmutableMap.of("type", "keyword")) + .put("@timestamp", ImmutableMap.of("type", "date")) + .build(); + + Map mappingSource = ImmutableMap.of("properties", properties); + when(mappingMetadata.getSourceAsMap()).thenReturn(mappingSource); + when(getMappingsResponse.getMappings()).thenReturn(ImmutableMap.of("test-index", mappingMetadata)); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(getMappingsResponse); + return null; + }).when(indicesAdminClient).getMappings(any(), any()); + } + + private SearchHit[] createSampleHits(Map... sources) { + SearchHit[] hits = new SearchHit[sources.length]; + for (int i = 0; i < sources.length; i++) { + hits[i] = new SearchHit(i); + hits[i].sourceRef(null); + hits[i].score(1.0f); + // Use reflection to set source + try { + java.lang.reflect.Field sourceField = SearchHit.class.getDeclaredField("source"); + sourceField.setAccessible(true); + sourceField.set(hits[i], sources[i]); + } catch (Exception e) { + // Fallback: create hit with source + } + } + return hits; + } + + private void mockSearchResponse(SearchHit[] hits) { + SearchHits searchHits = new SearchHits(hits, new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO), 1.0f); + when(searchResponse.getHits()).thenReturn(searchHits); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + } + + // ========== Percentile Calculation Tests ========== + + @Test + @SneakyThrows + public void testCalculatePercentile() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculatePercentile", List.class, int.class); + calculatePercentileMethod.setAccessible(true); + + // Test with sorted values + List values = List.of(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0); + + double p25 = (double) calculatePercentileMethod.invoke(tool, values, 25); + double p50 = (double) calculatePercentileMethod.invoke(tool, values, 50); + double p75 = (double) calculatePercentileMethod.invoke(tool, values, 75); + double p90 = (double) calculatePercentileMethod.invoke(tool, values, 90); + + assertEquals(3.25, p25, 0.01); + assertEquals(5.5, p50, 0.01); + assertEquals(7.75, p75, 0.01); + assertEquals(9.1, p90, 0.01); + } + + @Test + @SneakyThrows + public void testCalculatePercentileWithSingleValue() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculatePercentile", List.class, int.class); + calculatePercentileMethod.setAccessible(true); + + List values = List.of(5.0); + + double p25 = (double) calculatePercentileMethod.invoke(tool, values, 25); + double p50 = (double) calculatePercentileMethod.invoke(tool, values, 50); + double p75 = (double) calculatePercentileMethod.invoke(tool, values, 75); + double p90 = (double) calculatePercentileMethod.invoke(tool, values, 90); + + assertEquals(5.0, p25, 0.01); + assertEquals(5.0, p50, 0.01); + assertEquals(5.0, p75, 0.01); + assertEquals(5.0, p90, 0.01); + } + + @Test + @SneakyThrows + public void testCalculatePercentileWithEmptyList() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculatePercentile", List.class, int.class); + calculatePercentileMethod.setAccessible(true); + + List values = List.of(); + + double p50 = (double) calculatePercentileMethod.invoke(tool, values, 50); + + assertEquals(0.0, p50, 0.01); + } + + @Test + @SneakyThrows + public void testCalculatePercentiles() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentilesMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculatePercentiles", List.class); + calculatePercentilesMethod.setAccessible(true); + + List values = List.of(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0); + + Object result = calculatePercentilesMethod.invoke(tool, values); + + assertNotNull(result); + java.lang.reflect.Method p50Method = result.getClass().getDeclaredMethod("p50"); + java.lang.reflect.Method p90Method = result.getClass().getDeclaredMethod("p90"); + + double p50 = (double) p50Method.invoke(result); + double p90 = (double) p90Method.invoke(result); + + assertEquals(5.5, p50, 0.01); + assertEquals(9.1, p90, 0.01); + } + + // ========== Numeric Value Extraction Tests ========== + + @Test + @SneakyThrows + public void testExtractNumericValues() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method extractNumericValuesMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("extractNumericValues", List.class, String.class); + extractNumericValuesMethod.setAccessible(true); + + List> data = List + .of( + Map.of("level", 1, "status", "error"), + Map.of("level", 2, "status", "warning"), + Map.of("level", 3, "status", "info"), + Map.of("level", "4", "status", "debug"), // String number + Map.of("status", "error") // Missing level field + ); + + @SuppressWarnings("unchecked") + List values = (List) extractNumericValuesMethod.invoke(tool, data, "level"); + + assertNotNull(values); + assertEquals(4, values.size()); + assertTrue(values.contains(1.0)); + assertTrue(values.contains(2.0)); + assertTrue(values.contains(3.0)); + assertTrue(values.contains(4.0)); + } + + @Test + @SneakyThrows + public void testExtractNumericValuesWithNonNumericData() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method extractNumericValuesMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("extractNumericValues", List.class, String.class); + extractNumericValuesMethod.setAccessible(true); + + List> data = List.of(Map.of("status", "error"), Map.of("status", "warning"), Map.of("status", "info")); + + @SuppressWarnings("unchecked") + List values = (List) extractNumericValuesMethod.invoke(tool, data, "status"); + + assertNotNull(values); + assertTrue(values.isEmpty()); + } + + @Test + @SneakyThrows + public void testExtractNumericValuesWithNestedField() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method extractNumericValuesMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("extractNumericValues", List.class, String.class); + extractNumericValuesMethod.setAccessible(true); + + List> data = List + .of( + Map.of("metrics", Map.of("response_time", 100)), + Map.of("metrics", Map.of("response_time", 200)), + Map.of("metrics", Map.of("response_time", 300)) + ); + + @SuppressWarnings("unchecked") + List values = (List) extractNumericValuesMethod.invoke(tool, data, "metrics.response_time"); + + assertNotNull(values); + assertEquals(3, values.size()); + assertTrue(values.contains(100.0)); + assertTrue(values.contains(200.0)); + assertTrue(values.contains(300.0)); + } + + // ========== Log Ratio Tests ========== + + @Test + @SneakyThrows + public void testSafeLogRatio() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method safeLogRatioMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("safeLogRatio", double.class, double.class); + safeLogRatioMethod.setAccessible(true); + + // Test 2x increase: |log(2/1)| = log(2) ≈ 0.693 + double ratio1 = (double) safeLogRatioMethod.invoke(tool, 2.0, 1.0); + assertEquals(Math.log(2.0), ratio1, 0.01); + + // Test 1.8x increase: |log(180/100)| = log(1.8) ≈ 0.588 + double ratio2 = (double) safeLogRatioMethod.invoke(tool, 180.0, 100.0); + assertEquals(Math.log(1.8), ratio2, 0.01); + + // Test decrease: |log(5/10)| = |log(0.5)| = log(2) ≈ 0.693 + double ratio3 = (double) safeLogRatioMethod.invoke(tool, 5.0, 10.0); + assertEquals(Math.log(2.0), ratio3, 0.01); + + // Test zero baseline: should return cap (10.0) + double ratio4 = (double) safeLogRatioMethod.invoke(tool, 10.0, 0.0); + assertEquals(10.0, ratio4, 0.01); + + // Test both near zero: should return 0.0 + double ratio5 = (double) safeLogRatioMethod.invoke(tool, 0.0, 0.0); + assertEquals(0.0, ratio5, 0.01); + + // Test no change: |log(1)| = 0 + double ratio6 = (double) safeLogRatioMethod.invoke(tool, 100.0, 100.0); + assertEquals(0.0, ratio6, 0.01); + } + + // ========== Variance Calculation Tests ========== + + @Test + @SneakyThrows + public void testCalculatePercentileVarianceSkipsZeroBaseline() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileVarianceMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod( + "calculatePercentileVariance", + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"), + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats") + ); + calculatePercentileVarianceMethod.setAccessible(true); + + Class percentileStatsClass = Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"); + java.lang.reflect.Constructor constructor = percentileStatsClass.getDeclaredConstructor(double.class, double.class); + constructor.setAccessible(true); + + // Both baselines zero → score 0 + Object sel1 = constructor.newInstance(10.0, 5.0); + Object base1 = constructor.newInstance(0.0, 0.0); + assertEquals(0.0, (double) calculatePercentileVarianceMethod.invoke(tool, sel1, base1), 0.01); + + // Only P50 baseline zero → score based on P90 only + Object sel2 = constructor.newInstance(10.0, 20.0); + Object base2 = constructor.newInstance(0.0, 10.0); + double expected = Math.abs(Math.log(20.0 / 10.0)); // log(2) ≈ 0.693 + assertEquals(expected, (double) calculatePercentileVarianceMethod.invoke(tool, sel2, base2), 0.01); + + // Only P90 baseline zero → score based on P50 only + Object sel3 = constructor.newInstance(20.0, 5.0); + Object base3 = constructor.newInstance(10.0, 0.0); + double expected3 = Math.abs(Math.log(20.0 / 10.0)); // log(2) ≈ 0.693 + assertEquals(expected3, (double) calculatePercentileVarianceMethod.invoke(tool, sel3, base3), 0.01); + } + + // ========== Variance Calculation Tests ========== + + @Test + @SneakyThrows + public void testCalculatePercentileVariance() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileVarianceMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod( + "calculatePercentileVariance", + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"), + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats") + ); + calculatePercentileVarianceMethod.setAccessible(true); + + // Create PercentileStats using reflection (p50, p90) + Class percentileStatsClass = Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"); + java.lang.reflect.Constructor constructor = percentileStatsClass.getDeclaredConstructor(double.class, double.class); + constructor.setAccessible(true); + + // selection p50=20, p90=40; baseline p50=10, p90=20 → both are 2x + Object selectionStats = constructor.newInstance(20.0, 40.0); + Object baselineStats = constructor.newInstance(10.0, 20.0); + + double variance = (double) calculatePercentileVarianceMethod.invoke(tool, selectionStats, baselineStats); + + // score = 0.5 * log(2) + 0.5 * log(2) = log(2) ≈ 0.693 + assertEquals(Math.log(2.0), variance, 0.01); + } + + @Test + @SneakyThrows + public void testCalculatePercentileVarianceWithNoChange() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileVarianceMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod( + "calculatePercentileVariance", + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"), + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats") + ); + calculatePercentileVarianceMethod.setAccessible(true); + + Class percentileStatsClass = Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"); + java.lang.reflect.Constructor constructor = percentileStatsClass.getDeclaredConstructor(double.class, double.class); + constructor.setAccessible(true); + + Object selectionStats = constructor.newInstance(20.0, 40.0); + Object baselineStats = constructor.newInstance(20.0, 40.0); + + double variance = (double) calculatePercentileVarianceMethod.invoke(tool, selectionStats, baselineStats); + + assertEquals(0.0, variance, 0.01); + } + + @Test + @SneakyThrows + public void testCalculatePercentileVarianceRelativeChange() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileVarianceMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod( + "calculatePercentileVariance", + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"), + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats") + ); + calculatePercentileVarianceMethod.setAccessible(true); + + Class percentileStatsClass = Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"); + java.lang.reflect.Constructor constructor = percentileStatsClass.getDeclaredConstructor(double.class, double.class); + constructor.setAccessible(true); + + // Test that 1->2 (2x) ranks higher than 100->180 (1.8x) + Object smallChangeStats = constructor.newInstance(2.0, 2.0); + Object smallBaselineStats = constructor.newInstance(1.0, 1.0); + + Object largeChangeStats = constructor.newInstance(180.0, 180.0); + Object largeBaselineStats = constructor.newInstance(100.0, 100.0); + + double smallVariance = (double) calculatePercentileVarianceMethod.invoke(tool, smallChangeStats, smallBaselineStats); + double largeVariance = (double) calculatePercentileVarianceMethod.invoke(tool, largeChangeStats, largeBaselineStats); + + // 2x change: 0.5 * log(2) + 0.5 * log(2) = log(2) ≈ 0.693 + // 1.8x change: 0.5 * log(1.8) + 0.5 * log(1.8) = log(1.8) ≈ 0.588 + assertEquals(Math.log(2.0), smallVariance, 0.01); + assertEquals(Math.log(1.8), largeVariance, 0.01); + assertTrue("2x change should rank higher than 1.8x change", smallVariance > largeVariance); + } + + // ========== Percentile Analysis Tests ========== + + @Test + @SneakyThrows + public void testCalculateMetricChangeAnalysis() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculateMetricChangeAnalysisMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculateMetricChangeAnalysis", List.class, List.class, java.util.Set.class); + calculateMetricChangeAnalysisMethod.setAccessible(true); + + // Use non-monotonic (gauge-like) data to avoid counter detection + List> selectionData = List + .of( + Map.of("response_time", 400, "cpu_usage", 80), + Map.of("response_time", 100, "cpu_usage", 50), + Map.of("response_time", 300, "cpu_usage", 60), + Map.of("response_time", 200, "cpu_usage", 70) + ); + + List> baselineData = List + .of( + Map.of("response_time", 150, "cpu_usage", 65), + Map.of("response_time", 50, "cpu_usage", 45), + Map.of("response_time", 200, "cpu_usage", 75), + Map.of("response_time", 100, "cpu_usage", 55) + ); + + java.util.Set numberFields = java.util.Set.of("response_time", "cpu_usage"); + + @SuppressWarnings("unchecked") + List analyses = (List) calculateMetricChangeAnalysisMethod.invoke(tool, selectionData, baselineData, numberFields); + + assertNotNull(analyses); + assertEquals(2, analyses.size()); + + // Verify first analysis has highest variance + java.lang.reflect.Method varianceMethod = analyses.get(0).getClass().getDeclaredMethod("variance"); + double firstVariance = (double) varianceMethod.invoke(analyses.get(0)); + double secondVariance = (double) varianceMethod.invoke(analyses.get(1)); + + assertTrue(firstVariance >= secondVariance); + } + + @Test + @SneakyThrows + public void testCalculateMetricChangeAnalysisWithEmptyData() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculateMetricChangeAnalysisMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculateMetricChangeAnalysis", List.class, List.class, java.util.Set.class); + calculateMetricChangeAnalysisMethod.setAccessible(true); + + List> selectionData = List.of(); + List> baselineData = List.of(); + java.util.Set numberFields = java.util.Set.of("response_time"); + + @SuppressWarnings("unchecked") + List analyses = (List) calculateMetricChangeAnalysisMethod.invoke(tool, selectionData, baselineData, numberFields); + + assertNotNull(analyses); + assertTrue(analyses.isEmpty()); + } + + @Test + @SneakyThrows + public void testCalculateMetricChangeAnalysisWithMissingFields() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculateMetricChangeAnalysisMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculateMetricChangeAnalysis", List.class, List.class, java.util.Set.class); + calculateMetricChangeAnalysisMethod.setAccessible(true); + + List> selectionData = List.of(Map.of("response_time", 100), Map.of("response_time", 200)); + + List> baselineData = List.of(Map.of("cpu_usage", 50), Map.of("cpu_usage", 60)); + + java.util.Set numberFields = java.util.Set.of("response_time", "cpu_usage"); + + @SuppressWarnings("unchecked") + List analyses = (List) calculateMetricChangeAnalysisMethod.invoke(tool, selectionData, baselineData, numberFields); + + assertNotNull(analyses); + // Should skip fields that don't have data in both datasets + assertTrue(analyses.isEmpty()); + } + + @Test + @SneakyThrows + public void testCalculateMetricChangeAnalysisRanking() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculateMetricChangeAnalysisMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculateMetricChangeAnalysis", List.class, List.class, java.util.Set.class); + calculateMetricChangeAnalysisMethod.setAccessible(true); + + // Create non-monotonic data where response_time has high change and cpu_usage has low change + List> selectionData = List + .of( + Map.of("response_time", 3000, "cpu_usage", 52), + Map.of("response_time", 1000, "cpu_usage", 54), + Map.of("response_time", 4000, "cpu_usage", 51), + Map.of("response_time", 2000, "cpu_usage", 53) + ); + + List> baselineData = List + .of( + Map.of("response_time", 300, "cpu_usage", 52), + Map.of("response_time", 100, "cpu_usage", 50), + Map.of("response_time", 400, "cpu_usage", 53), + Map.of("response_time", 200, "cpu_usage", 51) + ); + + java.util.Set numberFields = java.util.Set.of("response_time", "cpu_usage"); + + @SuppressWarnings("unchecked") + List analyses = (List) calculateMetricChangeAnalysisMethod.invoke(tool, selectionData, baselineData, numberFields); + + assertNotNull(analyses); + assertEquals(2, analyses.size()); + + // First field should be response_time (higher variance) + java.lang.reflect.Method fieldMethod = analyses.get(0).getClass().getDeclaredMethod("field"); + String firstField = (String) fieldMethod.invoke(analyses.get(0)); + assertEquals("response_time", firstField); + + // Second field should be cpu_usage (lower variance) + String secondField = (String) fieldMethod.invoke(analyses.get(1)); + assertEquals("cpu_usage", secondField); + } + + @Test + @SneakyThrows + public void testFormatResultsLimitsToTopTen() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method formatResultsMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("formatResults", List.class, int.class); + formatResultsMethod.setAccessible(true); + + // Create 15 fields with different variance scores + List analyses = new ArrayList<>(); + Class fieldAnalysisClass = Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$FieldPercentileAnalysis"); + Class percentileStatsClass = Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"); + + java.lang.reflect.Constructor statsConstructor = percentileStatsClass.getDeclaredConstructor(double.class, double.class); + statsConstructor.setAccessible(true); + + java.lang.reflect.Constructor analysisConstructor = fieldAnalysisClass + .getDeclaredConstructor(String.class, double.class, percentileStatsClass, percentileStatsClass); + analysisConstructor.setAccessible(true); + + Object stats = statsConstructor.newInstance(20.0, 40.0); + + // Create 15 fields with descending variance scores + for (int i = 0; i < 15; i++) { + double variance = 15.0 - i; // 15.0, 14.0, 13.0, ..., 1.0 + Object analysis = analysisConstructor.newInstance("field_" + i, variance, stats, stats); + analyses.add(analysis); + } + + // Test with topN = 10 + @SuppressWarnings("unchecked") + List results10 = (List) formatResultsMethod.invoke(tool, analyses, 10); + assertNotNull(results10); + assertEquals("Should return only top 10 results", 10, results10.size()); + + // Test with topN = 5 (default) + @SuppressWarnings("unchecked") + List results5 = (List) formatResultsMethod.invoke(tool, analyses, 5); + assertNotNull(results5); + assertEquals("Should return only top 5 results", 5, results5.size()); + + // Test with topN = 3 + @SuppressWarnings("unchecked") + List results3 = (List) formatResultsMethod.invoke(tool, analyses, 3); + assertNotNull(results3); + assertEquals("Should return only top 3 results", 3, results3.size()); + } +} diff --git a/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java index 66e1c7d8..9b5a48e2 100644 --- a/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java @@ -95,6 +95,8 @@ public void setup() { null, null, null, + null, + null, null ); } diff --git a/src/test/java/org/opensearch/agent/tools/SearchAroundDocumentToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchAroundDocumentToolTests.java new file mode 100644 index 00000000..fb8dba22 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/SearchAroundDocumentToolTests.java @@ -0,0 +1,687 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.transport.client.Client; + +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; + +public class SearchAroundDocumentToolTests { + + private Client client; + private SearchAroundDocumentTool tool; + private static final Gson GSON = new Gson(); + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + client = mock(Client.class); + SearchAroundDocumentTool.Factory.getInstance().init(client, NamedXContentRegistry.EMPTY); + tool = SearchAroundDocumentTool.Factory.getInstance().create(Collections.emptyMap()); + } + + private SearchHit createMockHit(String id, String index, Map source, Object[] sortValues) { + SearchHit hit = mock(SearchHit.class); + when(hit.getId()).thenReturn(id); + when(hit.getIndex()).thenReturn(index); + when(hit.getScore()).thenReturn(1.0f); + when(hit.getSourceAsMap()).thenReturn(source); + when(hit.getSortValues()).thenReturn(sortValues); + return hit; + } + + private SearchResponse createMockSearchResponse(SearchHit[] hits) { + SearchResponse response = mock(SearchResponse.class); + SearchHits searchHits = mock(SearchHits.class); + when(searchHits.getHits()).thenReturn(hits); + when(response.getHits()).thenReturn(searchHits); + return response; + } + + private void mockThreeSearchCalls(SearchResponse targetResponse, SearchResponse beforeResponse, SearchResponse afterResponse) { + AtomicInteger callCount = new AtomicInteger(0); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + int call = callCount.getAndIncrement(); + switch (call) { + case 0: + listener.onResponse(targetResponse); + break; + case 1: + listener.onResponse(beforeResponse); + break; + case 2: + listener.onResponse(afterResponse); + break; + default: + listener.onFailure(new RuntimeException("Unexpected search call")); + } + return null; + }).when(client).search(any(SearchRequest.class), any()); + } + + // ========== Validate Tests ========== + + @Test + public void testValidateWithNullParameters() { + assertFalse(tool.validate(null)); + } + + @Test + public void testValidateWithEmptyParameters() { + assertFalse(tool.validate(Collections.emptyMap())); + } + + @Test + public void testValidateWithJsonInput() { + Map params = new HashMap<>(); + params.put("input", "{\"index\":\"test\",\"doc_id\":\"1\",\"timestamp_field\":\"@timestamp\",\"count\":5}"); + assertTrue(tool.validate(params)); + } + + @Test + public void testValidateWithDirectParameters() { + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "5"); + assertTrue(tool.validate(params)); + } + + @Test + public void testValidateWithMissingParameters() { + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc1"); + assertFalse(tool.validate(params)); + } + + @Test + public void testValidateWithEmptyValues() { + Map params = new HashMap<>(); + params.put("index", ""); + params.put("doc_id", "doc1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "5"); + assertFalse(tool.validate(params)); + } + + // ========== Factory Tests ========== + + @Test + public void testFactoryCreate() { + SearchAroundDocumentTool.Factory factory = SearchAroundDocumentTool.Factory.getInstance(); + SearchAroundDocumentTool createdTool = factory.create(Collections.emptyMap()); + assertNotNull(createdTool); + assertEquals(SearchAroundDocumentTool.TYPE, createdTool.getType()); + } + + @Test + public void testFactoryGetInstance() { + SearchAroundDocumentTool.Factory factory1 = SearchAroundDocumentTool.Factory.getInstance(); + SearchAroundDocumentTool.Factory factory2 = SearchAroundDocumentTool.Factory.getInstance(); + assertTrue(factory1 == factory2); + } + + @Test + public void testFactoryDefaults() { + SearchAroundDocumentTool.Factory factory = SearchAroundDocumentTool.Factory.getInstance(); + assertEquals(SearchAroundDocumentTool.TYPE, factory.getDefaultType()); + assertNotNull(factory.getDefaultDescription()); + assertNull(factory.getDefaultVersion()); + assertNotNull(factory.getDefaultAttributes()); + } + + // ========== Tool Metadata Tests ========== + + @Test + public void testGetType() { + assertEquals("SearchAroundDocumentTool", tool.getType()); + } + + @Test + public void testGetVersion() { + assertNull(tool.getVersion()); + } + + // ========== Run - Success Tests ========== + + @Test + public void testRunWithDirectParameters() throws Exception { + Object[] sortValues = new Object[] { 1000L, 5L }; + + SearchHit targetHit = createMockHit( + "target-doc", + "test-index", + Map.of("@timestamp", "2024-01-01T00:00:05", "message", "target"), + sortValues + ); + + SearchHit beforeHit1 = createMockHit( + "before-1", + "test-index", + Map.of("@timestamp", "2024-01-01T00:00:03", "message", "before1"), + new Object[] { 998L, 3L } + ); + SearchHit beforeHit2 = createMockHit( + "before-2", + "test-index", + Map.of("@timestamp", "2024-01-01T00:00:04", "message", "before2"), + new Object[] { 999L, 4L } + ); + + SearchHit afterHit1 = createMockHit( + "after-1", + "test-index", + Map.of("@timestamp", "2024-01-01T00:00:06", "message", "after1"), + new Object[] { 1001L, 6L } + ); + SearchHit afterHit2 = createMockHit( + "after-2", + "test-index", + Map.of("@timestamp", "2024-01-01T00:00:07", "message", "after2"), + new Object[] { 1002L, 7L } + ); + + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + // Before: returned in DESC order (before2, before1) - tool reverses to chronological + SearchResponse beforeResponse = createMockSearchResponse(new SearchHit[] { beforeHit2, beforeHit1 }); + SearchResponse afterResponse = createMockSearchResponse(new SearchHit[] { afterHit1, afterHit2 }); + + mockThreeSearchCalls(targetResponse, beforeResponse, afterResponse); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "target-doc"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "2"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + String result = future.get(); + assertNotNull(result); + + List> docs = GSON.fromJson(result, new TypeToken>>() { + }.getType()); + assertEquals(5, docs.size()); + + // Verify chronological order: before1, before2, target, after1, after2 + assertEquals("before-1", docs.get(0).get("_id")); + assertEquals("before-2", docs.get(1).get("_id")); + assertEquals("target-doc", docs.get(2).get("_id")); + assertEquals("after-1", docs.get(3).get("_id")); + assertEquals("after-2", docs.get(4).get("_id")); + } + + @Test + public void testRunWithJsonInput() throws Exception { + Object[] sortValues = new Object[] { 1000L, 5L }; + + SearchHit targetHit = createMockHit( + "doc-1", + "my-index", + Map.of("@timestamp", "2024-01-01T00:00:05", "message", "target"), + sortValues + ); + + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + SearchResponse beforeResponse = createMockSearchResponse(new SearchHit[] {}); + SearchResponse afterResponse = createMockSearchResponse(new SearchHit[] {}); + + mockThreeSearchCalls(targetResponse, beforeResponse, afterResponse); + + Map params = new HashMap<>(); + params.put("input", "{\"index\":\"my-index\",\"doc_id\":\"doc-1\",\"timestamp_field\":\"@timestamp\",\"count\":2}"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + String result = future.get(); + assertNotNull(result); + + List> docs = GSON.fromJson(result, new TypeToken>>() { + }.getType()); + assertEquals(1, docs.size()); + assertEquals("doc-1", docs.get(0).get("_id")); + } + + @Test + public void testRunWithNoBeforeOrAfterDocuments() throws Exception { + Object[] sortValues = new Object[] { 1000L, 5L }; + + SearchHit targetHit = createMockHit("doc-1", "test-index", Map.of("message", "only doc"), sortValues); + + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + SearchResponse beforeResponse = createMockSearchResponse(new SearchHit[] {}); + SearchResponse afterResponse = createMockSearchResponse(new SearchHit[] {}); + + mockThreeSearchCalls(targetResponse, beforeResponse, afterResponse); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "5"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + String result = future.get(); + List> docs = GSON.fromJson(result, new TypeToken>>() { + }.getType()); + assertEquals(1, docs.size()); + assertEquals("doc-1", docs.get(0).get("_id")); + } + + @Test + public void testRunWithCountAsDouble() throws Exception { + Object[] sortValues = new Object[] { 1000L, 5L }; + + SearchHit targetHit = createMockHit("doc-1", "test-index", Map.of("message", "target"), sortValues); + + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + SearchResponse beforeResponse = createMockSearchResponse(new SearchHit[] {}); + SearchResponse afterResponse = createMockSearchResponse(new SearchHit[] {}); + + mockThreeSearchCalls(targetResponse, beforeResponse, afterResponse); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "2.0"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + String result = future.get(); + assertNotNull(result); + } + + @Test + public void testRunResponseContainsSortValues() throws Exception { + Object[] sortValues = new Object[] { 1000L, 5L }; + + SearchHit targetHit = createMockHit("doc-1", "test-index", Map.of("message", "target"), sortValues); + + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + SearchResponse beforeResponse = createMockSearchResponse(new SearchHit[] {}); + SearchResponse afterResponse = createMockSearchResponse(new SearchHit[] {}); + + mockThreeSearchCalls(targetResponse, beforeResponse, afterResponse); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "1"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + String result = future.get(); + List> docs = GSON.fromJson(result, new TypeToken>>() { + }.getType()); + assertEquals(1, docs.size()); + + Map doc = docs.get(0); + assertEquals("doc-1", doc.get("_id")); + assertEquals("test-index", doc.get("_index")); + assertNotNull(doc.get("_source")); + assertNotNull(doc.get("sort")); + } + + // ========== Run - Error Tests ========== + + @Test + public void testRunWithDocumentNotFound() throws Exception { + SearchResponse emptyResponse = createMockSearchResponse(new SearchHit[] {}); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(emptyResponse); + return null; + }).when(client).search(any(SearchRequest.class), any()); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "nonexistent"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "2"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + try { + future.get(); + assertTrue("Should have thrown", false); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof IllegalArgumentException); + assertTrue(e.getCause().getMessage().contains("Document not found")); + } + } + + @Test + public void testRunWithMissingSortValues() throws Exception { + SearchHit targetHit = createMockHit("doc-1", "test-index", Map.of("message", "target"), new Object[] {}); + + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(targetResponse); + return null; + }).when(client).search(any(SearchRequest.class), any()); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "2"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + try { + future.get(); + assertTrue("Should have thrown", false); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof IllegalArgumentException); + assertTrue(e.getCause().getMessage().contains("sort values")); + } + } + + @Test + public void testRunWithNullSortValues() throws Exception { + SearchHit targetHit = createMockHit("doc-1", "test-index", Map.of("message", "target"), null); + + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(targetResponse); + return null; + }).when(client).search(any(SearchRequest.class), any()); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "2"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + try { + future.get(); + assertTrue("Should have thrown", false); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof IllegalArgumentException); + assertTrue(e.getCause().getMessage().contains("sort values")); + } + } + + @Test + public void testRunWithOnlySingleSortValue() throws Exception { + SearchHit targetHit = createMockHit("doc-1", "test-index", Map.of("message", "target"), new Object[] { 1000L }); + + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(targetResponse); + return null; + }).when(client).search(any(SearchRequest.class), any()); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "2"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + try { + future.get(); + assertTrue("Should have thrown", false); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof IllegalArgumentException); + assertTrue(e.getCause().getMessage().contains("sort values")); + } + } + + @Test + public void testRunWithMissingRequiredParameters() throws Exception { + Map params = new HashMap<>(); + params.put("index", "test-index"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + try { + future.get(); + assertTrue("Should have thrown", false); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof IllegalArgumentException); + assertTrue(e.getCause().getMessage().contains("requires")); + } + } + + @Test + public void testRunWithTargetSearchFailure() throws Exception { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Target search failed")); + return null; + }).when(client).search(any(SearchRequest.class), any()); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "2"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + try { + future.get(); + assertTrue("Should have thrown", false); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof RuntimeException); + assertEquals("Target search failed", e.getCause().getMessage()); + } + } + + @Test + public void testRunWithBeforeSearchFailure() throws Exception { + Object[] sortValues = new Object[] { 1000L, 5L }; + SearchHit targetHit = createMockHit("doc-1", "test-index", Map.of("message", "target"), sortValues); + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + + AtomicInteger callCount = new AtomicInteger(0); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + int call = callCount.getAndIncrement(); + if (call == 0) { + listener.onResponse(targetResponse); + } else { + listener.onFailure(new RuntimeException("Before search failed")); + } + return null; + }).when(client).search(any(SearchRequest.class), any()); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "2"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + try { + future.get(); + assertTrue("Should have thrown", false); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof RuntimeException); + assertEquals("Before search failed", e.getCause().getMessage()); + } + } + + @Test + public void testRunWithAfterSearchFailure() throws Exception { + Object[] sortValues = new Object[] { 1000L, 5L }; + SearchHit targetHit = createMockHit("doc-1", "test-index", Map.of("message", "target"), sortValues); + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + SearchResponse beforeResponse = createMockSearchResponse(new SearchHit[] {}); + + AtomicInteger callCount = new AtomicInteger(0); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + int call = callCount.getAndIncrement(); + if (call == 0) { + listener.onResponse(targetResponse); + } else if (call == 1) { + listener.onResponse(beforeResponse); + } else { + listener.onFailure(new RuntimeException("After search failed")); + } + return null; + }).when(client).search(any(SearchRequest.class), any()); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "2"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + try { + future.get(); + assertTrue("Should have thrown", false); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof RuntimeException); + assertEquals("After search failed", e.getCause().getMessage()); + } + } + + // ========== Run - Ordering Tests ========== + + @Test + public void testRunBeforeDocsAreReversedToChronological() throws Exception { + Object[] sortValues = new Object[] { 1000L, 5L }; + + SearchHit targetHit = createMockHit("target", "idx", Map.of("ts", 1000), sortValues); + + // Before search returns in DESC order: doc3 (newest before), doc2, doc1 (oldest before) + SearchHit beforeHit3 = createMockHit("before-3", "idx", Map.of("ts", 999), new Object[] { 999L, 4L }); + SearchHit beforeHit2 = createMockHit("before-2", "idx", Map.of("ts", 998), new Object[] { 998L, 3L }); + SearchHit beforeHit1 = createMockHit("before-1", "idx", Map.of("ts", 997), new Object[] { 997L, 2L }); + + SearchHit afterHit1 = createMockHit("after-1", "idx", Map.of("ts", 1001), new Object[] { 1001L, 6L }); + + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + SearchResponse beforeResponse = createMockSearchResponse(new SearchHit[] { beforeHit3, beforeHit2, beforeHit1 }); + SearchResponse afterResponse = createMockSearchResponse(new SearchHit[] { afterHit1 }); + + mockThreeSearchCalls(targetResponse, beforeResponse, afterResponse); + + Map params = new HashMap<>(); + params.put("index", "idx"); + params.put("doc_id", "target"); + params.put("timestamp_field", "ts"); + params.put("count", "3"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + String result = future.get(); + List> docs = GSON.fromJson(result, new TypeToken>>() { + }.getType()); + assertEquals(5, docs.size()); + + // Before docs should be reversed to chronological (oldest first) + assertEquals("before-1", docs.get(0).get("_id")); + assertEquals("before-2", docs.get(1).get("_id")); + assertEquals("before-3", docs.get(2).get("_id")); + assertEquals("target", docs.get(3).get("_id")); + assertEquals("after-1", docs.get(4).get("_id")); + } + + @Test + public void testRunResponseDocStructure() throws Exception { + Object[] sortValues = new Object[] { 1000L, 5L }; + Map source = Map.of("@timestamp", "2024-01-01", "level", "INFO", "message", "test log"); + + SearchHit targetHit = createMockHit("doc-1", "logs-index", source, sortValues); + + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + SearchResponse beforeResponse = createMockSearchResponse(new SearchHit[] {}); + SearchResponse afterResponse = createMockSearchResponse(new SearchHit[] {}); + + mockThreeSearchCalls(targetResponse, beforeResponse, afterResponse); + + Map params = new HashMap<>(); + params.put("index", "logs-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "1"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + String result = future.get(); + List> docs = GSON.fromJson(result, new TypeToken>>() { + }.getType()); + assertEquals(1, docs.size()); + + Map doc = docs.get(0); + assertEquals("doc-1", doc.get("_id")); + assertEquals("logs-index", doc.get("_index")); + assertNotNull(doc.get("_score")); + assertNotNull(doc.get("sort")); + + @SuppressWarnings("unchecked") + Map returnedSource = (Map) doc.get("_source"); + assertEquals("2024-01-01", returnedSource.get("@timestamp")); + assertEquals("INFO", returnedSource.get("level")); + assertEquals("test log", returnedSource.get("message")); + } +} diff --git a/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java b/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java index 76d5c72e..66cedc2f 100644 --- a/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java +++ b/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java @@ -65,6 +65,7 @@ public void updateClusterSettings() { updateClusterSettings("plugins.ml_commons.jvm_heap_memory_threshold", 100); updateClusterSettings("plugins.ml_commons.allow_registering_model_via_url", true); updateClusterSettings("plugins.ml_commons.agent_framework_enabled", true); + updateClusterSettings("plugins.ml_commons.connector.private_ip_enabled", true); } @SneakyThrows diff --git a/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java b/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java index 79154570..8c2f197a 100644 --- a/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java +++ b/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java @@ -15,6 +15,10 @@ import org.hamcrest.MatcherAssert; import org.junit.After; import org.junit.Before; +import org.opensearch.ml.common.utils.StringUtils; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; import lombok.SneakyThrows; @@ -132,7 +136,7 @@ public void testDataDistributionToolSingleAnalysis() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -166,13 +170,13 @@ public void testDataDistributionToolWithFilter() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows public void testDataDistributionToolMissingRequiredParameters() { Exception exception = assertThrows(Exception.class, () -> executeAgent(agentId, "{\"parameters\": {\"index\": \"test_index\"}}")); - MatcherAssert.assertThat(exception.getMessage(), containsString("Unable to parse time string")); + MatcherAssert.assertThat(exception.getMessage(), containsString("Invalid time format")); } @SneakyThrows @@ -202,7 +206,7 @@ public void testDataDistributionToolPPLSingleAnalysis() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3.0\",\"selectionPercentage\":0.5},{\"value\":\"2.0\",\"selectionPercentage\":0.25},{\"value\":\"4.0\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -238,7 +242,7 @@ public void testDataDistributionToolPPLWithCustomQuery() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3.0\",\"selectionPercentage\":0.67},{\"value\":\"4.0\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -255,7 +259,7 @@ public void testDataDistributionToolWithDSLQueryType() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -272,7 +276,7 @@ public void testDataDistributionToolWithMultipleFilters() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -289,7 +293,7 @@ public void testDataDistributionToolWithCustomSize() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -306,7 +310,7 @@ public void testDataDistributionToolWithCustomTimeField() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -323,7 +327,7 @@ public void testDataDistributionToolWithRangeFilter() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -340,7 +344,7 @@ public void testDataDistributionToolWithMatchFilter() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -357,7 +361,7 @@ public void testDataDistributionToolWithRawDSLQuery() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -374,7 +378,7 @@ public void testDataDistributionToolWithExistsFilter() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); } @SneakyThrows @@ -425,7 +429,7 @@ public void testDataDistributionToolInvalidTimeFormat() { ) ) ); - MatcherAssert.assertThat(exception.getMessage(), containsString("Unable to parse time string")); + MatcherAssert.assertThat(exception.getMessage(), containsString("Invalid time format")); } @SneakyThrows @@ -443,6 +447,30 @@ public void testDataDistributionToolPPLWithComplexQuery() { String expectedResult = "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3.0\",\"selectionPercentage\":0.67},{\"value\":\"4.0\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; - assertEquals(expectedResult, result); + assertResults(expectedResult, result); + } + + private void assertResults(String expectedResult, String result) { + try { + JsonNode resultJson = StringUtils.MAPPER.readTree(result); + JsonNode expectedJson = StringUtils.MAPPER.readTree(expectedResult); + JsonNode expectedAnalysis = expectedJson.get("singleAnalysis"); + JsonNode resultAnalysis = resultJson.get("singleAnalysis"); + for (int i = 0; i < expectedAnalysis.size(); i++) { + assertEquals(expectedAnalysis.get(i).get("field").asText(), resultAnalysis.get(i).get("field").asText()); + assertEquals(expectedAnalysis.get(i).get("divergence").asText(), resultAnalysis.get(i).get("divergence").asText()); + JsonNode expectedTopChanges = expectedAnalysis.get(i).get("topChanges"); + JsonNode resultTopChanges = resultAnalysis.get(i).get("topChanges"); + for (int j = 0; j < expectedTopChanges.size(); j++) { + assertEquals( + expectedTopChanges.get(j).get("selectionPercentage").asText(), + resultTopChanges.get(j).get("selectionPercentage").asText() + ); + assertEquals(expectedTopChanges.get(j).get("value").asText(), resultTopChanges.get(j).get("value").asText()); + } + } + } catch (JsonProcessingException e) { + fail("Failed to process jsons"); + } } } diff --git a/src/test/java/org/opensearch/integTest/LogPatternAnalysisToolIT.java b/src/test/java/org/opensearch/integTest/LogPatternAnalysisToolIT.java index 4ad50662..9ea60485 100644 --- a/src/test/java/org/opensearch/integTest/LogPatternAnalysisToolIT.java +++ b/src/test/java/org/opensearch/integTest/LogPatternAnalysisToolIT.java @@ -59,6 +59,9 @@ private void prepareLogIndex() { + " },\n" + " \"traceId\": {\n" + " \"type\": \"keyword\"\n" + + " },\n" + + " \"serviceName\": {\n" + + " \"type\": \"keyword\"\n" + " }\n" + " }\n" + " }\n" @@ -69,52 +72,52 @@ private void prepareLogIndex() { addDocToIndex( TEST_LOG_INDEX_NAME, "base1", - List.of("@timestamp", "message", "traceId"), - List.of("2025-01-01 09:30:00", "System startup completed", "trace-base-001") + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 09:30:00", "System startup completed", "trace-base-001", "auth-service") ); addDocToIndex( TEST_LOG_INDEX_NAME, "base2", - List.of("@timestamp", "message", "traceId"), - List.of("2025-01-01 09:45:00", "Database connection established", "trace-base-002") + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 09:45:00", "Database connection established", "trace-base-002", "db-service") ); addDocToIndex( TEST_LOG_INDEX_NAME, "base3", - List.of("@timestamp", "message", "traceId"), - List.of("2025-01-01 09:50:00", "User session initialized", "trace-base-003") + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 09:50:00", "User session initialized", "trace-base-003", "auth-service") ); // Add test log data with error keywords for logInsight addDocToIndex( TEST_LOG_INDEX_NAME, "1", - List.of("@timestamp", "message", "traceId"), - List.of("2025-01-01 10:00:00", "User login successful", "trace-001") + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 10:00:00", "User login successful", "trace-001", "auth-service") ); addDocToIndex( TEST_LOG_INDEX_NAME, "2", - List.of("@timestamp", "message", "traceId"), - List.of("2025-01-01 10:01:00", "Database connection established", "trace-001") + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 10:01:00", "Database connection established", "trace-001", "db-service") ); addDocToIndex( TEST_LOG_INDEX_NAME, "3", - List.of("@timestamp", "message", "traceId"), - List.of("2025-01-01 10:02:00", "Error connection timeout failed", "trace-002") + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 10:02:00", "Error connection timeout failed", "trace-002", "db-service") ); addDocToIndex( TEST_LOG_INDEX_NAME, "4", - List.of("@timestamp", "message", "traceId"), - List.of("2025-01-01 10:03:00", "User logout completed", "trace-001") + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 10:03:00", "User logout completed", "trace-001", "auth-service") ); addDocToIndex( TEST_LOG_INDEX_NAME, "5", - List.of("@timestamp", "message", "traceId"), - List.of("2025-01-01 10:04:00", "Exception in authentication service", "trace-003") + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 10:04:00", "Exception in authentication service", "trace-003", "auth-service") ); } @@ -206,6 +209,51 @@ public void testLogPatternAnalysisToolInvalidTimeFormat() { MatcherAssert.assertThat(exception.getMessage(), containsString("not a valid term")); } + @SneakyThrows + public void testLogPatternAnalysisToolLogInsightWithFilter() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\", \"filter\": \"serviceName='db-service'\"}}", + TEST_LOG_INDEX_NAME + ) + ); + assertNotNull(result); + assertTrue(result.contains("logInsights")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolWithBaseTimeRangeAndFilter() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", \"baseTimeRangeStart\": \"2025-01-01 09:00:00\", \"baseTimeRangeEnd\": \"2025-01-01 10:00:00\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\", \"filter\": \"serviceName='auth-service'\"}}", + TEST_LOG_INDEX_NAME + ) + ); + assertNotNull(result); + assertTrue(result.contains("patternMapDifference")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolWithTraceFieldAndFilter() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", \"traceFieldName\": \"traceId\", \"baseTimeRangeStart\": \"2025-01-01 09:00:00\", \"baseTimeRangeEnd\": \"2025-01-01 10:00:00\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\", \"filter\": \"serviceName='auth-service'\"}}", + TEST_LOG_INDEX_NAME + ) + ); + assertNotNull(result); + assertTrue(result.contains("BASE") || result.contains("EXCEPTIONAL")); + } + @SneakyThrows public void testLogPatternAnalysisToolEmptyTimeRange() { Exception exception = assertThrows( diff --git a/src/test/java/org/opensearch/integTest/PPLToolIT.java b/src/test/java/org/opensearch/integTest/PPLToolIT.java index 3d6120ee..9ca930aa 100644 --- a/src/test/java/org/opensearch/integTest/PPLToolIT.java +++ b/src/test/java/org/opensearch/integTest/PPLToolIT.java @@ -54,9 +54,9 @@ public void testPPLTool() { String agentId = registerAgent(); String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"correct\", \"index\": \"employee\"}}"); assertEquals( - "{\"ppl\":\"source\\u003demployee | where age \\u003e 56 | stats COUNT() as cnt\"," + "{\"ppl\":\"source=employee | where age > 56 | stats COUNT() as cnt\"," + "\"executionResult\":\"{\\n \\\"schema\\\": [\\n {\\n \\\"name\\\": \\\"cnt\\\",\\n " - + "\\\"type\\\": \\\"int\\\"\\n }\\n ],\\n \\\"datarows\\\": [\\n [\\n 0\\n ]\\n ],\\n " + + "\\\"type\\\": \\\"bigint\\\"\\n }\\n ],\\n \\\"datarows\\\": [\\n [\\n 0\\n ]\\n ],\\n " + "\\\"total\\\": 1,\\n \\\"size\\\": 1\\n}\"}", result ); @@ -114,7 +114,7 @@ public void testPPLTool_withNonExistingIndex_thenThrowException() { exception.getMessage(), allOf( containsString( - "Return this final answer to human directly and do not use other tools: 'Please provide index name'. Please try to directly send this message to human to ask for index name" + "Return this final answer to human directly and do not use other tools: 'Please provide the existing index name(s)'. Please try to directly send this message to human to ask for index name" ) ) ); diff --git a/src/test/java/org/opensearch/integTest/SearchAroundDocumentToolIT.java b/src/test/java/org/opensearch/integTest/SearchAroundDocumentToolIT.java new file mode 100644 index 00000000..50f45b29 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/SearchAroundDocumentToolIT.java @@ -0,0 +1,280 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import static org.hamcrest.Matchers.containsString; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Locale; + +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Before; + +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonParser; + +import lombok.SneakyThrows; + +public class SearchAroundDocumentToolIT extends BaseAgentToolsIT { + + private static final String TEST_INDEX_NAME = "test_search_around_document_index"; + private static final String REGISTER_AGENT_RESOURCE = + "org/opensearch/agent/tools/register_flow_agent_of_search_around_document_tool_request_body.json"; + + private String registerAgentRequestBody; + private String agentId; + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + registerAgentRequestBody = Files.readString(Path.of(this.getClass().getClassLoader().getResource(REGISTER_AGENT_RESOURCE).toURI())); + prepareDataIndex(); + agentId = createAgent(registerAgentRequestBody); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteExternalIndices(); + } + + @SneakyThrows + private void prepareDataIndex() { + createIndexWithConfiguration(TEST_INDEX_NAME, """ + { + "mappings": { + "properties": { + "@timestamp": { + "type": "date", + "format": "yyyy-MM-dd HH:mm:ss||strict_date_optional_time||epoch_millis" + }, + "message": { + "type": "text" + }, + "level": { + "type": "keyword" + } + } + } + }"""); + + // Index 7 documents with known timestamps and IDs + addDocToIndex( + TEST_INDEX_NAME, + "doc1", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 09:00:00", "First log entry", "INFO") + ); + addDocToIndex( + TEST_INDEX_NAME, + "doc2", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 09:10:00", "Second log entry", "INFO") + ); + addDocToIndex( + TEST_INDEX_NAME, + "doc3", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 09:20:00", "Third log entry", "WARN") + ); + addDocToIndex( + TEST_INDEX_NAME, + "doc4", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 09:30:00", "Fourth log entry - target", "ERROR") + ); + addDocToIndex( + TEST_INDEX_NAME, + "doc5", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 09:40:00", "Fifth log entry", "WARN") + ); + addDocToIndex( + TEST_INDEX_NAME, + "doc6", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 09:50:00", "Sixth log entry", "INFO") + ); + addDocToIndex( + TEST_INDEX_NAME, + "doc7", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 10:00:00", "Seventh log entry", "ERROR") + ); + } + + @SneakyThrows + public void testSearchAroundDocument_basicSearch() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"doc_id\": \"doc4\", \"timestamp_field\": \"@timestamp\", \"count\": \"2\"}}", + TEST_INDEX_NAME + ) + ); + + JsonArray docs = JsonParser.parseString(result).getAsJsonArray(); + + // Should have 5 documents: 2 before + target + 2 after + assertEquals(5, docs.size()); + + // Verify chronological order: doc2, doc3, doc4 (target), doc5, doc6 + assertEquals("doc2", docs.get(0).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc3", docs.get(1).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc4", docs.get(2).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc5", docs.get(3).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc6", docs.get(4).getAsJsonObject().get("_id").getAsString()); + } + + @SneakyThrows + public void testSearchAroundDocument_countExceedsAvailable() { + // doc1 is the first document, requesting 5 before but only 0 exist before it + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"doc_id\": \"doc1\", \"timestamp_field\": \"@timestamp\", \"count\": \"5\"}}", + TEST_INDEX_NAME + ) + ); + + JsonArray docs = JsonParser.parseString(result).getAsJsonArray(); + + // Should have target + up to 5 after (doc2-doc7 = 6 after, but count=5) + // No before docs since doc1 is the earliest + assertEquals(6, docs.size()); + + // First should be the target (doc1), followed by 5 after docs + assertEquals("doc1", docs.get(0).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc2", docs.get(1).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc3", docs.get(2).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc4", docs.get(3).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc5", docs.get(4).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc6", docs.get(5).getAsJsonObject().get("_id").getAsString()); + } + + @SneakyThrows + public void testSearchAroundDocument_lastDocument() { + // doc7 is the last document, requesting 3 before + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"doc_id\": \"doc7\", \"timestamp_field\": \"@timestamp\", \"count\": \"3\"}}", + TEST_INDEX_NAME + ) + ); + + JsonArray docs = JsonParser.parseString(result).getAsJsonArray(); + + // Should have 3 before + target, no after docs + assertEquals(4, docs.size()); + + assertEquals("doc4", docs.get(0).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc5", docs.get(1).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc6", docs.get(2).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc7", docs.get(3).getAsJsonObject().get("_id").getAsString()); + } + + @SneakyThrows + public void testSearchAroundDocument_countOfOne() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"doc_id\": \"doc4\", \"timestamp_field\": \"@timestamp\", \"count\": \"1\"}}", + TEST_INDEX_NAME + ) + ); + + JsonArray docs = JsonParser.parseString(result).getAsJsonArray(); + + // 1 before + target + 1 after = 3 + assertEquals(3, docs.size()); + + assertEquals("doc3", docs.get(0).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc4", docs.get(1).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc5", docs.get(2).getAsJsonObject().get("_id").getAsString()); + } + + @SneakyThrows + public void testSearchAroundDocument_jsonInput() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"input\": \"{\\\"index\\\": \\\"%s\\\", \\\"doc_id\\\": \\\"doc4\\\", \\\"timestamp_field\\\": \\\"@timestamp\\\", \\\"count\\\": 2}\"}}", + TEST_INDEX_NAME + ) + ); + + JsonArray docs = JsonParser.parseString(result).getAsJsonArray(); + assertEquals(5, docs.size()); + assertEquals("doc4", docs.get(2).getAsJsonObject().get("_id").getAsString()); + } + + @SneakyThrows + public void testSearchAroundDocument_responseContainsSource() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"doc_id\": \"doc4\", \"timestamp_field\": \"@timestamp\", \"count\": \"1\"}}", + TEST_INDEX_NAME + ) + ); + + JsonArray docs = JsonParser.parseString(result).getAsJsonArray(); + // Verify the target document has _source with correct fields + JsonElement targetDoc = docs.get(1); + assertTrue(targetDoc.getAsJsonObject().has("_source")); + assertTrue(targetDoc.getAsJsonObject().has("_id")); + assertTrue(targetDoc.getAsJsonObject().has("_index")); + + String source = targetDoc.getAsJsonObject().get("_source").toString(); + assertTrue(source.contains("Fourth log entry - target")); + assertTrue(source.contains("ERROR")); + } + + @SneakyThrows + public void testSearchAroundDocument_nonExistentDoc() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"doc_id\": \"nonexistent\", \"timestamp_field\": \"@timestamp\", \"count\": \"2\"}}", + TEST_INDEX_NAME + ) + ) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("Document not found")); + } + + @SneakyThrows + public void testSearchAroundDocument_missingRequiredParameters() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent(agentId, String.format(Locale.ROOT, "{\"parameters\": {\"index\": \"%s\"}}", TEST_INDEX_NAME)) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("requires")); + } +} diff --git a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_search_around_document_tool_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_search_around_document_tool_request_body.json new file mode 100644 index 00000000..c97af635 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_search_around_document_tool_request_body.json @@ -0,0 +1,9 @@ +{ + "name": "Test_Search_Around_Document_Agent", + "type": "flow", + "tools": [ + { + "type": "SearchAroundDocumentTool" + } + ] +} \ No newline at end of file