diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c098e120..c3edc539 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,7 +17,7 @@ jobs: needs: Get-CI-Image-Tag strategy: matrix: - java: [21, 23] + java: [21, 25] name: Build and Test skills plugin on Linux runs-on: ubuntu-latest container: @@ -52,7 +52,7 @@ jobs: build-MacOS: strategy: matrix: - java: [21, 23] + java: [21, 25] name: Build and Test skills Plugin on MacOS needs: Get-CI-Image-Tag @@ -77,7 +77,7 @@ jobs: build-windows: strategy: matrix: - java: [21, 23] + java: [21, 25] name: Build and Test skills plugin on Windows needs: Get-CI-Image-Tag runs-on: windows-latest diff --git a/.github/workflows/delete_backport_branch.yml b/.github/workflows/delete_backport_branch.yml index f24f022b..be2dffd5 100644 --- a/.github/workflows/delete_backport_branch.yml +++ b/.github/workflows/delete_backport_branch.yml @@ -7,9 +7,16 @@ on: jobs: delete-branch: runs-on: ubuntu-latest - if: startsWith(github.event.pull_request.head.ref,'backport/') + permissions: + contents: write + if: startsWith(github.event.pull_request.head.ref,'backport/') || startsWith(github.event.pull_request.head.ref,'release-chores/') steps: - - name: Delete merged branch - uses: SvanBoxel/delete-merged-branch@main - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Delete merged branch + uses: actions/github-script@v7 + with: + script: | + github.rest.git.deleteRef({ + owner: context.repo.owner, + repo: context.repo.repo, + ref: `heads/${context.payload.pull_request.head.ref}`, + }) \ No newline at end of file diff --git a/.github/workflows/maven-publish.yml b/.github/workflows/maven-publish.yml index d07e5b31..8b73d17d 100644 --- a/.github/workflows/maven-publish.yml +++ b/.github/workflows/maven-publish.yml @@ -24,14 +24,23 @@ jobs: distribution: temurin # Temurin is a distribution of adoptium java-version: 21 - uses: actions/checkout@v3 - - uses: aws-actions/configure-aws-credentials@v4 + + - 
name: Load secret + uses: 1password/load-secrets-action@v2 + with: + # Export loaded secrets as environment variables + export-env: true + env: + OP_SERVICE_ACCOUNT_TOKEN: ${{ secrets.OP_SERVICE_ACCOUNT_TOKEN }} + MAVEN_SNAPSHOTS_S3_REPO: op://opensearch-infra-secrets/maven-snapshots-s3/repo + MAVEN_SNAPSHOTS_S3_ROLE: op://opensearch-infra-secrets/maven-snapshots-s3/role + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v5 with: - role-to-assume: ${{ secrets.PUBLISH_SNAPSHOTS_ROLE }} + role-to-assume: ${{ env.MAVEN_SNAPSHOTS_S3_ROLE }} aws-region: us-east-1 + - name: publish snapshots to maven run: | - export SONATYPE_USERNAME=$(aws secretsmanager get-secret-value --secret-id maven-snapshots-username --query SecretString --output text) - export SONATYPE_PASSWORD=$(aws secretsmanager get-secret-value --secret-id maven-snapshots-password --query SecretString --output text) - echo "::add-mask::$SONATYPE_USERNAME" - echo "::add-mask::$SONATYPE_PASSWORD" ./gradlew publishPluginZipPublicationToSnapshotsRepository diff --git a/.github/workflows/test_security.yml b/.github/workflows/test_security.yml index 8821566d..63c99b29 100644 --- a/.github/workflows/test_security.yml +++ b/.github/workflows/test_security.yml @@ -16,7 +16,7 @@ jobs: integ-test-with-security-linux: strategy: matrix: - java: [21, 23] + java: [21, 25] env: ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true name: Run Security Integration Tests on Linux diff --git a/.gitignore b/.gitignore index 722e14d1..21b99ea6 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ out/ .settings .vscode bin/ +.factorypath diff --git a/build.gradle b/build.gradle index 09169432..be0d853d 100644 --- a/build.gradle +++ b/build.gradle @@ -11,7 +11,7 @@ import com.github.jengelman.gradle.plugins.shadow.ShadowPlugin buildscript { ext { opensearch_group = "org.opensearch" - opensearch_version = System.getProperty("opensearch.version", "3.1.0-SNAPSHOT") + opensearch_version = 
System.getProperty("opensearch.version", "3.6.1-SNAPSHOT") buildVersionQualifier = System.getProperty("build.version_qualifier", "") isSnapshot = "true" == System.getProperty("build.snapshot", "true") version_tokens = opensearch_version.tokenize('-') @@ -31,7 +31,7 @@ buildscript { repositories { mavenLocal() - maven { url "https://aws.oss.sonatype.org/content/repositories/snapshots" } + maven { url "https://ci.opensearch.org/ci/dbc/snapshots/maven/" } maven { url "https://plugins.gradle.org/m2/" } mavenCentral() } @@ -49,14 +49,14 @@ plugins { } lombok { - version = "1.18.34" + version = "1.18.38" } repositories { mavenLocal() mavenCentral() maven { url "https://plugins.gradle.org/m2/" } - maven { url "https://aws.oss.sonatype.org/content/repositories/snapshots" } + maven { url "https://ci.opensearch.org/ci/dbc/snapshots/maven/" } } allprojects { @@ -136,26 +136,36 @@ dependencies { compileOnly "org.apache.logging.log4j:log4j-slf4j-impl:2.23.1" compileOnly group: 'org.json', name: 'json', version: '20240303' compileOnly("com.google.guava:guava:33.2.1-jre") - compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: '3.16.0' + compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: "${versions.commonslang}" compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.12.0' - compileOnly("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}") + compileOnly group: 'org.apache.commons', name: 'commons-math3', version: '3.6.1' + compileOnly("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson_annotations}") compileOnly("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}") compileOnly(group: 'org.apache.httpcomponents.core5', name: 'httpcore5', version: "${versions.httpcore5}") compileOnly(group: 'org.apache.httpcomponents.client5', name: 'httpclient5', version: "${versions.httpclient5}") compileOnly ('com.jayway.jsonpath:json-path:2.9.0') { exclude group: 'net.minidev', 
module: 'json-smart' } + compileOnly (group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "${versions.aws}") { + exclude(group: 'org.reactivestreams', module: 'reactive-streams') + exclude(group: 'org.slf4j', module: 'slf4j-api') + } + compileOnly(group: 'software.amazon.awssdk', name: 'http-client-spi', version: "${versions.aws}") + compileOnly(group: 'software.amazon.awssdk', name: 'utils', version: "${versions.aws}") + compileOnly(group: 'software.amazon.awssdk', name: 'sdk-core', version: "${versions.aws}") - spark 'org.apache.spark:spark-sql-api_2.13:3.5.4' - spark 'org.apache.spark:spark-core_2.13:3.5.4' - spark group: 'org.apache.spark', name: 'spark-common-utils_2.13', version: '3.5.4' + spark 'org.apache.spark:spark-sql-api_2.13:3.5.8' + spark ('org.apache.spark:spark-core_2.13:3.5.8') { + exclude group: 'org.eclipse.jetty', module: 'jetty-server' + } + spark group: 'org.apache.spark', name: 'spark-common-utils_2.13', version: '3.5.8' implementation 'org.scala-lang:scala-library:2.13.9' implementation group: 'org.antlr', name: 'antlr4-runtime', version: '4.9.3' implementation("org.json4s:json4s-ast_2.13:3.7.0-M11") implementation("org.json4s:json4s-core_2.13:3.7.0-M11") implementation("org.json4s:json4s-jackson_2.13:3.7.0-M11") - implementation 'com.fasterxml.jackson.module:jackson-module-scala_3:2.18.2' + implementation "com.fasterxml.jackson.module:jackson-module-scala_3:${versions.jackson}" implementation group: 'org.scala-lang', name: 'scala3-library_3', version: '3.7.0-RC1-bin-20250119-bd699fc-NIGHTLY' implementation("com.thoughtworks.paranamer:paranamer:2.8") implementation("org.jsoup:jsoup:1.19.1") @@ -165,7 +175,7 @@ dependencies { compileOnly group: 'org.opensearch', name:'opensearch-ml-spi', version: "${opensearch_build}" compileOnly fileTree(dir: jsJarDirectory, include: ["opensearch-job-scheduler-${opensearch_build}.jar"]) implementation fileTree(dir: adJarDirectory, include: 
["opensearch-anomaly-detection-${opensearch_build}.jar"]) - implementation fileTree(dir: sqlJarDirectory, include: ["opensearch-sql-thin-${opensearch_build}.jar", "ppl-${opensearch_build}.jar", "protocol-${opensearch_build}.jar"]) + implementation fileTree(dir: sqlJarDirectory, include: ["opensearch-sql-thin-${opensearch_build}.jar", "ppl-${opensearch_build}.jar", "protocol-${opensearch_build}.jar", "core-${opensearch_build}.jar"]) implementation fileTree(dir: sparkDir, include: ["spark*.jar"]) compileOnly "org.opensearch:common-utils:${opensearch_build}" compileOnly "org.jetbrains.kotlin:kotlin-stdlib:${kotlin_version}" @@ -189,13 +199,15 @@ dependencies { testImplementation group: 'org.json', name: 'json', version: '20240303' testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.14.2' testImplementation group: 'org.mockito', name: 'mockito-inline', version: '5.2.0' - testImplementation("net.bytebuddy:byte-buddy:1.17.5") + testImplementation("net.bytebuddy:byte-buddy:1.17.7") testImplementation("net.bytebuddy:byte-buddy-agent:1.17.5") testImplementation 'org.junit.jupiter:junit-jupiter-api:5.11.2' testImplementation 'org.mockito:mockito-junit-jupiter:5.14.2' testImplementation "com.nhaarman.mockitokotlin2:mockito-kotlin:2.2.0" testImplementation "com.cronutils:cron-utils:9.2.1" testImplementation 'com.jayway.jsonpath:json-path:2.9.0' + testImplementation('net.minidev:json-smart:2.5.2') + testImplementation 'net.minidev:asm:1.0.2' testImplementation "commons-validator:commons-validator:1.8.0" testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.11.2' } @@ -205,9 +217,9 @@ task addSparkJar(type: Copy) { into sparkDir doLast { - def jarA = file("$sparkDir/spark-sql-api_2.13-3.5.4.jar") - def jarB = file("$sparkDir/spark-core_2.13-3.5.4.jar") - def jarC = file("$sparkDir/spark-common-utils_2.13-3.5.4.jar") + def jarA = file("$sparkDir/spark-sql-api_2.13-3.5.8.jar") + def jarB = file("$sparkDir/spark-core_2.13-3.5.8.jar") + def jarC = 
file("$sparkDir/spark-common-utils_2.13-3.5.8.jar") // 3a. Extract jar A to manipulate it def jarAContents = file("$buildDir/tmp/JarAContents") @@ -234,7 +246,9 @@ task addSparkJar(type: Copy) { // Remove the unwanted directory from jar B delete file("${jarBContents}/org/apache/spark/unused") delete file("${jarBContents}/org/sparkproject/jetty/http") + delete file("${jarBContents}/org/sparkproject/jetty/server") delete file("${jarBContents}/META-INF/maven/org.eclipse.jetty/jetty-http") + delete file("${jarBContents}/META-INF/maven/org.eclipse.jetty/jetty-server") // Re-compress jar B ant.zip(destfile: jarB, baseDir: jarBContents) @@ -432,10 +446,11 @@ publishing { repositories { maven { name = "Snapshots" - url = "https://aws.oss.sonatype.org/content/repositories/snapshots" - credentials { - username "$System.env.SONATYPE_USERNAME" - password "$System.env.SONATYPE_PASSWORD" + url = System.getenv("MAVEN_SNAPSHOTS_S3_REPO") + credentials(AwsCredentials) { + accessKey = System.getenv("AWS_ACCESS_KEY_ID") + secretKey = System.getenv("AWS_SECRET_ACCESS_KEY") + sessionToken = System.getenv("AWS_SESSION_TOKEN") } } } diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 247cf2a9..b11741a1 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,7 +1,7 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionSha256Sum=61ad310d3c7d3e5da131b76bbf22b5a4c0786e9d892dae8c1658d4b484de3caa -distributionUrl=https\://services.gradle.org/distributions/gradle-8.14-bin.zip +distributionSha256Sum=16f2b95838c1ddcf7242b1c39e7bbbb43c842f1f1a1a0dc4959b6d4d68abcac3 +distributionUrl=https\://services.gradle.org/distributions/gradle-9.2.0-all.zip networkTimeout=10000 validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME diff --git a/release-notes/opensearch-skills.relase-notes-3.1.0.0.md b/release-notes/opensearch-skills.release-notes-3.1.0.0.md similarity index 100% 
rename from release-notes/opensearch-skills.relase-notes-3.1.0.0.md rename to release-notes/opensearch-skills.release-notes-3.1.0.0.md diff --git a/release-notes/opensearch-skills.release-notes-3.2.0.0.md b/release-notes/opensearch-skills.release-notes-3.2.0.0.md new file mode 100644 index 00000000..16f7e946 --- /dev/null +++ b/release-notes/opensearch-skills.release-notes-3.2.0.0.md @@ -0,0 +1,15 @@ +## Version 3.2.0.0 Release Notes + +Compatible with OpenSearch and OpenSearch Dashboards version 3.2.0.0 + +### Enhancements +* Merge index schema meta ([#596](https://github.com/opensearch-project/skills/pull/596)) +* Mask error message in PPLTool ([#609](https://github.com/opensearch-project/skills/pull/609)) + +### Bug Fixes +* Update parameter handling of tools ([#618](https://github.com/opensearch-project/skills/pull/618)) + +### Maintenance +* Update the maven snapshot publish endpoint and credential ([#601](https://github.com/opensearch-project/skills/pull/601)) +* Bump gradle, java, lombok and fix ad configrequest change ([#615](https://github.com/opensearch-project/skills/pull/615)) +* Bump version to 3.2.0.0 ([#605](https://github.com/opensearch-project/skills/pull/605)) diff --git a/release-notes/opensearch-skills.release-notes-3.3.0.0.md b/release-notes/opensearch-skills.release-notes-3.3.0.0.md new file mode 100644 index 00000000..8a549a08 --- /dev/null +++ b/release-notes/opensearch-skills.release-notes-3.3.0.0.md @@ -0,0 +1,20 @@ +## Version 3.3.0 Release Notes + +Compatible with OpenSearch and OpenSearch Dashboards version 3.3.0 + +### Features +* Log patterns analysis tool ([#625](https://github.com/opensearch-project/skills/pull/625)) +* Data Distribution Tool ([#634](https://github.com/opensearch-project/skills/pull/634)) + +### Enhancements +* Add more information in ppl tool when passing to sagemaker ([#636](https://github.com/opensearch-project/skills/pull/636)) + +### Bug Fixes +* Delete-single-baseline 
([#641](https://github.com/opensearch-project/skills/pull/641)) +* Fix WebSearchTool issue ([#639](https://github.com/opensearch-project/skills/pull/639)) + +### Infrastructure +* Update System.env syntax for Gradle 9 compatibility ([#630](https://github.com/opensearch-project/skills/pull/630)) + +### Maintenance +* Increment version to 3.3.0-SNAPSHOT ([#626](https://github.com/opensearch-project/skills/pull/626)) \ No newline at end of file diff --git a/release-notes/opensearch-skills.release-notes-3.3.2.0.md b/release-notes/opensearch-skills.release-notes-3.3.2.0.md new file mode 100644 index 00000000..dfddd694 --- /dev/null +++ b/release-notes/opensearch-skills.release-notes-3.3.2.0.md @@ -0,0 +1,6 @@ +## Version 3.3.2 Release Notes + +Compatible with OpenSearch 3.3.2 and OpenSearch Dashboards 3.3.0 + +### Bug Fixes +* Fix regex bypass issue ([#656](https://github.com/opensearch-project/skills/pull/656)) diff --git a/release-notes/opensearch-skills.release-notes-3.6.0.0.md b/release-notes/opensearch-skills.release-notes-3.6.0.0.md new file mode 100644 index 00000000..6982daac --- /dev/null +++ b/release-notes/opensearch-skills.release-notes-3.6.0.0.md @@ -0,0 +1,17 @@ +## Version 3.6.0 Release Notes + +Compatible with OpenSearch and OpenSearch Dashboards version 3.6.0 + +### Features + +* Add SearchAroundTool to search N documents around a given document ([#702](https://github.com/opensearch-project/skills/pull/702)) +* Add MetricChangeAnalysisTool for detecting and analyzing metric changes via percentile comparison between baseline and selection periods ([#698](https://github.com/opensearch-project/skills/pull/698)) + +### Enhancements + +* Add filter support for LogPatternAnalysisTool to enable log pattern analysis for specific services ([#707](https://github.com/opensearch-project/skills/pull/707)) +* Update default tool descriptions for LogPatternAnalysisTool and DataDistributionTool to improve clarity for LLM usage 
([#703](https://github.com/opensearch-project/skills/pull/703)) + +### Maintenance + +* Update Apache Spark dependencies (spark-common-utils_2.13) from 3.5.4 to 3.5.8 ([#713](https://github.com/opensearch-project/skills/pull/713)) diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index bdf82853..a8faae88 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -8,23 +8,32 @@ import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; import org.opensearch.agent.tools.CreateAlertTool; import org.opensearch.agent.tools.CreateAnomalyDetectorTool; +import org.opensearch.agent.tools.DataDistributionTool; +import org.opensearch.agent.tools.LogPatternAnalysisTool; import org.opensearch.agent.tools.LogPatternTool; +import org.opensearch.agent.tools.MetricChangeAnalysisTool; import org.opensearch.agent.tools.NeuralSparseSearchTool; import org.opensearch.agent.tools.PPLTool; import org.opensearch.agent.tools.RAGTool; import org.opensearch.agent.tools.SearchAlertsTool; import org.opensearch.agent.tools.SearchAnomalyDetectorsTool; import org.opensearch.agent.tools.SearchAnomalyResultsTool; +import org.opensearch.agent.tools.SearchAroundDocumentTool; import org.opensearch.agent.tools.SearchMonitorsTool; import org.opensearch.agent.tools.VectorDBTool; import org.opensearch.agent.tools.WebSearchTool; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.IndexScopedSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.settings.SettingsFilter; import 
org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -32,8 +41,11 @@ import org.opensearch.env.NodeEnvironment; import org.opensearch.ml.common.spi.MLCommonsExtension; import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.repositories.RepositoriesService; +import org.opensearch.rest.RestController; +import org.opensearch.rest.RestHandler; import org.opensearch.script.ScriptService; import org.opensearch.threadpool.ExecutorBuilder; import org.opensearch.threadpool.FixedExecutorBuilder; @@ -41,19 +53,28 @@ import org.opensearch.transport.client.Client; import org.opensearch.watcher.ResourceWatcherService; -import com.google.common.collect.ImmutableList; - import lombok.SneakyThrows; -public class ToolPlugin extends Plugin implements MLCommonsExtension { - - private Client client; - private ClusterService clusterService; - private NamedXContentRegistry xContentRegistry; +public class ToolPlugin extends Plugin implements MLCommonsExtension, ActionPlugin { + private final AtomicReference restControllerRef = new AtomicReference<>(); public static final String SKILLS_THREAD_POOL_PREFIX = "thread_pool.skills"; public static final String WEBSEARCH_CRAWLER_THREADPOOL = "websearch-crawler-threadpool"; + @Override + public List getRestHandlers( + Settings settings, + RestController restController, + ClusterSettings clusterSettings, + IndexScopedSettings indexScopedSettings, + SettingsFilter settingsFilter, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier nodesInCluster + ) { + restControllerRef.set(restController); + return Collections.emptyList(); + } + @SneakyThrows @Override public Collection createComponents( @@ -69,9 +90,6 @@ public Collection createComponents( IndexNameExpressionResolver indexNameExpressionResolver, 
Supplier repositoriesServiceSupplier ) { - this.client = client; - this.clusterService = clusterService; - this.xContentRegistry = xContentRegistry; PPLTool.Factory.getInstance().init(client); NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry); VectorDBTool.Factory.getInstance().init(client, xContentRegistry); @@ -84,6 +102,10 @@ public Collection createComponents( CreateAnomalyDetectorTool.Factory.getInstance().init(client); LogPatternTool.Factory.getInstance().init(client, xContentRegistry); WebSearchTool.Factory.getInstance().init(threadPool); + LogPatternAnalysisTool.Factory.getInstance().init(client); + DataDistributionTool.Factory.getInstance().init(client); + SearchAroundDocumentTool.Factory.getInstance().init(client, xContentRegistry); + MetricChangeAnalysisTool.Factory.getInstance().init(client); return Collections.emptyList(); } @@ -102,7 +124,11 @@ public List> getToolFactories() { CreateAlertTool.Factory.getInstance(), CreateAnomalyDetectorTool.Factory.getInstance(), LogPatternTool.Factory.getInstance(), - WebSearchTool.Factory.getInstance() + WebSearchTool.Factory.getInstance(), + LogPatternAnalysisTool.Factory.getInstance(), + DataDistributionTool.Factory.getInstance(), + SearchAroundDocumentTool.Factory.getInstance(), + MetricChangeAnalysisTool.Factory.getInstance() ); } @@ -117,7 +143,7 @@ public List> getExecutorBuilders(Settings settings) { false ); - return ImmutableList.of(websearchCrawlThread); + return List.of(websearchCrawlThread); } } diff --git a/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java b/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java index dd713ae6..4865e70a 100644 --- a/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java +++ b/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java @@ -22,6 +22,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import 
org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.transport.client.Client; @@ -94,7 +95,8 @@ protected SearchRequest buildSearchRequest(Map parameters) t } @Override - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); SearchRequest searchRequest; try { searchRequest = buildSearchRequest(parameters); diff --git a/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java b/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java index 368807c0..3e7096a5 100644 --- a/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java +++ b/src/main/java/org/opensearch/agent/tools/CreateAlertTool.java @@ -39,6 +39,7 @@ import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.transport.client.Client; import com.google.gson.reflect.TypeToken; @@ -133,7 +134,8 @@ public boolean validate(Map parameters) { } @Override - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); Map tmpParams = new HashMap<>(parameters); if (!tmpParams.containsKey("indices") || Strings.isEmpty(tmpParams.get("indices"))) { throw new IllegalArgumentException( diff --git a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java index 2beb4cf4..2c4a2273 100644 --- 
a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java +++ b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java @@ -44,6 +44,7 @@ import org.opensearch.ml.common.spi.tools.WithModelTool; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.transport.client.Client; import com.google.common.collect.ImmutableMap; @@ -169,6 +170,7 @@ public CreateAnomalyDetectorTool(Client client, String modelId, String modelType */ @Override public void run(Map parameters, ActionListener listener) { + parameters = ToolUtils.extractInputParameters(parameters, attributes); final String tenantId = parameters.get(TENANT_ID_FIELD); Map enrichedParameters = enrichParameters(parameters); String indexName = enrichedParameters.get("index"); diff --git a/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java b/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java new file mode 100644 index 00000000..b133f999 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/DataDistributionTool.java @@ -0,0 +1,985 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.agent.tools.DataFetchingHelper.DATE_FORMAT_PATTERN; +import static org.opensearch.agent.tools.DataFetchingHelper.NUMBER_FIELD_TYPES; +import static org.opensearch.agent.tools.DataFetchingHelper.QUERY_TYPE_PPL; +import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.time.LocalDateTime; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeParseException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import 
java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.math.NumberUtils; +import org.opensearch.agent.tools.utils.PPLExecuteHelper; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; +import org.opensearch.transport.client.Client; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * Usage: + * 1. Register agent: + * POST /_plugins/_ml/agents/_register + * { + * "name": "DataDistribution", + * "type": "flow", + * "tools": [ + * { + * "name": "data_distribution_tool", + * "type": "DataDistributionTool", + * "parameters": { + * } + * } + * ] + * } + * 2. Execute agent: + * POST /_plugins/_ml/agents/{agent_id}/_execute + * { + * "parameters": { + * "index": "logs-2025.01.15", + * "timeField": "@timestamp", + * "selectionTimeRangeStart": "2025-01-15 10:00:00", + * "selectionTimeRangeEnd": "2025-01-15 11:00:00", + * "baselineTimeRangeStart": "2025-01-15 08:00:00", + * "baselineTimeRangeEnd": "2025-01-15 09:00:00", + * "size": 1000, + * "queryType": "dsl", + * "filter": ["{'term': {'status': 'error'}}", "{'range': {'response_time': {'gte': 100}}}"], + * "dsl": "{\"bool\": {\"must\": [{\"term\": {\"status\": \"error\"}}]}}", + * "ppl": "source index where a=0" + * } + * } + * 3. 
Result: analysis of data distribution patterns + * { + * "comparisonAnalysis": [ + * { + * "field": "status", + * "divergence": 0.2, + * "topChanges": [ + * { + * "value": "error", + * "selectionPercentage": 0.3, + * "baselinePercentage": 0.1 + * }, + * { + * "value": "success", + * "selectionPercentage": 0.7, + * "baselinePercentage": 0.9 + * } + * ] + * } + * ] + * } + */ +@Log4j2 +@Setter +@Getter +@ToolAnnotation(DataDistributionTool.TYPE) +public class DataDistributionTool implements Tool { + public static final String TYPE = "DataDistributionTool"; + public static final String STRICT_FIELD = "strict"; + + private static final String DEFAULT_DESCRIPTION = + "Analyzes field value distributions in a target time range, optionally compared to a baseline. " + + "Use to identify which fields changed most when investigating anomalies. " + + "Two modes: (1) Comparison (baseline provided): ranks fields by divergence. " + + "(2) Single (no baseline): summarizes field distributions."; + + private static final Set USEFUL_FIELD_TYPES = Set + .of("keyword", "boolean", "text", "byte", "short", "integer", "long", "float", "double", "half_float", "scaled_float"); + + private static final int DEFAULT_COMPARISON_RESULT_LIMIT = 10; + private static final int DEFAULT_SINGLE_ANALYSIS_RESULT_LIMIT = 30; + private static final int MIN_CARDINALITY_DIVISOR = 4; + private static final int MIN_CARDINALITY_BASE = 5; + private static final int ID_FIELD_MAX_CARDINALITY = 30; + private static final int DATA_FIELD_MAX_CARDINALITY = 10; + private static final int DATA_FIELD_CARDINALITY_DIVISOR = 2; + private static final int NUMERIC_GROUPING_THRESHOLD = 10; + private static final double PERCENTAGE_MULTIPLIER = 100.0; + private static final int TOP_CHANGES_LIMIT = 10; + private static final int MAX_SIZE_LIMIT = 10000; + + public static final String DEFAULT_INPUT_SCHEMA = """ + { + "type": "object", + "properties": { + "index": { + "type": "string", + "description": "Target OpenSearch index 
name" + }, + "timeField": { + "type": "string", + "description": "Date/time field for filtering" + }, + "selectionTimeRangeStart": { + "type": "string", + "description": "Start of target period (format: yyyy-MM-dd HH:mm:ss)" + }, + "selectionTimeRangeEnd": { + "type": "string", + "description": "End of target period (format: yyyy-MM-dd HH:mm:ss)" + }, + "baselineTimeRangeStart": { + "type": "string", + "description": "Start of baseline period (format: yyyy-MM-dd HH:mm:ss). Must pair with baselineTimeRangeEnd" + }, + "baselineTimeRangeEnd": { + "type": "string", + "description": "End of baseline period (format: yyyy-MM-dd HH:mm:ss). Must pair with baselineTimeRangeStart" + }, + "size": { + "type": "integer", + "description": "Max documents to sample (default: 1000, max: 10000)" + }, + "queryType": { + "type": "string", + "description": "Query type: 'dsl' (default) or 'ppl'" + }, + "filter": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Additional DSL filter clauses as JSON strings" + }, + "dsl": { + "type": "string", + "description": "Complete DSL query as JSON string" + }, + "ppl": { + "type": "string", + "description": "PPL query without time filtering (added automatically)" + } + }, + "required": ["index", "timeField", "selectionTimeRangeStart", "selectionTimeRangeEnd"], + "additionalProperties": false + } + """; + + public static final Map DEFAULT_ATTRIBUTES = Map + .of(TOOL_INPUT_SCHEMA_FIELD, gson.toJson(gson.fromJson(DEFAULT_INPUT_SCHEMA, Map.class)), STRICT_FIELD, false); + + /** + * Result class for data distribution analysis + */ + private record SummaryDataItem(String field, double divergence, List topChanges) { + } + + /** + * Individual change item for field values + */ + private record ChangeItem(String value, double selectionPercentage, Double baselinePercentage) { + } + + @Setter + @Getter + private String name = TYPE; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + @Getter + private String 
version; + private Client client; + private DataFetchingHelper dataFetchingHelper; + + /** + * Constructs a DataDistributionTool with the given OpenSearch client + * + * @param client The OpenSearch client for executing queries + */ + public DataDistributionTool(Client client) { + this.client = client; + this.dataFetchingHelper = new DataFetchingHelper(client); + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public Map getAttributes() { + return DEFAULT_ATTRIBUTES; + } + + @Override + public void setAttributes(Map map) {} + + @Override + public boolean validate(Map map) { + try { + new DataFetchingHelper.AnalysisParameters(map).validate(); + } catch (Exception e) { + log.error("Failed to validate the data distribution analysis parameter: {}", e.getMessage()); + return false; + } + return true; + } + + /** + * Executes data distribution analysis based on provided parameters. + * Supports both single dataset analysis and comparative analysis between time periods. 
+ * + * @param The response type + * @param originalParameters Input parameters for analysis + * @param listener Action listener for handling results or failures + */ + @Override + public void run(Map originalParameters, ActionListener listener) { + try { + Map parameters = ToolUtils.extractInputParameters(originalParameters, DEFAULT_ATTRIBUTES); + log.debug("Starting data distribution analysis with parameters: {}", parameters.keySet()); + DataFetchingHelper.AnalysisParameters params = new DataFetchingHelper.AnalysisParameters(parameters); + + if (QUERY_TYPE_PPL.equals(params.queryType)) { + executePPLAnalysis(params, listener); + } else { + executeDSLAnalysis(params, listener); + } + } catch (IllegalArgumentException e) { + log.error("Invalid parameters for DataDistributionTool: {}", e.getMessage()); + listener.onFailure(e); + } catch (Exception e) { + log.error("Unexpected error in DataDistributionTool", e); + listener.onFailure(e); + } + } + + /** + * Executes analysis using PPL (Piped Processing Language) queries + * + * @param The response type + * @param params Analysis parameters containing query details + * @param listener Action listener for handling results + */ + private void executePPLAnalysis(DataFetchingHelper.AnalysisParameters params, ActionListener listener) { + if (params.hasBaselineTimeRange()) { + fetchPPLComparisonData(params, listener); + } else { + String pplQuery = buildPPLQuery( + params.index, + params.timeField, + params.selectionTimeRangeStart, + params.selectionTimeRangeEnd, + params.size, + params.ppl + ); + + Function, List>> pplResultParser = this::parsePPLResult; + + PPLExecuteHelper.executePPLAndParseResult(client, pplQuery, pplResultParser, ActionListener.wrap(data -> { + try { + analyzeSingleDataset(data, params.index, ActionListener.wrap(result -> { + listener.onResponse((T) gson.toJson(Map.of("singleAnalysis", result))); + }, listener::onFailure)); + } catch (Exception e) { + listener.onFailure(e); + } + }, 
listener::onFailure)); + } + } + + /** + * Executes analysis using DSL (Domain Specific Language) queries + * + * @param The response type + * @param params Analysis parameters containing query details + * @param listener Action listener for handling results + */ + private void executeDSLAnalysis(DataFetchingHelper.AnalysisParameters params, ActionListener listener) { + if (params.hasBaselineTimeRange()) { + fetchComparisonData(params, listener); + } else { + getSingleDataDistribution(params, listener); + } + } + + /** + * Fetches data for both selection and baseline time ranges for comparison analysis + * + * @param The response type + * @param params Analysis parameters containing time ranges + * @param listener Action listener for handling comparison results + */ + private void fetchComparisonData(DataFetchingHelper.AnalysisParameters params, ActionListener listener) { + fetchIndexData(params.selectionTimeRangeStart, params.selectionTimeRangeEnd, params, ActionListener.wrap(selectionData -> { + fetchIndexData(params.baselineTimeRangeStart, params.baselineTimeRangeEnd, params, ActionListener.wrap(baselineData -> { + try { + if (selectionData.isEmpty()) { + throw new IllegalStateException("No data found for selection time range"); + } + if (baselineData.isEmpty()) { + throw new IllegalStateException("No data found for baseline time range"); + } + getComparisonDataDistribution(selectionData, baselineData, params.index, ActionListener.wrap(result -> { + listener.onResponse((T) gson.toJson(Map.of("comparisonAnalysis", result))); + }, listener::onFailure)); + } catch (Exception e) { + listener.onFailure(e); + } + }, listener::onFailure)); + }, listener::onFailure)); + } + + /** + * Performs single dataset distribution analysis for the selection time range + * + * @param The response type + * @param params Analysis parameters containing selection time range + * @param listener Action listener for handling single analysis results + */ + private void 
getSingleDataDistribution(DataFetchingHelper.AnalysisParameters params, ActionListener listener) { + fetchIndexData(params.selectionTimeRangeStart, params.selectionTimeRangeEnd, params, ActionListener.wrap(data -> { + try { + if (data.isEmpty()) { + throw new IllegalStateException("No data found for selection time range"); + } + analyzeSingleDataset(data, params.index, ActionListener.wrap(result -> { + listener.onResponse((T) gson.toJson(Map.of("singleAnalysis", result))); + }, listener::onFailure)); + } catch (Exception e) { + listener.onFailure(e); + } + }, listener::onFailure)); + } + + /** + * Fetches data from the specified index within the given time range + * + * @param startTime Start time for data retrieval + * @param endTime End time for data retrieval + * @param params Analysis parameters containing index and field information + * @param listener Action listener for handling retrieved data + */ + private void fetchIndexData( + String startTime, + String endTime, + DataFetchingHelper.AnalysisParameters params, + ActionListener>> listener + ) { + // Convert AnalysisParameters to helper parameters + Map helperParams = new HashMap<>(); + helperParams.put("index", params.index); + helperParams.put("timeField", params.timeField); + helperParams.put("size", String.valueOf(params.size)); + helperParams.put("queryType", params.queryType); + if (!Strings.isEmpty(params.dsl)) { + helperParams.put("dsl", params.dsl); + } + if (!params.filter.isEmpty()) { + helperParams.put("filter", gson.toJson(params.filter)); + } + if (!Strings.isEmpty(params.ppl)) { + helperParams.put("ppl", params.ppl); + } + + DataFetchingHelper.AnalysisParameters helperAnalysisParams = new DataFetchingHelper.AnalysisParameters(helperParams); + + dataFetchingHelper.fetchIndexData(startTime, endTime, helperAnalysisParams, listener); + } + + /** + * Fetches data for both selection and baseline time ranges using PPL for comparison analysis + * + * @param The response type + * @param params 
Analysis parameters containing time ranges + * @param listener Action listener for handling comparison results + */ + private void fetchPPLComparisonData(DataFetchingHelper.AnalysisParameters params, ActionListener listener) { + String selectionQuery = buildPPLQuery( + params.index, + params.timeField, + params.selectionTimeRangeStart, + params.selectionTimeRangeEnd, + params.size, + params.ppl + ); + String baselineQuery = buildPPLQuery( + params.index, + params.timeField, + params.baselineTimeRangeStart, + params.baselineTimeRangeEnd, + params.size, + params.ppl + ); + + Function, List>> pplResultParser = this::parsePPLResult; + + PPLExecuteHelper.executePPLAndParseResult(client, selectionQuery, pplResultParser, ActionListener.wrap(selectionData -> { + PPLExecuteHelper.executePPLAndParseResult(client, baselineQuery, pplResultParser, ActionListener.wrap(baselineData -> { + try { + if (selectionData.isEmpty()) { + throw new IllegalStateException("No data found for selection time range"); + } + if (baselineData.isEmpty()) { + throw new IllegalStateException("No data found for baseline time range"); + } + getComparisonDataDistribution(selectionData, baselineData, params.index, ActionListener.wrap(result -> { + listener.onResponse((T) gson.toJson(Map.of("comparisonAnalysis", result))); + }, listener::onFailure)); + } catch (Exception e) { + listener.onFailure(e); + } + }, listener::onFailure)); + }, listener::onFailure)); + } + + /** + * Converts time string to PPL format (yyyy-MM-dd HH:mm:ss) + * + * @param timeString Input time string + * @return Formatted time string for PPL + */ + private String formatTimeForPPL(String timeString) { + try { + // Parse ISO format and convert to PPL format + ZonedDateTime dateTime = ZonedDateTime.parse(timeString); + return dateTime.format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSS", Locale.ROOT)); + } catch (DateTimeParseException e) { + // Try parsing as local time without zone + try { + DateTimeFormatter formatter = 
DateTimeFormatter.ofPattern(DATE_FORMAT_PATTERN, Locale.ROOT); + LocalDateTime localDateTime = LocalDateTime.parse(timeString, formatter); + return localDateTime.format(formatter); + } catch (DateTimeParseException e2) { + // Return original if parsing fails + return timeString; + } + } + } + + /** + * Adds time range filter to PPL query + * + * @param query PPL query string (can be empty) + * @param startTime Start time for filtering + * @param endTime End time for filtering + * @param timeField Time field name + * @return PPL query with time range filter added + */ + private String getPPLQueryWithTimeRange(String query, String startTime, String endTime, String timeField) { + if (Strings.isEmpty(query)) { + throw new IllegalArgumentException("PPL query cannot be empty"); + } + if (Strings.isEmpty(timeField)) { + return query; + } + + String formattedStartTime = formatTimeForPPL(startTime); + String formattedEndTime = formatTimeForPPL(endTime); + String timePredicate = String + .format(Locale.ROOT, "`%s` >= '%s' AND `%s` <= '%s'", timeField, formattedStartTime, timeField, formattedEndTime); + + String[] commands = query.split("\\|"); + List commandList = new ArrayList<>(); + + // Always insert time filter right after first command (safest approach) + commandList.add(commands[0].trim()); + commandList.add("WHERE " + timePredicate); + + // Add remaining commands + for (int i = 1; i < commands.length; i++) { + String cmd = commands[i].trim(); + if (!cmd.isEmpty()) { + commandList.add(cmd); + } + } + + return String.join(" | ", commandList); + } + + /** + * Builds PPL query string for data retrieval within specified time range + * + * @param index Index name + * @param timeField Time field name + * @param startTime Start time for query + * @param endTime End time for query + * @param size Maximum number of documents + * @param customPpl Custom PPL statement (optional) + * @return Formatted PPL query string + */ + private String buildPPLQuery(String index, String 
timeField, String startTime, String endTime, int size, String customPpl) { + String baseQuery; + + if (!Strings.isEmpty(customPpl)) { + baseQuery = getPPLQueryWithTimeRange(customPpl, startTime, endTime, timeField); + } else { + baseQuery = getPPLQueryWithTimeRange(String.format(Locale.ROOT, "source=%s", index), startTime, endTime, timeField); + } + + return baseQuery + String.format(Locale.ROOT, " | head %d", size); + } + + /** + * Analyzes and compares data distributions between selection and baseline datasets + * + * @param selectionData Data from the selection time period + * @param baselineData Data from the baseline time period + * @param index Index name for field mapping retrieval + * @param listener Action listener for handling comparison results + */ + private void getComparisonDataDistribution( + List> selectionData, + List> baselineData, + String index, + ActionListener> listener + ) { + getFieldTypes(index, ActionListener.wrap(fieldTypes -> { + try { + List usefulFields = getUsefulFields(selectionData, fieldTypes); + Set numberFields = getNumberFields(fieldTypes); + List analyses = new ArrayList<>(); + + for (String field : usefulFields) { + Map selectionDist = calculateFieldDistribution(selectionData, field); + Map baselineDist = calculateFieldDistribution(baselineData, field); + + if (numberFields.contains(field)) { + GroupedDistributions grouped = groupNumericKeys(selectionDist, baselineDist); + selectionDist = grouped.groupedSelectionDist(); + baselineDist = grouped.groupedBaselineDist(); + } + + double divergence = calculateMaxDifference(selectionDist, baselineDist); + analyses.add(new FieldAnalysis(field, divergence, selectionDist, baselineDist)); + } + + analyses.sort(Comparator.comparingDouble((FieldAnalysis a) -> a.divergence).reversed()); + listener.onResponse(formatComparisonSummary(analyses, DEFAULT_COMPARISON_RESULT_LIMIT)); + } catch (Exception e) { + listener.onFailure(e); + } + }, listener::onFailure)); + } + + /** + * Analyzes 
distribution patterns within a single dataset + * + * @param data Dataset to analyze + * @param index Index name for field mapping retrieval + * @param listener Action listener for handling single analysis results + */ + private void analyzeSingleDataset(List> data, String index, ActionListener> listener) { + getFieldTypes(index, ActionListener.wrap(fieldTypes -> { + try { + List usefulFields = getUsefulFields(data, fieldTypes); + Set numberFields = getNumberFields(fieldTypes); + List analyses = new ArrayList<>(); + + for (String field : usefulFields) { + Map selectionDist = calculateFieldDistribution(data, field); + Map baselineDist = new HashMap<>(); + + if (numberFields.contains(field)) { + GroupedDistributions grouped = groupNumericKeys(selectionDist, baselineDist); + selectionDist = grouped.groupedSelectionDist(); + } + + double divergence = calculateMaxDifference(selectionDist, baselineDist); + analyses.add(new FieldAnalysis(field, divergence, selectionDist, baselineDist)); + } + + analyses.sort(Comparator.comparingDouble((FieldAnalysis a) -> a.divergence).reversed()); + listener.onResponse(formatComparisonSummary(analyses, DEFAULT_SINGLE_ANALYSIS_RESULT_LIMIT)); + } catch (Exception e) { + listener.onFailure(e); + } + }, listener::onFailure)); + } + + /** + * Internal record for field analysis results + */ + private record FieldAnalysis(String field, double divergence, Map selectionDist, Map baselineDist) { + } + + /** + * Record for grouped numeric distributions + */ + private record GroupedDistributions(Map groupedSelectionDist, Map groupedBaselineDist) { + } + + /** + * Gets field type mappings from index + * + * @param index Index name for mapping retrieval + * @param listener Action listener for handling field types result + */ + private void getFieldTypes(String index, ActionListener> listener) { + dataFetchingHelper.getFieldTypes(index, listener); + } + + /** + * Identifies useful fields for analysis based on index mapping and data characteristics + * 
+ * @param data Sample data for cardinality analysis + * @param fieldTypes Map of field names to their types + * @return List of field names suitable for distribution analysis + */ + private List getUsefulFields(List> data, Map fieldTypes) { + if (fieldTypes.isEmpty()) { + log.warn("No field types available, using data-based field detection"); + return getFieldsFromData(data); + } + + Set keywordFields = new HashSet<>(); + Set numberFields = new HashSet<>(); + + for (Map.Entry entry : fieldTypes.entrySet()) { + String fieldType = entry.getValue(); + String fieldName = entry.getKey(); + + if (USEFUL_FIELD_TYPES.contains(fieldType)) { + keywordFields.add(fieldName); + } + if (NUMBER_FIELD_TYPES.contains(fieldType)) { + numberFields.add(fieldName); + } + } + + Set normalizedFields = keywordFields + .stream() + .map(field -> field.endsWith(".keyword") ? field.replace(".keyword", "") : field) + .collect(Collectors.toSet()); + + Map> fieldValueSets = new HashMap<>(); + normalizedFields.forEach(field -> fieldValueSets.put(field, new HashSet<>())); + + int maxCardinality = Math.max(MIN_CARDINALITY_BASE, data.size() / MIN_CARDINALITY_DIVISOR); + + data.forEach(doc -> { + normalizedFields.forEach(field -> { + Object value = getFlattenedValue(doc, field); + if (value != null) { + fieldValueSets.get(field).add(gson.toJson(value)); + } + }); + }); + + return normalizedFields.stream().filter(field -> { + int cardinality = fieldValueSets.get(field).size(); + if (field.toLowerCase(Locale.ROOT).endsWith("id")) { + return cardinality <= ID_FIELD_MAX_CARDINALITY && cardinality > 0; + } else if (numberFields.contains(field)) { + return true; + } + return cardinality <= maxCardinality && cardinality > 0; + }).collect(Collectors.toList()); + } + + /** + * Extracts nested field values from document using dot notation + * + * @param doc Document map to extract value from + * @param field Field path using dot notation (e.g., "user.name") + * @return Field value or null if not found + */ + 
private Object getFlattenedValue(Map doc, String field) { + return dataFetchingHelper.getFlattenedValue(doc, field); + } + + /** + * Calculates distribution of values for a specific field across the dataset + * + * @param data Dataset to analyze + * @param field Field name to calculate distribution for + * @return Map of field values to their relative frequencies + */ + private Map calculateFieldDistribution(List> data, String field) { + if (data == null || data.isEmpty()) { + return new HashMap<>(); + } + + Map counts = new HashMap<>(); + + for (Map doc : data) { + Object value = getFlattenedValue(doc, field); + if (value != null) { + String strValue = String.valueOf(value); + counts.merge(strValue, 1, Integer::sum); + } + } + return counts.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> (double) entry.getValue() / data.size())); + } + + /** + * Calculates maximum difference between selection and baseline distributions + * + * @param selectionDist Selection period distribution + * @param baselineDist Baseline period distribution + * @return Maximum difference value across all field values + */ + private double calculateMaxDifference(Map selectionDist, Map baselineDist) { + Set allKeys = new HashSet<>(selectionDist.keySet()); + allKeys.addAll(baselineDist.keySet()); + + if (allKeys.isEmpty()) { + return Double.NEGATIVE_INFINITY; + } + return allKeys.stream().mapToDouble(key -> { + double selectionVal = selectionDist.getOrDefault(key, 0.0); + double baselineVal = baselineDist.getOrDefault(key, 0.0); + return Math.abs(selectionVal - baselineVal); + }).max().orElse(Double.NEGATIVE_INFINITY); + } + + /** + * Extracts field names from sample data when mapping is not available + * + * @param data Sample data to analyze + * @return List of field names suitable for analysis + */ + private List getFieldsFromData(List> data) { + if (data.isEmpty()) { + return List.of(); + } + + Set allFields = new HashSet<>(); + for (Map doc : data) { + 
allFields.addAll(doc.keySet()); + } + + // Filter out timestamp and other non-useful fields + return allFields + .stream() + .filter(field -> !field.equals("@timestamp") && !field.equals("_id") && !field.equals("_index")) + .filter(field -> { + // Check cardinality - exclude high cardinality fields + Set values = new HashSet<>(); + for (Map doc : data) { + Object value = doc.get(field); + if (value != null) { + values.add(String.valueOf(value)); + } + } + int cardinality = values.size(); + return cardinality > 0 && cardinality <= Math.max(DATA_FIELD_MAX_CARDINALITY, data.size() / DATA_FIELD_CARDINALITY_DIVISOR); + }) + .collect(Collectors.toList()); + } + + /** + * Gets number fields from field type mappings + * + * @param fieldTypes Map of field names to their types + * @return Set of number field names + */ + private Set getNumberFields(Map fieldTypes) { + return dataFetchingHelper.getNumberFields(fieldTypes); + } + + /** + * Groups numeric keys and merges counts + * + * @param selectionDist Selection distribution + * @param baselineDist Baseline distribution + * @return Grouped distributions + */ + private GroupedDistributions groupNumericKeys(Map selectionDist, Map baselineDist) { + Set allKeys = new HashSet<>(selectionDist.keySet()); + allKeys.addAll(baselineDist.keySet()); + + if (allKeys.size() <= NUMERIC_GROUPING_THRESHOLD || allKeys.stream().anyMatch(key -> !NumberUtils.isCreatable(key))) { + return new GroupedDistributions(selectionDist, baselineDist); + } + + List numericKeys = allKeys.stream().map(Double::parseDouble).sorted().collect(Collectors.toList()); + Function getGroupLabel = getDoubleStringFunction(numericKeys); + // Group the keys and aggregate the values + Map groupedSelectionDist = numericKeys + .stream() + .collect( + Collectors + .groupingBy(getGroupLabel, Collectors.summingDouble(numKey -> selectionDist.getOrDefault(String.valueOf(numKey), 0.0))) + ); + Map groupedBaselineDist = numericKeys + .stream() + .collect( + Collectors + 
.groupingBy(getGroupLabel, Collectors.summingDouble(numKey -> baselineDist.getOrDefault(String.valueOf(numKey), 0.0))) + ); + // Ensure all groups are present in both maps (in case some have zero values) + Set allGroups = new HashSet<>(); + allGroups.addAll(groupedSelectionDist.keySet()); + allGroups.addAll(groupedBaselineDist.keySet()); + allGroups.forEach(group -> { + groupedSelectionDist.putIfAbsent(group, 0.0); + groupedBaselineDist.putIfAbsent(group, 0.0); + }); + + return new GroupedDistributions(groupedSelectionDist, groupedBaselineDist); + } + + private static Function getDoubleStringFunction(List numericKeys) { + double min = numericKeys.get(0); + double max = numericKeys.get(numericKeys.size() - 1); + double range = max - min; + int numGroups = 5; + double groupSize = range / numGroups; + // Create a function to determine which group a key belongs to + Function getGroupLabel = numKey -> { + int groupIndex = numKey == max ? numGroups - 1 : (int) ((numKey - min) / groupSize); + double lowerBound = min + groupIndex * groupSize; + double upperBound = groupIndex == numGroups - 1 ? 
max : min + (groupIndex + 1) * groupSize; + return String.format(Locale.ROOT, "%.1f-%.1f", lowerBound, upperBound); + }; + return getGroupLabel; + } + + /** + * Formats field analysis results into summary data items + * + * @param differences List of field analysis results + * @param maxResults Maximum number of results to return + * @return Formatted list of summary data items + */ + private List formatComparisonSummary(List differences, int maxResults) { + return differences.stream().filter(diff -> diff.divergence > 0).limit(maxResults).map(diff -> { + Set allKeys = new HashSet<>(diff.selectionDist.keySet()); + allKeys.addAll(diff.baselineDist.keySet()); + + boolean hasBaseline = !diff.baselineDist.isEmpty(); + + List changes = allKeys.stream().map(value -> { + double selectionPercentage = Math.round(diff.selectionDist.getOrDefault(value, 0.0) * PERCENTAGE_MULTIPLIER) + / PERCENTAGE_MULTIPLIER; + Double baselinePercentage = hasBaseline + ? Math.round(diff.baselineDist.getOrDefault(value, 0.0) * PERCENTAGE_MULTIPLIER) / PERCENTAGE_MULTIPLIER + : null; + return new ChangeItem(value, selectionPercentage, baselinePercentage); + }).collect(Collectors.toList()); + + List topChanges = changes + .stream() + .sorted( + (a, b) -> hasBaseline + ? Double + .compare( + Math.max(b.baselinePercentage != null ? b.baselinePercentage : 0.0, b.selectionPercentage), + Math.max(a.baselinePercentage != null ? 
a.baselinePercentage : 0.0, a.selectionPercentage) + ) + : Double.compare(b.selectionPercentage, a.selectionPercentage) + ) + .limit(TOP_CHANGES_LIMIT) + .collect(Collectors.toList()); + + return new SummaryDataItem(diff.field, diff.divergence, topChanges); + }).collect(Collectors.toList()); + } + + /** + * Parses PPL query result into list of documents + * + * @param pplResult PPL query result containing datarows and schema + * @return List of documents as maps + */ + private List> parsePPLResult(Map pplResult) { + Object datarowsObj = pplResult.get("datarows"); + Object schemaObj = pplResult.get("schema"); + + if (!(datarowsObj instanceof List) || !(schemaObj instanceof List)) { + return List.of(); + } + + @SuppressWarnings("unchecked") + List> dataRows = (List>) datarowsObj; + @SuppressWarnings("unchecked") + List> schema = (List>) schemaObj; + + List> result = new ArrayList<>(); + for (List row : dataRows) { + Map doc = new HashMap<>(); + for (int i = 0; i < Math.min(row.size(), schema.size()); i++) { + String columnName = (String) schema.get(i).get("name"); + if (columnName != null) { + doc.put(columnName, row.get(i)); + } + } + result.add(doc); + } + return result; + } + + /** + * Factory class for creating DataDistributionTool instances + */ + public static class Factory implements Tool.Factory { + private Client client; + private static Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (DataDistributionTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + /** + * Initialize this factory + * + * @param client The OpenSearch client + */ + public void init(Client client) { + this.client = client; + } + + @Override + public DataDistributionTool create(Map map) { + return new DataDistributionTool(client); + } + + @Override + public String 
getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public Map getDefaultAttributes() { + return DEFAULT_ATTRIBUTES; + } + + @Override + public String getDefaultVersion() { + return null; + } + } +} diff --git a/src/main/java/org/opensearch/agent/tools/DataFetchingHelper.java b/src/main/java/org/opensearch/agent/tools/DataFetchingHelper.java new file mode 100644 index 00000000..39b8e45d --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/DataFetchingHelper.java @@ -0,0 +1,482 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.time.LocalDateTime; +import java.time.ZoneOffset; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeParseException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + +import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.agent.tools.utils.PPLExecuteHelper; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.transport.client.Client; + +import com.google.gson.reflect.TypeToken; + +import lombok.extern.log4j.Log4j2; + +/** + * Helper class for fetching and processing data from OpenSearch indices. 
+ * Provides common functionality for data analysis tools including: + * - Field type detection and mapping + * - Time-based data fetching with DSL/PPL query support + * - Query building and parameter validation + * - Nested field value extraction + */ +@Log4j2 +public class DataFetchingHelper { + + private static final String DEFAULT_TIME_FIELD = "@timestamp"; + public static final String DATE_FORMAT_PATTERN = "yyyy-MM-dd HH:mm:ss"; + public static final String QUERY_TYPE_PPL = "ppl"; + public static final String QUERY_TYPE_DSL = "dsl"; + private static final String DEFAULT_SIZE = "1000"; + private static final int MAX_SIZE_LIMIT = 10000; + + public static final Set NUMBER_FIELD_TYPES = Set + .of("byte", "short", "integer", "long", "float", "double", "half_float", "scaled_float"); + + private final Client client; + + /** + * Constructs a DataFetchingHelper with the given OpenSearch client + * + * @param client The OpenSearch client for executing queries + */ + public DataFetchingHelper(Client client) { + this.client = client; + } + + /** + * Parameters for data analysis operations + */ + public static class AnalysisParameters { + public final String index; + public final String timeField; + public final String selectionTimeRangeStart; + public final String selectionTimeRangeEnd; + public final String baselineTimeRangeStart; + public final String baselineTimeRangeEnd; + public final int size; + public final String queryType; + public final List filter; + public final String dsl; + public final String ppl; + + public AnalysisParameters(Map parameters) { + this.index = parameters.getOrDefault("index", ""); + this.timeField = parameters.getOrDefault("timeField", DEFAULT_TIME_FIELD); + this.selectionTimeRangeStart = parameters.getOrDefault("selectionTimeRangeStart", ""); + this.selectionTimeRangeEnd = parameters.getOrDefault("selectionTimeRangeEnd", ""); + this.baselineTimeRangeStart = parameters.getOrDefault("baselineTimeRangeStart", ""); + this.baselineTimeRangeEnd = 
parameters.getOrDefault("baselineTimeRangeEnd", ""); + + String sizeStr = parameters.getOrDefault("size", DEFAULT_SIZE); + int parsedSize; + try { + parsedSize = Double.valueOf(sizeStr).intValue(); + if (parsedSize <= 0 || parsedSize > MAX_SIZE_LIMIT) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Invalid 'size' parameter: must be between 1 and %d, got '%s'", MAX_SIZE_LIMIT, sizeStr) + ); + } + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Invalid 'size' parameter: '%s', must be a valid integer", sizeStr) + ); + } + this.size = parsedSize; + + this.queryType = parameters.getOrDefault("queryType", QUERY_TYPE_DSL); + + // Parse filter from JSON string to List + String filterParam = parameters.getOrDefault("filter", ""); + if (Strings.isNullOrEmpty(filterParam)) { + this.filter = List.of(); + } else { + try { + this.filter = List.of(gson.fromJson(filterParam, String[].class)); + } catch (Exception e) { + throw new IllegalArgumentException( + String + .format( + Locale.ROOT, + "Invalid 'filter' parameter: must be a valid JSON array of strings, got '%s'. 
" + + "Example: [\"{'term': {'status': 'error'}}\", \"{'range': {'level': {'gte': 3}}}\"]", + filterParam + ) + ); + } + } + + this.dsl = parameters.getOrDefault("dsl", ""); + this.ppl = parameters.getOrDefault("ppl", ""); + } + + /** + * Validates required parameters are present + * + * @throws IllegalArgumentException if required parameters are missing + */ + public void validate() { + if (Strings.isNullOrEmpty(index)) { + throw new IllegalArgumentException("Missing required parameter: 'index'"); + } + if (Strings.isNullOrEmpty(selectionTimeRangeStart) || Strings.isNullOrEmpty(selectionTimeRangeEnd)) { + throw new IllegalArgumentException("Missing required parameters: 'selectionTimeRangeStart' and 'selectionTimeRangeEnd'"); + } + if (!QUERY_TYPE_DSL.equals(queryType) && !QUERY_TYPE_PPL.equals(queryType)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Invalid 'queryType': must be 'dsl' or 'ppl', got '%s'", queryType) + ); + } + } + + public boolean hasBaselineTimeRange() { + return !Strings.isNullOrEmpty(baselineTimeRangeStart) && !Strings.isNullOrEmpty(baselineTimeRangeEnd); + } + } + + /** + * Retrieves field type mappings from the specified index + * + * @param index The index name + * @param listener Action listener for handling the field type map or failures + */ + public void getFieldTypes(String index, ActionListener> listener) { + GetMappingsRequest request = new GetMappingsRequest().indices(index); + + client.admin().indices().getMappings(request, ActionListener.wrap(response -> { + Map fieldTypes = new HashMap<>(); + + for (Map.Entry entry : response.getMappings().entrySet()) { + MappingMetadata metadata = entry.getValue(); + Map mappingSource = metadata.getSourceAsMap(); + + if (mappingSource.containsKey("properties")) { + @SuppressWarnings("unchecked") + Map properties = (Map) mappingSource.get("properties"); + extractFieldTypes(properties, "", fieldTypes); + } + } + + listener.onResponse(fieldTypes); + }, listener::onFailure)); 
+ } + + /** + * Recursively extracts field types from mapping properties + */ + private void extractFieldTypes(Map properties, String prefix, Map fieldTypes) { + for (Map.Entry entry : properties.entrySet()) { + String fieldName = prefix.isEmpty() ? entry.getKey() : prefix + "." + entry.getKey(); + + @SuppressWarnings("unchecked") + Map fieldProps = (Map) entry.getValue(); + + if (fieldProps.containsKey("type")) { + fieldTypes.put(fieldName, (String) fieldProps.get("type")); + } + + if (fieldProps.containsKey("properties")) { + @SuppressWarnings("unchecked") + Map nestedProps = (Map) fieldProps.get("properties"); + extractFieldTypes(nestedProps, fieldName, fieldTypes); + } + } + } + + /** + * Fetches data from index within the specified time range + * + * @param timeRangeStart Start time string + * @param timeRangeEnd End time string + * @param params Analysis parameters + * @param listener Action listener for handling the fetched data or failures + */ + public void fetchIndexData( + String timeRangeStart, + String timeRangeEnd, + AnalysisParameters params, + ActionListener>> listener + ) { + try { + if (QUERY_TYPE_PPL.equals(params.queryType)) { + fetchDataWithPPL(timeRangeStart, timeRangeEnd, params, listener); + } else { + fetchDataWithDSL(timeRangeStart, timeRangeEnd, params, listener); + } + } catch (Exception e) { + log.error("Failed to fetch index data", e); + listener.onFailure(e); + } + } + + /** + * Builds a BoolQueryBuilder with time range filter and optional custom filters. + * Can be used with any SearchSourceBuilder (for documents or aggregations). 
+ * + * @param timeRangeStart Start time string + * @param timeRangeEnd End time string + * @param params Analysis parameters containing timeField, dsl, and filter settings + * @return BoolQueryBuilder with time range and custom filters applied + */ + public BoolQueryBuilder buildFilterQuery(String timeRangeStart, String timeRangeEnd, AnalysisParameters params) { + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + + // Add time range filter + RangeQueryBuilder timeRangeQuery = QueryBuilders + .rangeQuery(params.timeField) + .gte(formatTimeString(timeRangeStart)) + .lte(formatTimeString(timeRangeEnd)) + .format("strict_date_optional_time||epoch_millis"); + boolQuery.must(timeRangeQuery); + + // Add custom query if provided + if (!Strings.isNullOrEmpty(params.dsl)) { + Map dslMap = buildQueryFromMap(params.dsl); + boolQuery.must(QueryBuilders.wrapperQuery(gson.toJson(dslMap))); + } else if (!params.filter.isEmpty()) { + for (String filterStr : params.filter) { + Map filterMap = buildQueryFromMap(filterStr); + boolQuery.must(QueryBuilders.wrapperQuery(gson.toJson(filterMap))); + } + } + + return boolQuery; + } + + /** + * Fetches data using DSL query + */ + private void fetchDataWithDSL( + String timeRangeStart, + String timeRangeEnd, + AnalysisParameters params, + ActionListener>> listener + ) { + try { + BoolQueryBuilder boolQuery = buildFilterQuery(timeRangeStart, timeRangeEnd, params); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(boolQuery).size(params.size).fetchSource(true); + + SearchRequest searchRequest = new SearchRequest(params.index).source(searchSourceBuilder); + + client.search(searchRequest, ActionListener.wrap(response -> { + List> data = new ArrayList<>(); + for (SearchHit hit : response.getHits().getHits()) { + data.add(hit.getSourceAsMap()); + } + listener.onResponse(data); + }, listener::onFailure)); + + } catch (Exception e) { + log.error("Failed to fetch data with DSL", e); + listener.onFailure(e); + } + } 
+
+    /**
+     * Fetches documents with a PPL query limited to the given time range.
+     *
+     * The query text comes from buildPPLQuery; execution and response delivery are delegated to
+     * PPLExecuteHelper.executePPLAndParseResult, which is handed parsePPLResult as the result parser.
+     * An exception thrown while building or dispatching the query is logged and routed to
+     * listener.onFailure instead of being rethrown.
+     *
+     * @param timeRangeStart Start time string
+     * @param timeRangeEnd End time string
+     * @param params Analysis parameters (ppl, index, timeField and size are read when building the query)
+     * @param listener Action listener for handling the fetched data or failures
+     */
+    private void fetchDataWithPPL(
+        String timeRangeStart,
+        String timeRangeEnd,
+        AnalysisParameters params,
+        ActionListener<List<Map<String, Object>>> listener
+    ) {
+        try {
+            String pplQuery = buildPPLQuery(timeRangeStart, timeRangeEnd, params);
+
+            // Use PPLExecuteHelper with parser function
+            Function<Map<String, Object>, List<Map<String, Object>>> pplResultParser = this::parsePPLResult;
+
+            PPLExecuteHelper.executePPLAndParseResult(client, pplQuery, pplResultParser, listener);
+
+        } catch (Exception e) {
+            log.error("Failed to fetch data with PPL", e);
+            listener.onFailure(e);
+        }
+    }
+
+    /**
+     * Parses a PPL query result into a list of documents.
+     *
+     * The result is expected to carry a "schema" list (one map per column, whose "name" entry is the
+     * column name) and a "datarows" list (one list of values per row). Each row becomes a map from
+     * column name to value; when a row and the schema differ in length, only the common prefix is used.
+     *
+     * @param pplResult Parsed PPL response
+     * @return One map per data row; an empty list when the result is null or lacks list-valued
+     *         "datarows"/"schema" entries
+     */
+    private List<Map<String, Object>> parsePPLResult(Map<String, Object> pplResult) {
+        // A null result is treated like any other malformed response instead of raising an NPE.
+        Object datarowsObj = pplResult == null ? null : pplResult.get("datarows");
+        Object schemaObj = pplResult == null ? null : pplResult.get("schema");
+
+        if (!(datarowsObj instanceof List) || !(schemaObj instanceof List)) {
+            log.warn("Invalid PPL result format");
+            return new ArrayList<>();
+        }
+
+        @SuppressWarnings("unchecked")
+        List<List<Object>> datarows = (List<List<Object>>) datarowsObj;
+        @SuppressWarnings("unchecked")
+        List<Map<String, Object>> schema = (List<Map<String, Object>>) schemaObj;
+
+        List<String> fieldNames = new ArrayList<>(schema.size());
+        for (Map<String, Object> field : schema) {
+            fieldNames.add((String) field.get("name"));
+        }
+
+        List<Map<String, Object>> documents = new ArrayList<>(datarows.size());
+        for (List<Object> row : datarows) {
+            Map<String, Object> doc = new HashMap<>();
+            // Bound by the shorter of row/schema so a ragged row cannot index out of range.
+            int columnCount = Math.min(row.size(), fieldNames.size());
+            for (int i = 0; i < columnCount; i++) {
+                doc.put(fieldNames.get(i), row.get(i));
+            }
+            documents.add(doc);
+        }
+
+        return documents;
+    }
+
+    /**
+     * Builds PPL query with time range filter.
+     *
+     * Uses params.ppl as the base query when it is non-empty, otherwise "source=" + params.index.
+     * A where clause bounding params.timeField to the inclusive range [timeRangeStart, timeRangeEnd]
+     * (both normalized through formatTimeString) and a head limit of params.size are appended.
+     *
+     * @param timeRangeStart Start time string
+     * @param timeRangeEnd End time string
+     * @param params Analysis parameters
+     * @return The complete PPL query string
+     * @throws DateTimeParseException if either time string cannot be parsed
+     */
+    private String buildPPLQuery(String timeRangeStart, String timeRangeEnd, AnalysisParameters params) {
+        String pplBase = !Strings.isNullOrEmpty(params.ppl) ? params.ppl : "source=" + params.index;
+
+        String timeFilter = String
+            .format(
+                Locale.ROOT,
+                "%s >= '%s' AND %s <= '%s'",
+                params.timeField,
+                formatTimeString(timeRangeStart),
+                params.timeField,
+                formatTimeString(timeRangeEnd)
+            );
+
+        return String.format(Locale.ROOT, "%s | where %s | head %d", pplBase, timeFilter, params.size);
+    }
+
+    /**
+     * Formats time string to ISO 8601 instant format for OpenSearch compatibility.
+     *
+     * Three input shapes are tried in order:
+     * 1. "yyyy-MM-dd HH:mm:ss'Z'" (attempted only when the input ends with "Z"), read as UTC;
+     * 2. DATE_FORMAT_PATTERN without a zone, read as UTC;
+     * 3. anything accepted by ZonedDateTime.parse (ISO date-time with offset/zone).
+     *
+     * @param timeString Input time string
+     * @return Formatted time string in ISO 8601 format (DateTimeFormatter.ISO_INSTANT)
+     * @throws DateTimeParseException if time string is null, empty, or matches none of the supported formats
+     */
+    private String formatTimeString(String timeString) throws DateTimeParseException {
+        log.debug("Attempting to parse time string: {}", timeString);
+
+        // Fail with the documented exception type instead of an NPE from timeString.endsWith below.
+        if (Strings.isNullOrEmpty(timeString)) {
+            throw new DateTimeParseException("Invalid time format: " + timeString, String.valueOf(timeString), 0);
+        }
+
+        // Try parsing with zone first
+        try {
+            if (timeString.endsWith("Z")) {
+                DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss'Z'", Locale.ROOT);
+                ZonedDateTime dateTime = ZonedDateTime.parse(timeString, formatter.withZone(ZoneOffset.UTC));
+                return dateTime.format(DateTimeFormatter.ISO_INSTANT);
+            }
+        } catch (DateTimeParseException e) {
+            log.debug("Failed to parse as UTC time: {}", e.getMessage());
+        }
+
+        // Try parsing as local time without zone
+        try {
+            DateTimeFormatter formatter = DateTimeFormatter.ofPattern(DATE_FORMAT_PATTERN, Locale.ROOT);
+            LocalDateTime localDateTime = LocalDateTime.parse(timeString, formatter);
+            ZonedDateTime zonedDateTime = localDateTime.atOffset(ZoneOffset.UTC).toZonedDateTime();
+            return zonedDateTime.format(DateTimeFormatter.ISO_INSTANT);
+        } catch (DateTimeParseException e) {
+            log.debug("Failed to parse as local time: {}", e.getMessage());
+        }
+
+        // Try ISO format
+        try {
+            ZonedDateTime dateTime = ZonedDateTime.parse(timeString);
+            return dateTime.format(DateTimeFormatter.ISO_INSTANT);
+        } catch (DateTimeParseException e) {
+            log.debug("Failed to parse as ISO format: {}", e.getMessage());
+        }
+
+        // DateTimeParseException is itself a RuntimeException, so callers that caught the previous
+        // bare RuntimeException keep working while the declared/documented contract now holds.
+        throw new DateTimeParseException("Invalid time format: " + timeString, timeString, 0);
+    }
+ + /** + * Builds query map from JSON string + */ + public Map buildQueryFromMap(String queryStr) { + if (Strings.isNullOrEmpty(queryStr)) { + return new HashMap<>(); + } + + try { + String normalizedQuery = queryStr.trim().replace("'", "\""); + return gson.fromJson(normalizedQuery, new TypeToken>() { + }.getType()); + } catch (Exception e) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Invalid query format: %s. Error: %s", queryStr, e.getMessage())); + } + } + + /** + * Filters numeric fields from field type mappings + * + * @param fieldTypes Map of field names to their types + * @return Set of numeric field names + */ + public Set getNumberFields(Map fieldTypes) { + Set numberFields = new HashSet<>(); + for (Map.Entry entry : fieldTypes.entrySet()) { + if (NUMBER_FIELD_TYPES.contains(entry.getValue())) { + numberFields.add(entry.getKey()); + } + } + return numberFields; + } + + /** + * Extracts nested field value from document using dot notation + * + * @param doc The document map + * @param field The field path (e.g., "metrics.response_time") + * @return The field value or null if not found + */ + public Object getFlattenedValue(Map doc, String field) { + if (doc == null || field == null) { + return null; + } + + String[] parts = field.split("\\."); + Object current = doc; + + for (String part : parts) { + if (!(current instanceof Map)) { + return null; + } + @SuppressWarnings("unchecked") + Map currentMap = (Map) current; + current = currentMap.get(part); + if (current == null) { + return null; + } + } + + return current; + } +} diff --git a/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java b/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java new file mode 100644 index 00000000..741cc94c --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/LogPatternAnalysisTool.java @@ -0,0 +1,1014 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package 
org.opensearch.agent.tools; + +import static org.opensearch.agent.tools.utils.clustering.HierarchicalAgglomerativeClustering.calculateCosineSimilarity; +import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.opensearch.agent.tools.utils.PPLExecuteHelper; +import org.opensearch.agent.tools.utils.clustering.ClusteringHelper; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; +import org.opensearch.transport.client.Client; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * Usage: + * 1. Register agent: + * POST /_plugins/_ml/agents/_register + * { + * "name": "LogPatternAnalysis", + * "type": "flow", + * "tools": [ + * { + * "name": "log_pattern_analysis_tool", + * "type": "LogPatternAnalysisTool", + * "parameters": { + * } + * } + * ] + * } + * 2. 
Execute agent: + * POST /_plugins/_ml/agents/{agent_id}/_execute + * { + * "parameters": { + * "index": "ss4o_logs-otel-2025.06.24", + * "logFieldName": "body", + * "traceFieldName": "traceId", + * "baseTimeRangeStart": "2025-06-24 07:33:05", + * "baseTimeRangeEnd": "2025-06-24 07:51:27", + * "selectionTimeRangeStart": "2025-06-24 07:50:26", + * "selectionTimeRangeEnd": "2025-06-24 07:55:56" + * } + * } + * 3. Result: a list of selection traceId + * { + * "inference_results": [ + * { + * "output": [ + * { + * "name": "response", + * "result": """{"EXCEPTIONAL": {"traceId": "sequence"}}""" + * } + * ] + * } + * ] + * } + */ +@Log4j2 +@Setter +@Getter +@ToolAnnotation(LogPatternAnalysisTool.TYPE) +public class LogPatternAnalysisTool implements Tool { + public static final String TYPE = "LogPatternAnalysisTool"; + public static final String STRICT_FIELD = "strict"; + + // Constants + private static final String DEFAULT_DESCRIPTION = "Analyzes log patterns in a target time range, optionally compared to a baseline. " + + "Use when investigating incidents to find new, spiking, or anomalous log messages. " + + "Three modes: " + + "(1) Sequence analysis (traceFieldName + baseline): groups logs by trace ID, returns exceptional trace sequences. " + + "(2) Pattern diff (baseline, no traceFieldName): compares pattern frequencies, returns highest-lift patterns. 
" + + "(3) Log insight (no baseline): finds top error/warning patterns with sample logs."; + private static final double LOG_VECTORS_CLUSTERING_THRESHOLD = 0.5; + private static final double LOG_PATTERN_THRESHOLD = 0.75; + private static final double LOG_PATTERN_LIFT = 3; + private static final String DEFAULT_TIME_FIELD = "@timestamp"; + private static final int MAX_LOG_SAMPLE_SIZE = 10000; + + public static final String DEFAULT_INPUT_SCHEMA = + """ + { + "type": "object", + "properties": { + "index": { + "type": "string", + "description": "Target OpenSearch index name" + }, + "timeField": { + "type": "string", + "description": "Date/time field for filtering" + }, + "logFieldName": { + "type": "string", + "description": "Field containing log message text" + }, + "traceFieldName": { + "type": "string", + "description": "Trace/correlation ID field. Enables sequence analysis mode when provided with baseline time range" + }, + "baseTimeRangeStart": { + "type": "string", + "description": "Start of baseline period (format: yyyy-MM-dd HH:mm:ss). Must pair with baseTimeRangeEnd" + }, + "baseTimeRangeEnd": { + "type": "string", + "description": "End of baseline period (format: yyyy-MM-dd HH:mm:ss). Must pair with baseTimeRangeStart" + }, + "selectionTimeRangeStart": { + "type": "string", + "description": "Start of target/incident period (format: yyyy-MM-dd HH:mm:ss)" + }, + "selectionTimeRangeEnd": { + "type": "string", + "description": "End of target/incident period (format: yyyy-MM-dd HH:mm:ss)" + }, + "filter": { + "type": "string", + "description": "PPL boolean expression to filter logs (e.g. serviceName='ts-auth-service' or severity='ERROR'). Applied as additional where clause. 
Not applicable for sequence analysis mode (when traceFieldName is provided with baseline), as sequence analysis requires all logs within a trace" + } + }, + "required": [ + "index", + "timeField", + "logFieldName", + "selectionTimeRangeStart", + "selectionTimeRangeEnd" + ], + "additionalProperties": false + } + """; + + public static final Map DEFAULT_ATTRIBUTES = Map + .of(TOOL_INPUT_SCHEMA_FIELD, gson.toJson(gson.fromJson(DEFAULT_INPUT_SCHEMA, Map.class)), STRICT_FIELD, false); + + // Compiled regex patterns for better performance + private static final Pattern REPEATED_WILDCARDS_PATTERN = Pattern.compile("(<\\*>)(\\s+<\\*>)+"); + + /** + * Parameter class to hold analysis parameters with validation + */ + private static class AnalysisParameters { + final String index; + final String timeField; + final String logFieldName; + final String traceFieldName; + final String baseTimeRangeStart; + final String baseTimeRangeEnd; + final String selectionTimeRangeStart; + final String selectionTimeRangeEnd; + final String filter; + + AnalysisParameters(Map parameters) { + this.index = parameters.getOrDefault("index", ""); + this.timeField = parameters.getOrDefault("timeField", DEFAULT_TIME_FIELD); + this.logFieldName = parameters.getOrDefault("logFieldName", "message"); + this.traceFieldName = parameters.getOrDefault("traceFieldName", ""); + this.baseTimeRangeStart = parameters.getOrDefault("baseTimeRangeStart", ""); + this.baseTimeRangeEnd = parameters.getOrDefault("baseTimeRangeEnd", ""); + this.selectionTimeRangeStart = parameters.getOrDefault("selectionTimeRangeStart", ""); + this.selectionTimeRangeEnd = parameters.getOrDefault("selectionTimeRangeEnd", ""); + this.filter = parameters.getOrDefault("filter", ""); + } + + private void validate() { + List missingParams = new ArrayList<>(); + if (Strings.isEmpty(index)) + missingParams.add("index"); + if (Strings.isEmpty(timeField)) + missingParams.add("timeField"); + if (Strings.isEmpty(logFieldName)) + 
missingParams.add("logFieldName"); + if (Strings.isEmpty(selectionTimeRangeStart)) + missingParams.add("selectionTimeRangeStart"); + if (Strings.isEmpty(selectionTimeRangeEnd)) + missingParams.add("selectionTimeRangeEnd"); + if (!missingParams.isEmpty()) { + throw new IllegalArgumentException("Missing required parameters: " + String.join(", ", missingParams)); + } + } + + boolean hasBaseTime() { + return !Strings.isEmpty(baseTimeRangeStart) && !Strings.isEmpty(baseTimeRangeEnd); + } + + boolean hasTraceField() { + return !Strings.isEmpty(traceFieldName); + } + } + + /** + * Result class for pattern analysis + */ + private record PatternAnalysisResult(Map> tracePatternMap, Map> patternCountMap, + Map patternWeightsMap) { + } + + private record PatternDiffResult(String pattern, Double base, Double selection, Double lift) { + } + + Comparator comparator = (d1, d2) -> { + Double lift1 = Optional.ofNullable(d1.lift).orElse(Double.MIN_VALUE); + Double lift2 = Optional.ofNullable(d2.lift).orElse(Double.MIN_VALUE); + + if (lift1.compareTo(lift2) == 0) { + return Optional + .ofNullable(d2.selection) + .orElse(Double.MIN_VALUE) + .compareTo(Optional.ofNullable(d1.selection).orElse(Double.MIN_VALUE)); + } else { + return lift2.compareTo(lift1); + } + }; + + private record PatternWithSamples(String pattern, double count, List sampleLogs) { + } + + // Instance fields + @Setter + @Getter + private String name = TYPE; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + @Getter + private String version; + private Client client; + private ClusteringHelper clusteringHelper; + + public LogPatternAnalysisTool(Client client) { + this.client = client; + this.clusteringHelper = new ClusteringHelper(LOG_VECTORS_CLUSTERING_THRESHOLD); + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public Map getAttributes() { + return DEFAULT_ATTRIBUTES; + } + + @Override + public void setAttributes(Map map) { + + } + + @Override + public boolean 
validate(Map map) { + try { + new AnalysisParameters(map).validate(); + } catch (Exception e) { + return false; + } + return true; + } + + @Override + public void run(Map originalParameters, ActionListener listener) { + try { + Map parameters = ToolUtils.extractInputParameters(originalParameters, DEFAULT_ATTRIBUTES); + log.debug("Starting log pattern analysis with parameters: {}", parameters.keySet()); + AnalysisParameters params = new AnalysisParameters(parameters); + params.validate(); + + if (params.hasTraceField() && params.hasBaseTime()) { + log.debug("Performing log sequence analysis for index: {}", params.index); + logSequenceAnalysis(params, listener); + } else if (params.hasBaseTime()) { + log.debug("Performing log pattern analysis for index: {}", params.index); + logPatternDiffAnalysis(params, listener); + } else { + logInsight(params, listener); + } + } catch (IllegalArgumentException e) { + log.error("Invalid parameters for LogPatternAnalysisTool: {}", e.getMessage()); + listener.onFailure(new IllegalArgumentException("Invalid parameters: " + e.getMessage(), e)); + } catch (Exception e) { + log.error("Unexpected error in LogPatternAnalysisTool", e); + listener.onFailure(new RuntimeException("Failed to execute log pattern analysis", e)); + } + } + + private void logSequenceAnalysis(AnalysisParameters params, ActionListener listener) { + if (!Strings.isEmpty(params.filter)) { + log.warn("Filter parameter is ignored for sequence analysis mode as it requires all logs within a trace"); + } + // Step 1: Analyze selection time range + analyzeSelectionTimeRange(params, ActionListener.wrap(selectionResult -> { + log.debug("Selection time range analysis completed, found {} traces", selectionResult.tracePatternMap.size()); + + if (selectionResult.tracePatternMap.isEmpty()) { + Map> emptyResult = buildFinalResult( + List.of(), + List.of(), + Collections.emptyMap(), + Collections.emptyMap() + ); + listener.onResponse((T) gson.toJson(emptyResult)); + return; + } + + 
// Step 2: Analyze base time range + analyzeBaseTimeRange(params, ActionListener.wrap(baseResult -> { + log.debug("Base time range analysis completed, found {} traces", baseResult.tracePatternMap.size()); + + // Step 3: Generate comparison result + generateSequenceComparisonResult(baseResult, selectionResult, listener); + }, listener::onFailure)); + }, error -> { + log.error("Failed to execute analysis", error); + listener.onFailure(new RuntimeException("Analysis failed: " + error.getMessage(), error)); + })); + } + + private void analyzeBaseTimeRange(AnalysisParameters params, ActionListener listener) { + String baseTimeRangeLogPatternPPL = buildLogPatternPPL( + params.index, + params.timeField, + params.logFieldName, + params.traceFieldName, + params.baseTimeRangeStart, + params.baseTimeRangeEnd, + "" + ); + + executePPL(baseTimeRangeLogPatternPPL, listener); + } + + private void analyzeSelectionTimeRange(AnalysisParameters params, ActionListener listener) { + String selectionTimeRangeLogPatternPPL = buildLogPatternPPL( + params.index, + params.timeField, + params.logFieldName, + params.traceFieldName, + params.selectionTimeRangeStart, + params.selectionTimeRangeEnd, + "" + ); + + executePPL(selectionTimeRangeLogPatternPPL, listener); + } + + private void executePPL(String ppl, ActionListener listener) { + Function>, PatternAnalysisResult> rowParser = dataRows -> { + Map> tracePatternMap = new HashMap<>(); + Map> patternCountMap = new HashMap<>(); + Map rawPatternCache = new HashMap<>(); + + for (List row : dataRows) { + if (row.size() < 2) { + continue; + } + + String traceId = (String) row.get(0); + String rawPattern = (String) row.get(1); + + String simplifiedPattern = rawPatternCache.computeIfAbsent(rawPattern, this::postProcessPattern); + + tracePatternMap.computeIfAbsent(traceId, k -> new LinkedHashSet<>()).add(simplifiedPattern); + patternCountMap.computeIfAbsent(simplifiedPattern, k -> new HashSet<>()).add(traceId); + } + + // Calculate pattern values 
using IDF and sigmoid + Map patternVectors = vectorizePattern(patternCountMap, tracePatternMap.size()); + + return new PatternAnalysisResult(tracePatternMap, patternCountMap, patternVectors); + }; + + PPLExecuteHelper.executePPLAndParseResult(client, ppl, PPLExecuteHelper.dataRowsParser(rowParser), listener); + } + + private String buildLogPatternPPL( + String index, + String timeField, + String logFieldName, + String traceFieldName, + String startTime, + String endTime, + String filter + ) { + String filterClause = Strings.isEmpty(filter) ? "" : String.format(Locale.ROOT, " | where %s", filter); + + String pplTemplate = + "source={INDEX} | where {TRACE_FIELD}!='' | where {TIME_FIELD}>'{START_TIME}' and {TIME_FIELD}<'{END_TIME}'{FILTER} " + + "| fields {TRACE_FIELD}, {LOG_FIELD}, {TIME_FIELD} | patterns {LOG_FIELD} method=brain variable_count_threshold=3 " + + "| fields {TRACE_FIELD}, patterns_field, {TIME_FIELD} | sort {TIME_FIELD}"; + + return pplTemplate + .replace("{INDEX}", index) + .replace("{TRACE_FIELD}", traceFieldName) + .replace("{TIME_FIELD}", timeField) + .replace("{START_TIME}", startTime) + .replace("{END_TIME}", endTime) + .replace("{FILTER}", filterClause) + .replace("{LOG_FIELD}", logFieldName); + } + + private Map vectorizePattern(Map> patternCountMap, int totalTraceCount) { + Map patternValues = new HashMap<>(); + + for (Map.Entry> entry : patternCountMap.entrySet()) { + String pattern = entry.getKey(); + Set traceIds = entry.getValue(); + + if (traceIds != null && !traceIds.isEmpty()) { + // IDF calculation + double idf = Math.log((double) totalTraceCount / traceIds.size()); + // Apply sigmoid function + double value = 1.0 / (1.0 + Math.exp(-idf)); + patternValues.put(pattern, value); + } else { + patternValues.put(pattern, 0.0); + } + } + + return patternValues; + } + + private void generateSequenceComparisonResult( + PatternAnalysisResult baseResult, + PatternAnalysisResult selectionResult, + ActionListener listener + ) { + try { + // Step 3: 
Build pattern index for vector construction + Map patternIndexMap = buildPatternIndex(baseResult, selectionResult); + + // Step 4: Build vectors for base time range + Map baseVectorMap = buildVectorMap( + baseResult.tracePatternMap, + baseResult.patternWeightsMap, + patternIndexMap, + false + ); + + // Step 5: Cluster base vectors and find centroids + List baseRepresentative = this.clusteringHelper.clusterLogVectorsAndGetRepresentative(baseVectorMap); + + // Step 6: Build vectors for traceNeedToExamine time range + Map selectionVectorMap = buildVectorMap( + selectionResult.tracePatternMap, + selectionResult.patternWeightsMap, + patternIndexMap, + true, + baseResult.patternCountMap, + selectionResult.patternCountMap + ); + + // Step 7: Find traceNeedToExamine centroids + List selectionRepresentative = this.clusteringHelper.clusterLogVectorsAndGetRepresentative(selectionVectorMap); + + List traceNeedToExamine = filterSelectionCentroids( + baseRepresentative, + selectionRepresentative, + baseVectorMap, + selectionVectorMap + ); + + log + .info( + "Identified {} traceNeedToExamine centroids from {} candidates", + traceNeedToExamine.size(), + selectionRepresentative.size() + ); + + // Generate final result + Map> result = buildFinalResult( + baseRepresentative, + traceNeedToExamine, + baseResult.tracePatternMap, + selectionResult.tracePatternMap + ); + listener.onResponse((T) gson.toJson(result)); + + } catch (Exception e) { + log.error("Failed to generate sequence comparison result", e); + listener.onFailure(new RuntimeException("Failed to generate comparison result: " + e.getMessage(), e)); + } + } + + private Map buildPatternIndex(PatternAnalysisResult baseResult, PatternAnalysisResult selectionResult) { + Set allPatterns = new HashSet<>(baseResult.patternCountMap.keySet()); + allPatterns.addAll(selectionResult.patternCountMap.keySet()); + + List sortedPatterns = new ArrayList<>(allPatterns); + Collections.sort(sortedPatterns); + log.debug("vector dimension is {}", 
sortedPatterns.size()); + + // pattern and its index in a vector + Map patternIndexMap = new HashMap<>(); + for (int i = 0; i < sortedPatterns.size(); i++) { + patternIndexMap.put(sortedPatterns.get(i), i); + } + + return patternIndexMap; + } + + @SafeVarargs + private Map buildVectorMap( + Map> tracePatternMap, + Map patternWeightsMap, + Map patternIndexMap, + boolean isSelection, + Map>... additionalPatternMaps + ) { + Map vectorMap = new HashMap<>(); + int dimension = patternIndexMap.size(); + + for (Map.Entry> entry : tracePatternMap.entrySet()) { + String traceId = entry.getKey(); + Set patterns = entry.getValue(); + double[] vector = new double[dimension]; + + for (String pattern : patterns) { + Integer index = patternIndexMap.get(pattern); + if (index != null) { + double baseValue = 0.5 * patternWeightsMap.getOrDefault(pattern, 0.0); + + if (isSelection && additionalPatternMaps.length >= 2) { + // Add existence weight for selection patterns + Map> basePatterns = additionalPatternMaps[0]; + + int existenceWeight = basePatterns.containsKey(pattern) ? 
0 : 1; + vector[index] = baseValue + 0.5 * existenceWeight; + } else { + vector[index] = baseValue; + } + } + } + + vectorMap.put(traceId, vector); + } + + return vectorMap; + } + + private List filterSelectionCentroids( + List baseCentroids, + List selectionCandidates, + Map baseVectorMap, + Map selectionVectorMap + ) { + List selectionCentroids = new ArrayList<>(); + + for (String candidate : selectionCandidates) { + boolean isSelection = true; + double[] candidateVector = selectionVectorMap.get(candidate); + + if (candidateVector == null) { + log.warn("No vector found for selection candidate: {}", candidate); + continue; + } + + for (String baseCentroid : baseCentroids) { + double[] baseVector = baseVectorMap.get(baseCentroid); + if (baseVector != null && calculateCosineSimilarity(baseVector, candidateVector) > LOG_VECTORS_CLUSTERING_THRESHOLD) { + isSelection = false; + break; + } + } + + if (isSelection) { + selectionCentroids.add(candidate); + } + } + + return selectionCentroids; + } + + private Map> buildFinalResult( + List baseCentroids, + List selectionCentroids, + Map> baseTracePatternMap, + Map> selectionTracePatternMap + ) { + Map baseSequences = new HashMap<>(); + for (String centroid : baseCentroids) { + Set patterns = baseTracePatternMap.get(centroid); + if (patterns != null) { + baseSequences.put(centroid, String.join(" -> ", patterns)); + } + } + + Map selectionSequences = new HashMap<>(); + for (String centroid : selectionCentroids) { + Set patterns = selectionTracePatternMap.get(centroid); + if (patterns != null) { + selectionSequences.put(centroid, String.join(" -> ", patterns)); + } + } + + Map> result = new HashMap<>(); + result.put("BASE", baseSequences); + result.put("EXCEPTIONAL", selectionSequences); + + return result; + } + + private void logPatternDiffAnalysis(AnalysisParameters params, ActionListener listener) { + // Step 1: Generate log patterns for baseline time range + String baseTimeRangeLogPatternPPL = buildLogPatternPPL( + 
params.index, + params.timeField, + params.logFieldName, + params.baseTimeRangeStart, + params.baseTimeRangeEnd, + params.filter + ); + Function>, Map> dataRowsParser = dataRows -> { + Map patternMap = new HashMap<>(); + for (List row : dataRows) { + if (row.size() == 2) { + String pattern = (String) row.get(1); + double count = ((Number) row.get(0)).doubleValue(); + patternMap.put(pattern, count); + } + } + return patternMap; + }; + + log.debug("Executing base time range pattern PPL: {}", baseTimeRangeLogPatternPPL); + PPLExecuteHelper + .executePPLAndParseResult( + client, + baseTimeRangeLogPatternPPL, + PPLExecuteHelper.dataRowsParser(dataRowsParser), + ActionListener.wrap(basePatterns -> { + try { + mergeSimilarPatterns(basePatterns); + + log.debug("Base patterns processed: {} patterns", basePatterns.size()); + + // Step 2: Generate log patterns for selection time range + String selectionTimeRangeLogPatternPPL = buildLogPatternPPL( + params.index, + params.timeField, + params.logFieldName, + params.selectionTimeRangeStart, + params.selectionTimeRangeEnd, + params.filter + ); + + log.debug("Executing selection time range pattern PPL: {}", selectionTimeRangeLogPatternPPL); + PPLExecuteHelper + .executePPLAndParseResult( + client, + selectionTimeRangeLogPatternPPL, + PPLExecuteHelper.dataRowsParser(dataRowsParser), + ActionListener.wrap(selectionPatterns -> { + mergeSimilarPatterns(selectionPatterns); + + log.debug("Selection patterns processed: {} patterns", selectionPatterns.size()); + + // Step 3: Calculate pattern differences + List patternDifferences = calculatePatternDifferences( + basePatterns, + selectionPatterns + ); + + // Step 4: Sort the difference and get top 10 + List topDiffs = Stream + .concat( + patternDifferences + .stream() + .filter(diff -> !Objects.isNull(diff.lift)) + .sorted(comparator) + .limit(10), + patternDifferences + .stream() + .filter(diff -> Objects.isNull(diff.lift)) + .sorted(comparator) + .limit(10) + ) + 
.collect(Collectors.toList()); + + Map finalResult = new HashMap<>(); + finalResult.put("patternMapDifference", topDiffs); + + log.debug("Pattern analysis completed: {} differences found", patternDifferences.size()); + listener.onResponse((T) gson.toJson(finalResult)); + }, listener::onFailure) + ); + + } catch (Exception e) { + log.error("Failed to process base pattern response", e); + listener.onFailure(new RuntimeException("Failed to process base patterns: " + e.getMessage(), e)); + } + }, error -> { + log.error("Failed to execute pattern analysis", error); + listener.onFailure(new RuntimeException("Analysis failed: " + error.getMessage(), error)); + }) + ); + } + + private void logInsight(AnalysisParameters params, ActionListener listener) { + Set errorKeywords = Set + .of( + "error", + "err", + "exception", + "failed", + "failure", + "timeout", + "panic", + "fatal", + "critical", + "severe", + "abort", + "aborted", + "aborting", + "crash", + "crashed", + "broken", + "corrupt", + "corrupted", + "invalid", + "malformed", + "unprocessable", + "denied", + "forbidden", + "unauthorized", + "conflict", + "deadlock", + "overflow", + "underflow", + "throttled", + "disk_full", + "insufficient", + "retrying", + "backpressure", + "degraded", + "unexpected", + "unusual", + "missing", + "stale", + "expired", + "mismatch", + "violation" + ); + + String filterClause = Strings.isEmpty(params.filter) ? 
"" : String.format(Locale.ROOT, " | where %s", params.filter); + + String pplTemplate = "source={INDEX} | where {TIME_FIELD}>'{START_TIME}' and {TIME_FIELD}<'{END_TIME}'{FILTER} " + + "| where match({LOG_FIELD}, '{ERROR_KEYWORDS}') | head " + + MAX_LOG_SAMPLE_SIZE + + " | fields {LOG_FIELD} | patterns {LOG_FIELD} method=brain " + + "mode=aggregation max_sample_count=5 variable_count_threshold=3 " + + "| fields patterns_field, pattern_count, sample_logs | sort -pattern_count | head 5"; + + String selectionTimeRangeLogPatternPPL = pplTemplate + .replace("{INDEX}", params.index) + .replace("{TIME_FIELD}", params.timeField) + .replace("{START_TIME}", params.selectionTimeRangeStart) + .replace("{END_TIME}", params.selectionTimeRangeEnd) + .replace("{FILTER}", filterClause) + .replace("{LOG_FIELD}", params.logFieldName) + .replace("{ERROR_KEYWORDS}", String.join(" ", errorKeywords)); + + Function>, List> dataRowsParser = dataRows -> { + List patternWithSamplesList = new ArrayList<>(); + for (List row : dataRows) { + if (row.size() == 3) { + String pattern = (String) row.get(0); + double count = ((Number) row.get(1)).doubleValue(); + List samples = (List) row.get(2); + patternWithSamplesList.add(new PatternWithSamples(pattern, count, samples)); + } + } + return patternWithSamplesList; + }; + + PPLExecuteHelper + .executePPLAndParseResult( + client, + selectionTimeRangeLogPatternPPL, + PPLExecuteHelper.dataRowsParser(dataRowsParser), + ActionListener.wrap(logInsights -> { + try { + Map finalResult = new HashMap<>(); + finalResult.put("logInsights", logInsights); + listener.onResponse((T) gson.toJson(finalResult)); + } catch (Exception e) { + log.error("Failed to process base pattern response", e); + listener.onFailure(new RuntimeException("Failed to process base patterns: " + e.getMessage(), e)); + } + }, error -> { + log.error("Failed to execute log insights analysis", error); + listener.onFailure(new RuntimeException("Log insights analysis failed: " + error.getMessage(), 
error)); + }) + ); + } + + private String buildLogPatternPPL( + String index, + String timeField, + String logFieldName, + String startTime, + String endTime, + String filter + ) { + String filterClause = Strings.isEmpty(filter) ? "" : String.format(Locale.ROOT, " | where %s", filter); + + String pplTemplate = "source={INDEX} | where {TIME_FIELD}>'{START_TIME}' and {TIME_FIELD}<'{END_TIME}'{FILTER} " + + "| fields {LOG_FIELD} | patterns {LOG_FIELD} method=brain mode=aggregation variable_count_threshold=3 " + + "| fields pattern_count, patterns_field"; + + return pplTemplate + .replace("{INDEX}", index) + .replace("{TIME_FIELD}", timeField) + .replace("{START_TIME}", startTime) + .replace("{END_TIME}", endTime) + .replace("{FILTER}", filterClause) + .replace("{LOG_FIELD}", logFieldName); + } + + private List calculatePatternDifferences(Map basePatterns, Map selectionPatterns) { + List differences = new ArrayList<>(); + + double selectionTotal = selectionPatterns.values().stream().mapToDouble(Double::doubleValue).sum(); + double baseTotal = basePatterns.values().stream().mapToDouble(Double::doubleValue).sum(); + + for (Map.Entry entry : selectionPatterns.entrySet()) { + String pattern = entry.getKey(); + double selectionCount = entry.getValue(); + + if (basePatterns.containsKey(pattern)) { + double baseCount = basePatterns.get(pattern); + double lift = (selectionCount / selectionTotal) / (baseCount / baseTotal); + + if (lift < 1) { + lift = 1.0 / lift; + } + + if (lift > LOG_PATTERN_LIFT) { + differences.add(new PatternDiffResult(pattern, baseCount / baseTotal, selectionCount / selectionTotal, lift)); + } + } else { + // Pattern only exists in selection time range + differences.add(new PatternDiffResult(pattern, 0.0, selectionCount / selectionTotal, null)); + log.debug("New selection pattern detected: {} (count: {})", pattern, selectionCount); + } + } + + return differences; + } + + private double jaccardSimilarity(String pattern1, String pattern2) { + if 
(Strings.isEmpty(pattern1) && Strings.isEmpty(pattern2)) { + return 1.0; + } + if (Strings.isEmpty(pattern1) || Strings.isEmpty(pattern2)) { + return 0.0; + } + + Set set1 = new HashSet<>(Arrays.asList(pattern1.split("\\s+"))); + Set set2 = new HashSet<>(Arrays.asList(pattern2.split("\\s+"))); + + // Calculate union + Set union = new HashSet<>(set1); + union.addAll(set2); + + int intersectionSize = set1.size() + set2.size() - union.size(); + return (double) intersectionSize / union.size(); + } + + private void mergeSimilarPatterns(Map patternMap) { + if (patternMap.isEmpty()) { + return; + } + + List patterns = new ArrayList<>(patternMap.keySet()); + patterns.sort(String::compareTo); + Set removed = new HashSet<>(); + + for (int i = 0; i < patterns.size(); i++) { + String pattern1 = patterns.get(i); + if (removed.contains(pattern1)) { + continue; + } + + for (int j = i + 1; j < patterns.size(); j++) { + String pattern2 = patterns.get(j); + if (removed.contains(pattern2)) { + continue; + } + + if (jaccardSimilarity(pattern1, pattern2) > LOG_PATTERN_THRESHOLD) { + // Merge pattern2 into pattern1 + double count1 = patternMap.getOrDefault(pattern1, 0.0); + double count2 = patternMap.getOrDefault(pattern2, 0.0); + patternMap.put(pattern1, count1 + count2); + patternMap.remove(pattern2); + removed.add(pattern2); + log.debug("Merged similar patterns: '{}' + '{}' -> '{}'", pattern1, pattern2, pattern1); + } + } + } + + // Post-process patterns and merge those with similar processed forms + Map toReplace = new HashMap<>(); + for (String pattern : patternMap.keySet()) { + String processedPattern = postProcessPattern(pattern); + if (!processedPattern.equals(pattern)) { + toReplace.put(pattern, processedPattern); + } + } + + for (Map.Entry entry : toReplace.entrySet()) { + String originalPattern = entry.getKey(); + String processedPattern = entry.getValue(); + double count = patternMap.remove(originalPattern); + patternMap.merge(processedPattern, count, Double::sum); + } + + 
log.debug("Pattern merging completed: {} patterns remaining", patternMap.size()); + } + + private String postProcessPattern(String pattern) { + if (Strings.isEmpty(pattern)) { + return pattern; + } + + // Replace repeated <*> with single <*> using compiled pattern + pattern = REPEATED_WILDCARDS_PATTERN.matcher(pattern).replaceAll("<*>"); + return pattern; + } + + public static class Factory implements Tool.Factory { + private Client client; + + private static LogPatternAnalysisTool.Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static LogPatternAnalysisTool.Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (LogPatternAnalysisTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new LogPatternAnalysisTool.Factory(); + return INSTANCE; + } + } + + /** + * Initialize this factory + * + * @param client The OpenSearch client + */ + public void init(Client client) { + this.client = client; + } + + @Override + public LogPatternAnalysisTool create(Map map) { + + return new LogPatternAnalysisTool(client); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public Map getDefaultAttributes() { + return DEFAULT_ATTRIBUTES; + } + + @Override + public String getDefaultVersion() { + return null; + } + } +} diff --git a/src/main/java/org/opensearch/agent/tools/LogPatternTool.java b/src/main/java/org/opensearch/agent/tools/LogPatternTool.java index 4359b2dc..70464156 100644 --- a/src/main/java/org/opensearch/agent/tools/LogPatternTool.java +++ b/src/main/java/org/opensearch/agent/tools/LogPatternTool.java @@ -31,6 +31,7 @@ import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; import 
org.opensearch.search.SearchHit; import org.opensearch.sql.plugin.transport.PPLQueryAction; import org.opensearch.sql.plugin.transport.TransportPPLQueryRequest; @@ -107,7 +108,8 @@ protected String getQueryBody(String queryText) { } @Override - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); String dsl = parameters.get(INPUT_FIELD); String ppl = parameters.get(PPL_FIELD); if (!StringUtils.isBlank(dsl)) { diff --git a/src/main/java/org/opensearch/agent/tools/MetricChangeAnalysisTool.java b/src/main/java/org/opensearch/agent/tools/MetricChangeAnalysisTool.java new file mode 100644 index 00000000..d2af20d1 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/MetricChangeAnalysisTool.java @@ -0,0 +1,557 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; +import org.opensearch.transport.client.Client; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * Tool for analyzing metric changes by comparing percentile distributions between time periods. + * Uses relative change (percentage change) to rank fields by significance. 
+ * + * Usage: + * POST /_plugins/_ml/agents/{agent_id}/_execute + * { + * "parameters": { + * "index": "logs-2025.01.15", + * "timeField": "@timestamp", + * "selectionTimeRangeStart": "2025-01-15 10:00:00", + * "selectionTimeRangeEnd": "2025-01-15 11:00:00", + * "baselineTimeRangeStart": "2025-01-15 08:00:00", + * "baselineTimeRangeEnd": "2025-01-15 09:00:00", + * "size": 1000 + * } + * } + */ +@Log4j2 +@Setter +@Getter +@ToolAnnotation(MetricChangeAnalysisTool.TYPE) +public class MetricChangeAnalysisTool implements Tool { + public static final String TYPE = "MetricChangeAnalysisTool"; + + private static final String DEFAULT_DESCRIPTION = + "Compares percentile distributions (P50, P90) of numeric fields between two time ranges. " + + "Returns top fields ranked by change score. " + + "Use for root cause analysis when investigating metric anomalies. " + + "Keep both time ranges short (e.g. 15-30 minutes) and similar in duration for accurate comparison."; + + private static final int DEFAULT_TOP_N = 10; + + public static final String DEFAULT_INPUT_SCHEMA = + """ + { + "type": "object", + "properties": { + "index": { + "type": "string", + "description": "Target OpenSearch index name" + }, + "timeField": { + "type": "string", + "description": "Date/time field for filtering" + }, + "selectionTimeRangeStart": { + "type": "string", + "description": "Start of target period (format: yyyy-MM-dd HH:mm:ss)" + }, + "selectionTimeRangeEnd": { + "type": "string", + "description": "End of target period (format: yyyy-MM-dd HH:mm:ss)" + }, + "baselineTimeRangeStart": { + "type": "string", + "description": "Start of baseline period (format: yyyy-MM-dd HH:mm:ss)" + }, + "baselineTimeRangeEnd": { + "type": "string", + "description": "End of baseline period (format: yyyy-MM-dd HH:mm:ss). 
Should be at or before selectionTimeRangeStart" + }, + "size": { + "type": "integer", + "description": "Maximum number of documents to analyze (default: 1000, max: 10000)" + }, + "topN": { + "type": "integer", + "description": "Number of top fields to return, ranked by change score (default: 10)" + }, + "queryType": { + "type": "string", + "description": "Query type: 'ppl' or 'dsl' (default: 'dsl')" + }, + "filter": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Additional DSL query conditions (optional)" + }, + "dsl": { + "type": "string", + "description": "Complete raw DSL query as JSON string (optional)" + }, + "ppl": { + "type": "string", + "description": "Complete PPL statement without time information (optional)" + } + }, + "required": ["index", "timeField", "selectionTimeRangeStart", "selectionTimeRangeEnd", "baselineTimeRangeStart", "baselineTimeRangeEnd"] + } + """; + + public static final Map DEFAULT_ATTRIBUTES = Map + .of(TOOL_INPUT_SCHEMA_FIELD, gson.toJson(gson.fromJson(DEFAULT_INPUT_SCHEMA, Map.class))); + + /** + * Record for percentile statistics + */ + private record PercentileStats(double p50, double p90) { + } + + /** + * Record for field percentile analysis with variance + */ + private record FieldPercentileAnalysis(String field, double variance, PercentileStats selectionStats, PercentileStats baselineStats) { + } + + /** + * Result item for JSON output + */ + private record PercentileAnalysisResult(String field, Double changeScore, Map selectionPercentiles, + Map baselinePercentiles, Map logRatios) { + } + + @Setter + @Getter + private String name = TYPE; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + @Getter + private String version; + private Client client; + private DataFetchingHelper dataFetchingHelper; + + /** + * Constructs a MetricChangeAnalysisTool with the given OpenSearch client + * + * @param client The OpenSearch client for executing queries + */ + public 
MetricChangeAnalysisTool(Client client) { + this.client = client; + this.dataFetchingHelper = new DataFetchingHelper(client); + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public Map getAttributes() { + return DEFAULT_ATTRIBUTES; + } + + @Override + public void setAttributes(Map map) {} + + @Override + public boolean validate(Map parameters) { + try { + // Use helper's validation logic + DataFetchingHelper.AnalysisParameters params = new DataFetchingHelper.AnalysisParameters(parameters); + params.validate(); + return true; + } catch (IllegalArgumentException e) { + log.error("Invalid parameters: {}", e.getMessage()); + return false; + } + } + + /** + * Executes percentile analysis on numeric fields between selection and baseline periods + * + * @param The response type + * @param originalParameters Input parameters for analysis + * @param listener Action listener for handling results or failures + */ + @Override + public void run(Map originalParameters, ActionListener listener) { + try { + Map parameters = ToolUtils.extractInputParameters(originalParameters, DEFAULT_ATTRIBUTES); + log.debug("Starting metric change analysis with parameters: {}", parameters.keySet()); + + // Extract topN parameter (use Double.valueOf to handle "30.0" from Gson serialization) + int topN = DEFAULT_TOP_N; + String topNStr = parameters.get("topN"); + if (topNStr != null && !topNStr.isEmpty()) { + try { + topN = Double.valueOf(topNStr).intValue(); + if (topN <= 0) { + topN = DEFAULT_TOP_N; + } + } catch (NumberFormatException e) { + log.warn("Invalid topN parameter '{}', using default: {}", topNStr, DEFAULT_TOP_N); + } + } + + // Use DataDistributionTool's data fetching mechanism + fetchDataAndAnalyze(parameters, topN, listener); + + } catch (IllegalArgumentException e) { + log.error("Invalid parameters for MetricChangeAnalysisTool: {}", e.getMessage()); + listener.onFailure(e); + } catch (Exception e) { + log.error("Unexpected error in 
MetricChangeAnalysisTool", e); + listener.onFailure(e); + } + } + + /** + * Fetches data using DataDistributionTool's mechanism and performs percentile analysis + */ + private void fetchDataAndAnalyze(Map parameters, int topN, ActionListener listener) { + try { + // Create analysis parameters + DataFetchingHelper.AnalysisParameters params = new DataFetchingHelper.AnalysisParameters(parameters); + + // Get field types first + String index = parameters.get("index"); + dataFetchingHelper.getFieldTypes(index, ActionListener.wrap((Map fieldTypes) -> { + // Get number fields + Set numberFields = dataFetchingHelper.getNumberFields(fieldTypes); + + if (numberFields.isEmpty()) { + listener + .onFailure( + new IllegalStateException("No numeric fields found in index. Percentile analysis requires numeric fields.") + ); + return; + } + + // Fetch selection and baseline data + fetchBothDatasets(params, numberFields, topN, listener); + }, listener::onFailure)); + + } catch (Exception e) { + log.error("Failed to fetch data for percentile analysis", e); + listener.onFailure(e); + } + } + + /** + * Fetches both selection and baseline datasets + */ + private void fetchBothDatasets( + DataFetchingHelper.AnalysisParameters params, + Set numberFields, + int topN, + ActionListener listener + ) { + try { + // Fetch selection data + dataFetchingHelper + .fetchIndexData( + params.selectionTimeRangeStart, + params.selectionTimeRangeEnd, + params, + ActionListener.wrap((List> selectionData) -> { + // Fetch baseline data + dataFetchingHelper + .fetchIndexData( + params.baselineTimeRangeStart, + params.baselineTimeRangeEnd, + params, + ActionListener.wrap((List> baselineData) -> { + // Perform metric change analysis + performAnalysis(selectionData, baselineData, numberFields, topN, listener); + }, listener::onFailure) + ); + }, listener::onFailure) + ); + + } catch (Exception e) { + log.error("Failed to fetch datasets", e); + listener.onFailure(e); + } + } + + /** + * Performs metric change 
analysis on the fetched data + */ + private void performAnalysis( + List> selectionData, + List> baselineData, + Set numberFields, + int topN, + ActionListener listener + ) { + try { + if (selectionData.isEmpty()) { + listener.onFailure(new IllegalStateException("No data found for selection time range")); + return; + } + + if (baselineData.isEmpty()) { + listener.onFailure(new IllegalStateException("No data found for baseline time range")); + return; + } + + // Calculate percentiles and relative changes + List analyses = calculateMetricChangeAnalysis(selectionData, baselineData, numberFields); + List results = formatResults(analyses, topN); + listener.onResponse((T) gson.toJson(Map.of("percentileAnalysis", results))); + + } catch (Exception e) { + log.error("Failed to perform metric change analysis", e); + listener.onFailure(e); + } + } + + /** + * Formats analysis results for JSON output, limiting to top N results + */ + private List formatResults(List analyses, int topN) { + return analyses.stream().limit(topN).map(analysis -> { + Map selectionStats = Map.of("p50", analysis.selectionStats.p50, "p90", analysis.selectionStats.p90); + + Map baselineStats = Map.of("p50", analysis.baselineStats.p50, "p90", analysis.baselineStats.p90); + + Map logRatios = Map + .of( + "p50", + safeLogRatio(analysis.selectionStats.p50, analysis.baselineStats.p50), + "p90", + safeLogRatio(analysis.selectionStats.p90, analysis.baselineStats.p90) + ); + + return new PercentileAnalysisResult(analysis.field, analysis.variance, selectionStats, baselineStats, logRatios); + }).toList(); + } + + // ========== Metric Change Analysis Functions ========== + + /** + * Calculates metric change analysis for all numeric fields + */ + private List calculateMetricChangeAnalysis( + List> selectionData, + List> baselineData, + Set numberFields + ) { + List analyses = new ArrayList<>(); + + for (String field : numberFields) { + List selectionValues = extractNumericValues(selectionData, field); + List 
baselineValues = extractNumericValues(baselineData, field); + + if (selectionValues.isEmpty() || baselineValues.isEmpty()) { + continue; + } + + PercentileStats selectionStats = calculatePercentiles(selectionValues); + PercentileStats baselineStats = calculatePercentiles(baselineValues); + double variance = calculatePercentileVariance(selectionStats, baselineStats); + + analyses.add(new FieldPercentileAnalysis(field, variance, selectionStats, baselineStats)); + } + + analyses.sort(Comparator.comparingDouble((FieldPercentileAnalysis a) -> a.variance).reversed()); + return analyses; + } + + /** + * Extracts numeric values from dataset for a specific field + */ + private List extractNumericValues(List> data, String field) { + List values = new ArrayList<>(); + + for (Map doc : data) { + Object value = getFlattenedValue(doc, field); + if (value != null) { + try { + if (value instanceof Number) { + values.add(((Number) value).doubleValue()); + } else { + values.add(Double.parseDouble(value.toString())); + } + } catch (NumberFormatException e) { + // Skip non-numeric values + } + } + } + + return values; + } + + /** + * Extracts nested field values from document using dot notation + */ + private Object getFlattenedValue(Map doc, String field) { + return dataFetchingHelper.getFlattenedValue(doc, field); + } + + /** + * Calculates statistics (avg, P50, P90) for a list of values + */ + private PercentileStats calculatePercentiles(List values) { + if (values.isEmpty()) { + return new PercentileStats(0.0, 0.0); + } + + List sorted = new ArrayList<>(values); + sorted.sort(Double::compareTo); + + double p50 = calculatePercentile(sorted, 50); + double p90 = calculatePercentile(sorted, 90); + + return new PercentileStats(p50, p90); + } + + /** + * Calculates a specific percentile from sorted values + */ + private double calculatePercentile(List sortedValues, int percentile) { + if (sortedValues.isEmpty()) { + return 0.0; + } + + if (sortedValues.size() == 1) { + return 
sortedValues.get(0); + } + + double index = (percentile / 100.0) * (sortedValues.size() - 1); + int lowerIndex = (int) Math.floor(index); + int upperIndex = (int) Math.ceil(index); + + if (lowerIndex == upperIndex) { + return sortedValues.get(lowerIndex); + } + + double lowerValue = sortedValues.get(lowerIndex); + double upperValue = sortedValues.get(upperIndex); + double fraction = index - lowerIndex; + + return lowerValue + (upperValue - lowerValue) * fraction; + } + + private static final double LOG_RATIO_CAP = 10.0; + private static final double EPSILON = 1e-10; + + /** + * Calculates change score between selection and baseline statistics. + * Uses weighted log-ratio scoring on avg, P50, and P90. + * Skips a metric if its baseline is near zero to avoid inflated scores. + * Redistributes weight equally among valid metrics. + */ + private double calculatePercentileVariance(PercentileStats selectionStats, PercentileStats baselineStats) { + boolean p50Valid = Math.abs(baselineStats.p50) >= EPSILON; + boolean p90Valid = Math.abs(baselineStats.p90) >= EPSILON; + + if (!p50Valid && !p90Valid) { + return 0.0; + } + if (p50Valid && p90Valid) { + return 0.5 * safeLogRatio(selectionStats.p50, baselineStats.p50) + 0.5 * safeLogRatio(selectionStats.p90, baselineStats.p90); + } + if (p50Valid) { + return safeLogRatio(selectionStats.p50, baselineStats.p50); + } + return safeLogRatio(selectionStats.p90, baselineStats.p90); + } + + /** + * Computes |log(selection / baseline)| with safe handling of near-zero values. + * Returns 0 when both values are near zero, caps at LOG_RATIO_CAP when only baseline is near zero. 
+ */ + private double safeLogRatio(double selection, double baseline) { + if (Math.abs(baseline) < EPSILON && Math.abs(selection) < EPSILON) { + return 0.0; + } + if (Math.abs(baseline) < EPSILON) { + return LOG_RATIO_CAP; + } + double ratio = selection / baseline; + if (ratio <= 0) { + return 0.0; + } + return Math.abs(Math.log(ratio)); + } + + /** + * Factory class for creating MetricChangeAnalysisTool instances + */ + public static class Factory implements Tool.Factory { + private Client client; + private static Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (MetricChangeAnalysisTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + /** + * Initialize this factory + * + * @param client The OpenSearch client + */ + public void init(Client client) { + this.client = client; + } + + @Override + public MetricChangeAnalysisTool create(Map map) { + return new MetricChangeAnalysisTool(client); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public Map getDefaultAttributes() { + return DEFAULT_ATTRIBUTES; + } + + @Override + public String getDefaultVersion() { + return null; + } + } +} diff --git a/src/main/java/org/opensearch/agent/tools/PPLTool.java b/src/main/java/org/opensearch/agent/tools/PPLTool.java index 182d396d..15bdc7c8 100644 --- a/src/main/java/org/opensearch/agent/tools/PPLTool.java +++ b/src/main/java/org/opensearch/agent/tools/PPLTool.java @@ -7,6 +7,7 @@ import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD; import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.ml.common.utils.ToolUtils.NO_ESCAPE_PARAMS; import java.io.IOException; import 
java.io.InputStream; @@ -14,6 +15,7 @@ import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.time.Instant; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -22,20 +24,27 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.StringJoiner; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.stream.Collectors; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.math.NumberUtils; import org.apache.commons.text.StringSubstitutor; import org.apache.spark.sql.types.DataType; import org.json.JSONObject; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest; import org.opensearch.action.search.SearchRequest; import org.opensearch.agent.tools.utils.ToolHelper; +import org.opensearch.agent.tools.utils.mergeMetaData.MergeRuleHelper; import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; @@ -50,6 +59,7 @@ import org.opensearch.ml.common.spi.tools.WithModelTool; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.sql.plugin.transport.PPLQueryAction; @@ -79,6 +89,9 @@ public class PPLTool implements WithModelTool { private static final String DEFAULT_DESCRIPTION = "\"Use this tool when user ask question 
based on the data in the cluster or parse user statement about which index to use in a conversion.\nAlso use this tool when question only contains index information.\n1. If uesr question contain both question and index name, the input parameters are {'question': UserQuestion, 'index': IndexName}.\n2. If user question contain only question, the input parameter is {'question': UserQuestion}.\n3. If uesr question contain only index name, find the original human input from the conversation histroy and formulate parameter as {'question': UserQuestion, 'index': IndexName}\nThe index name should be exactly as stated in user's input."; + private static final String TABLE_INFO_KEY = "table_info"; + private static final String MAPPING_KEY = "mappings"; + @Setter private String name = TYPE; @Getter @@ -194,30 +207,54 @@ public PPLTool( @SuppressWarnings("unchecked") @Override - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); final String tenantId = parameters.get(TENANT_ID_FIELD); extractFromChatParameters(parameters); - String indexName = getIndexNameFromParameters(parameters); - if (StringUtils.isBlank(indexName)) { + List indices = Optional + .ofNullable(getIndexNameFromParameters(parameters, "index")) + .filter(list -> !list.isEmpty()) + .orElseGet(() -> getIndexNameFromParameters(parameters, this.previousToolKey + ".output")); + if (indices.isEmpty()) { throw new IllegalArgumentException( "Return this final answer to human directly and do not use other tools: 'Please provide index name'. 
Please try to directly send this message to human to ask for index name" ); } String question = parameters.get("question"); - if (StringUtils.isBlank(indexName) || StringUtils.isBlank(question)) { + if (StringUtils.isBlank(question)) { throw new IllegalArgumentException("Parameter index and question can not be null or empty."); } - if (indexName.startsWith(".")) { - throw new IllegalArgumentException( - "PPLTool doesn't support searching indices starting with '.' since it could be system index, current searching index name: " - + indexName - ); + for (String index : indices) { + if (index.startsWith(".")) { + throw new IllegalArgumentException( + "PPLTool doesn't support searching indices starting with '.' since it could be system index, current searching index name: " + + index + ); + } } - ActionListener actionsAfterTableinfo = ActionListener.wrap(tableInfo -> { - String prompt = constructPrompt(tableInfo, question.strip(), indexName); + ActionListener> actionsAfterTableinfo = ActionListener.wrap(indexInfo -> { + if (Objects.isNull(indexInfo.get(TABLE_INFO_KEY)) || Objects.isNull(indexInfo.get(MAPPING_KEY))) { + log.error("The table info and mappings are missing in: {}", indexInfo); + listener.onFailure(new RuntimeException("The table info and mappings are missing in: " + indexInfo)); + } + String tableInfo = indexInfo.get(TABLE_INFO_KEY).toString(); + String prompt = constructPrompt(tableInfo, question.strip(), indices); + Map reformattedInput = Map + .of( + "prompt", + prompt, + "mappings", + indexInfo.get(MAPPING_KEY), + "os_version", + Version.CURRENT.toString(), + "current_time", + Instant.now().toString(), + "datasourceType", + parameters.getOrDefault("type", "Opensearch") + ); RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet .builder() - .parameters(Map.of("prompt", prompt, "datasourceType", parameters.getOrDefault("type", "Opensearch"))) + .parameters(Map.of("prompt", formatString(reformattedInput), NO_ESCAPE_PARAMS, "prompt")) 
.build(); ActionRequest request = new MLPredictionTaskRequest( modelId, @@ -234,7 +271,7 @@ public void run(Map parameters, ActionListener listener) listener.onFailure(new IllegalStateException("Remote endpoint fails to inference.")); return; } - String ppl = parseOutput(dataAsMap.get("response"), indexName); + String ppl = parseOutput(dataAsMap.get("response")); if (!this.execute) { Map ret = ImmutableMap.of("ppl", ppl); listener.onResponse((T) AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(ret))); @@ -262,8 +299,14 @@ public void run(Map parameters, ActionListener listener) ); // Execute output here }, e -> { + log.error(String.format(Locale.ROOT, "fail to predict model: %s with error: %s", modelId, e.getMessage()), e); - listener.onFailure(e); + if (e instanceof OpenSearchStatusException) { + String errorMessage = redactSagemakerArns(redactCloudwatchUrl(e.getMessage())); + listener.onFailure(new OpenSearchStatusException(errorMessage, ((OpenSearchStatusException) e).status())); + } else { + listener.onFailure(e); + } })); }, e -> { log.info("fail to get index schema"); @@ -271,6 +314,7 @@ public void run(Map parameters, ActionListener listener) } ); + // Logic for schema/samples as input if (parameters.containsKey("schema") && parameters.containsKey("samples") && Objects.equals(parameters.getOrDefault("type", ""), "s3")) { @@ -281,44 +325,61 @@ public void run(Map parameters, ActionListener listener) transferS3SchemaFormat(schema), (Map) samples.get(0) ); - actionsAfterTableinfo.onResponse(tableInfo); + actionsAfterTableinfo.onResponse(Map.of(TABLE_INFO_KEY, tableInfo, MAPPING_KEY, gson.toJson(schema))); } catch (Exception e) { log.info("fail to get table info for s3"); actionsAfterTableinfo.onFailure(e); } - return; } - GetMappingsRequest getMappingsRequest = buildGetMappingRequest(indexName); - client.admin().indices().getMappings(getMappingsRequest, ActionListener.wrap(getMappingsResponse -> { - Map mappings = 
getMappingsResponse.getMappings(); - if (mappings.isEmpty()) { - throw new IllegalArgumentException("No matching mapping with index name: " + indexName); - } - String firstIndexName = (String) mappings.keySet().toArray()[0]; - SearchRequest searchRequest = buildSearchRequest(firstIndexName); - client.search(searchRequest, ActionListener.wrap(searchResponse -> { - SearchHit[] searchHits = searchResponse.getHits().getHits(); - String tableInfo = constructTableInfo(searchHits, mappings); - actionsAfterTableinfo.onResponse(tableInfo); + + CountDownLatch latch = new CountDownLatch(indices.size()); + ConcurrentHashMap tableInfos = new ConcurrentHashMap<>(); + ConcurrentHashMap mappingInfos = new ConcurrentHashMap<>(); + for (String index : indices) { + GetMappingsRequest getMappingsRequest = buildGetMappingRequest(index); + client.admin().indices().getMappings(getMappingsRequest, ActionListener.wrap(getMappingsResponse -> { + Map mappings = getMappingsResponse.getMappings(); + if (mappings.isEmpty()) { + throw new IllegalArgumentException("No matching mapping with index name: " + index); + } + String firstIndexName = (String) mappings.keySet().toArray()[0]; + SearchRequest searchRequest = buildSearchRequest(firstIndexName); + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + SearchHit[] searchHits = searchResponse.getHits().getHits(); + Map finalMappings = new HashMap<>(); + for (MappingMetadata mappingMetadata : mappings.values()) { + Map mappingSource = (Map) mappingMetadata.getSourceAsMap().get("properties"); + MergeRuleHelper.merge(mappingSource, finalMappings); + } + String tableInfo = constructTableInfo(searchHits, finalMappings); + tableInfos.put(index, tableInfo); + mappingInfos.put(index, finalMappings); + latch.countDown(); + if (latch.getCount() == 0) { + String mergedTableInfo = mergeTableInfo(tableInfos); + actionsAfterTableinfo.onResponse(Map.of(TABLE_INFO_KEY, mergedTableInfo, MAPPING_KEY, mappingInfos)); + } + }, e -> { + 
log.error(String.format(Locale.ROOT, "fail to search index: %s with error: %s", firstIndexName, e.getMessage()), e); + listener.onFailure(e); + })); }, e -> { - log.error(String.format(Locale.ROOT, "fail to search model: %s with error: %s", modelId, e.getMessage()), e); - listener.onFailure(e); + log.error(String.format(Locale.ROOT, "fail to get mapping of index: %s with error: %s", indices, e.getMessage()), e); + String errorMessage = e.getMessage(); + if (errorMessage.contains("no such index")) { + listener + .onFailure( + new IllegalArgumentException( + "Return this final answer to human directly and do not use other tools: 'Please provide the existing index name(s)'. Please try to directly send this message to human to ask for index name" + ) + ); + } else { + listener.onFailure(e); + } })); - }, e -> { - log.error(String.format(Locale.ROOT, "fail to get mapping of index: %s with error: %s", indexName, e.getMessage()), e); - String errorMessage = e.getMessage(); - if (errorMessage.contains("no such index")) { - listener - .onFailure( - new IllegalArgumentException( - "Return this final answer to human directly and do not use other tools: 'Please provide index name'. Please try to directly send this message to human to ask for index name" - ) - ); - } else { - listener.onFailure(e); - } - })); + } + } @Override @@ -504,17 +565,10 @@ private String constructTableInfoByPPLResultForSpark(Map schema, } - private String constructTableInfo(SearchHit[] searchHits, Map mappings) throws PrivilegedActionException { - String firstIndexName = (String) mappings.keySet().toArray()[0]; - MappingMetadata mappingMetadata = mappings.get(firstIndexName); - Map mappingSource = (Map) mappingMetadata.getSourceAsMap().get("properties"); - if (Objects.isNull(mappingSource)) { - throw new IllegalArgumentException( - "The querying index doesn't have mapping metadata, please add data to it or using another index." 
- ); - } + private String constructTableInfo(SearchHit[] searchHits, Map allFields) throws PrivilegedActionException { Map fieldsToType = new HashMap<>(); - ToolHelper.extractFieldNamesTypes(mappingSource, fieldsToType, "", true); + ToolHelper.extractFieldNamesTypes(allFields, fieldsToType, "", false); + StringJoiner tableInfoJoiner = new StringJoiner("\n"); List sortedKeys = new ArrayList<>(fieldsToType.keySet()); Collections.sort(sortedKeys); @@ -546,8 +600,8 @@ private String constructTableInfo(SearchHit[] searchHits, Map indexInfo = ImmutableMap.of("mappingInfo", tableInfo, "question", question, "indexName", indexName); + private String constructPrompt(String tableInfo, String question, List indices) { + Map indexInfo = ImmutableMap.of("mappingInfo", tableInfo, "question", question, "indexName", indices.toString()); StringSubstitutor substitutor = new StringSubstitutor(indexInfo, "${indexInfo.", "}"); return substitutor.replace(contextPrompt); } @@ -602,7 +656,7 @@ private void extractFromChatParameters(Map parameters) { } } - private String parseOutput(String llmOutput, String indexName) { + private String parseOutput(String llmOutput) { String ppl; Pattern pattern = Pattern.compile("((.|[\\r\\n])+?)"); // For ppl like source=a \n | fields b Matcher matcher = pattern.matcher(llmOutput); @@ -612,32 +666,10 @@ private String parseOutput(String llmOutput, String indexName) { } else { // logic for only ppl returned int sourceIndex = llmOutput.indexOf("source="); int describeIndex = llmOutput.indexOf("describe "); - if (sourceIndex != -1) { - llmOutput = llmOutput.substring(sourceIndex); - - // Splitting the string at "|" - String[] lists = llmOutput.split("\\|"); - - // Modifying the first element - if (lists.length > 0) { - lists[0] = "source=" + indexName; - } - - // Joining the string back together - ppl = String.join("|", lists); - } else if (describeIndex != -1) { - llmOutput = llmOutput.substring(describeIndex); - String[] lists = llmOutput.split("\\|"); - - 
// Modifying the first element - if (lists.length > 0) { - lists[0] = "describe " + indexName; - } - - // Joining the string back together - ppl = String.join("|", lists); - } else { + if (sourceIndex == -1 && describeIndex == -1) { throw new IllegalArgumentException("The returned PPL: " + llmOutput + " has wrong format"); + } else { + ppl = llmOutput; } } if (this.pplModelType != PPLModelType.FINETUNE) { @@ -656,12 +688,26 @@ private String parseOutput(String llmOutput, String indexName) { return ppl; } - private String getIndexNameFromParameters(Map parameters) { - String indexName = parameters.getOrDefault("index", ""); - if (!StringUtils.isBlank(this.previousToolKey) && StringUtils.isBlank(indexName)) { - indexName = parameters.getOrDefault(this.previousToolKey + ".output", ""); // read index name from previous key + private List getIndexNameFromParameters(Map parameters, String key) { + if (!parameters.containsKey(key)) { + return List.of(); + } + String indexName = parameters.get(key); + try { + List list = gson.fromJson(indexName, List.class); + return list.stream().map(Object::toString).map(String::trim).collect(Collectors.toList()); + } catch (Exception e) { + return List.of(indexName.trim()); + } + } + + private String mergeTableInfo(ConcurrentHashMap tableInfos) { + StringBuilder mergedTableInfo = new StringBuilder(); + for (Map.Entry entry : tableInfos.entrySet()) { + mergedTableInfo.append(entry.getKey()).append("\n"); + mergedTableInfo.append(entry.getValue()).append("\n"); } - return indexName.trim(); + return mergedTableInfo.toString(); } private Map transferS3SchemaFormat(Map originalSchema) { @@ -686,4 +732,25 @@ private static Map loadDefaultPromptDict() { } return new HashMap<>(); } + + private static String redactSagemakerArns(String input) { + String regex = "arn:aws:logs:[^:]+:\\d+:log-group:/aws/sagemaker/Endpoints/[^ \\t\\n\\r\\f\\v,\"']+"; + Pattern pattern = Pattern.compile(regex); + Matcher matcher = pattern.matcher(input); + + return 
matcher.replaceAll(""); + } + + public static String redactCloudwatchUrl(String input) { + String regex = "See\\s+.+?\\s+in\\s+account\\s+.+?\\s+for\\s+more\\s+information"; + Pattern pattern = Pattern.compile(regex); + Matcher matcher = pattern.matcher(input); + + return matcher.replaceAll(""); + } + + public String formatString(Map targetMap) { + String mapString = gson.toJson(gson.toJson(targetMap)); + return mapString.substring(1, mapString.length() - 1); + } } diff --git a/src/main/java/org/opensearch/agent/tools/RAGTool.java b/src/main/java/org/opensearch/agent/tools/RAGTool.java index 7771ca66..c1a32667 100644 --- a/src/main/java/org/opensearch/agent/tools/RAGTool.java +++ b/src/main/java/org/opensearch/agent/tools/RAGTool.java @@ -28,6 +28,7 @@ import org.opensearch.ml.common.spi.tools.WithModelTool; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.transport.client.Client; import com.google.gson.Gson; @@ -95,7 +96,9 @@ public Object parse(Object o) { }; } - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); + final String tenantId = parameters.get(TENANT_ID_FIELD); String input = null; diff --git a/src/main/java/org/opensearch/agent/tools/SearchAlertsTool.java b/src/main/java/org/opensearch/agent/tools/SearchAlertsTool.java index cab2bc7c..e144dd83 100644 --- a/src/main/java/org/opensearch/agent/tools/SearchAlertsTool.java +++ b/src/main/java/org/opensearch/agent/tools/SearchAlertsTool.java @@ -21,6 +21,7 @@ import org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; import 
org.opensearch.transport.client.Client; import org.opensearch.transport.client.node.NodeClient; @@ -70,7 +71,8 @@ public Object parse(Object o) { } @Override - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); final String tableSortOrder = parameters.getOrDefault("sortOrder", "asc"); final String tableSortString = parameters.getOrDefault("sortString", "monitor_name.keyword"); final int tableSize = parameters.containsKey("size") && StringUtils.isNumeric(parameters.get("size")) diff --git a/src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java b/src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java index efe8dfd7..9830f5d6 100644 --- a/src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java +++ b/src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java @@ -36,6 +36,7 @@ import org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.sort.SortOrder; @@ -52,7 +53,7 @@ public class SearchAnomalyDetectorsTool implements Tool { public static final String TYPE = "SearchAnomalyDetectorsTool"; private static final String DEFAULT_DESCRIPTION = "This is a tool that searches anomaly detectors. 
It takes 12 optional arguments named detectorName which is the explicit name of the detector (default is null), and detectorNamePattern which is a wildcard query to match detector name (default is null), and indices which defines the index or index pattern the detector is detecting over (default is null), and highCardinality which defines whether the anomaly detector is high cardinality (synonymous with multi-entity) of non-high-cardinality (synonymous with single-entity) (default is null, indicating both), and lastUpdateTime which defines the latest update time of the anomaly detector in epoch milliseconds (default is null), and sortOrder which defines the order of the results (options are asc or desc, and default is asc), and sortString which defines how to sort the results (default is name.keyword), and size which defines the size of the request to be returned (default is 20), and startIndex which defines the paginated index to start from (default is 0), and running which defines whether the anomaly detector is running (default is null, indicating both), and failed which defines whether the anomaly detector has failed (default is null, indicating both). The tool returns 2 values: a list of anomaly detectors (each containing the detector id, detector name, detector type indicating multi-entity or single-entity (where multi-entity also means high-cardinality), detector description, name of the configured index, last update time in epoch milliseconds), and the total number of anomaly detectors."; - + public static final String CONFIG_INDEX = ".opendistro-anomaly-detectors"; @Setter @Getter private String name = TYPE; @@ -94,7 +95,8 @@ public Object parse(Object o) { // number of total detectors. The output will likely need to be updated, standardized, and include more fields in the // future to cover a sufficient amount of potential questions the agent will need to handle. 
@Override - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); final String detectorName = parameters.getOrDefault("detectorName", null); final String detectorNamePattern = parameters.getOrDefault("detectorNamePattern", null); final String indices = parameters.getOrDefault("indices", null); @@ -169,6 +171,7 @@ public void run(Map parameters, ActionListener listener) GetConfigRequest profileRequest = new GetConfigRequest( hit.getId(), + CONFIG_INDEX, Versions.MATCH_ANY, false, true, diff --git a/src/main/java/org/opensearch/agent/tools/SearchAnomalyResultsTool.java b/src/main/java/org/opensearch/agent/tools/SearchAnomalyResultsTool.java index 76b322cb..b7936417 100644 --- a/src/main/java/org/opensearch/agent/tools/SearchAnomalyResultsTool.java +++ b/src/main/java/org/opensearch/agent/tools/SearchAnomalyResultsTool.java @@ -26,6 +26,7 @@ import org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; @@ -84,7 +85,8 @@ public Object parse(Object o) { // and total # of results. The output will likely need to be updated, standardized, and include more fields in the // future to cover a sufficient amount of potential questions the agent will need to handle. @Override - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); final String detectorId = parameters.getOrDefault("detectorId", null); final Boolean realTime = parameters.containsKey("realTime") ? 
Boolean.parseBoolean(parameters.get("realTime")) : null; final Double anomalyGradeThreshold = parameters.containsKey("anomalyGradeThreshold") diff --git a/src/main/java/org/opensearch/agent/tools/SearchAroundDocumentTool.java b/src/main/java/org/opensearch/agent/tools/SearchAroundDocumentTool.java new file mode 100644 index 00000000..104c0bc0 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/SearchAroundDocumentTool.java @@ -0,0 +1,370 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.FieldSortBuilder; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.transport.client.Client; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.JsonObject; +import com.google.gson.JsonSyntaxException; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * Tool to search N documents before and N documents after a specific document ID, + * ordered by a timestamp field using search_after pagination. 
+ */ +@Getter +@Setter +@Log4j2 +@ToolAnnotation(SearchAroundDocumentTool.TYPE) +public class SearchAroundDocumentTool implements Tool { + + public static final String INPUT_FIELD = "input"; + public static final String INDEX_FIELD = "index"; + public static final String DOC_ID_FIELD = "doc_id"; + public static final String TIMESTAMP_FIELD = "timestamp_field"; + public static final String COUNT_FIELD = "count"; + public static final String INPUT_SCHEMA_FIELD = "input_schema"; + + public static final String TYPE = "SearchAroundDocumentTool"; + private static final String DEFAULT_DESCRIPTION = """ + Use this tool to search documents around a specific document by providing: \ + 'index' for the index name, 'doc_id' for the target document ID, \ + 'timestamp_field' for the field to order by, and 'count' for the number of documents before and after. \ + Returns N documents before, the target document, and N documents after."""; + + public static final String DEFAULT_INPUT_SCHEMA = """ + { + "type": "object", + "properties": { + "index": { + "type": "string", + "description": "OpenSearch index name" + }, + "doc_id": { + "type": "string", + "description": "Target document ID" + }, + "timestamp_field": { + "type": "string", + "description": "Timestamp field for ordering (e.g., @timestamp)" + }, + "count": { + "type": "integer", + "description": "Number of documents before and after the target" + } + }, + "required": ["index", "doc_id", "timestamp_field", "count"], + "additionalProperties": false + } + """; + + private static final Gson GSON = new GsonBuilder().serializeSpecialFloatingPointValues().create(); + + public static final Map DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA); + + private String name = TYPE; + private Map attributes; + private String description = DEFAULT_DESCRIPTION; + + private Client client; + + private NamedXContentRegistry xContentRegistry; + + public SearchAroundDocumentTool(Client client, NamedXContentRegistry 
xContentRegistry) { + this.client = client; + this.xContentRegistry = xContentRegistry; + + this.attributes = new HashMap<>(); + attributes.put(INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA); + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public String getVersion() { + return null; + } + + @Override + public boolean validate(Map parameters) { + if (parameters == null || parameters.isEmpty()) { + return false; + } + boolean argumentsFromInput = parameters.containsKey(INPUT_FIELD) && !StringUtils.isEmpty(parameters.get(INPUT_FIELD)); + boolean argumentsFromParameters = parameters.containsKey(INDEX_FIELD) + && parameters.containsKey(DOC_ID_FIELD) + && parameters.containsKey(TIMESTAMP_FIELD) + && parameters.containsKey(COUNT_FIELD) + && !StringUtils.isEmpty(parameters.get(INDEX_FIELD)) + && !StringUtils.isEmpty(parameters.get(DOC_ID_FIELD)) + && !StringUtils.isEmpty(parameters.get(TIMESTAMP_FIELD)) + && !StringUtils.isEmpty(parameters.get(COUNT_FIELD)); + boolean validRequest = argumentsFromInput || argumentsFromParameters; + if (!validRequest) { + log.error("SearchAroundDocumentTool requires: index, doc_id, timestamp_field, and count parameters!"); + return false; + } + return true; + } + + private static Map processResponse(SearchHit hit) { + Map docContent = new HashMap<>(); + docContent.put("_index", hit.getIndex()); + docContent.put("_id", hit.getId()); + docContent.put("_score", hit.getScore()); + docContent.put("_source", hit.getSourceAsMap()); + if (hit.getSortValues() != null && hit.getSortValues().length > 0) { + docContent.put("sort", hit.getSortValues()); + } + return docContent; + } + + @Override + public void run(Map originalParameters, ActionListener listener) { + try { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); + String input = parameters.get(INPUT_FIELD); + String index = null; + String docId = null; + String timestampField = null; + Integer count = null; + + if 
(!StringUtils.isEmpty(input)) { + try { + JsonObject jsonObject = GSON.fromJson(input, JsonObject.class); + if (jsonObject != null) { + if (jsonObject.has(INDEX_FIELD)) { + index = jsonObject.get(INDEX_FIELD).getAsString(); + } + if (jsonObject.has(DOC_ID_FIELD)) { + docId = jsonObject.get(DOC_ID_FIELD).getAsString(); + } + if (jsonObject.has(TIMESTAMP_FIELD)) { + timestampField = jsonObject.get(TIMESTAMP_FIELD).getAsString(); + } + if (jsonObject.has(COUNT_FIELD)) { + count = jsonObject.get(COUNT_FIELD).getAsInt(); + } + } + } catch (JsonSyntaxException e) { + log.error("Invalid JSON input: {}", input, e); + } + } + + // Fall back to direct parameters + if (StringUtils.isEmpty(index)) { + index = parameters.get(INDEX_FIELD); + } + if (StringUtils.isEmpty(docId)) { + docId = parameters.get(DOC_ID_FIELD); + } + if (StringUtils.isEmpty(timestampField)) { + timestampField = parameters.get(TIMESTAMP_FIELD); + } + if (count == null && parameters.containsKey(COUNT_FIELD)) { + try { + count = Double.valueOf(parameters.get(COUNT_FIELD)).intValue(); + } catch (NumberFormatException e) { + log.error("Invalid count parameter: {}", parameters.get(COUNT_FIELD), e); + } + } + + // Validate all required parameters + if (StringUtils.isEmpty(index) || StringUtils.isEmpty(docId) || StringUtils.isEmpty(timestampField) || count == null) { + listener + .onFailure( + new IllegalArgumentException( + "SearchAroundDocumentTool requires: index, doc_id, timestamp_field, and count parameters" + ) + ); + return; + } + + final String finalIndex = index; + final String finalDocId = docId; + final String finalTimestampField = timestampField; + final int finalCount = count; + + // Step 1: Fetch the target document by ID with sort values + log + .debug( + "Calling SearchAroundDocumentTool: index={}, doc_id={}, timestamp_field={}, count={}", + index, + docId, + timestampField, + count + ); + + SearchSourceBuilder targetSource = new SearchSourceBuilder() + 
.query(QueryBuilders.idsQuery().addIds(docId)) + .sort(new FieldSortBuilder(timestampField).order(SortOrder.DESC).unmappedType("boolean")) + .sort(new FieldSortBuilder("_doc").order(SortOrder.DESC).unmappedType("boolean")) + .size(1); + SearchRequest targetRequest = new SearchRequest(index).source(targetSource); + + client.search(targetRequest, ActionListener.wrap(targetResponse -> { + SearchHit[] targetHits = targetResponse.getHits().getHits(); + if (targetHits == null || targetHits.length == 0) { + listener.onFailure(new IllegalArgumentException("Document not found: " + finalDocId)); + return; + } + + SearchHit targetHit = targetHits[0]; + Object[] sortValues = targetHit.getSortValues(); + if (sortValues == null || sortValues.length < 2) { + listener.onFailure(new IllegalArgumentException("Could not get sort values from target document")); + return; + } + + // Build target document response + Map targetDoc = processResponse(targetHit); + + // Step 2: Search for documents BEFORE using search_after with DESC sort + BoolQueryBuilder beforeQuery = QueryBuilders.boolQuery().mustNot(QueryBuilders.idsQuery().addIds(finalDocId)); + + SearchSourceBuilder beforeSource = new SearchSourceBuilder() + .query(beforeQuery) + .sort(new FieldSortBuilder(finalTimestampField).order(SortOrder.DESC).unmappedType("boolean")) + .sort(new FieldSortBuilder("_doc").order(SortOrder.DESC).unmappedType("boolean")) + .searchAfter(sortValues) + .size(finalCount); + SearchRequest beforeRequest = new SearchRequest(finalIndex).source(beforeSource); + + client.search(beforeRequest, ActionListener.wrap(beforeResponse -> { + // Step 3: Search for documents AFTER using search_after with ASC sort + BoolQueryBuilder afterQuery = QueryBuilders.boolQuery().mustNot(QueryBuilders.idsQuery().addIds(finalDocId)); + + SearchSourceBuilder afterSource = new SearchSourceBuilder() + .query(afterQuery) + .sort(new FieldSortBuilder(finalTimestampField).order(SortOrder.ASC).unmappedType("boolean")) + .sort(new 
FieldSortBuilder("_doc").order(SortOrder.ASC).unmappedType("boolean")) + .searchAfter(sortValues) + .size(finalCount); + SearchRequest afterRequest = new SearchRequest(finalIndex).source(afterSource); + + client.search(afterRequest, ActionListener.wrap(afterSearchResponse -> { + + // Process "before" results (need to reverse to get chronological order) + SearchHit[] beforeHits = beforeResponse.getHits().getHits(); + List> beforeDocs = new ArrayList<>(); + for (SearchHit hit : beforeHits) { + beforeDocs.add(processResponse(hit)); + } + Collections.reverse(beforeDocs); + List> result = new ArrayList<>(beforeDocs); + + // Add target document + result.add(targetDoc); + + // Process "after" results + SearchHit[] afterHits = afterSearchResponse.getHits().getHits(); + for (SearchHit hit : afterHits) { + result.add(processResponse(hit)); + } + + String resultJson = GSON.toJson(result); + listener.onResponse((T) resultJson); + }, e -> { + log.error("Failed to search for documents after target", e); + listener.onFailure(e); + })); + }, e -> { + log.error("Failed to search for documents before target", e); + listener.onFailure(e); + })); + }, e -> { + log.error("Failed to fetch target document", e); + listener.onFailure(e); + })); + } catch (Exception e) { + log.error("Failed to run SearchAroundDocumentTool", e); + listener.onFailure(e); + } + } + + public static class Factory implements Tool.Factory { + + private Client client; + private static Factory INSTANCE; + + private NamedXContentRegistry xContentRegistry; + + /** + * Create or return the singleton factory instance + */ + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (SearchAroundDocumentTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + public void init(Client client, NamedXContentRegistry xContentRegistry) { + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + 
@Override + public SearchAroundDocumentTool create(Map params) { + return new SearchAroundDocumentTool(client, xContentRegistry); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + + @Override + public Map getDefaultAttributes() { + return DEFAULT_ATTRIBUTES; + } + } +} diff --git a/src/main/java/org/opensearch/agent/tools/SearchMonitorsTool.java b/src/main/java/org/opensearch/agent/tools/SearchMonitorsTool.java index 0e928c73..91c2bf14 100644 --- a/src/main/java/org/opensearch/agent/tools/SearchMonitorsTool.java +++ b/src/main/java/org/opensearch/agent/tools/SearchMonitorsTool.java @@ -30,6 +30,7 @@ import org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.sort.SortOrder; @@ -84,7 +85,8 @@ public Object parse(Object o) { // number of total monitors. The output will likely need to be updated, standardized, and include more fields in the // future to cover a sufficient amount of potential questions the agent will need to handle. 
@Override - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); final String monitorId = parameters.getOrDefault("monitorId", null); final String monitorName = parameters.getOrDefault("monitorName", null); final String monitorNamePattern = parameters.getOrDefault("monitorNamePattern", null); diff --git a/src/main/java/org/opensearch/agent/tools/WebSearchTool.java b/src/main/java/org/opensearch/agent/tools/WebSearchTool.java index c7081fcb..a1d2fa7d 100644 --- a/src/main/java/org/opensearch/agent/tools/WebSearchTool.java +++ b/src/main/java/org/opensearch/agent/tools/WebSearchTool.java @@ -8,6 +8,9 @@ import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -17,23 +20,25 @@ import java.util.Optional; import org.apache.commons.lang3.math.NumberUtils; -import org.apache.hc.client5.http.classic.methods.HttpGet; -import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; -import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse; -import org.apache.hc.client5.http.impl.classic.HttpClients; import org.apache.hc.core5.http.HttpStatus; -import org.apache.hc.core5.http.io.entity.EntityUtils; import org.jsoup.Connection; import org.jsoup.Jsoup; import org.jsoup.nodes.Document; import org.jsoup.nodes.Element; import org.jsoup.select.Elements; +import org.opensearch.OpenSearchStatusException; import org.opensearch.agent.ToolPlugin; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.httpclient.MLHttpClientFactory; import org.opensearch.ml.common.spi.tools.Tool; import 
org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.threadpool.ThreadPool; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import com.google.common.collect.ImmutableMap; import com.google.gson.JsonArray; @@ -44,6 +49,14 @@ import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; +import software.amazon.awssdk.core.internal.http.async.SimpleHttpContentPublisher; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpFullResponse; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.http.async.AsyncExecuteRequest; +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; +import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler; @Log4j2 @Setter @@ -76,6 +89,28 @@ public class WebSearchTool implements Tool { + "}"; public static final Map DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, "strict", false); + public static final String NEXT_PAGE = "next_page"; + public static final String ENGINE_ID = "engine_id"; + public static final String OFFSET = "offset"; + public static final String DUCKDUCKGO = "duckduckgo"; + public static final String GOOGLE = "google"; + public static final String BING = "bing"; + public static final String CUSTOM = "custom"; + public static final String ITEMS = "items"; + public static final String ENGINE = "engine"; + public static final String ENDPOINT = "endpoint"; + public static final String API_KEY = "api_key"; + public static final String CUSTOM_API = "custom_api"; + public static final String AUTHORIZATION = "Authorization"; + public static final String TITLE = "title"; + public static final String URL = "url"; + public static final String CONTENT = 
"content"; + public static final String QUERY = "query"; + public static final String QUESTION = "question"; + public static final String QUERY_KEY = "query_key"; + public static final String LIMIT_KEY = "limit_key"; + public static final String CUSTOM_RES_URL_JSONPATH = "custom_res_url_jsonpath"; + public static final String START = "start"; @Setter @Getter @@ -85,13 +120,15 @@ public class WebSearchTool implements Tool { private String description = DEFAULT_DESCRIPTION; @Getter private String version; - private CloseableHttpClient httpClient; + private final SdkAsyncHttpClient httpClient; private final ThreadPool threadPool; private Map attributes; public WebSearchTool(ThreadPool threadPool) { - this.httpClient = HttpClients.createDefault(); + // Use 1s for connection timeout, 3s for read timeout, 30 for max connections of httpclient. + // For WebSearchTool, we don't allow user to connect to private ip. + this.httpClient = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(1), Duration.ofSeconds(3), 30, false); this.threadPool = threadPool; this.attributes = new HashMap<>(); attributes.put(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA); @@ -99,103 +136,107 @@ public WebSearchTool(ThreadPool threadPool) { } @Override - public void run(Map parameters, ActionListener listener) { + public void run(Map originalParameters, ActionListener listener) { try { + Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); // common search parameters - String query = parameters.getOrDefault("query", parameters.get("question")).replaceAll(" ", "+"); - String engine = parameters.getOrDefault("engine", "google"); - String endpoint = parameters.getOrDefault("endpoint", getDefaultEndpoint(engine)); - String apiKey = parameters.get("api_key"); - String nextPage = parameters.get("next_page"); + String query = parameters.getOrDefault(QUERY, parameters.get(QUESTION)).replaceAll(" ", "+"); + String engine = parameters.getOrDefault(ENGINE, GOOGLE); + String 
endpoint = parameters.getOrDefault(ENDPOINT, getDefaultEndpoint(engine)); + String apiKey = parameters.get(API_KEY); + String nextPage = parameters.get(NEXT_PAGE); // Google search parameters - String engineId = parameters.get("engine_id"); + String engineId = parameters.get(ENGINE_ID); // Custom search parameters - String authorization = parameters.get("Authorization"); - String queryKey = parameters.getOrDefault("query_key", "q"); - String offsetKey = parameters.getOrDefault("offset_key", "offset"); - String limitKey = parameters.getOrDefault("limit_key", "limit"); - String customResUrlJsonpath = parameters.get("custom_res_url_jsonpath"); + String authorization = parameters.get(AUTHORIZATION); + String queryKey = parameters.getOrDefault(QUERY_KEY, "q"); + String offsetKey = parameters.getOrDefault(OFFSET + "_key", OFFSET); + String limitKey = parameters.getOrDefault(LIMIT_KEY, "limit"); + String customResUrlJsonpath = parameters.get(CUSTOM_RES_URL_JSONPATH); + threadPool.executor(ToolPlugin.WEBSEARCH_CRAWLER_THREADPOOL).submit(() -> { - try { - String parsedNextPage = null; - if ("duckduckgo".equalsIgnoreCase(engine)) { - // duckduckgo has different approach to other APIs as it's not a standard public API. + String parsedNextPage; + if (DUCKDUCKGO.equalsIgnoreCase(engine)) { + // duckduckgo has different approach to other APIs as it's not a standard public API. 
+ if (nextPage != null) { + fetchDuckDuckGoResult(nextPage, listener); + } else { + fetchDuckDuckGoResult(buildDDGEndpoint(getDefaultEndpoint(engine), query), listener); + } + } else { + SdkHttpFullRequest.Builder builder = SdkHttpFullRequest.builder().method(SdkHttpMethod.GET); + if (GOOGLE.equalsIgnoreCase(engine)) { if (nextPage != null) { - fetchDuckDuckGoResult(nextPage, listener); + builder.uri(nextPage); + parsedNextPage = buildGoogleNextPage(endpoint, engineId, query, apiKey, nextPage); } else { - fetchDuckDuckGoResult(buildDDGEndpoint(getDefaultEndpoint(engine), query), listener); + builder.uri(buildGoogleUrl(endpoint, engineId, query, apiKey, 0)); + parsedNextPage = buildGoogleUrl(endpoint, engineId, query, apiKey, 10); } - } else { - HttpGet getRequest = null; - if ("google".equalsIgnoreCase(engine)) { - if (nextPage != null) { - getRequest = new HttpGet(nextPage); - parsedNextPage = buildGoogleNextPage(endpoint, engineId, query, apiKey, nextPage); - } else { - getRequest = new HttpGet(buildGoogleUrl(endpoint, engineId, query, apiKey, 0)); - parsedNextPage = buildGoogleUrl(endpoint, engineId, query, apiKey, 10); - } - } else if ("bing".equalsIgnoreCase(engine)) { - if (nextPage != null) { - getRequest = new HttpGet(nextPage); - parsedNextPage = buildBingNextPage(endpoint, query, nextPage); - } else { - getRequest = new HttpGet(buildBingUrl(endpoint, query, 0)); - parsedNextPage = buildBingUrl(endpoint, query, 10); - } - getRequest.addHeader("Ocp-Apim-Subscription-Key", apiKey); - } else if ("custom".equalsIgnoreCase(engine)) { - if (nextPage != null) { - getRequest = new HttpGet(nextPage); - parsedNextPage = buildCustomNextPage(endpoint, nextPage, queryKey, query, offsetKey, limitKey); - } else { - getRequest = new HttpGet(buildCustomUrl(endpoint, queryKey, query, offsetKey, 0, limitKey)); - parsedNextPage = buildCustomUrl(endpoint, queryKey, query, offsetKey, 10, limitKey); - } - getRequest.addHeader("Authorization", authorization); + } else if 
(BING.equalsIgnoreCase(engine)) { + if (nextPage != null) { + builder.uri(nextPage); + parsedNextPage = buildBingNextPage(endpoint, query, nextPage); } else { - // Search engine not supported. - listener.onFailure(new IllegalArgumentException("Unsupported search engine: %s".formatted(engine))); - return; + builder.uri(buildBingUrl(endpoint, query, 0)); + parsedNextPage = buildBingUrl(endpoint, query, 10); } - CloseableHttpResponse res = httpClient.execute(getRequest); - if (res.getCode() >= HttpStatus.SC_BAD_REQUEST) { - listener - .onFailure( - new IllegalArgumentException("Web search failed: %d %s".formatted(res.getCode(), res.getReasonPhrase())) - ); + builder.putHeader("Ocp-Apim-Subscription-Key", apiKey); + } else if (CUSTOM.equalsIgnoreCase(engine)) { + if (nextPage != null) { + builder.uri(nextPage); + parsedNextPage = buildCustomNextPage(endpoint, nextPage, queryKey, query, offsetKey, limitKey); } else { - String responseString = EntityUtils.toString(res.getEntity()); - parseResponse(responseString, authorization, parsedNextPage, engine, customResUrlJsonpath, listener); + builder.uri(buildCustomUrl(endpoint, queryKey, query, offsetKey, 0, limitKey)); + parsedNextPage = buildCustomUrl(endpoint, queryKey, query, offsetKey, 10, limitKey); } + builder.putHeader(AUTHORIZATION, authorization); + } else { + // Search engine not supported. 
+ listener + .onFailure(new IllegalArgumentException(String.format(Locale.ROOT, "Unsupported search engine: %s", engine))); + return; + } + SdkHttpFullRequest getRequest = builder.build(); + AsyncExecuteRequest executeRequest = AsyncExecuteRequest + .builder() + .request(getRequest) + .requestContentPublisher(new SimpleHttpContentPublisher(getRequest)) + .responseHandler( + new WebSearchResponseHandler(endpoint, authorization, parsedNextPage, engine, customResUrlJsonpath, listener) + ) + .build(); + try { + httpClient.execute(executeRequest); + } catch (Exception e) { + log.error("Web search failed!", e); + listener.onFailure(new IllegalStateException(String.format(Locale.ROOT, "Web search failed: %s", e.getMessage()))); } - } catch (Exception e) { - listener.onFailure(new IllegalStateException("Web search failed: %s".formatted(e.getMessage()))); } }); } catch (Exception e) { - listener.onFailure(new IllegalStateException("Web search failed: %s".formatted(e.getMessage()))); + listener.onFailure(new IllegalStateException(String.format(Locale.ROOT, "Web search failed: %s", e.getMessage()))); } } private String buildDDGEndpoint(String endpoint, String query) { - return "%s?q=%s".formatted(endpoint, query); + return String.format(Locale.ROOT, "%s?q=%s", endpoint, query); } private String buildGoogleNextPage(String endpoint, String engineId, String query, String apiKey, String currentPage) { - String[] offsetSplit = currentPage.split("&start="); + String[] offsetSplit = currentPage.split("&" + START + "="); int offset = NumberUtils.toInt(offsetSplit[1], 0) + 10; return buildGoogleUrl(endpoint, engineId, query, apiKey, offset); } private String buildGoogleUrl(String endpoint, String engineId, String query, String apiKey, int start) { - return "%s?q=%s&cx=%s&key=%s&start=%d".formatted(endpoint, query, engineId, apiKey, start); + return String.format(Locale.ROOT, "%s?q=%s&cx=%s&key=%s&" + START + "=%d", endpoint, query, engineId, apiKey, start); } private String 
buildBingNextPage(String endpoint, String query, String currentPage) { - String[] offsetSplit = currentPage.split("&offset="); + String[] offsetSplit = currentPage.split("&" + OFFSET + "="); int offset = NumberUtils.toInt(offsetSplit[1], 0) + 10; return buildBingUrl(endpoint, query, offset); } @@ -208,100 +249,28 @@ private String buildCustomNextPage( String offsetKey, String limitKey ) { - String[] pageSplit = currentPage.split("&%s=".formatted(offsetKey)); + String[] pageSplit = currentPage.split(String.format(Locale.ROOT, "&%s=", offsetKey)); int offsetValue = NumberUtils.toInt(pageSplit[1].split("&")[0], 0) + 10; return buildCustomUrl(endpoint, queryKey, query, offsetKey, offsetValue, limitKey); } private String buildCustomUrl(String endpoint, String queryKey, String query, String offsetKey, int offsetValue, String limitKey) { - return "%s?%s=%s&%s=%d&%s=10".formatted(endpoint, queryKey, query, offsetKey, offsetValue, limitKey); + return String.format(Locale.ROOT, "%s?%s=%s&%s=%d&%s=10", endpoint, queryKey, query, offsetKey, offsetValue, limitKey); } private String getDefaultEndpoint(String engine) { return switch (engine.toLowerCase(Locale.ROOT)) { - case "google" -> "https://customsearch.googleapis.com/customsearch/v1"; - case "bing" -> "https://api.bing.microsoft.com/v7.0/search"; - case "duckduckgo" -> "https://duckduckgo.com/html"; - case "custom" -> null; - default -> throw new IllegalArgumentException("Unsupported search engine: %s".formatted(engine)); + case GOOGLE -> "https://customsearch.googleapis.com/customsearch/v1"; + case BING -> "https://api.bing.microsoft.com/v7.0/search"; + case DUCKDUCKGO -> "https://duckduckgo.com/html"; + case CUSTOM -> null; + default -> throw new IllegalArgumentException(String.format(Locale.ROOT, "Unsupported search engine: %s", engine)); }; } // pagination: https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/page-results#paging-through-search-results private String buildBingUrl(String endpoint, String 
query, int offset) { - return "%s?q%s&textFormat=HTML&count=10&offset=%d".formatted(endpoint, query, offset); - } - - private void parseResponse( - String rawResponse, - String authorization, - String nextPage, - String engine, - String customResUrlJsonpath, - ActionListener listener - ) { - JsonObject rawJson = JsonParser.parseString(rawResponse).getAsJsonObject(); - switch (engine.toLowerCase(Locale.ROOT)) { - case "google": - parseGoogleResults(rawJson, nextPage, listener); - break; - case "bing": - parseBingResults(rawJson, nextPage, listener); - break; - case "custom": - List urls = JsonPath.read(rawResponse, customResUrlJsonpath); - parseCustomResults(urls, authorization, nextPage, listener); - break; - default: - listener.onFailure(new RuntimeException("Unsupported search engine: %s".formatted(engine))); - } - } - - private void parseGoogleResults(JsonObject googleResponse, String nextPage, ActionListener listener) { - Map results = new HashMap<>(); - results.put("next_page", nextPage); - // extract search results, each item is a search result: - // https://developers.google.com/custom-search/v1/reference/rest/v1/Search#result - JsonArray items = googleResponse.getAsJsonArray("items"); - List> crawlResults = new ArrayList<>(); - for (int i = 0; i < items.size(); i++) { - JsonObject item = items.get(i).getAsJsonObject(); - // extract the actual link for scrawl. - String link = item.get("link").getAsString(); - // extract title and content. 
- Map crawlResult = crawlPage(link, null); - crawlResults.add(crawlResult); - } - results.put("items", crawlResults); - listener.onResponse((T) StringUtils.gson.toJson(results)); - } - - private void parseBingResults(JsonObject bingResponse, String nextPage, ActionListener listener) { - Map results = new HashMap<>(); - results.put("next_page", nextPage); - List> crawlResults = new ArrayList<>(); - JsonArray values = bingResponse.get("webPages").getAsJsonObject().getAsJsonArray("value"); - for (int i = 0; i < values.size(); i++) { - JsonObject value = values.get(i).getAsJsonObject(); - String link = value.get("url").getAsString(); - Map crawlResult = crawlPage(link, null); - crawlResults.add(crawlResult); - } - results.put("items", crawlResults); - listener.onResponse((T) StringUtils.gson.toJson(results)); - } - - private void parseCustomResults(List urls, String authorization, String nextPage, ActionListener listener) { - Map results = new HashMap<>(); - results.put("next_page", nextPage); - List> crawlResults = new ArrayList<>(); - for (int i = 0; i < urls.size(); i++) { - String link = urls.get(i); - Map crawlResult = crawlPage(link, authorization); - crawlResults.add(crawlResult); - } - results.put("items", crawlResults); - listener.onResponse((T) StringUtils.gson.toJson(results)); + return String.format(Locale.ROOT, "%s?q%s&textFormat=HTML&count=10&" + OFFSET + "=%d", endpoint, query, offset); } private void fetchDuckDuckGoResult(String endpoint, ActionListener listener) { @@ -333,8 +302,8 @@ private void fetchDuckDuckGoResult(String endpoint, ActionListener listen Map crawlResult = crawlPage(link, null); crawlResults.add(crawlResult); } - results.put("next_page", nextPage); - results.put("items", crawlResults); + results.put(NEXT_PAGE, nextPage); + results.put(ITEMS, crawlResults); listener.onResponse((T) StringUtils.gson.toJson(results)); } catch (IOException e) { log.error("Failed to fetch duckduckgo results due to exception!"); @@ -376,48 +345,6 @@ private 
String getDDGNextPageLink(String endpoint, Document doc) { return sb.toString(); } - /** - * crawl a page and put the page content into the results map if it can be crawled successfully. - * - * @param url The url to crawl - */ - private Map crawlPage(String url, String authorization) { - try { - Connection connection = Jsoup.connect(url).timeout(10000).userAgent(USER_AGENT); - if (authorization != null) { - connection.header("Authorization", authorization); - } - Document doc = connection.get(); - Elements parentElements = doc.select("body"); - if (isCaptchaOrLoginPage(doc)) { - log.debug("Skipping {} - CAPTCHA required", url); - return null; - } - - Element bodyElement = parentElements.getFirst(); - String title = bodyElement.select("title").text(); - String content = bodyElement.text(); - return ImmutableMap.of("url", url, "title", title, "content", content); - } catch (Exception e) { - log.error("Failed to crawl link: {}", url); - return null; - } - } - - private boolean isCaptchaOrLoginPage(Document doc) { - String html = doc.html().toLowerCase(Locale.ROOT); - // 1. 
Check for CAPTCHA indicators - return !doc.select("input[name*='captcha'], input[id*='captcha']").isEmpty() || - // Google reCAPTCHA markers - !doc.select(".g-recaptcha, div[data-sitekey]").isEmpty() || - // CAPTCHA image patterns - !doc.select("img[src*='captcha'], img[src*='recaptcha']").isEmpty() || - // Text-based indicators - org.apache.commons.lang3.StringUtils.containsIgnoreCase(html, "verify you are human") || - // hCAPTCHA detection - !doc.select(".h-captcha").isEmpty(); - } - @Override public String getType() { return TYPE; @@ -425,46 +352,46 @@ public String getType() { @Override public boolean validate(Map parameters) { - String engine = parameters.get("engine"); + String engine = parameters.get(ENGINE); if (org.apache.commons.lang3.StringUtils.isEmpty(engine)) { return false; } - boolean isQueryEmpty = org.apache.commons.lang3.StringUtils.isEmpty(parameters.getOrDefault("query", parameters.get("question"))); + boolean isQueryEmpty = org.apache.commons.lang3.StringUtils.isEmpty(parameters.getOrDefault(QUERY, parameters.get(QUESTION))); if (isQueryEmpty) { log.warn("Query is empty"); return false; } boolean isEndpointEmpty = org.apache.commons.lang3.StringUtils - .isEmpty(parameters.getOrDefault("endpoint", getDefaultEndpoint(engine))); + .isEmpty(parameters.getOrDefault(ENDPOINT, getDefaultEndpoint(engine))); if (isEndpointEmpty) { log.warn("Endpoint is empty"); return false; } - if ("google".equalsIgnoreCase(engine)) { - boolean hasEngineIdAndApiKey = parameters.containsKey("engine_id") - && !parameters.get("engine_id").isEmpty() - && parameters.containsKey("api_key") - && !parameters.get("api_key").isEmpty(); + if (GOOGLE.equalsIgnoreCase(engine)) { + boolean hasEngineIdAndApiKey = parameters.containsKey(ENGINE_ID) + && !parameters.get(ENGINE_ID).isEmpty() + && parameters.containsKey(API_KEY) + && !parameters.get(API_KEY).isEmpty(); if (!hasEngineIdAndApiKey) { - log.warn("Google search engine_id or api_key is empty"); + log.warn("Google search" + 
ENGINE_ID + "or api_key is empty"); return false; } return true; - } else if ("duckduckgo".equalsIgnoreCase(engine)) { + } else if (DUCKDUCKGO.equalsIgnoreCase(engine)) { return true; - } else if ("bing".equalsIgnoreCase(engine)) { - boolean hasApiKey = org.apache.commons.lang3.StringUtils.isEmpty(parameters.get("api_key")); + } else if (BING.equalsIgnoreCase(engine)) { + boolean hasApiKey = org.apache.commons.lang3.StringUtils.isEmpty(parameters.get(API_KEY)); if (!hasApiKey) { log.warn("Bing search api_key is empty"); return false; } return true; - } else if ("custom".equalsIgnoreCase(engine)) { - String customApi = parameters.get("custom_api"); - String customResUrlJsonpath = parameters.get("custom_res_url_jsonpath"); + } else if (CUSTOM.equalsIgnoreCase(engine)) { + String customApi = parameters.get(CUSTOM_API); + String customResUrlJsonpath = parameters.get(CUSTOM_RES_URL_JSONPATH); if (org.apache.commons.lang3.StringUtils.isEmpty(customApi) || org.apache.commons.lang3.StringUtils.isEmpty(customResUrlJsonpath)) { log.warn("custom search API is empty or result json path is empty"); @@ -478,6 +405,48 @@ public boolean validate(Map parameters) { return false; } + /** + * crawl a page and put the page content into the results map if it can be crawled successfully. 
+ * + * @param url The url to crawl + */ + public Map crawlPage(String url, String authorization) { + try { + Connection connection = Jsoup.connect(url).timeout(10000).userAgent(USER_AGENT); + if (authorization != null) { + connection.header(AUTHORIZATION, authorization); + } + Document doc = connection.get(); + Elements parentElements = doc.select("body"); + if (isCaptchaOrLoginPage(doc)) { + log.debug("Skipping {} - CAPTCHA required", url); + return null; + } + + Element bodyElement = parentElements.getFirst(); + String title = bodyElement.select(TITLE).text(); + String content = bodyElement.text(); + return ImmutableMap.of(URL, url, TITLE, title, CONTENT, content); + } catch (Exception e) { + log.error("Failed to crawl link: {}", url); + return null; + } + } + + private boolean isCaptchaOrLoginPage(Document doc) { + String html = doc.html().toLowerCase(Locale.ROOT); + // 1. Check for CAPTCHA indicators + return !doc.select("input[name*='captcha'], input[id*='captcha']").isEmpty() || + // Google reCAPTCHA markers + !doc.select(".g-recaptcha, div[data-sitekey]").isEmpty() || + // CAPTCHA image patterns + !doc.select("img[src*='captcha'], img[src*='recaptcha']").isEmpty() || + // Text-based indicators + org.apache.commons.lang3.StringUtils.containsIgnoreCase(html, "verify you are human") || + // hCAPTCHA detection + !doc.select(".h-captcha").isEmpty(); + } + public static class Factory implements Tool.Factory { private static Factory INSTANCE; private ThreadPool threadPool; @@ -522,4 +491,163 @@ public Map getDefaultAttributes() { return DEFAULT_ATTRIBUTES; } } + + private final class WebSearchResponseHandler implements SdkAsyncHttpResponseHandler { + private final String endpoint; + private final String authorization; + private final String parsedNextPage; + private final String engine; + private final String customResUrlJsonpath; + private final ActionListener listener; + + public WebSearchResponseHandler( + String endpoint, + String authorization, + String 
parsedNextPage, + String engine, + String customResUrlJsonpath, + ActionListener listener + ) { + this.endpoint = endpoint; + this.authorization = authorization; + this.parsedNextPage = parsedNextPage; + this.engine = engine; + this.customResUrlJsonpath = customResUrlJsonpath; + this.listener = listener; + } + + @Override + public void onHeaders(SdkHttpResponse response) { + SdkHttpFullResponse sdkResponse = (SdkHttpFullResponse) response; + log.debug("received response headers: " + sdkResponse.headers()); + int statusCode = sdkResponse.statusCode(); + if (statusCode < HttpStatus.SC_OK || statusCode > HttpStatus.SC_MULTIPLE_CHOICES) { + log + .error( + "Received error from endpoint:{} with status code {}, response headers: {}", + endpoint, + statusCode, + sdkResponse.headers() + ); + listener + .onFailure( + new OpenSearchStatusException( + String.format(Locale.ROOT, "Failed to fetch results from endpoint: %s", endpoint), + RestStatus.fromCode(statusCode) + ) + ); + } + } + + @Override + public void onStream(Publisher stream) { + stream.subscribe(new Subscriber<>() { + private final StringBuilder responseBuilder = new StringBuilder(); + private Subscription subscription; + + @Override + public void onSubscribe(Subscription subscription) { + log.debug("Starting to fetch response..."); + this.subscription = subscription; + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + responseBuilder.append(StandardCharsets.UTF_8.decode(byteBuffer)); + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onError(Throwable throwable) { + log.error("Failed to fetch results from endpoint: {}", endpoint, throwable); + listener.onFailure(new RuntimeException(throwable)); + } + + @Override + public void onComplete() { + log.debug("Successfully fetched results from endpoint: {}", endpoint); + parseResponse(responseBuilder.toString(), authorization, parsedNextPage, engine, customResUrlJsonpath, listener); + } + }); + 
} + + @Override + public void onError(Throwable error) { + log.error("Failed to fetch results from endpoint: {}", endpoint, error); + listener.onFailure(new RuntimeException(error)); + } + + private void parseResponse( + String rawResponse, + String authorization, + String nextPage, + String engine, + String customResUrlJsonpath, + ActionListener listener + ) { + JsonObject rawJson = JsonParser.parseString(rawResponse).getAsJsonObject(); + switch (engine.toLowerCase(Locale.ROOT)) { + case GOOGLE: + parseGoogleResults(rawJson, nextPage, listener); + break; + case BING: + parseBingResults(rawJson, nextPage, listener); + break; + case CUSTOM: + List urls = JsonPath.read(rawResponse, customResUrlJsonpath); + parseCustomResults(urls, authorization, nextPage, listener); + break; + default: + listener.onFailure(new RuntimeException(String.format(Locale.ROOT, "Unsupported search engine: %s", engine))); + } + } + + private void parseGoogleResults(JsonObject googleResponse, String nextPage, ActionListener listener) { + Map results = new HashMap<>(); + results.put(NEXT_PAGE, nextPage); + // extract search results, each item is a search result: + // https://developers.google.com/custom-search/v1/reference/rest/v1/Search#result + JsonArray items = googleResponse.getAsJsonArray(ITEMS); + List> crawlResults = new ArrayList<>(); + for (int i = 0; i < items.size(); i++) { + JsonObject item = items.get(i).getAsJsonObject(); + // extract the actual link for scrawl. + String link = item.get("link").getAsString(); + // extract title and content. 
+ Map crawlResult = crawlPage(link, null); + crawlResults.add(crawlResult); + } + results.put(ITEMS, crawlResults); + listener.onResponse((T) StringUtils.gson.toJson(results)); + } + + private void parseBingResults(JsonObject bingResponse, String nextPage, ActionListener listener) { + Map results = new HashMap<>(); + results.put(NEXT_PAGE, nextPage); + List> crawlResults = new ArrayList<>(); + JsonArray values = bingResponse.get("webPages").getAsJsonObject().getAsJsonArray("value"); + for (int i = 0; i < values.size(); i++) { + JsonObject value = values.get(i).getAsJsonObject(); + String link = value.get(URL).getAsString(); + Map crawlResult = crawlPage(link, null); + crawlResults.add(crawlResult); + } + results.put(ITEMS, crawlResults); + listener.onResponse((T) StringUtils.gson.toJson(results)); + } + + private void parseCustomResults(List urls, String authorization, String nextPage, ActionListener listener) { + Map results = new HashMap<>(); + results.put(NEXT_PAGE, nextPage); + List> crawlResults = new ArrayList<>(); + for (int i = 0; i < urls.size(); i++) { + String link = urls.get(i); + Map crawlResult = crawlPage(link, authorization); + crawlResults.add(crawlResult); + } + results.put(ITEMS, crawlResults); + listener.onResponse((T) StringUtils.gson.toJson(results)); + } + } } diff --git a/src/main/java/org/opensearch/agent/tools/utils/PPLExecuteHelper.java b/src/main/java/org/opensearch/agent/tools/utils/PPLExecuteHelper.java new file mode 100644 index 00000000..10f7cd6a --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/utils/PPLExecuteHelper.java @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools.utils; + +import static org.opensearch.agent.tools.utils.ToolHelper.getPPLTransportActionListener; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import 
java.util.function.Function; + +import org.json.JSONObject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.sql.plugin.transport.PPLQueryAction; +import org.opensearch.sql.plugin.transport.TransportPPLQueryRequest; +import org.opensearch.sql.ppl.domain.PPLQueryRequest; +import org.opensearch.transport.client.Client; + +import com.google.common.collect.ImmutableMap; +import com.google.gson.reflect.TypeToken; + +import lombok.extern.log4j.Log4j2; + +/** + * Utility class for executing PPL queries and parsing results + */ +@Log4j2 +public class PPLExecuteHelper { + + /** + * Executes PPL query and parses the result using provided result parser + * + * @param The parsed result type + * @param client OpenSearch client + * @param ppl PPL query string to execute + * @param resultParser Function to parse PPL result into desired format + * @param listener Action listener for handling parsed results or failures + */ + public static void executePPLAndParseResult( + Client client, + String ppl, + Function, T> resultParser, + ActionListener listener + ) { + try { + JSONObject jsonContent = new JSONObject(ImmutableMap.of("query", ppl)); + PPLQueryRequest pplQueryRequest = new PPLQueryRequest(ppl, jsonContent, null, "jdbc"); + TransportPPLQueryRequest transportPPLQueryRequest = new TransportPPLQueryRequest(pplQueryRequest); + + client + .execute( + PPLQueryAction.INSTANCE, + transportPPLQueryRequest, + getPPLTransportActionListener(ActionListener.wrap(transportPPLQueryResponse -> { + String result = transportPPLQueryResponse.getResult(); + if (Strings.isEmpty(result)) { + listener.onFailure(new RuntimeException("Empty PPL response")); + } else { + Map pplResult = gson.fromJson(result, new TypeToken>() { + }.getType()); + if (pplResult.containsKey("error")) { + Object errorObj = pplResult.get("error"); + String errorDetail; + if (errorObj instanceof Map) { + Map errorMap = (Map) errorObj; + Object reason = 
errorMap.get("reason"); + errorDetail = reason != null ? reason.toString() : errorMap.toString(); + } else { + errorDetail = errorObj != null ? errorObj.toString() : "Unknown error"; + } + throw new RuntimeException("PPL query error: " + errorDetail); + } + + Object datarowsObj = pplResult.get("datarows"); + if (!(datarowsObj instanceof List)) { + throw new IllegalStateException("Invalid PPL response format: missing or invalid datarows"); + } + + listener.onResponse(resultParser.apply(pplResult)); + } + }, error -> { + log.error("PPL execution failed: {}", error.getMessage()); + listener.onFailure(new RuntimeException("PPL execution failed: " + error.getMessage(), error)); + })) + ); + } catch (Exception e) { + String errorMessage = String.format(Locale.ROOT, "Failed to execute PPL query: %s", e.getMessage()); + log.error(errorMessage, e); + listener.onFailure(new RuntimeException(errorMessage, e)); + } + } + + /** + * Helper method to create a result parser that extracts datarows + */ + public static Function, T> dataRowsParser(Function>, T> rowParser) { + return pplResult -> { + Object datarowsObj = pplResult.get("datarows"); + @SuppressWarnings("unchecked") + List> dataRows = (List>) datarowsObj; + if (dataRows.isEmpty()) { + log.debug("PPL query returned no data rows for the specified criteria"); + } + return rowParser.apply(dataRows); + }; + } +} diff --git a/src/main/java/org/opensearch/agent/tools/utils/ToolConstants.java b/src/main/java/org/opensearch/agent/tools/utils/ToolConstants.java index b5433a0e..fa77024f 100644 --- a/src/main/java/org/opensearch/agent/tools/utils/ToolConstants.java +++ b/src/main/java/org/opensearch/agent/tools/utils/ToolConstants.java @@ -36,5 +36,4 @@ public static ModelType from(String value) { public static final String ALERTING_CONFIG_INDEX = ".opendistro-alerting-config"; public static final String ALERTING_ALERTS_INDEX = ".opendistro-alerting-alerts"; - } diff --git 
a/src/main/java/org/opensearch/agent/tools/utils/clustering/ClusteringHelper.java b/src/main/java/org/opensearch/agent/tools/utils/clustering/ClusteringHelper.java new file mode 100644 index 00000000..c937e997 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/utils/clustering/ClusteringHelper.java @@ -0,0 +1,514 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools.utils.clustering; + +import static org.opensearch.agent.tools.utils.clustering.HierarchicalAgglomerativeClustering.calculateCosineSimilarity; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.commons.math3.ml.clustering.CentroidCluster; +import org.apache.commons.math3.ml.clustering.DoublePoint; +import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer; +import org.apache.commons.math3.ml.distance.DistanceMeasure; + +import com.google.common.collect.Lists; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class ClusteringHelper { + private final double logVectorsClusteringThreshold; + + /** + * Constructor for ClusteringHelper + * + * @param logVectorsClusteringThreshold Threshold for determining when two vectors are similar + * Should be between 0 and 1.0 (inclusive) + * @throws IllegalArgumentException if threshold is outside valid range + */ + public ClusteringHelper(double logVectorsClusteringThreshold) { + if (logVectorsClusteringThreshold < 0.0 || logVectorsClusteringThreshold > 1.0) { + throw new IllegalArgumentException("Clustering threshold must be between 0.0 and 1.0, got: " + logVectorsClusteringThreshold); + } + this.logVectorsClusteringThreshold = logVectorsClusteringThreshold; + } + + /** + * Cluster log vectors using a two-phase approach and get representative vectors. + * Input validation is performed to ensure log vectors are valid. 
+ * + * @param logVectors Map of trace IDs to their vector representations + * @return List of trace IDs representing the centroids of each cluster + * @throws IllegalArgumentException if logVectors contains invalid entries + */ + public List clusterLogVectorsAndGetRepresentative(Map logVectors) { + if (logVectors == null || logVectors.isEmpty()) { + return new ArrayList<>(); + } + + // Validate input vectors + validateLogVectors(logVectors); + + log.debug("Starting two-phase clustering for {} log vectors", logVectors.size()); + + // Convert map to arrays for processing + double[][] vectors = new double[logVectors.size()][]; + Map indexTraceIdMap = new HashMap<>(); + convertLogVectorsToArrays(logVectors, vectors, indexTraceIdMap); + + List finalCentroids; + + // Choose clustering approach based on dataset size + if (logVectors.size() > 1000) { + finalCentroids = processTwoPhaseClusteringForLargeDataset(vectors, indexTraceIdMap); + } else { + // Small dataset - use hierarchical clustering directly + finalCentroids = performClustering(vectors, indexTraceIdMap); + } + + log + .debug( + "Two-phase clustering completed: {} input vectors -> {} representative centroids", + logVectors.size(), + finalCentroids.size() + ); + + return finalCentroids; + } + + /** + * Converts log vectors map to arrays for processing + * + * @param logVectors Map of trace IDs to vector representations + * @param vectors Output array for vectors + * @param indexTraceIdMap Output map for index to trace ID mapping + */ + private void convertLogVectorsToArrays(Map logVectors, double[][] vectors, Map indexTraceIdMap) { + int i = 0; + for (Map.Entry entry : logVectors.entrySet()) { + vectors[i] = entry.getValue(); + indexTraceIdMap.put(i, entry.getKey()); + i++; + } + } + + /** + * Processes large datasets using two-phase clustering approach + * + * @param vectors Array of vectors + * @param indexTraceIdMap Mapping from vector index to trace ID + * @return List of trace IDs representing cluster 
centroids + */ + private List processTwoPhaseClusteringForLargeDataset(double[][] vectors, Map indexTraceIdMap) { + List finalCentroids = new ArrayList<>(); + log.debug("Large dataset detected ({}), applying K-means pre-clustering", vectors.length); + + // Calculate optimal number of K-means clusters (target 500 points per cluster) + int targetClusterSize = 500; + int numKMeansClusters = (vectors.length + (targetClusterSize - 1)) / targetClusterSize; + + log.debug("Using {} K-means clusters for pre-clustering", numKMeansClusters); + + try { + List> kMeansClusters = performKMeansClustering(vectors, numKMeansClusters); + + // Process each K-means cluster + for (int clusterIdx = 0; clusterIdx < kMeansClusters.size(); clusterIdx++) { + List kMeansCluster = kMeansClusters.get(clusterIdx); + log.debug("Processing K-means cluster {} with {} points", clusterIdx, kMeansCluster.size()); + + List clusterCentroids = processCluster(kMeansCluster, vectors, indexTraceIdMap, clusterIdx); + finalCentroids.addAll(clusterCentroids); + } + + } catch (Exception e) { + log.warn("K-means clustering failed, falling back to hierarchical clustering only: {}", e.getMessage()); + // Fallback to hierarchical clustering only + finalCentroids = performClustering(vectors, indexTraceIdMap); + } + + return finalCentroids; + } + + /** + * Processes a single K-means cluster + * + * @param kMeansCluster List of indices in the K-means cluster + * @param vectors Original vector array + * @param indexTraceIdMap Original mapping from indices to trace IDs + * @param clusterIdx Index of the cluster (for logging) + * @return List of trace IDs representing cluster centroids + */ + private List processCluster( + List kMeansCluster, + double[][] vectors, + Map indexTraceIdMap, + int clusterIdx + ) { + if (kMeansCluster.isEmpty()) { + return List.of(); + } + + if (kMeansCluster.size() == 1) { + return List.of(indexTraceIdMap.get(kMeansCluster.getFirst())); + } + + if (kMeansCluster.size() > 500) { + 
log.debug("The cluster size is greater than 500, performing partitioned clustering"); + return performHierarchicalClusteringOfPartition(kMeansCluster, vectors, indexTraceIdMap); + } + + log.debug("Applying hierarchical clustering to K-means cluster {} with {} points", clusterIdx, kMeansCluster.size()); + + // Extract vectors for this K-means cluster + double[][] clusterVectors = extractVectors(kMeansCluster, vectors); + Map clusterIndexTraceIdMap = createTraceIdMapping(kMeansCluster, indexTraceIdMap); + + // Apply hierarchical clustering within this K-means cluster + return performClustering(clusterVectors, clusterIndexTraceIdMap); + } + + /** + * Perform K-means clustering using Apache Commons Math3 + * + * @param vectors Input vectors for clustering + * @param numClusters Number of K-means clusters + * @return List of clusters, each containing indices of points in that cluster + * @throws RuntimeException if clustering fails + */ + private List> performKMeansClustering(double[][] vectors, int numClusters) { + if (vectors == null || vectors.length == 0) { + return new ArrayList<>(); + } + + if (numClusters <= 0) { + numClusters = 1; + } + + // Cap number of clusters to vector size + numClusters = Math.min(numClusters, vectors.length); + + try { + KMeansPlusPlusClusterer clusterer = createKMeansClusterer(numClusters); + List points = convertVectorsToPoints(vectors); + List> clusters = clusterer.cluster(points); + return extractClusterIndices(clusters, vectors); + } catch (Exception e) { + log.error("K-means clustering failed: {}", e.getMessage(), e); + throw new RuntimeException("K-means clustering failed: " + e.getMessage(), e); + } + } + + /** + * Creates a KMeansPlusPlusClusterer with cosine distance metric + * + * @param numClusters Number of clusters to create + * @return Configured KMeansPlusPlusClusterer + */ + private KMeansPlusPlusClusterer createKMeansClusterer(int numClusters) { + return new KMeansPlusPlusClusterer<>( + numClusters, + 300, // Maximum 
iterations + (DistanceMeasure) (a, b) -> 1 - calculateCosineSimilarity(a, b) + ); + } + + /** + * Converts vector array to list of DoublePoint objects + * + * @param vectors Array of vectors + * @return List of DoublePoint objects + */ + private List convertVectorsToPoints(double[][] vectors) { + List points = new ArrayList<>(vectors.length); + for (double[] vector : vectors) { + points.add(new DoublePoint(vector)); + } + return points; + } + + /** + * Validates log vectors to ensure they are valid for clustering + * + * @param logVectors Map of trace IDs to vector representations + * @throws IllegalArgumentException if vectors are invalid + */ + private void validateLogVectors(Map logVectors) { + int vectorDimension = -1; + + for (Map.Entry entry : logVectors.entrySet()) { + String traceId = entry.getKey(); + double[] vector = entry.getValue(); + + if (traceId == null || traceId.isEmpty()) { + throw new IllegalArgumentException("Trace ID cannot be null or empty"); + } + + if (vector == null) { + throw new IllegalArgumentException("Vector for trace ID '" + traceId + "' is null"); + } + + if (vector.length == 0) { + throw new IllegalArgumentException("Vector for trace ID '" + traceId + "' is empty"); + } + + // Ensure all vectors have the same dimension + if (vectorDimension == -1) { + vectorDimension = vector.length; + } else if (vector.length != vectorDimension) { + throw new IllegalArgumentException( + "Vector dimension mismatch: expected " + + vectorDimension + + " but got " + + vector.length + + " for trace ID '" + + traceId + + "'" + ); + } + + // Check for NaN or Infinity values + for (int i = 0; i < vector.length; i++) { + if (Double.isNaN(vector[i]) || Double.isInfinite(vector[i])) { + throw new IllegalArgumentException( + "Vector for trace ID '" + traceId + "' contains invalid value at index " + i + ": " + vector[i] + ); + } + } + } + } + + /** + * Extracts original vector indices for each K-means cluster + * + * @param clusters K-means clustering result + 
* @param vectors Original vector array + * @return List of clusters with original vector indices + */ + private List> extractClusterIndices(List> clusters, double[][] vectors) { + List> result = new ArrayList<>(); + for (CentroidCluster cluster : clusters) { + List clusterIndices = new ArrayList<>(); + for (DoublePoint point : cluster.getPoints()) { + // Find the original index of this point + for (int i = 0; i < vectors.length; i++) { + if (Arrays.equals(vectors[i], point.getPoint())) { + clusterIndices.add(i); + break; + } + } + } + if (!clusterIndices.isEmpty()) { + result.add(clusterIndices); + } + } + return result; + } + + /** + * Generic method to perform clustering with specified linkage method + * + * @param vectors Input vectors for clustering + * @param indexTraceIdMap Mapping from vector index to trace ID + * @return List of trace IDs representing cluster centroids + */ + private List performClustering(double[][] vectors, Map indexTraceIdMap) { + if (vectors == null || vectors.length == 0) { + return List.of(); + } + + if (vectors.length == 1) { + String traceId = indexTraceIdMap.get(0); + return List.of(traceId); + } + + List centroids = new ArrayList<>(); + try { + HierarchicalAgglomerativeClustering hac = new HierarchicalAgglomerativeClustering(vectors); + List clusters = hac + .fit(HierarchicalAgglomerativeClustering.LinkageMethod.COMPLETE, this.logVectorsClusteringThreshold); + + for (HierarchicalAgglomerativeClustering.ClusterNode cluster : clusters) { + int centroidIndex = hac.getClusterCentroid(cluster); + String traceId = indexTraceIdMap.get(centroidIndex); + centroids.add(traceId); + } + } catch (Exception e) { + log.error("Hierarchical clustering failed: {}", e.getMessage(), e); + // Fallback: return first point as representative if available + String traceId = indexTraceIdMap.get(0); + centroids.add(traceId); + } + + return centroids; + } + + /** + * If the first stage K-means clustering results exceed 500 clusters, implement batch 
processing and merge the results. + * @param kMeansCluster Clustering results from the first stage. + * @param vectors List of vectors by index. + * @param indexTraceIdMap Map of index to their trace id. + * @return List of trace IDs representing cluster centroids after partitioned processing + */ + private List performHierarchicalClusteringOfPartition( + List kMeansCluster, + double[][] vectors, + Map indexTraceIdMap + ) { + List> partition = Lists.partition(kMeansCluster, 500); + + List vectorRes = new ArrayList<>(); + Map index2Trace = new HashMap<>(); + + for (List partList : partition) { + double[][] clusterVectors = extractVectors(partList, vectors); + Map clusterIndexTraceIdMap = createTraceIdMapping(partList, indexTraceIdMap); + + log.debug("Starting performHierarchicalClusteringOfPartition!"); + processPartition(clusterVectors, clusterIndexTraceIdMap, vectorRes, index2Trace); + } + + return removeSimilarVectors(vectorRes, index2Trace); + } + + /** + * Extracts vectors for a partition based on indices + * + * @param partList List of indices in the partition + * @param vectors Original vector array + * @return Array of vectors for the partition + */ + private double[][] extractVectors(List partList, double[][] vectors) { + double[][] clusterVectors = new double[partList.size()][]; + for (int j = 0; j < partList.size(); j++) { + int originalIndex = partList.get(j); + clusterVectors[j] = vectors[originalIndex]; + } + return clusterVectors; + } + + /** + * Creates a mapping from partition indices to trace IDs + * + * @param partList List of indices in the partition + * @param indexTraceIdMap Original mapping from indices to trace IDs + * @return Mapping from partition indices to trace IDs + */ + private Map createTraceIdMapping(List partList, Map indexTraceIdMap) { + Map clusterIndexTraceIdMap = new HashMap<>(); + for (int j = 0; j < partList.size(); j++) { + int originalIndex = partList.get(j); + clusterIndexTraceIdMap.put(j, 
indexTraceIdMap.get(originalIndex)); + } + return clusterIndexTraceIdMap; + } + + /** + * Processes a partition for hierarchical clustering + * + * @param clusterVectors Vectors in the partition + * @param clusterIndexTraceIdMap Mapping from partition indices to trace IDs + * @param vectorRes Result vector collection to append to + * @param index2Trace Result mapping from indices to trace IDs to append to + */ + private void processPartition( + double[][] clusterVectors, + Map clusterIndexTraceIdMap, + List vectorRes, + Map index2Trace + ) { + if (clusterVectors.length == 0) { + return; + } + + if (clusterVectors.length == 1) { + vectorRes.add(clusterVectors[0]); + index2Trace.put(vectorRes.size() - 1, clusterIndexTraceIdMap.get(0)); + return; + } + + try { + HierarchicalAgglomerativeClustering hac = new HierarchicalAgglomerativeClustering(clusterVectors); + List clusters = hac + .fit(HierarchicalAgglomerativeClustering.LinkageMethod.COMPLETE, this.logVectorsClusteringThreshold); + log.info("Completing performHierarchicalClusteringOfPartition!"); + + for (HierarchicalAgglomerativeClustering.ClusterNode cluster : clusters) { + int centroidIndex = hac.getClusterCentroid(cluster); + vectorRes.add(clusterVectors[centroidIndex]); + index2Trace.put(vectorRes.size() - 1, clusterIndexTraceIdMap.get(centroidIndex)); + } + } catch (Exception e) { + log.error("Hierarchical clustering failed: {}", e.getMessage(), e); + // Fallback: return first point as representative + vectorRes.add(clusterVectors[0]); + index2Trace.put(vectorRes.size() - 1, clusterIndexTraceIdMap.get(0)); + } + } + + /** + * Compute the cosine similarity pairwise and remove vectors that are too similar. + * Vectors with similarity higher than threshold are considered duplicates. 
+ * + * @param vectorRes List of vectors + * @param index2Trace Map of index to their trace id + * @return List of trace IDs after removing similar vectors + */ + private List removeSimilarVectors(List vectorRes, Map index2Trace) { + Set toRemove = new HashSet<>(); + + for (int i = 0; i < vectorRes.size(); i++) { + if (toRemove.contains(i)) { + continue; + } + + for (int j = i + 1; j < vectorRes.size(); j++) { + if (toRemove.contains(j)) { + continue; + } + + double similarity = calculateCosineSimilarity(vectorRes.get(i), vectorRes.get(j)); + // If similarity is higher than threshold, vectors are considered similar enough to remove one + if (similarity > this.logVectorsClusteringThreshold) { + log.debug("Removing similar vector with similarity: {}", similarity); + toRemove.add(j); + } + } + } + + log.debug("Removed {} similar vectors out of {}", toRemove.size(), vectorRes.size()); + return collectNonRemovedTraceIds(vectorRes, index2Trace, toRemove); + } + + /** + * Collects trace IDs for vectors that are not marked for removal + * + * @param vectors List of vectors + * @param indexToTraceMap Mapping from indices to trace IDs + * @param indicesToRemove Set of indices to exclude + * @return List of trace IDs for non-removed vectors + */ + private List collectNonRemovedTraceIds( + List vectors, + Map indexToTraceMap, + Set indicesToRemove + ) { + List result = new ArrayList<>(vectors.size() - indicesToRemove.size()); + for (int i = 0; i < vectors.size(); i++) { + if (!indicesToRemove.contains(i)) { + result.add(indexToTraceMap.get(i)); + } + } + return result; + } + +} diff --git a/src/main/java/org/opensearch/agent/tools/utils/clustering/HierarchicalAgglomerativeClustering.java b/src/main/java/org/opensearch/agent/tools/utils/clustering/HierarchicalAgglomerativeClustering.java new file mode 100644 index 00000000..a9600118 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/utils/clustering/HierarchicalAgglomerativeClustering.java @@ -0,0 +1,281 @@ +/* + * 
Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools.utils.clustering; + +import java.util.ArrayList; +import java.util.List; + +public class HierarchicalAgglomerativeClustering { + + private final double[][] data; + private final double[][] distanceMatrix; + private final int nSamples; + private final int nFeatures; + + public enum LinkageMethod { + SINGLE, // Minimum distance between clusters + COMPLETE, // Maximum distance between clusters + AVERAGE // Average distance between clusters + } + + /** + * Internal cluster node for tracking during clustering process + */ + public static class ClusterNode { + final int id; + final List samples; + final int size; + + ClusterNode(int id, int sample) { + this.id = id; + this.samples = new ArrayList<>(); + this.samples.add(sample); + this.size = 1; + } + + ClusterNode(int id, ClusterNode left, ClusterNode right) { + this.id = id; + this.samples = new ArrayList<>(); + this.samples.addAll(left.samples); + this.samples.addAll(right.samples); + this.size = left.size + right.size; + } + } + + /** + * Constructor - computes cosine distance matrix + */ + public HierarchicalAgglomerativeClustering(double[][] data) { + this.data = data; + this.nSamples = data.length; + this.nFeatures = data[0].length; + this.distanceMatrix = new double[nSamples][nSamples]; + + // Compute cosine distance matrix + computeCosineDistanceMatrix(); + } + + /** + * Compute pairwise cosine distances + * Cosine distance = 1 - cosine similarity + */ + private void computeCosineDistanceMatrix() { + // Pre-calculate norms for efficiency + double[] norms = new double[nSamples]; + for (int i = 0; i < nSamples; i++) { + double norm = 0.0; + for (int j = 0; j < nFeatures; j++) { + norm += data[i][j] * data[i][j]; + } + norms[i] = Math.sqrt(norm); + } + + // Calculate cosine distances + for (int i = 0; i < nSamples; i++) { + distanceMatrix[i][i] = 0.0; + for (int j = i + 1; j < nSamples; j++) { + 
double similarity = calculateCosineSimilarity(data[i], data[j], norms[i], norms[j]); + double distance = 1.0 - similarity; + distanceMatrix[i][j] = distanceMatrix[j][i] = distance; + } + } + } + + /** + * Optimized cosine similarity calculation with pre-calculated norms + */ + private static double calculateCosineSimilarity(double[] a, double[] b, double normA, double normB) { + if (normA == 0.0 || normB == 0.0) { + return 0.0; + } + + double dotProduct = 0.0; + for (int i = 0; i < a.length; i++) { + dotProduct += a[i] * b[i]; + } + + return dotProduct / (normA * normB); + } + + /** + * Perform hierarchical clustering with distance threshold + * + * @param linkage The linkage method to use + * @param threshold Distance threshold - clustering stops when minimum distance exceeds this value + * @return List of final clusters + */ + public List fit(LinkageMethod linkage, double threshold) { + if (threshold < 0) { + throw new IllegalArgumentException("Distance threshold must be non-negative"); + } + + // Initialize clusters - each sample starts as its own cluster + List activeClusters = new ArrayList<>(); + for (int i = 0; i < nSamples; i++) { + activeClusters.add(new ClusterNode(i, i)); + } + + int nextClusterId = nSamples; + + // Main clustering loop + while (activeClusters.size() > 1) { + // Find the closest pair of clusters + int[] closestPair = findClosestClusters(activeClusters, linkage, threshold); + if (closestPair == null) { + break; + } + + int i = closestPair[0]; + int j = closestPair[1]; + + // Merge the two closest clusters + ClusterNode newCluster = new ClusterNode(nextClusterId++, activeClusters.get(i), activeClusters.get(j)); + + // Remove old clusters and add new one + activeClusters.remove(Math.max(i, j)); + activeClusters.remove(Math.min(i, j)); + activeClusters.add(newCluster); + } + + return activeClusters; + } + + /** + * Find the two closest clusters + */ + private int[] findClosestClusters(List clusters, LinkageMethod linkage, double threshold) { 
+ double minDistance = threshold; + int bestI = -1, bestJ = -1; + + for (int i = 0; i < clusters.size(); i++) { + for (int j = i + 1; j < clusters.size(); j++) { + double distance = computeClusterDistance(clusters.get(i), clusters.get(j), linkage); + if (distance < minDistance) { + minDistance = distance; + bestI = i; + bestJ = j; + } + } + } + + return (bestI == -1) ? null : new int[] { bestI, bestJ }; + } + + /** + * Compute distance between clusters using specified linkage method + */ + private double computeClusterDistance(ClusterNode c1, ClusterNode c2, LinkageMethod linkage) { + return switch (linkage) { + case SINGLE -> singleLinkage(c1, c2); + case COMPLETE -> completeLinkage(c1, c2); + case AVERAGE -> averageLinkage(c1, c2); + }; + } + + /** + * Single linkage: minimum distance between any two points in different clusters + */ + private double singleLinkage(ClusterNode c1, ClusterNode c2) { + double minDist = Double.MAX_VALUE; + + for (int i : c1.samples) { + for (int j : c2.samples) { + double dist = distanceMatrix[i][j]; + if (dist < minDist) { + minDist = dist; + // Early termination for very small distances + if (minDist < 1e-10) { + return minDist; + } + } + } + } + + return minDist; + } + + /** + * Complete linkage: maximum distance between any two points in different clusters + */ + private double completeLinkage(ClusterNode c1, ClusterNode c2) { + double maxDist = Double.MIN_VALUE; + + for (int i : c1.samples) { + for (int j : c2.samples) { + double dist = distanceMatrix[i][j]; + if (dist > maxDist) { + maxDist = dist; + } + } + } + + return maxDist; + } + + /** + * Average linkage: average distance between all pairs of points in different clusters + */ + private double averageLinkage(ClusterNode c1, ClusterNode c2) { + double sumDist = 0.0; + int count = 0; + + for (int i : c1.samples) { + for (int j : c2.samples) { + sumDist += distanceMatrix[i][j]; + count++; + } + } + + return sumDist / count; + } + + /** + * Get cluster centroid (medoid) - the 
point with minimum total distance to other points in cluster + */ + public int getClusterCentroid(ClusterNode cluster) { + if (cluster.samples.size() == 1) { + return cluster.samples.getFirst(); + } + + int medoidIndex = cluster.samples.getFirst(); + double minTotalDistance = Double.MAX_VALUE; + + for (int pointI : cluster.samples) { + double totalDistance = 0.0; + for (int pointJ : cluster.samples) { + totalDistance += distanceMatrix[pointI][pointJ]; + } + + if (totalDistance < minTotalDistance) { + minTotalDistance = totalDistance; + medoidIndex = pointI; + } + } + + return medoidIndex; + } + + /** + * Backward compatibility method for cosine similarity calculation + */ + public static double calculateCosineSimilarity(double[] a, double[] b) { + double dotProduct = 0.0; + double normA = 0.0; + double normB = 0.0; + + for (int i = 0; i < a.length; i++) { + dotProduct += a[i] * b[i]; + normA += a[i] * a[i]; + normB += b[i] * b[i]; + } + + if (normA == 0 || normB == 0) { + return 0; + } + + return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); + } +} diff --git a/src/main/java/org/opensearch/agent/tools/utils/mergeMetaData/DeepMergeRule.java b/src/main/java/org/opensearch/agent/tools/utils/mergeMetaData/DeepMergeRule.java new file mode 100644 index 00000000..749b5bc2 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/utils/mergeMetaData/DeepMergeRule.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools.utils.mergeMetaData; + +import java.util.Map; + +/** This rule will merge two array/struct object and merge their properties */ +public class DeepMergeRule implements MergeRule { + + @Override + public boolean isMatch(Map source, Map target) { + return source != null + && target != null + && source.get("properties") != null + && target.get("properties") != null + && source.getOrDefault("type", "object").equals(target.getOrDefault("type", "object")); + } + + 
@Override + public void mergeInto(String key, Map source, Map target) { + Map existing = (Map) target.get(key); + MergeRuleHelper.merge((Map) source.get("properties"), (Map) existing.get("properties")); + target.put(key, existing); + } +} diff --git a/src/main/java/org/opensearch/agent/tools/utils/mergeMetaData/LatestRule.java b/src/main/java/org/opensearch/agent/tools/utils/mergeMetaData/LatestRule.java new file mode 100644 index 00000000..caa7b53e --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/utils/mergeMetaData/LatestRule.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools.utils.mergeMetaData; + +import java.util.Map; + +/** Fallback rule that matches every source/target pair and overwrites the target entry with the source value. */ +public class LatestRule implements MergeRule { + + @Override + public boolean isMatch(Map source, Map target) { + return true; + } + + @Override + public void mergeInto(String key, Map source, Map target) { + target.put(key, source); + } +} diff --git a/src/main/java/org/opensearch/agent/tools/utils/mergeMetaData/MergeRule.java b/src/main/java/org/opensearch/agent/tools/utils/mergeMetaData/MergeRule.java new file mode 100644 index 00000000..16c9bd0b --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/utils/mergeMetaData/MergeRule.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools.utils.mergeMetaData; + +import java.util.Map; + +/** + * Contract for rules that merge index schemas. Implementations provide isMatch, which reports whether the rule applies to a source/target pair, + * and mergeInto, which merges the source value into the target map under the given key.
+ */ +public interface MergeRule { + boolean isMatch(Map source, Map target); + + void mergeInto(String key, Map source, Map target); +} diff --git a/src/main/java/org/opensearch/agent/tools/utils/mergeMetaData/MergeRuleHelper.java b/src/main/java/org/opensearch/agent/tools/utils/mergeMetaData/MergeRuleHelper.java new file mode 100644 index 00000000..5e18cbf7 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/utils/mergeMetaData/MergeRuleHelper.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools.utils.mergeMetaData; + +import java.util.List; +import java.util.Map; + +public class MergeRuleHelper { + private static final List RULES = List + .of( + new DeepMergeRule(), + new LatestRule() // catch-all fallback (isMatch is always true), must come last + ); + + public static MergeRule selectRule(Map source, Map target) { + MergeRule resultRule = RULES.stream().filter(rule -> rule.isMatch(source, target)).findFirst().orElseThrow(); // the throw is + // unreachable while the + // LatestRule fallback is listed + return resultRule; + } + + public static void merge(Map source, Map target) { + for (Map.Entry entry : source.entrySet()) { + String key = entry.getKey(); + Map sourceValue = (Map) entry.getValue(); + Map targetValue = (Map) target.get(key); + MergeRuleHelper.selectRule(sourceValue, targetValue).mergeInto(key, sourceValue, target); + } + } +} diff --git a/src/test/java/org/opensearch/agent/ToolPluginTest.java b/src/test/java/org/opensearch/agent/ToolPluginTest.java deleted file mode 100644 index 8bbe9138..00000000 --- a/src/test/java/org/opensearch/agent/ToolPluginTest.java +++ /dev/null @@ -1,12 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.agent; - -import org.opensearch.test.OpenSearchTestCase; - -public class ToolPluginTest extends OpenSearchTestCase { - -} diff --git a/src/test/java/org/opensearch/agent/ToolPluginTests.java
b/src/test/java/org/opensearch/agent/ToolPluginTests.java new file mode 100644 index 00000000..d8cbc388 --- /dev/null +++ b/src/test/java/org/opensearch/agent/ToolPluginTests.java @@ -0,0 +1,126 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent; + +import static org.junit.Assert.assertEquals; + +import java.util.Collection; +import java.util.List; +import java.util.function.Supplier; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.IndexScopedSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.settings.SettingsFilter; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.env.Environment; +import org.opensearch.env.NodeEnvironment; +import org.opensearch.repositories.RepositoriesService; +import org.opensearch.rest.RestController; +import org.opensearch.rest.RestHandler; +import org.opensearch.script.ScriptService; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.Client; +import org.opensearch.watcher.ResourceWatcherService; + +public class ToolPluginTests { + + @Mock + Client client; + @Mock + ClusterService clusterService; + @Mock + ThreadPool threadPool; + @Mock + ResourceWatcherService resourceWatcherService; + @Mock + ScriptService scriptService; + @Mock + NamedXContentRegistry xContentRegistry; + @Mock + Environment environment; + @Mock + NodeEnvironment nodeEnvironment; + @Mock + NamedWriteableRegistry namedWriteableRegistry; + @Mock + IndexNameExpressionResolver 
indexNameExpressionResolver; + @Mock + Supplier repositoriesServiceSupplier; + + Settings settings; + @Mock + RestController restController; + @Mock + ClusterSettings clusterSettings; + @Mock + IndexScopedSettings indexScopedSettings; + @Mock + SettingsFilter settingsFilter; + @Mock + Supplier nodesInCluster; + + ToolPlugin toolPlugin = new ToolPlugin(); + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + settings = Settings.builder().put("node.processors", 8).build(); + } + + @Test + public void test_getRestHandlers_successful() { + List restHandlers = toolPlugin + .getRestHandlers( + settings, + restController, + clusterSettings, + indexScopedSettings, + settingsFilter, + indexNameExpressionResolver, + nodesInCluster + ); + assertEquals(0, restHandlers.size()); + } + + @Test + public void test_getToolFactories_successful() { + assertEquals(16, toolPlugin.getToolFactories().size()); + } + + @Test + public void test_getExecutorBuilders_successful() { + assertEquals(1, toolPlugin.getExecutorBuilders(settings).size()); + } + + @Test + public void test_createComponent_successful() { + Collection collection = toolPlugin + .createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier + ); + assertEquals(0, collection.size()); + } + +} diff --git a/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java b/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java index 63ed33af..e93e661a 100644 --- a/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java @@ -165,7 +165,9 @@ public void testRunAsyncWithIllegalQueryThenListenerOnFailure() { mockedImpl.run(null, listener4); Exception exception4 = assertThrows(Exception.class, future4::join); -
assertTrue(exception4.getCause() instanceof NullPointerException); + // parameter is re-created with extractInputParameters, thus will not be null + assertTrue(exception4.getCause() instanceof IllegalArgumentException); + assertEquals("[input] is null or empty, can not process it.", exception4.getCause().getMessage()); } @Test diff --git a/src/test/java/org/opensearch/agent/tools/DataDistributionToolTests.java b/src/test/java/org/opensearch/agent/tools/DataDistributionToolTests.java new file mode 100644 index 00000000..d79336b0 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/DataDistributionToolTests.java @@ -0,0 +1,2373 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.lucene.search.TotalHits; +import org.hamcrest.MatcherAssert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.cluster.metadata.MappingMetadata; +import
org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.sql.plugin.transport.PPLQueryAction; +import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; +import org.opensearch.transport.client.AdminClient; +import org.opensearch.transport.client.Client; +import org.opensearch.transport.client.IndicesAdminClient; + +import com.google.common.collect.ImmutableMap; +import com.google.gson.JsonElement; + +import lombok.SneakyThrows; + +public class DataDistributionToolTests { + + private Map params = new HashMap<>(); + private final Client client = mock(Client.class); + @Mock + private AdminClient adminClient; + @Mock + private IndicesAdminClient indicesAdminClient; + @Mock + private GetMappingsResponse getMappingsResponse; + @Mock + private MappingMetadata mappingMetadata; + @Mock + private SearchResponse searchResponse; + @Mock + private TransportPPLQueryResponse pplQueryResponse; + + @SneakyThrows + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + setupMockMappings(); + DataDistributionTool.Factory.getInstance().init(client); + } + + private void mockSearchResponse() { + SearchHit[] hits = createSampleHits(); + SearchHits searchHits = new SearchHits(hits, new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO), 1.0f); + when(searchResponse.getHits()).thenReturn(searchHits); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + } + + private void mockPPLInvocation(String response) { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(pplQueryResponse); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), 
any(), any()); + when(pplQueryResponse.getResult()).thenReturn(response); + } + + @Test + @SneakyThrows + public void testCreateTool() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + assertEquals("DataDistributionTool", tool.getType()); + assertEquals("DataDistributionTool", tool.getName()); + assertEquals(DataDistributionTool.Factory.getInstance().getDefaultDescription(), tool.getDescription()); + assertNull(DataDistributionTool.Factory.getInstance().getDefaultVersion()); + } + + @Test + public void testValidate() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + // Valid parameters + assertTrue( + tool + .validate( + Map + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00" + ) + ) + ); + + // Valid parameters with new fields + assertTrue( + tool + .validate( + Map + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "filter", + "[\"{'term': {'status': 'error'}}\"]", + "ppl", + "source=logs-* | where status='error'" + ) + ) + ); + + // Missing required parameters + assertFalse(tool.validate(Map.of("index", "test_index"))); + assertFalse(tool.validate(Map.of())); + + // Missing selectionTimeRangeStart + assertFalse(tool.validate(Map.of("index", "test_index", "selectionTimeRangeEnd", "2025-01-15 11:00:00"))); + + // Missing selectionTimeRangeEnd + assertFalse(tool.validate(Map.of("index", "test_index", "selectionTimeRangeStart", "2025-01-15 10:00:00"))); + + // Valid with default queryType and timeField + assertTrue( + tool + .validate( + Map + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00" + ) + ) + ); + + // Valid with explicit queryType and timeField + assertTrue( + tool + .validate( + Map + .of( + "index", + 
"test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "ppl", + "timeField", + "timestamp" + ) + ) + ); + } + + @Test + @SneakyThrows + public void testDSLSingleAnalysis() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + // Verify the analysis contains field distribution data + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("singleAnalysis should be a JSON array", singleAnalysis.isJsonArray()); + assertTrue("singleAnalysis should contain at least one field analysis", singleAnalysis.getAsJsonArray().size() > 0); + + // Verify each field analysis has required structure (SummaryDataItem) + for (int i = 0; i < singleAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldAnalysis = singleAnalysis.getAsJsonArray().get(i); + assertTrue("Field analysis should be a JSON object", fieldAnalysis.isJsonObject()); + assertTrue("Field analysis should have 'field' property", fieldAnalysis.getAsJsonObject().has("field")); + assertTrue("Field analysis should have 'divergence' property", fieldAnalysis.getAsJsonObject().has("divergence")); + assertTrue("Field analysis should have 'topChanges' property", fieldAnalysis.getAsJsonObject().has("topChanges")); + assertNotNull("Field name should not be null", fieldAnalysis.getAsJsonObject().get("field").getAsString()); + assertTrue("TopChanges should be a JSON array", fieldAnalysis.getAsJsonObject().get("topChanges").isJsonArray()); + } + }, e -> 
fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithFilter() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'term': {'status': 'error'}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + // Verify filter was applied (should still have analysis data) + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("singleAnalysis should be a JSON array", singleAnalysis.isJsonArray()); + assertTrue("singleAnalysis should contain field analyses even with filter", singleAnalysis.getAsJsonArray().size() > 0); + + // Verify structure is maintained with filter + JsonElement firstField = singleAnalysis.getAsJsonArray().get(0); + assertTrue("Field analysis should have proper structure with filter", firstField.getAsJsonObject().has("field")); + assertTrue("Field analysis should have topChanges with filter", firstField.getAsJsonObject().has("topChanges")); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithMultipleFilters() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'term': {'status': 'error'}}\", \"{'range': {'level': {'gte': 3}}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = 
gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + // Verify multiple filters were applied + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("singleAnalysis should be a JSON array", singleAnalysis.isJsonArray()); + assertTrue("singleAnalysis should contain analyses with multiple filters", singleAnalysis.getAsJsonArray().size() > 0); + + // Verify each field has proper structure with multiple filters + for (int i = 0; i < singleAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldAnalysis = singleAnalysis.getAsJsonArray().get(i); + assertTrue( + "Field analysis should maintain structure with multiple filters", + fieldAnalysis.getAsJsonObject().has("field") + ); + assertTrue( + "Field analysis should have topChanges with multiple filters", + fieldAnalysis.getAsJsonObject().has("topChanges") + ); + } + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testPPLSingleAnalysis() { + String pplResponse = + """ + {"schema":[{"name":"status","type":"keyword"},{"name":"level","type":"integer"},{"name":"host","type":"keyword"}], + "datarows":[["error",3,"server-01"],["info",1,"server-02"],["warning",2,"server-03"],["error",4,"server-01"],["debug",1,"server-02"]], + "total":5,"size":5} + """; + + mockPPLInvocation(pplResponse); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "ppl" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + // Verify PPL data was processed correctly + JsonElement 
singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("singleAnalysis should be a JSON array", singleAnalysis.isJsonArray()); + assertTrue("Should have at least one field from PPL response", singleAnalysis.getAsJsonArray().size() > 0); + + // Verify each field has proper structure + for (int i = 0; i < singleAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldAnalysis = singleAnalysis.getAsJsonArray().get(i); + assertTrue("Field analysis should have field property", fieldAnalysis.getAsJsonObject().has("field")); + assertTrue("Field analysis should have divergence property", fieldAnalysis.getAsJsonObject().has("divergence")); + assertTrue("Field analysis should have topChanges property", fieldAnalysis.getAsJsonObject().has("topChanges")); + assertTrue("TopChanges should be an array", fieldAnalysis.getAsJsonObject().get("topChanges").isJsonArray()); + } + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testPPLWithCustomStatement() { + String pplResponse = """ + {"schema":[{"name":"status","type":"keyword"},{"name":"host","type":"keyword"},{"name":"count","type":"long"}], + "datarows":[["error","server-01",15],["error","server-02",8],["warning","server-01",3]], + "total":3,"size":3} + """; + + mockPPLInvocation(pplResponse); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "ppl", + "ppl", + "source=logs-* | where status='error' | stats count() by host" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + // Verify custom PPL statement was processed + JsonElement singleAnalysis = 
result.getAsJsonObject().get("singleAnalysis"); + assertTrue("singleAnalysis should be a JSON array", singleAnalysis.isJsonArray()); + assertTrue("Should have at least one field from custom PPL response", singleAnalysis.getAsJsonArray().size() > 0); + + // Verify each field has proper structure + for (int i = 0; i < singleAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldAnalysis = singleAnalysis.getAsJsonArray().get(i); + assertTrue("Field analysis should have field property", fieldAnalysis.getAsJsonObject().has("field")); + assertTrue("Field analysis should have divergence property", fieldAnalysis.getAsJsonObject().has("divergence")); + assertTrue("Field analysis should have topChanges property", fieldAnalysis.getAsJsonObject().has("topChanges")); + } + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testComparisonAnalysis() { + // Mock different responses for baseline and selection data + SearchHit[] baselineHits = createBaselineHits(); + SearchHit[] selectionHits = createSelectionHits(); + + SearchHits baselineSearchHits = new SearchHits(baselineHits, new TotalHits(baselineHits.length, TotalHits.Relation.EQUAL_TO), 1.0f); + SearchHits selectionSearchHits = new SearchHits( + selectionHits, + new TotalHits(selectionHits.length, TotalHits.Relation.EQUAL_TO), + 1.0f + ); + + // Mock sequential search calls - first selection, then baseline (based on new implementation) + when(searchResponse.getHits()) + .thenReturn(selectionSearchHits) // First call returns selection data + .thenReturn(baselineSearchHits); // Second call returns baseline data + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + 
"selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "baselineTimeRangeStart", + "2025-01-15 08:00:00", + "baselineTimeRangeEnd", + "2025-01-15 09:00:00", + "queryType", + "dsl" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain comparisonAnalysis", result.getAsJsonObject().has("comparisonAnalysis")); + + // Verify comparison analysis contains divergence data + JsonElement comparisonAnalysis = result.getAsJsonObject().get("comparisonAnalysis"); + assertTrue("comparisonAnalysis should be a JSON array", comparisonAnalysis.isJsonArray()); + assertTrue("comparisonAnalysis should contain field comparisons", comparisonAnalysis.getAsJsonArray().size() > 0); + + // Verify each comparison has required structure (SummaryDataItem) + for (int i = 0; i < comparisonAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldComparison = comparisonAnalysis.getAsJsonArray().get(i); + assertTrue("Field comparison should be a JSON object", fieldComparison.isJsonObject()); + assertTrue("Field comparison should have 'field' property", fieldComparison.getAsJsonObject().has("field")); + assertTrue( + "Field comparison should have 'divergence' property", + fieldComparison.getAsJsonObject().has("divergence") + ); + assertTrue( + "Field comparison should have 'topChanges' property", + fieldComparison.getAsJsonObject().has("topChanges") + ); + + // Verify divergence is a valid number + assertTrue("Divergence should be a number", fieldComparison.getAsJsonObject().get("divergence").isJsonPrimitive()); + double divergence = fieldComparison.getAsJsonObject().get("divergence").getAsDouble(); + assertTrue("Divergence should be non-negative", divergence >= 0.0); + + // Verify topChanges structure + JsonElement topChanges = fieldComparison.getAsJsonObject().get("topChanges"); + assertTrue("TopChanges should be a JSON array", topChanges.isJsonArray()); + 
} + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testPPLComparisonAnalysis() { + String baseResponse = """ + {"schema":[{"name":"status","type":"keyword"},{"name":"level","type":"integer"}], + "datarows":[["info",1],["warning",2],["debug",1]], + "total":3,"size":3} + """; + + String selectionResponse = """ + {"schema":[{"name":"status","type":"keyword"},{"name":"level","type":"integer"}], + "datarows":[["error",3],["error",4],["warning",2]], + "total":3,"size":3} + """; + + // Mock sequential PPL calls - first selection, then baseline (based on new implementation) + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(pplQueryResponse); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + when(pplQueryResponse.getResult()) + .thenReturn(selectionResponse) // First call returns selection data + .thenReturn(baseResponse); // Second call returns baseline data + + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "baselineTimeRangeStart", + "2025-01-15 08:00:00", + "baselineTimeRangeEnd", + "2025-01-15 09:00:00", + "queryType", + "ppl", + "ppl", + "source=logs-* | where level > 1" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain comparisonAnalysis", result.getAsJsonObject().has("comparisonAnalysis")); + + // Verify comparison shows differences between baseline and selection + JsonElement comparisonAnalysis = result.getAsJsonObject().get("comparisonAnalysis"); + assertTrue("comparisonAnalysis should be a JSON array", comparisonAnalysis.isJsonArray()); + assertTrue("Should have at least one field from PPL 
comparison", comparisonAnalysis.getAsJsonArray().size() > 0); + + // Verify each field has proper structure + for (int i = 0; i < comparisonAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldComparison = comparisonAnalysis.getAsJsonArray().get(i); + assertTrue("Field comparison should have field property", fieldComparison.getAsJsonObject().has("field")); + assertTrue("Field comparison should have divergence property", fieldComparison.getAsJsonObject().has("divergence")); + assertTrue("Field comparison should have topChanges property", fieldComparison.getAsJsonObject().has("topChanges")); + + // Verify divergence is a valid number + double divergence = fieldComparison.getAsJsonObject().get("divergence").getAsDouble(); + assertTrue("Divergence should be non-negative", divergence >= 0.0); + } + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithInvalidParameters() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap.of("index", "test_index"), + ActionListener + .wrap( + response -> fail("Should have failed with invalid parameters"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("Invalid time format")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithInvalidFilter() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "invalid-json" + ), + ActionListener.wrap(response -> fail("Should have failed with invalid filter JSON"), e -> { + MatcherAssert.assertThat(e.getMessage(), containsString("Invalid 'filter' parameter")); + MatcherAssert.assertThat(e.getMessage(), containsString("must be a valid JSON array of 
strings")); + }) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithCustomTimeField() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "custom_timestamp", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + // Verify custom time field was used + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("singleAnalysis should be a JSON array", singleAnalysis.isJsonArray()); + assertTrue( + "singleAnalysis should contain field analyses with custom time field", + singleAnalysis.getAsJsonArray().size() > 0 + ); + + // Verify that the custom time field doesn't appear in the analysis (it's used for filtering, not analysis) + for (int i = 0; i < singleAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldAnalysis = singleAnalysis.getAsJsonArray().get(i); + String fieldName = fieldAnalysis.getAsJsonObject().get("field").getAsString(); + assertFalse("Custom time field should not appear in analysis results", "custom_timestamp".equals(fieldName)); + assertTrue("Field analysis should have topChanges property", fieldAnalysis.getAsJsonObject().has("topChanges")); + } + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testExecutionFailedInSearch() { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onFailure(new Exception("Search execution failed")); + return null; + }).when(client).search(any(), any()); + + DataDistributionTool tool = 
DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl" + ), + ActionListener + .wrap( + response -> fail("Should have failed"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("Search execution failed")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionFailedInPPL() { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new Exception("PPL execution failed")); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "ppl" + ), + ActionListener + .wrap( + response -> fail("Should have failed"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("PPL execution failed")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithEmptyPPLResponse() { + String emptyResponse = ""; + mockPPLInvocation(emptyResponse); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "ppl" + ), + ActionListener + .wrap( + response -> fail("Should have failed with empty response"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("Empty PPL response")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithPPLErrorResponse() { + String errorResponse = "{\"error\":{\"type\":\"parsing_exception\",\"reason\":\"Syntax error 
in PPL query\"}}"; + mockPPLInvocation(errorResponse); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "ppl" + ), + ActionListener + .wrap( + response -> fail("Should have failed with PPL error response"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("PPL query error")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithNoData() { + // Mock empty search response + SearchHit[] emptyHits = new SearchHit[0]; + SearchHits emptySearchHits = new SearchHits(emptyHits, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0.0f); + when(searchResponse.getHits()).thenReturn(emptySearchHits); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl" + ), + ActionListener + .wrap( + response -> fail("Should have failed with no data"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("No data found for selection time range")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithInvalidTimeFormat() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "invalid-time-format", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl" + ), + ActionListener + .wrap( + response -> fail("Should have failed with invalid time format"), 
+ e -> MatcherAssert.assertThat(e.getMessage(), containsString("Invalid time format")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithInvalidSize() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "size", + "not-a-number" + ), + ActionListener.wrap(response -> fail("Should have failed with invalid size"), e -> { + MatcherAssert.assertThat(e.getMessage(), containsString("Invalid 'size' parameter")); + MatcherAssert.assertThat(e.getMessage(), containsString("must be a valid integer")); + }) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithSizeExceedsMaxLimit() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "size", + "15000" + ), + ActionListener.wrap(response -> fail("Should have failed with size exceeding limit"), e -> { + MatcherAssert.assertThat(e.getMessage(), containsString("must be between 1 and 10000")); + MatcherAssert.assertThat(e.getMessage(), containsString("15000")); + }) + ); + } + + private void setupMockMappings() { + Map indexMappings = Map + .of( + "properties", + Map + .of( + "status", + Map.of("type", "keyword"), + "level", + Map.of("type", "integer"), + "@timestamp", + Map.of("type", "date"), + "message", + Map.of("type", "text"), + "host", + Map.of("type", "keyword"), + "service", + Map.of("type", "keyword") + ) + ); + Map mockedMappings = Map.of("test_index", mappingMetadata); + + when(mappingMetadata.getSourceAsMap()).thenReturn(indexMappings); + when(getMappingsResponse.getMappings()).thenReturn(mockedMappings); + 
when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesAdminClient); + + // Mock the ActionFuture returned by getMappings + org.opensearch.common.action.ActionFuture mockActionFuture = mock( + org.opensearch.common.action.ActionFuture.class + ); + when(mockActionFuture.actionGet(anyLong())).thenReturn(getMappingsResponse); + when(mockActionFuture.actionGet()).thenReturn(getMappingsResponse); + when(indicesAdminClient.getMappings(any())).thenReturn(mockActionFuture); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(getMappingsResponse); + return null; + }).when(indicesAdminClient).getMappings(any(), any()); + } + + private SearchHit[] createSampleHits() { + SearchHit[] hits = new SearchHit[20]; + String[] statuses = { "error", "info", "warning", "debug" }; + String[] hosts = { "server-01", "server-02", "server-03" }; + String[] services = { "auth", "payment", "notification" }; + int[] levels = { 1, 2, 3, 4, 5 }; + + for (int i = 0; i < 20; i++) { + SearchHit hit = new SearchHit(i + 1); + String status = statuses[i % statuses.length]; + String host = hosts[i % hosts.length]; + String service = services[i % services.length]; + int level = levels[i % levels.length]; + + String source = String + .format( + "{\"status\":\"%s\",\"level\":%d,\"@timestamp\":\"2025-01-15T10:%02d:00Z\",\"host\":\"%s\",\"service\":\"%s\",\"message\":\"Sample message %d\"}", + status, + level, + 30 + i, + host, + service, + i + ); + + BytesReference sourceRef = new BytesArray(source); + hit.sourceRef(sourceRef); + hits[i] = hit; + } + + return hits; + } + + private SearchHit[] createBaselineHits() { + SearchHit[] hits = new SearchHit[10]; + // Baseline data: mostly info and warning + String[] statuses = { "info", "warning" }; + String[] hosts = { "server-01", "server-02" }; + String[] services = { "auth", "payment" }; + int[] levels = { 1, 2 }; + + for (int i = 0; i < 10; i++) { + 
SearchHit hit = new SearchHit(i + 1); + String status = statuses[i % statuses.length]; + String host = hosts[i % hosts.length]; + String service = services[i % services.length]; + int level = levels[i % levels.length]; + + String source = String + .format( + "{\"status\":\"%s\",\"level\":%d,\"@timestamp\":\"2025-01-15T08:%02d:00Z\",\"host\":\"%s\",\"service\":\"%s\",\"message\":\"Baseline message %d\"}", + status, + level, + 30 + i, + host, + service, + i + ); + + BytesReference sourceRef = new BytesArray(source); + hit.sourceRef(sourceRef); + hits[i] = hit; + } + + return hits; + } + + private SearchHit[] createSelectionHits() { + SearchHit[] hits = new SearchHit[10]; + // Selection data: mostly error and debug (different from baseline) + String[] statuses = { "error", "debug" }; + String[] hosts = { "server-02", "server-03" }; + String[] services = { "payment", "notification" }; + int[] levels = { 3, 4, 5 }; + + for (int i = 0; i < 10; i++) { + SearchHit hit = new SearchHit(i + 1); + String status = statuses[i % statuses.length]; + String host = hosts[i % hosts.length]; + String service = services[i % services.length]; + int level = levels[i % levels.length]; + + String source = String + .format( + "{\"status\":\"%s\",\"level\":%d,\"@timestamp\":\"2025-01-15T10:%02d:00Z\",\"host\":\"%s\",\"service\":\"%s\",\"message\":\"Selection message %d\"}", + status, + level, + 30 + i, + host, + service, + i + ); + + BytesReference sourceRef = new BytesArray(source); + hit.sourceRef(sourceRef); + hits[i] = hit; + } + + return hits; + } + + @Test + @SneakyThrows + public void testGetUsefulFieldsWithValidMapping() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + List> testData = createTestDataForFieldAnalysis(); + Map fieldTypes = Map + .of("status", "keyword", "level", "integer", "host", "keyword", "service", "keyword", "@timestamp", "date", "message", "text"); + + java.lang.reflect.Method getUsefulFieldsMethod = 
DataDistributionTool.class + .getDeclaredMethod("getUsefulFields", List.class, Map.class); + getUsefulFieldsMethod.setAccessible(true); + + @SuppressWarnings("unchecked") + List usefulFields = (List) getUsefulFieldsMethod.invoke(tool, testData, fieldTypes); + + assertNotNull(usefulFields); + assertFalse(usefulFields.isEmpty()); + assertTrue(usefulFields.contains("status")); + assertTrue(usefulFields.contains("level")); + assertTrue(usefulFields.contains("host")); + assertTrue(usefulFields.contains("service")); + assertFalse(usefulFields.contains("@timestamp")); + } + + @Test + @SneakyThrows + public void testGetUsefulFieldsWithEmptyMapping() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + List> testData = createTestDataForFieldAnalysis(); + Map emptyFieldTypes = Map.of(); + + java.lang.reflect.Method getUsefulFieldsMethod = DataDistributionTool.class + .getDeclaredMethod("getUsefulFields", List.class, Map.class); + getUsefulFieldsMethod.setAccessible(true); + + @SuppressWarnings("unchecked") + List usefulFields = (List) getUsefulFieldsMethod.invoke(tool, testData, emptyFieldTypes); + + assertNotNull(usefulFields); + assertTrue(usefulFields.size() > 0); + assertFalse(usefulFields.contains("@timestamp")); + } + + @Test + @SneakyThrows + public void testGetUsefulFieldsWithMappingException() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + List> testData = createTestDataForFieldAnalysis(); + Map emptyFieldTypes = Map.of(); + + java.lang.reflect.Method getUsefulFieldsMethod = DataDistributionTool.class + .getDeclaredMethod("getUsefulFields", List.class, Map.class); + getUsefulFieldsMethod.setAccessible(true); + + @SuppressWarnings("unchecked") + List usefulFields = (List) getUsefulFieldsMethod.invoke(tool, testData, emptyFieldTypes); + + assertNotNull(usefulFields); + assertTrue(usefulFields.size() > 0); + assertFalse(usefulFields.contains("@timestamp")); + 
assertFalse(usefulFields.contains("_id")); + assertFalse(usefulFields.contains("_index")); + } + + @Test + @SneakyThrows + public void testGetUsefulFieldsWithHighCardinalityFields() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + List> testData = createHighCardinalityTestData(); + Map fieldTypes = Map.of("status", "keyword", "unique_field", "keyword", "@timestamp", "date"); + + java.lang.reflect.Method getUsefulFieldsMethod = DataDistributionTool.class + .getDeclaredMethod("getUsefulFields", List.class, Map.class); + getUsefulFieldsMethod.setAccessible(true); + + @SuppressWarnings("unchecked") + List usefulFields = (List) getUsefulFieldsMethod.invoke(tool, testData, fieldTypes); + + assertNotNull(usefulFields); + // unique_field has high cardinality (20 unique values in 20 documents) so should be excluded + assertFalse("High cardinality field unique_field should be excluded", usefulFields.contains("unique_field")); + // status has low cardinality (2 unique values) so should be included + assertTrue("Low cardinality field status should be included", usefulFields.contains("status")); + } + + @Test + @SneakyThrows + public void testGetUsefulFieldsWithEmptyData() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + List> emptyData = List.of(); + Map fieldTypes = Map.of("status", "keyword", "level", "integer"); + + java.lang.reflect.Method getUsefulFieldsMethod = DataDistributionTool.class + .getDeclaredMethod("getUsefulFields", List.class, Map.class); + getUsefulFieldsMethod.setAccessible(true); + + @SuppressWarnings("unchecked") + List usefulFields = (List) getUsefulFieldsMethod.invoke(tool, emptyData, fieldTypes); + + assertNotNull(usefulFields); + assertTrue(usefulFields.size() > 0); + } + + private List> createTestDataForFieldAnalysis() { + List> data = new ArrayList<>(); + String[] statuses = { "error", "info", "warning" }; + String[] hosts = { "server-01", "server-02" }; + 
String[] services = { "auth", "payment" }; + + for (int i = 0; i < 10; i++) { + Map doc = new HashMap<>(); + doc.put("status", statuses[i % statuses.length]); + doc.put("level", i % 5 + 1); + doc.put("host", hosts[i % hosts.length]); + doc.put("service", services[i % services.length]); + doc.put("@timestamp", "2025-01-15T10:" + String.format("%02d", 30 + i) + ":00Z"); + doc.put("message", "Test message " + i); + data.add(doc); + } + return data; + } + + private List> createHighCardinalityTestData() { + List> data = new ArrayList<>(); + String[] statuses = { "error", "info" }; + + for (int i = 0; i < 20; i++) { + Map doc = new HashMap<>(); + doc.put("status", statuses[i % statuses.length]); + doc.put("unique_field", "value_" + i); + doc.put("@timestamp", "2025-01-15T10:" + String.format("%02d", 30 + i) + ":00Z"); + data.add(doc); + } + return data; + } + + @Test + @SneakyThrows + public void testGroupNumericKeysWithManyNumericValues() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method groupNumericKeysMethod = DataDistributionTool.class + .getDeclaredMethod("groupNumericKeys", Map.class, Map.class); + groupNumericKeysMethod.setAccessible(true); + + Map selectionDist = new HashMap<>(); + Map baselineDist = new HashMap<>(); + + for (int i = 1; i <= 15; i++) { + selectionDist.put(String.valueOf(i), 0.1); + baselineDist.put(String.valueOf(i + 5), 0.1); + } + + Object result = groupNumericKeysMethod.invoke(tool, selectionDist, baselineDist); + + assertNotNull(result); + java.lang.reflect.Method groupedSelectionDistMethod = result.getClass().getDeclaredMethod("groupedSelectionDist"); + @SuppressWarnings("unchecked") + Map groupedSelection = (Map) groupedSelectionDistMethod.invoke(result); + + assertEquals(5, groupedSelection.size()); + assertTrue(groupedSelection.keySet().stream().allMatch(key -> key.contains("-"))); + } + + @Test + @SneakyThrows + public void testGroupNumericKeysWithFewNumericValues() { 
+ DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method groupNumericKeysMethod = DataDistributionTool.class + .getDeclaredMethod("groupNumericKeys", Map.class, Map.class); + groupNumericKeysMethod.setAccessible(true); + + Map selectionDist = Map.of("1", 0.3, "2", 0.4, "3", 0.3); + Map baselineDist = Map.of("1", 0.2, "2", 0.5, "3", 0.3); + + Object result = groupNumericKeysMethod.invoke(tool, selectionDist, baselineDist); + + assertNotNull(result); + java.lang.reflect.Method groupedSelectionDistMethod = result.getClass().getDeclaredMethod("groupedSelectionDist"); + @SuppressWarnings("unchecked") + Map groupedSelection = (Map) groupedSelectionDistMethod.invoke(result); + + assertEquals(selectionDist, groupedSelection); + } + + @Test + @SneakyThrows + public void testGroupNumericKeysWithNonNumericValues() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method groupNumericKeysMethod = DataDistributionTool.class + .getDeclaredMethod("groupNumericKeys", Map.class, Map.class); + groupNumericKeysMethod.setAccessible(true); + + Map selectionDist = new HashMap<>(); + Map baselineDist = new HashMap<>(); + + for (int i = 1; i <= 15; i++) { + selectionDist.put(String.valueOf(i), 0.1); + } + selectionDist.put("error", 0.2); + selectionDist.put("warning", 0.3); + + Object result = groupNumericKeysMethod.invoke(tool, selectionDist, baselineDist); + + assertNotNull(result); + java.lang.reflect.Method groupedSelectionDistMethod = result.getClass().getDeclaredMethod("groupedSelectionDist"); + @SuppressWarnings("unchecked") + Map groupedSelection = (Map) groupedSelectionDistMethod.invoke(result); + + assertEquals(selectionDist, groupedSelection); + } + + @Test + @SneakyThrows + public void testGetNumberFieldsWithValidMapping() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + Map fieldTypes = Map.of("status", 
"keyword", "level", "integer", "host", "keyword", "response_time", "float"); + + java.lang.reflect.Method getNumberFieldsMethod = DataDistributionTool.class.getDeclaredMethod("getNumberFields", Map.class); + getNumberFieldsMethod.setAccessible(true); + + @SuppressWarnings("unchecked") + java.util.Set numberFields = (java.util.Set) getNumberFieldsMethod.invoke(tool, fieldTypes); + + assertNotNull(numberFields); + assertTrue(numberFields.contains("level")); + assertTrue(numberFields.contains("response_time")); + assertFalse(numberFields.contains("status")); + assertFalse(numberFields.contains("host")); + } + + @Test + @SneakyThrows + public void testGetNumberFieldsWithEmptyMapping() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + Map emptyFieldTypes = Map.of(); + + java.lang.reflect.Method getNumberFieldsMethod = DataDistributionTool.class.getDeclaredMethod("getNumberFields", Map.class); + getNumberFieldsMethod.setAccessible(true); + + @SuppressWarnings("unchecked") + java.util.Set numberFields = (java.util.Set) getNumberFieldsMethod.invoke(tool, emptyFieldTypes); + + assertNotNull(numberFields); + assertTrue(numberFields.isEmpty()); + } + + @Test + @SneakyThrows + public void testGetNumberFieldsWithMappingException() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + Map emptyFieldTypes = Map.of(); + + java.lang.reflect.Method getNumberFieldsMethod = DataDistributionTool.class.getDeclaredMethod("getNumberFields", Map.class); + getNumberFieldsMethod.setAccessible(true); + + @SuppressWarnings("unchecked") + java.util.Set numberFields = (java.util.Set) getNumberFieldsMethod.invoke(tool, emptyFieldTypes); + + assertNotNull(numberFields); + assertTrue(numberFields.isEmpty()); + } + + @Test + @SneakyThrows + public void testGetNumberFieldsWithNullActionFuture() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + Map emptyFieldTypes = Map.of(); + 
+ java.lang.reflect.Method getNumberFieldsMethod = DataDistributionTool.class.getDeclaredMethod("getNumberFields", Map.class); + getNumberFieldsMethod.setAccessible(true); + + @SuppressWarnings("unchecked") + java.util.Set numberFields = (java.util.Set) getNumberFieldsMethod.invoke(tool, emptyFieldTypes); + + assertNotNull(numberFields); + assertTrue(numberFields.isEmpty()); + } + + @Test + @SneakyThrows + public void testGetFieldTypesWithValidMapping() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method getFieldTypesMethod = DataDistributionTool.class + .getDeclaredMethod("getFieldTypes", String.class, ActionListener.class); + getFieldTypesMethod.setAccessible(true); + + java.util.concurrent.CountDownLatch latch = new java.util.concurrent.CountDownLatch(1); + java.util.concurrent.atomic.AtomicReference> resultRef = new java.util.concurrent.atomic.AtomicReference<>(); + + ActionListener> listener = ActionListener.wrap(result -> { + resultRef.set(result); + latch.countDown(); + }, e -> { + latch.countDown(); + fail("getFieldTypes failed: " + e.getMessage()); + }); + + getFieldTypesMethod.invoke(tool, "test_index", listener); + latch.await(); + + Map fieldTypes = resultRef.get(); + assertNotNull(fieldTypes); + assertFalse(fieldTypes.isEmpty()); + assertEquals("keyword", fieldTypes.get("status")); + assertEquals("integer", fieldTypes.get("level")); + assertEquals("keyword", fieldTypes.get("host")); + } + + @Test + @SneakyThrows + public void testGetFieldTypesWithEmptyMapping() { + when(getMappingsResponse.getMappings()).thenReturn(Map.of()); + + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method getFieldTypesMethod = DataDistributionTool.class + .getDeclaredMethod("getFieldTypes", String.class, ActionListener.class); + getFieldTypesMethod.setAccessible(true); + + java.util.concurrent.CountDownLatch latch = new 
java.util.concurrent.CountDownLatch(1); + java.util.concurrent.atomic.AtomicReference> resultRef = new java.util.concurrent.atomic.AtomicReference<>(); + + ActionListener> listener = ActionListener.wrap(result -> { + resultRef.set(result); + latch.countDown(); + }, e -> { + latch.countDown(); + fail("getFieldTypes failed: " + e.getMessage()); + }); + + getFieldTypesMethod.invoke(tool, "test_index", listener); + latch.await(); + + Map fieldTypes = resultRef.get(); + assertNotNull(fieldTypes); + assertTrue(fieldTypes.isEmpty()); + } + + @Test + @SneakyThrows + public void testGetFieldTypesWithMappingException() { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onFailure(new RuntimeException("Mapping failed")); + return null; + }).when(indicesAdminClient).getMappings(any(), any()); + + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method getFieldTypesMethod = DataDistributionTool.class + .getDeclaredMethod("getFieldTypes", String.class, ActionListener.class); + getFieldTypesMethod.setAccessible(true); + + java.util.concurrent.CountDownLatch latch = new java.util.concurrent.CountDownLatch(1); + java.util.concurrent.atomic.AtomicReference> resultRef = new java.util.concurrent.atomic.AtomicReference<>(); + + ActionListener> listener = ActionListener.wrap(result -> { + resultRef.set(result); + latch.countDown(); + }, e -> { + resultRef.set(Map.of()); + latch.countDown(); + }); + + getFieldTypesMethod.invoke(tool, "test_index", listener); + latch.await(); + + Map fieldTypes = resultRef.get(); + assertNotNull(fieldTypes); + assertTrue(fieldTypes.isEmpty()); + } + + @Test + @SneakyThrows + public void testGetPPLQueryWithTimeRangeEmptyQuery() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method getPPLQueryWithTimeRangeMethod = DataDistributionTool.class + 
.getDeclaredMethod("getPPLQueryWithTimeRange", String.class, String.class, String.class, String.class); + getPPLQueryWithTimeRangeMethod.setAccessible(true); + + try { + getPPLQueryWithTimeRangeMethod.invoke(tool, "", "2025-01-15 10:00:00", "2025-01-15 11:00:00", "@timestamp"); + fail("Expected IllegalArgumentException for empty PPL query"); + } catch (java.lang.reflect.InvocationTargetException e) { + assertTrue("Expected IllegalArgumentException", e.getCause() instanceof IllegalArgumentException); + assertEquals("PPL query cannot be empty", e.getCause().getMessage()); + } + } + + @Test + @SneakyThrows + public void testGetPPLQueryWithTimeRangeEmptyTimeField() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method getPPLQueryWithTimeRangeMethod = DataDistributionTool.class + .getDeclaredMethod("getPPLQueryWithTimeRange", String.class, String.class, String.class, String.class); + getPPLQueryWithTimeRangeMethod.setAccessible(true); + + String result = (String) getPPLQueryWithTimeRangeMethod + .invoke(tool, "source=logs-*", "2025-01-15 10:00:00", "2025-01-15 11:00:00", ""); + + assertEquals("source=logs-*", result); + } + + @Test + @SneakyThrows + public void testGetPPLQueryWithTimeRangeExistingWhere() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method getPPLQueryWithTimeRangeMethod = DataDistributionTool.class + .getDeclaredMethod("getPPLQueryWithTimeRange", String.class, String.class, String.class, String.class); + getPPLQueryWithTimeRangeMethod.setAccessible(true); + + String result = (String) getPPLQueryWithTimeRangeMethod + .invoke(tool, "source=logs-* | where status='error'", "2025-01-15 10:00:00", "2025-01-15 11:00:00", "@timestamp"); + + assertEquals( + "source=logs-* | WHERE `@timestamp` >= '2025-01-15 10:00:00' AND `@timestamp` <= '2025-01-15 11:00:00' | where status='error'", + result + ); + } + + @Test + @SneakyThrows + 
public void testGetPPLQueryWithTimeRangeNoExistingWhere() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + java.lang.reflect.Method getPPLQueryWithTimeRangeMethod = DataDistributionTool.class + .getDeclaredMethod("getPPLQueryWithTimeRange", String.class, String.class, String.class, String.class); + getPPLQueryWithTimeRangeMethod.setAccessible(true); + + String result = (String) getPPLQueryWithTimeRangeMethod + .invoke(tool, "source=logs-* | stats count() by status", "2025-01-15 10:00:00", "2025-01-15 11:00:00", "@timestamp"); + + assertEquals( + "source=logs-* | WHERE `@timestamp` >= '2025-01-15 10:00:00' AND `@timestamp` <= '2025-01-15 11:00:00' | stats count() by status", + result + ); + } + + // ========== DSL Query Format Tests ========== + + @Test + @SneakyThrows + public void testDSLWithRawDSLQuery() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "dsl", + "{\"bool\": {\"must\": [{\"term\": {\"status\": \"error\"}}]}}" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("singleAnalysis should be a JSON array", singleAnalysis.isJsonArray()); + assertTrue("singleAnalysis should contain field analyses with raw DSL", singleAnalysis.getAsJsonArray().size() > 0); + + for (int i = 0; i < singleAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldAnalysis = singleAnalysis.getAsJsonArray().get(i); + assertTrue( + "Field analysis should have proper structure with raw DSL", + 
fieldAnalysis.getAsJsonObject().has("field") + ); + assertTrue("Field analysis should have topChanges with raw DSL", fieldAnalysis.getAsJsonObject().has("topChanges")); + } + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithComplexRawDSLQuery() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + String complexDSL = """ + { + "bool": { + "must": [ + {"term": {"status": "error"}}, + {"range": {"level": {"gte": 3}}} + ], + "should": [ + {"match": {"message": "timeout"}}, + {"wildcard": {"host": "server-*"}} + ], + "minimum_should_match": 1 + } + } + """; + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "dsl", + complexDSL + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("singleAnalysis should be a JSON array", singleAnalysis.isJsonArray()); + assertTrue("Complex DSL should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithInvalidRawDSLQuery() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + // Test that invalid DSL query causes execution to fail + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "dsl", + "invalid-json-query" + ), + ActionListener.wrap(response -> { + fail("Should have 
failed with invalid DSL query"); + }, e -> { + // Expect failure due to invalid DSL format + assertTrue( + "Should fail with exception for invalid DSL", + e instanceof IllegalArgumentException || e.getMessage().contains("Invalid query format") + ); + }) + ); + } + + @Test + @SneakyThrows + public void testDSLWithFilterArraySingleFilter() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'term': {'status': 'error'}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Single filter should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithFilterArrayComplexFilters() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'term': {'status': 'error'}}\", \"{'range': {'level': {'gte': 3, 'lte': 5}}}\", \"{'wildcard': {'host': 'server-*'}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + 
assertTrue("Complex filters should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithFilterArrayMatchQueries() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'match': {'message': 'error timeout'}}\", \"{'match_phrase': {'service': 'payment service'}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Match query filters should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithFilterArrayExistsAndPrefixQueries() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'exists': {'field': 'error_code'}}\", \"{'prefix': {'host': 'prod'}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Exists and prefix filters should produce 
analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithFilterArrayRegexpQueries() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'regexp': {'host': 'server-[0-9]+'}}\", \"{'wildcard': {'service': '*payment*'}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Regexp and wildcard filters should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithFilterArrayTermsQueries() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'terms': {'status': ['error', 'warning']}}\", \"{'terms': {'level': [3, 4, 5]}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Terms filters should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); 
+ }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithFilterArrayMultiMatchQueries() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'multi_match': {'query': 'error timeout', 'fields': ['message', 'description']}}\", \"{'multi_match': {'query': 'connection failed', 'fields': ['error_msg', 'details']}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Multi-match filters should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithFilterArrayInvalidFilter() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + // Test that invalid filter JSON causes parameter validation to fail + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[\"{'term': {'status': 'error'}}\", \"invalid-json-filter\"]" + ), + ActionListener.wrap(response -> { + fail("Should have failed with invalid filter JSON"); + }, e -> { + // Expect IllegalArgumentException due to invalid filter format + assertTrue("Should fail with IllegalArgumentException for invalid filter", e instanceof IllegalArgumentException); + assertTrue( + "Error message should 
mention invalid filter", + e.getMessage().contains("Invalid 'filter' parameter") || e.getMessage().contains("Invalid query format") + ); + }) + ); + } + + @Test + @SneakyThrows + public void testDSLComparisonWithRawDSLQuery() { + SearchHit[] baselineHits = createBaselineHits(); + SearchHit[] selectionHits = createSelectionHits(); + + SearchHits baselineSearchHits = new SearchHits(baselineHits, new TotalHits(baselineHits.length, TotalHits.Relation.EQUAL_TO), 1.0f); + SearchHits selectionSearchHits = new SearchHits( + selectionHits, + new TotalHits(selectionHits.length, TotalHits.Relation.EQUAL_TO), + 1.0f + ); + + when(searchResponse.getHits()).thenReturn(selectionSearchHits).thenReturn(baselineSearchHits); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "baselineTimeRangeStart", + "2025-01-15 08:00:00", + "baselineTimeRangeEnd", + "2025-01-15 09:00:00", + "queryType", + "dsl", + "dsl", + "{\"bool\": {\"must\": [{\"term\": {\"status\": \"error\"}}]}}" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain comparisonAnalysis", result.getAsJsonObject().has("comparisonAnalysis")); + + JsonElement comparisonAnalysis = result.getAsJsonObject().get("comparisonAnalysis"); + assertTrue("comparisonAnalysis should be a JSON array", comparisonAnalysis.isJsonArray()); + assertTrue("Raw DSL comparison should produce results", comparisonAnalysis.getAsJsonArray().size() > 0); + + for (int i = 0; i < comparisonAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldComparison = 
comparisonAnalysis.getAsJsonArray().get(i); + assertTrue("Field comparison should have divergence", fieldComparison.getAsJsonObject().has("divergence")); + assertTrue("Field comparison should have topChanges", fieldComparison.getAsJsonObject().has("topChanges")); + } + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLComparisonWithFilterArray() { + SearchHit[] baselineHits = createBaselineHits(); + SearchHit[] selectionHits = createSelectionHits(); + + SearchHits baselineSearchHits = new SearchHits(baselineHits, new TotalHits(baselineHits.length, TotalHits.Relation.EQUAL_TO), 1.0f); + SearchHits selectionSearchHits = new SearchHits( + selectionHits, + new TotalHits(selectionHits.length, TotalHits.Relation.EQUAL_TO), + 1.0f + ); + + when(searchResponse.getHits()).thenReturn(selectionSearchHits).thenReturn(baselineSearchHits); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "baselineTimeRangeStart", + "2025-01-15 08:00:00", + "baselineTimeRangeEnd", + "2025-01-15 09:00:00", + "queryType", + "dsl", + "filter", + "[\"{'term': {'status': 'error'}}\", \"{'range': {'level': {'gte': 3}}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain comparisonAnalysis", result.getAsJsonObject().has("comparisonAnalysis")); + + JsonElement comparisonAnalysis = result.getAsJsonObject().get("comparisonAnalysis"); + assertTrue("comparisonAnalysis should be a JSON array", comparisonAnalysis.isJsonArray()); + 
assertTrue("Filter array comparison should produce results", comparisonAnalysis.getAsJsonArray().size() > 0); + + for (int i = 0; i < comparisonAnalysis.getAsJsonArray().size(); i++) { + JsonElement fieldComparison = comparisonAnalysis.getAsJsonArray().get(i); + assertTrue("Field comparison should have divergence", fieldComparison.getAsJsonObject().has("divergence")); + assertTrue("Field comparison should have topChanges", fieldComparison.getAsJsonObject().has("topChanges")); + } + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithBothRawDSLAndFilterArray() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "dsl", + "{\"bool\": {\"must\": [{\"term\": {\"status\": \"error\"}}]}}", + "filter", + "[\"{'range': {'level': {'gte': 3}}}\"]" + ), + ActionListener.wrap(response -> { + // When both dsl and filter are provided, dsl should take precedence + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Both DSL and filter should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + // ========== Query Format Validation Tests ========== + + @Test + @SneakyThrows + public void testValidateFilterArrayFormat() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + // Valid filter array formats + assertTrue( + "Single filter should be valid", + tool + .validate( + Map + .of( + "index", + "test_index", + 
"selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "filter", + "[\"{'term': {'status': 'error'}}\"]" + ) + ) + ); + + assertTrue( + "Multiple filters should be valid", + tool + .validate( + Map + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "filter", + "[\"{'term': {'status': 'error'}}\", \"{'range': {'level': {'gte': 3}}}\"]" + ) + ) + ); + + assertTrue( + "Empty filter array should be valid", + tool + .validate( + Map + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "filter", + "[]" + ) + ) + ); + } + + @Test + @SneakyThrows + public void testValidateRawDSLFormat() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + // Valid DSL formats + assertTrue( + "Simple DSL should be valid", + tool + .validate( + Map + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "dsl", + "{\"term\": {\"status\": \"error\"}}" + ) + ) + ); + + assertTrue( + "Complex DSL should be valid", + tool + .validate( + Map + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "dsl", + "{\"bool\": {\"must\": [{\"term\": {\"status\": \"error\"}}], \"filter\": [{\"range\": {\"level\": {\"gte\": 3}}}]}}" + ) + ) + ); + + assertTrue( + "Empty DSL should be valid", + tool + .validate( + Map + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "dsl", + "" + ) + ) + ); + } + + @Test + @SneakyThrows + public void testValidateBothDSLAndFilterFormats() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + // Both DSL and 
filter provided should be valid + assertTrue( + "Both DSL and filter should be valid", + tool + .validate( + Map + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "dsl", + "{\"term\": {\"status\": \"error\"}}", + "filter", + "[\"{'range': {'level': {'gte': 3}}}\"]" + ) + ) + ); + } + + // ========== Edge Cases and Error Handling Tests ========== + + @Test + @SneakyThrows + public void testDSLWithEmptyFilterArray() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis with empty filter", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Empty filter should still produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithMalformedFilterJSON() { + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "filter", + "[malformed-json]" + ), + ActionListener.wrap(response -> fail("Should have failed with malformed filter JSON"), e -> { + // The filter "[malformed-json]" is valid JSON array syntax, but "malformed-json" is not valid query JSON + // Error occurs during query execution, not parameter 
validation + MatcherAssert + .assertThat( + e.getMessage(), + anyOf(containsString("Invalid query format"), containsString("Invalid 'filter' parameter")) + ); + }) + ); + } + + @Test + @SneakyThrows + public void testDSLWithBoolQueryInRawDSL() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + String boolDSL = """ + { + "bool": { + "must": [ + {"term": {"status": "error"}} + ], + "should": [ + {"match": {"message": "timeout"}}, + {"match": {"message": "connection"}} + ], + "must_not": [ + {"term": {"level": 1}} + ], + "filter": [ + {"range": {"@timestamp": {"gte": "2025-01-15T09:00:00Z"}}} + ], + "minimum_should_match": 1 + } + } + """; + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "dsl", + boolDSL + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Bool query DSL should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithRawDSLTermsQuery() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "dsl", + "{\"terms\": {\"status\": [\"error\", \"warning\"]}}" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain 
singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Terms query DSL should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLWithRawDSLMultiMatchQuery() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "dsl", + "{\"multi_match\": {\"query\": \"error timeout\", \"fields\": [\"message\", \"description\"]}}" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("Multi-match query DSL should produce analysis results", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testDSLQueryPrecedenceOverFilter() { + mockSearchResponse(); + DataDistributionTool tool = DataDistributionTool.Factory.getInstance().create(params); + + // When both dsl and filter are provided, dsl should take precedence + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "selectionTimeRangeStart", + "2025-01-15 10:00:00", + "selectionTimeRangeEnd", + "2025-01-15 11:00:00", + "queryType", + "dsl", + "dsl", + "{\"term\": {\"status\": \"error\"}}", + "filter", + "[\"{'term': {'status': 'info'}}\"]" + ), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue("Response should contain 
singleAnalysis", result.getAsJsonObject().has("singleAnalysis")); + + JsonElement singleAnalysis = result.getAsJsonObject().get("singleAnalysis"); + assertTrue("DSL should take precedence over filter", singleAnalysis.getAsJsonArray().size() > 0); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } +} diff --git a/src/test/java/org/opensearch/agent/tools/LogPatternAnalysisToolTests.java b/src/test/java/org/opensearch/agent/tools/LogPatternAnalysisToolTests.java new file mode 100644 index 00000000..dd17e50d --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/LogPatternAnalysisToolTests.java @@ -0,0 +1,667 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + +import org.hamcrest.MatcherAssert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.plugin.transport.PPLQueryAction; +import org.opensearch.sql.plugin.transport.TransportPPLQueryRequest; +import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; +import org.opensearch.transport.client.Client; + +import com.google.common.collect.ImmutableMap; +import com.google.gson.JsonElement; + +import 
lombok.SneakyThrows; + +public class LogPatternAnalysisToolTests { + + private Map params = new HashMap<>(); + private final Client client = mock(Client.class); + @Mock + private TransportPPLQueryResponse pplQueryResponse; + + @SneakyThrows + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + LogPatternAnalysisTool.Factory.getInstance().init(client); + } + + private void mockPPLInvocation(String response) { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(pplQueryResponse); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + when(pplQueryResponse.getResult()).thenReturn(response); + } + + @Test + @SneakyThrows + public void testCreateTool() { + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + assertEquals("LogPatternAnalysisTool", tool.getType()); + assertEquals("LogPatternAnalysisTool", tool.getName()); + assertEquals(LogPatternAnalysisTool.Factory.getInstance().getDefaultDescription(), tool.getDescription()); + assertNull(LogPatternAnalysisTool.Factory.getInstance().getDefaultVersion()); + } + + @Test + public void testValidate() { + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + // Valid parameters + assertTrue( + tool + .validate( + Map + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ) + ) + ); + + // Missing required parameters + assertFalse(tool.validate(Map.of("index", "test_index"))); + assertFalse(tool.validate(Map.of())); + } + + @Test + @SneakyThrows + public void testLogInsightExecution() { + String pplResponse = + """ + {"schema":[{"name":"patterns_field","type":"string"},{"name":"pattern_count","type":"long"},{"name":"sample_logs","type":"array"}], + "datarows":[["Error in 
processing <*>",5,["Error in processing request","Error in processing data"]], + ["Failed to connect <*>",3,["Failed to connect to database","Failed to connect to server"]]], + "total":2,"size":2} + """; + + mockPPLInvocation(pplResponse); + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener.wrap(response -> { + System.out.println(response); + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue(result.getAsJsonObject().has("logInsights")); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testLogPatternDiffAnalysis() { + // Mock different responses for base and selection time ranges + String baseResponse = """ + {"schema":[{"name":"cnt","type":"long"},{"name":"patterns_field","type":"string"}], + "datarows":[[100,"User login successful"],[20,"Database query executed"],[10,"Cache hit"]], + "total":3,"size":3} + """; + + String selectionResponse = + """ + {"schema":[{"name":"cnt","type":"long"},{"name":"patterns_field","type":"string"}], + "datarows":[[50,"User login successful"],[80,"Error in authentication <*>"],[15,"Connection timeout <*>"],[5,"Database query executed"]], + "total":4,"size":4} + """; + + // Mock sequential PPL calls - first base, then selection + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(pplQueryResponse); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + when(pplQueryResponse.getResult()) + .thenReturn(baseResponse) // First call returns base data + .thenReturn(selectionResponse); // Second call returns selection data + + LogPatternAnalysisTool tool = 
LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "baseTimeRangeStart", + "2025-01-01T00:00:00Z", + "baseTimeRangeEnd", + "2025-01-01T01:00:00Z", + "selectionTimeRangeStart", + "2025-01-01T01:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T02:00:00Z" + ), + ActionListener.wrap(response -> { + System.out.println("Pattern diff response: " + response); + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue(result.getAsJsonObject().has("patternMapDifference")); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testLogSequenceAnalysis() { + // Mock different responses for base and selection time ranges + String baseResponse = + """ + {"schema":[{"name":"traceId","type":"string"},{"name":"patterns_field","type":"string"},{"name":"@timestamp","type":"timestamp"}], + "datarows":[["trace1","User login attempt","2025-01-01T00:00:00Z"],["trace1","Authentication successful","2025-01-01T00:00:01Z"],["trace1","Session created","2025-01-01T00:00:02Z"], + ["trace2","User login attempt","2025-01-01T00:00:10Z"],["trace2","Authentication successful","2025-01-01T00:00:11Z"],["trace2","Session created","2025-01-01T00:00:12Z"]], + "total":6,"size":6} + """; + + String selectionResponse = + """ + {"schema":[{"name":"traceId","type":"string"},{"name":"patterns_field","type":"string"},{"name":"@timestamp","type":"timestamp"}], + "datarows":[["trace3","User login attempt","2025-01-01T01:00:00Z"],["trace3","Authentication failed","2025-01-01T01:00:01Z"],["trace3","Account locked","2025-01-01T01:00:02Z"], + ["trace4","Database connection timeout","2025-01-01T01:00:10Z"],["trace4","Retry connection","2025-01-01T01:00:11Z"],["trace4","Connection failed","2025-01-01T01:00:12Z"], + ["trace5","User login attempt","2025-01-01T01:00:20Z"],["trace5","Authentication 
successful","2025-01-01T01:00:21Z"],["trace5","Session created","2025-01-01T01:00:22Z"]], + "total":9,"size":9} + """; + + // Mock sequential PPL calls - first base, then selection + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(pplQueryResponse); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + when(pplQueryResponse.getResult()) + .thenReturn(baseResponse) // First call returns base data + .thenReturn(selectionResponse); // Second call returns selection data + + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "traceFieldName", + "traceId", + "baseTimeRangeStart", + "2025-01-01T00:00:00Z", + "baseTimeRangeEnd", + "2025-01-01T01:00:00Z", + "selectionTimeRangeStart", + "2025-01-01T01:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T02:00:00Z" + ), + ActionListener.wrap(response -> { + System.out.println("Sequence analysis response: " + response); + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue(result.getAsJsonObject().has("BASE") || result.getAsJsonObject().has("EXCEPTIONAL")); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithInvalidParameters() { + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap.of("index", "test_index"), + ActionListener + .wrap( + response -> fail("Should have failed with invalid parameters"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("Missing required parameters")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithEmptyPPLResponse() { + String emptyResponse = """ + {"schema":[],"datarows":[],"total":0,"size":0} + """; + + 
mockPPLInvocation(emptyResponse); + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener.wrap(response -> { + System.out.println("response: " + response); + JsonElement result = gson.fromJson(response, JsonElement.class); + assertTrue(result.getAsJsonObject().has("logInsights")); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testExecutionFailedInPPL() { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new Exception("PPL execution failed")); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener + .wrap( + response -> fail("Should have failed"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("PPL execution failed:")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithIndexNotFound() { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new Exception("IndexNotFoundException: no such index")); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "nonexistent_index", + "timeField", + 
"@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener + .wrap( + response -> fail("Should have failed with IndexNotFoundException"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("IndexNotFoundException")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithEmptyPPLResult() { + String emptyResponse = ""; + mockPPLInvocation(emptyResponse); + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener + .wrap( + response -> fail("Should have failed with empty response"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("Empty PPL response")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithInvalidPPLResponse() { + String invalidResponse = "{\"invalid\":\"response\"}"; + mockPPLInvocation(invalidResponse); + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener + .wrap( + response -> fail("Should have failed with invalid response"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("Invalid PPL response")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithNonExistentIndex() { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new Exception("no such index [nonexistent_index]")); + return null; + 
}).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "nonexistent_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener + .wrap( + response -> fail("Should have failed with non-existent index"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("no such index")) + ) + ); + } + + @Test + @SneakyThrows + public void testLogInsightWithFilter() { + String pplResponse = + """ + {"schema":[{"name":"patterns_field","type":"string"},{"name":"pattern_count","type":"long"},{"name":"sample_logs","type":"array"}], + "datarows":[["Auth error for user <*>",3,["Auth error for user admin","Auth error for user guest"]]], + "total":1,"size":1} + """; + + AtomicReference capturedPPL = new AtomicReference<>(); + doAnswer(invocation -> { + TransportPPLQueryRequest request = (TransportPPLQueryRequest) invocation.getArguments()[1]; + capturedPPL.set(request.getRequest()); + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(pplQueryResponse); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + when(pplQueryResponse.getResult()).thenReturn(pplResponse); + + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .builder() + .put("index", "test_index") + .put("timeField", "@timestamp") + .put("logFieldName", "message") + .put("selectionTimeRangeStart", "2025-01-01T00:00:00Z") + .put("selectionTimeRangeEnd", "2025-01-01T01:00:00Z") + .put("filter", "serviceName='ts-auth-service'") + .build(), + ActionListener.wrap(response -> { + JsonElement result = gson.fromJson(response, JsonElement.class); + 
assertTrue(result.getAsJsonObject().has("logInsights")); + // Verify the PPL query contains the filter clause + assertTrue(capturedPPL.get().contains("where serviceName='ts-auth-service'")); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testLogPatternDiffWithFilter() { + String baseResponse = """ + {"schema":[{"name":"cnt","type":"long"},{"name":"patterns_field","type":"string"}], + "datarows":[[100,"User login successful"],[20,"Database query executed"]], + "total":2,"size":2} + """; + + String selectionResponse = """ + {"schema":[{"name":"cnt","type":"long"},{"name":"patterns_field","type":"string"}], + "datarows":[[50,"User login successful"],[80,"Error in authentication <*>"]], + "total":2,"size":2} + """; + + AtomicReference firstPPL = new AtomicReference<>(); + AtomicReference secondPPL = new AtomicReference<>(); + doAnswer(invocation -> { + TransportPPLQueryRequest request = (TransportPPLQueryRequest) invocation.getArguments()[1]; + String ppl = request.getRequest(); + if (firstPPL.get() == null) { + firstPPL.set(ppl); + } else { + secondPPL.set(ppl); + } + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(pplQueryResponse); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + when(pplQueryResponse.getResult()).thenReturn(baseResponse).thenReturn(selectionResponse); + + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .builder() + .put("index", "test_index") + .put("timeField", "@timestamp") + .put("logFieldName", "message") + .put("baseTimeRangeStart", "2025-01-01T00:00:00Z") + .put("baseTimeRangeEnd", "2025-01-01T01:00:00Z") + .put("selectionTimeRangeStart", "2025-01-01T01:00:00Z") + .put("selectionTimeRangeEnd", "2025-01-01T02:00:00Z") + .put("filter", "severity='ERROR'") + .build(), + ActionListener.wrap(response -> { + JsonElement result 
= gson.fromJson(response, JsonElement.class); + assertTrue(result.getAsJsonObject().has("patternMapDifference")); + // Verify both PPL queries contain the filter clause + assertTrue(firstPPL.get().contains("where severity='ERROR'")); + assertTrue(secondPPL.get().contains("where severity='ERROR'")); + }, e -> fail("Tool execution failed: " + e.getMessage())) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithNonExistentLogField() { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new Exception("Unknown field [nonexistent_field]")); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "nonexistent_field", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener + .wrap( + response -> fail("Should have failed with non-existent field"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("Unknown field")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithInvalidTimeFormat() { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new Exception("Invalid date format: invalid-time-format")); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "invalid-time-format", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener + .wrap( + response -> fail("Should have failed with 
invalid time format"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("Invalid date format")) + ) + ); + } + + @Test + @SneakyThrows + public void testExecutionWithPPLErrorResponse() { + String errorResponse = "{\"error\":{\"type\":\"parsing_exception\",\"reason\":\"Syntax error in PPL query\"}}"; + mockPPLInvocation(errorResponse); + LogPatternAnalysisTool tool = LogPatternAnalysisTool.Factory.getInstance().create(params); + + tool + .run( + ImmutableMap + .of( + "index", + "test_index", + "timeField", + "@timestamp", + "logFieldName", + "message", + "selectionTimeRangeStart", + "2025-01-01T00:00:00Z", + "selectionTimeRangeEnd", + "2025-01-01T01:00:00Z" + ), + ActionListener + .wrap( + response -> fail("Should have failed with PPL error response"), + e -> MatcherAssert.assertThat(e.getMessage(), containsString("PPL query error")) + ) + ); + } +} diff --git a/src/test/java/org/opensearch/agent/tools/MetricChangeAnalysisToolTests.java b/src/test/java/org/opensearch/agent/tools/MetricChangeAnalysisToolTests.java new file mode 100644 index 00000000..0ffcf06e --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/MetricChangeAnalysisToolTests.java @@ -0,0 +1,619 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; +import 
org.opensearch.action.search.SearchResponse; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.core.action.ActionListener; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.transport.client.AdminClient; +import org.opensearch.transport.client.Client; +import org.opensearch.transport.client.IndicesAdminClient; + +import com.google.common.collect.ImmutableMap; + +import lombok.SneakyThrows; + +public class MetricChangeAnalysisToolTests { + + private Map params = new HashMap<>(); + private final Client client = mock(Client.class); + @Mock + private AdminClient adminClient; + @Mock + private IndicesAdminClient indicesAdminClient; + @Mock + private GetMappingsResponse getMappingsResponse; + @Mock + private MappingMetadata mappingMetadata; + @Mock + private SearchResponse searchResponse; + + @SneakyThrows + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + setupMockMappings(); + MetricChangeAnalysisTool.Factory.getInstance().init(client); + } + + private void setupMockMappings() { + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesAdminClient); + + Map properties = ImmutableMap + .builder() + .put("response_time", ImmutableMap.of("type", "long")) + .put("cpu_usage", ImmutableMap.of("type", "double")) + .put("memory_usage", ImmutableMap.of("type", "float")) + .put("status", ImmutableMap.of("type", "keyword")) + .put("@timestamp", ImmutableMap.of("type", "date")) + .build(); + + Map mappingSource = ImmutableMap.of("properties", properties); + when(mappingMetadata.getSourceAsMap()).thenReturn(mappingSource); + when(getMappingsResponse.getMappings()).thenReturn(ImmutableMap.of("test-index", mappingMetadata)); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(getMappingsResponse); + return null; + }).when(indicesAdminClient).getMappings(any(), any()); 
+ } + + private SearchHit[] createSampleHits(Map... sources) { + SearchHit[] hits = new SearchHit[sources.length]; + for (int i = 0; i < sources.length; i++) { + hits[i] = new SearchHit(i); + hits[i].sourceRef(null); + hits[i].score(1.0f); + // Use reflection to set source + try { + java.lang.reflect.Field sourceField = SearchHit.class.getDeclaredField("source"); + sourceField.setAccessible(true); + sourceField.set(hits[i], sources[i]); + } catch (Exception e) { + // Fallback: create hit with source + } + } + return hits; + } + + private void mockSearchResponse(SearchHit[] hits) { + SearchHits searchHits = new SearchHits(hits, new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO), 1.0f); + when(searchResponse.getHits()).thenReturn(searchHits); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + } + + // ========== Percentile Calculation Tests ========== + + @Test + @SneakyThrows + public void testCalculatePercentile() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculatePercentile", List.class, int.class); + calculatePercentileMethod.setAccessible(true); + + // Test with sorted values + List values = List.of(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0); + + double p25 = (double) calculatePercentileMethod.invoke(tool, values, 25); + double p50 = (double) calculatePercentileMethod.invoke(tool, values, 50); + double p75 = (double) calculatePercentileMethod.invoke(tool, values, 75); + double p90 = (double) calculatePercentileMethod.invoke(tool, values, 90); + + assertEquals(3.25, p25, 0.01); + assertEquals(5.5, p50, 0.01); + assertEquals(7.75, p75, 0.01); + assertEquals(9.1, p90, 0.01); + } + + @Test + @SneakyThrows + public void 
testCalculatePercentileWithSingleValue() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculatePercentile", List.class, int.class); + calculatePercentileMethod.setAccessible(true); + + List values = List.of(5.0); + + double p25 = (double) calculatePercentileMethod.invoke(tool, values, 25); + double p50 = (double) calculatePercentileMethod.invoke(tool, values, 50); + double p75 = (double) calculatePercentileMethod.invoke(tool, values, 75); + double p90 = (double) calculatePercentileMethod.invoke(tool, values, 90); + + assertEquals(5.0, p25, 0.01); + assertEquals(5.0, p50, 0.01); + assertEquals(5.0, p75, 0.01); + assertEquals(5.0, p90, 0.01); + } + + @Test + @SneakyThrows + public void testCalculatePercentileWithEmptyList() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculatePercentile", List.class, int.class); + calculatePercentileMethod.setAccessible(true); + + List values = List.of(); + + double p50 = (double) calculatePercentileMethod.invoke(tool, values, 50); + + assertEquals(0.0, p50, 0.01); + } + + @Test + @SneakyThrows + public void testCalculatePercentiles() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentilesMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculatePercentiles", List.class); + calculatePercentilesMethod.setAccessible(true); + + List values = List.of(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0); + + Object result = calculatePercentilesMethod.invoke(tool, values); + + assertNotNull(result); + java.lang.reflect.Method p50Method = result.getClass().getDeclaredMethod("p50"); + java.lang.reflect.Method 
p90Method = result.getClass().getDeclaredMethod("p90"); + + double p50 = (double) p50Method.invoke(result); + double p90 = (double) p90Method.invoke(result); + + assertEquals(5.5, p50, 0.01); + assertEquals(9.1, p90, 0.01); + } + + // ========== Numeric Value Extraction Tests ========== + + @Test + @SneakyThrows + public void testExtractNumericValues() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method extractNumericValuesMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("extractNumericValues", List.class, String.class); + extractNumericValuesMethod.setAccessible(true); + + List> data = List + .of( + Map.of("level", 1, "status", "error"), + Map.of("level", 2, "status", "warning"), + Map.of("level", 3, "status", "info"), + Map.of("level", "4", "status", "debug"), // String number + Map.of("status", "error") // Missing level field + ); + + @SuppressWarnings("unchecked") + List values = (List) extractNumericValuesMethod.invoke(tool, data, "level"); + + assertNotNull(values); + assertEquals(4, values.size()); + assertTrue(values.contains(1.0)); + assertTrue(values.contains(2.0)); + assertTrue(values.contains(3.0)); + assertTrue(values.contains(4.0)); + } + + @Test + @SneakyThrows + public void testExtractNumericValuesWithNonNumericData() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method extractNumericValuesMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("extractNumericValues", List.class, String.class); + extractNumericValuesMethod.setAccessible(true); + + List> data = List.of(Map.of("status", "error"), Map.of("status", "warning"), Map.of("status", "info")); + + @SuppressWarnings("unchecked") + List values = (List) extractNumericValuesMethod.invoke(tool, data, "status"); + + assertNotNull(values); + assertTrue(values.isEmpty()); + } + + @Test + @SneakyThrows + public void 
testExtractNumericValuesWithNestedField() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method extractNumericValuesMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("extractNumericValues", List.class, String.class); + extractNumericValuesMethod.setAccessible(true); + + List> data = List + .of( + Map.of("metrics", Map.of("response_time", 100)), + Map.of("metrics", Map.of("response_time", 200)), + Map.of("metrics", Map.of("response_time", 300)) + ); + + @SuppressWarnings("unchecked") + List values = (List) extractNumericValuesMethod.invoke(tool, data, "metrics.response_time"); + + assertNotNull(values); + assertEquals(3, values.size()); + assertTrue(values.contains(100.0)); + assertTrue(values.contains(200.0)); + assertTrue(values.contains(300.0)); + } + + // ========== Log Ratio Tests ========== + + @Test + @SneakyThrows + public void testSafeLogRatio() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method safeLogRatioMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("safeLogRatio", double.class, double.class); + safeLogRatioMethod.setAccessible(true); + + // Test 2x increase: |log(2/1)| = log(2) ≈ 0.693 + double ratio1 = (double) safeLogRatioMethod.invoke(tool, 2.0, 1.0); + assertEquals(Math.log(2.0), ratio1, 0.01); + + // Test 1.8x increase: |log(180/100)| = log(1.8) ≈ 0.588 + double ratio2 = (double) safeLogRatioMethod.invoke(tool, 180.0, 100.0); + assertEquals(Math.log(1.8), ratio2, 0.01); + + // Test decrease: |log(5/10)| = |log(0.5)| = log(2) ≈ 0.693 + double ratio3 = (double) safeLogRatioMethod.invoke(tool, 5.0, 10.0); + assertEquals(Math.log(2.0), ratio3, 0.01); + + // Test zero baseline: should return cap (10.0) + double ratio4 = (double) safeLogRatioMethod.invoke(tool, 10.0, 0.0); + assertEquals(10.0, ratio4, 0.01); + + // Test both near zero: should return 0.0 + double ratio5 = (double) 
safeLogRatioMethod.invoke(tool, 0.0, 0.0); + assertEquals(0.0, ratio5, 0.01); + + // Test no change: |log(1)| = 0 + double ratio6 = (double) safeLogRatioMethod.invoke(tool, 100.0, 100.0); + assertEquals(0.0, ratio6, 0.01); + } + + // ========== Variance Calculation Tests ========== + + @Test + @SneakyThrows + public void testCalculatePercentileVarianceSkipsZeroBaseline() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileVarianceMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod( + "calculatePercentileVariance", + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"), + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats") + ); + calculatePercentileVarianceMethod.setAccessible(true); + + Class percentileStatsClass = Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"); + java.lang.reflect.Constructor constructor = percentileStatsClass.getDeclaredConstructor(double.class, double.class); + constructor.setAccessible(true); + + // Both baselines zero → score 0 + Object sel1 = constructor.newInstance(10.0, 5.0); + Object base1 = constructor.newInstance(0.0, 0.0); + assertEquals(0.0, (double) calculatePercentileVarianceMethod.invoke(tool, sel1, base1), 0.01); + + // Only P50 baseline zero → score based on P90 only + Object sel2 = constructor.newInstance(10.0, 20.0); + Object base2 = constructor.newInstance(0.0, 10.0); + double expected = Math.abs(Math.log(20.0 / 10.0)); // log(2) ≈ 0.693 + assertEquals(expected, (double) calculatePercentileVarianceMethod.invoke(tool, sel2, base2), 0.01); + + // Only P90 baseline zero → score based on P50 only + Object sel3 = constructor.newInstance(20.0, 5.0); + Object base3 = constructor.newInstance(10.0, 0.0); + double expected3 = Math.abs(Math.log(20.0 / 10.0)); // log(2) ≈ 0.693 + assertEquals(expected3, (double) 
calculatePercentileVarianceMethod.invoke(tool, sel3, base3), 0.01); + } + + // ========== Variance Calculation Tests ========== + + @Test + @SneakyThrows + public void testCalculatePercentileVariance() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileVarianceMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod( + "calculatePercentileVariance", + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"), + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats") + ); + calculatePercentileVarianceMethod.setAccessible(true); + + // Create PercentileStats using reflection (p50, p90) + Class percentileStatsClass = Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"); + java.lang.reflect.Constructor constructor = percentileStatsClass.getDeclaredConstructor(double.class, double.class); + constructor.setAccessible(true); + + // selection p50=20, p90=40; baseline p50=10, p90=20 → both are 2x + Object selectionStats = constructor.newInstance(20.0, 40.0); + Object baselineStats = constructor.newInstance(10.0, 20.0); + + double variance = (double) calculatePercentileVarianceMethod.invoke(tool, selectionStats, baselineStats); + + // score = 0.5 * log(2) + 0.5 * log(2) = log(2) ≈ 0.693 + assertEquals(Math.log(2.0), variance, 0.01); + } + + @Test + @SneakyThrows + public void testCalculatePercentileVarianceWithNoChange() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileVarianceMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod( + "calculatePercentileVariance", + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"), + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats") + ); + 
calculatePercentileVarianceMethod.setAccessible(true); + + Class percentileStatsClass = Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"); + java.lang.reflect.Constructor constructor = percentileStatsClass.getDeclaredConstructor(double.class, double.class); + constructor.setAccessible(true); + + Object selectionStats = constructor.newInstance(20.0, 40.0); + Object baselineStats = constructor.newInstance(20.0, 40.0); + + double variance = (double) calculatePercentileVarianceMethod.invoke(tool, selectionStats, baselineStats); + + assertEquals(0.0, variance, 0.01); + } + + @Test + @SneakyThrows + public void testCalculatePercentileVarianceRelativeChange() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculatePercentileVarianceMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod( + "calculatePercentileVariance", + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"), + Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats") + ); + calculatePercentileVarianceMethod.setAccessible(true); + + Class percentileStatsClass = Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"); + java.lang.reflect.Constructor constructor = percentileStatsClass.getDeclaredConstructor(double.class, double.class); + constructor.setAccessible(true); + + // Test that 1->2 (2x) ranks higher than 100->180 (1.8x) + Object smallChangeStats = constructor.newInstance(2.0, 2.0); + Object smallBaselineStats = constructor.newInstance(1.0, 1.0); + + Object largeChangeStats = constructor.newInstance(180.0, 180.0); + Object largeBaselineStats = constructor.newInstance(100.0, 100.0); + + double smallVariance = (double) calculatePercentileVarianceMethod.invoke(tool, smallChangeStats, smallBaselineStats); + double largeVariance = (double) calculatePercentileVarianceMethod.invoke(tool, 
largeChangeStats, largeBaselineStats); + + // 2x change: 0.5 * log(2) + 0.5 * log(2) = log(2) ≈ 0.693 + // 1.8x change: 0.5 * log(1.8) + 0.5 * log(1.8) = log(1.8) ≈ 0.588 + assertEquals(Math.log(2.0), smallVariance, 0.01); + assertEquals(Math.log(1.8), largeVariance, 0.01); + assertTrue("2x change should rank higher than 1.8x change", smallVariance > largeVariance); + } + + // ========== Percentile Analysis Tests ========== + + @Test + @SneakyThrows + public void testCalculateMetricChangeAnalysis() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculateMetricChangeAnalysisMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculateMetricChangeAnalysis", List.class, List.class, java.util.Set.class); + calculateMetricChangeAnalysisMethod.setAccessible(true); + + // Use non-monotonic (gauge-like) data to avoid counter detection + List> selectionData = List + .of( + Map.of("response_time", 400, "cpu_usage", 80), + Map.of("response_time", 100, "cpu_usage", 50), + Map.of("response_time", 300, "cpu_usage", 60), + Map.of("response_time", 200, "cpu_usage", 70) + ); + + List> baselineData = List + .of( + Map.of("response_time", 150, "cpu_usage", 65), + Map.of("response_time", 50, "cpu_usage", 45), + Map.of("response_time", 200, "cpu_usage", 75), + Map.of("response_time", 100, "cpu_usage", 55) + ); + + java.util.Set numberFields = java.util.Set.of("response_time", "cpu_usage"); + + @SuppressWarnings("unchecked") + List analyses = (List) calculateMetricChangeAnalysisMethod.invoke(tool, selectionData, baselineData, numberFields); + + assertNotNull(analyses); + assertEquals(2, analyses.size()); + + // Verify first analysis has highest variance + java.lang.reflect.Method varianceMethod = analyses.get(0).getClass().getDeclaredMethod("variance"); + double firstVariance = (double) varianceMethod.invoke(analyses.get(0)); + double secondVariance = (double) 
varianceMethod.invoke(analyses.get(1)); + + assertTrue(firstVariance >= secondVariance); + } + + @Test + @SneakyThrows + public void testCalculateMetricChangeAnalysisWithEmptyData() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculateMetricChangeAnalysisMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculateMetricChangeAnalysis", List.class, List.class, java.util.Set.class); + calculateMetricChangeAnalysisMethod.setAccessible(true); + + List> selectionData = List.of(); + List> baselineData = List.of(); + java.util.Set numberFields = java.util.Set.of("response_time"); + + @SuppressWarnings("unchecked") + List analyses = (List) calculateMetricChangeAnalysisMethod.invoke(tool, selectionData, baselineData, numberFields); + + assertNotNull(analyses); + assertTrue(analyses.isEmpty()); + } + + @Test + @SneakyThrows + public void testCalculateMetricChangeAnalysisWithMissingFields() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculateMetricChangeAnalysisMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculateMetricChangeAnalysis", List.class, List.class, java.util.Set.class); + calculateMetricChangeAnalysisMethod.setAccessible(true); + + List> selectionData = List.of(Map.of("response_time", 100), Map.of("response_time", 200)); + + List> baselineData = List.of(Map.of("cpu_usage", 50), Map.of("cpu_usage", 60)); + + java.util.Set numberFields = java.util.Set.of("response_time", "cpu_usage"); + + @SuppressWarnings("unchecked") + List analyses = (List) calculateMetricChangeAnalysisMethod.invoke(tool, selectionData, baselineData, numberFields); + + assertNotNull(analyses); + // Should skip fields that don't have data in both datasets + assertTrue(analyses.isEmpty()); + } + + @Test + @SneakyThrows + public void testCalculateMetricChangeAnalysisRanking() { + MetricChangeAnalysisTool 
tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method calculateMetricChangeAnalysisMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("calculateMetricChangeAnalysis", List.class, List.class, java.util.Set.class); + calculateMetricChangeAnalysisMethod.setAccessible(true); + + // Create non-monotonic data where response_time has high change and cpu_usage has low change + List> selectionData = List + .of( + Map.of("response_time", 3000, "cpu_usage", 52), + Map.of("response_time", 1000, "cpu_usage", 54), + Map.of("response_time", 4000, "cpu_usage", 51), + Map.of("response_time", 2000, "cpu_usage", 53) + ); + + List> baselineData = List + .of( + Map.of("response_time", 300, "cpu_usage", 52), + Map.of("response_time", 100, "cpu_usage", 50), + Map.of("response_time", 400, "cpu_usage", 53), + Map.of("response_time", 200, "cpu_usage", 51) + ); + + java.util.Set numberFields = java.util.Set.of("response_time", "cpu_usage"); + + @SuppressWarnings("unchecked") + List analyses = (List) calculateMetricChangeAnalysisMethod.invoke(tool, selectionData, baselineData, numberFields); + + assertNotNull(analyses); + assertEquals(2, analyses.size()); + + // First field should be response_time (higher variance) + java.lang.reflect.Method fieldMethod = analyses.get(0).getClass().getDeclaredMethod("field"); + String firstField = (String) fieldMethod.invoke(analyses.get(0)); + assertEquals("response_time", firstField); + + // Second field should be cpu_usage (lower variance) + String secondField = (String) fieldMethod.invoke(analyses.get(1)); + assertEquals("cpu_usage", secondField); + } + + @Test + @SneakyThrows + public void testFormatResultsLimitsToTopTen() { + MetricChangeAnalysisTool tool = MetricChangeAnalysisTool.Factory.getInstance().create(params); + + java.lang.reflect.Method formatResultsMethod = MetricChangeAnalysisTool.class + .getDeclaredMethod("formatResults", List.class, int.class); + 
formatResultsMethod.setAccessible(true); + + // Create 15 fields with different variance scores + List analyses = new ArrayList<>(); + Class fieldAnalysisClass = Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$FieldPercentileAnalysis"); + Class percentileStatsClass = Class.forName("org.opensearch.agent.tools.MetricChangeAnalysisTool$PercentileStats"); + + java.lang.reflect.Constructor statsConstructor = percentileStatsClass.getDeclaredConstructor(double.class, double.class); + statsConstructor.setAccessible(true); + + java.lang.reflect.Constructor analysisConstructor = fieldAnalysisClass + .getDeclaredConstructor(String.class, double.class, percentileStatsClass, percentileStatsClass); + analysisConstructor.setAccessible(true); + + Object stats = statsConstructor.newInstance(20.0, 40.0); + + // Create 15 fields with descending variance scores + for (int i = 0; i < 15; i++) { + double variance = 15.0 - i; // 15.0, 14.0, 13.0, ..., 1.0 + Object analysis = analysisConstructor.newInstance("field_" + i, variance, stats, stats); + analyses.add(analysis); + } + + // Test with topN = 10 + @SuppressWarnings("unchecked") + List results10 = (List) formatResultsMethod.invoke(tool, analyses, 10); + assertNotNull(results10); + assertEquals("Should return only top 10 results", 10, results10.size()); + + // Test with topN = 5 (default) + @SuppressWarnings("unchecked") + List results5 = (List) formatResultsMethod.invoke(tool, analyses, 5); + assertNotNull(results5); + assertEquals("Should return only top 5 results", 5, results5.size()); + + // Test with topN = 3 + @SuppressWarnings("unchecked") + List results3 = (List) formatResultsMethod.invoke(tool, analyses, 3); + assertNotNull(results3); + assertEquals("Should return only top 3 results", 3, results3.size()); + } +} diff --git a/src/test/java/org/opensearch/agent/tools/PPLToolTests.java b/src/test/java/org/opensearch/agent/tools/PPLToolTests.java index 5669e639..140022be 100644 --- 
a/src/test/java/org/opensearch/agent/tools/PPLToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/PPLToolTests.java @@ -10,6 +10,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.when; +import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.utils.StringUtils.gson; @@ -23,12 +24,14 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; import org.opensearch.action.search.SearchResponse; import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.output.model.MLResultDataType; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; @@ -176,6 +179,36 @@ public void testTool() { } + @Test + public void testToolWhenGettingSagemakerError() { + PPLTool tool = PPLTool.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt", "head", "100")); + assertEquals(PPLTool.TYPE, tool.getName()); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + OpenSearchStatusException exception = new OpenSearchStatusException( + "Error from remote service: {\"ErrorCode\":\"CLIENT_ERROR_FROM_MODEL\",\"LogStreamArn\":\"arn:aws:logs:us-east-1:12345678:log-group:/aws/sagemaker/Endpoints/demo-test-name\",\"Message\":\"Received client error (404) from primary with message \\\"{\\n \\\"code\\\":404,\\n \\\"message\\\":\\\"prediction 
failure\\\",\\n \\\"error\\\":\\\"Input token limit exceeded. The model only supports schemas with less than 1000-1500 fields, and has optimal performance for 350 fields or fewer.\\\"\\n}\\\". See https://us-east-1.console.aws.amazon.com/cloudwatch/home?region=us-east-1#logEventViewer:group=/aws/sagemaker/Endpoints/demo-test-name in account 12345678 for more information.\",\"OriginalMessage\":\"{\\n \\\"code\\\":500,\\n \\\"message\\\":\\\"prediction failure\\\",\\n \\\"error\\\":\\\"Input token limit exceeded. The model only supports schemas with less than 1000-1500 fields, and has optimal performance for 350 fields or fewer.\\\"\\n}\",\"OriginalStatusCode\":404}", + RestStatus.fromCode(404) + ); + + listener.onFailure(exception); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + OpenSearchStatusException exception = assertThrows( + OpenSearchStatusException.class, + () -> tool.run(ImmutableMap.of("index", "demo", "question", "demo"), ActionListener.wrap(executePPLResult -> { + Map returnResults = gson.fromJson(executePPLResult, Map.class); + assertEquals("ppl result", returnResults.get("executionResult")); + assertEquals("source=demo| head 1", returnResults.get("ppl")); + }, e -> { throw new OpenSearchStatusException(e.getMessage(), ((OpenSearchStatusException) e).status()); })) + ); + assertTrue(exception.getMessage().contains("")); + assertFalse(exception.getMessage().contains("demo-test-name")); + assertFalse(exception.getMessage().contains("12345678")); + + } + @Test public void testTool_ForSparkInputWithWrongSchema() { PPLTool tool = PPLTool.Factory @@ -292,6 +325,18 @@ public void testTool_ForSparkInputWithStructInput() { } + @Test + public void testTool_basic() { + PPLTool tool = PPLTool.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt", "previous_tool_name", "previousTool", "head", "-5")); + assertEquals(tool.getDescription(), 
PPLTool.Factory.getInstance().getDefaultDescription()); + assertEquals(tool.getType(), PPLTool.Factory.getInstance().getDefaultType()); + assertEquals(null, PPLTool.Factory.getInstance().getDefaultVersion()); + assertEquals(List.of(COMMON_MODEL_ID_FIELD), PPLTool.Factory.getInstance().getAllModelKeys()); + + } + @Test public void testTool_withPreviousInput() { PPLTool tool = PPLTool.Factory diff --git a/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java index 66e1c7d8..9b5a48e2 100644 --- a/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java @@ -95,6 +95,8 @@ public void setup() { null, null, null, + null, + null, null ); } diff --git a/src/test/java/org/opensearch/agent/tools/SearchAroundDocumentToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchAroundDocumentToolTests.java new file mode 100644 index 00000000..fb8dba22 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/SearchAroundDocumentToolTests.java @@ -0,0 +1,687 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Before; +import 
org.junit.Test; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.transport.client.Client; + +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; + +public class SearchAroundDocumentToolTests { + + private Client client; + private SearchAroundDocumentTool tool; + private static final Gson GSON = new Gson(); + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + client = mock(Client.class); + SearchAroundDocumentTool.Factory.getInstance().init(client, NamedXContentRegistry.EMPTY); + tool = SearchAroundDocumentTool.Factory.getInstance().create(Collections.emptyMap()); + } + + private SearchHit createMockHit(String id, String index, Map source, Object[] sortValues) { + SearchHit hit = mock(SearchHit.class); + when(hit.getId()).thenReturn(id); + when(hit.getIndex()).thenReturn(index); + when(hit.getScore()).thenReturn(1.0f); + when(hit.getSourceAsMap()).thenReturn(source); + when(hit.getSortValues()).thenReturn(sortValues); + return hit; + } + + private SearchResponse createMockSearchResponse(SearchHit[] hits) { + SearchResponse response = mock(SearchResponse.class); + SearchHits searchHits = mock(SearchHits.class); + when(searchHits.getHits()).thenReturn(hits); + when(response.getHits()).thenReturn(searchHits); + return response; + } + + private void mockThreeSearchCalls(SearchResponse targetResponse, SearchResponse beforeResponse, SearchResponse afterResponse) { + AtomicInteger callCount = new AtomicInteger(0); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + int call = callCount.getAndIncrement(); + switch (call) { + case 0: + listener.onResponse(targetResponse); + break; + case 1: + 
listener.onResponse(beforeResponse); + break; + case 2: + listener.onResponse(afterResponse); + break; + default: + listener.onFailure(new RuntimeException("Unexpected search call")); + } + return null; + }).when(client).search(any(SearchRequest.class), any()); + } + + // ========== Validate Tests ========== + + @Test + public void testValidateWithNullParameters() { + assertFalse(tool.validate(null)); + } + + @Test + public void testValidateWithEmptyParameters() { + assertFalse(tool.validate(Collections.emptyMap())); + } + + @Test + public void testValidateWithJsonInput() { + Map params = new HashMap<>(); + params.put("input", "{\"index\":\"test\",\"doc_id\":\"1\",\"timestamp_field\":\"@timestamp\",\"count\":5}"); + assertTrue(tool.validate(params)); + } + + @Test + public void testValidateWithDirectParameters() { + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "5"); + assertTrue(tool.validate(params)); + } + + @Test + public void testValidateWithMissingParameters() { + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc1"); + assertFalse(tool.validate(params)); + } + + @Test + public void testValidateWithEmptyValues() { + Map params = new HashMap<>(); + params.put("index", ""); + params.put("doc_id", "doc1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "5"); + assertFalse(tool.validate(params)); + } + + // ========== Factory Tests ========== + + @Test + public void testFactoryCreate() { + SearchAroundDocumentTool.Factory factory = SearchAroundDocumentTool.Factory.getInstance(); + SearchAroundDocumentTool createdTool = factory.create(Collections.emptyMap()); + assertNotNull(createdTool); + assertEquals(SearchAroundDocumentTool.TYPE, createdTool.getType()); + } + + @Test + public void testFactoryGetInstance() { + SearchAroundDocumentTool.Factory factory1 = 
SearchAroundDocumentTool.Factory.getInstance(); + SearchAroundDocumentTool.Factory factory2 = SearchAroundDocumentTool.Factory.getInstance(); + assertTrue(factory1 == factory2); + } + + @Test + public void testFactoryDefaults() { + SearchAroundDocumentTool.Factory factory = SearchAroundDocumentTool.Factory.getInstance(); + assertEquals(SearchAroundDocumentTool.TYPE, factory.getDefaultType()); + assertNotNull(factory.getDefaultDescription()); + assertNull(factory.getDefaultVersion()); + assertNotNull(factory.getDefaultAttributes()); + } + + // ========== Tool Metadata Tests ========== + + @Test + public void testGetType() { + assertEquals("SearchAroundDocumentTool", tool.getType()); + } + + @Test + public void testGetVersion() { + assertNull(tool.getVersion()); + } + + // ========== Run - Success Tests ========== + + @Test + public void testRunWithDirectParameters() throws Exception { + Object[] sortValues = new Object[] { 1000L, 5L }; + + SearchHit targetHit = createMockHit( + "target-doc", + "test-index", + Map.of("@timestamp", "2024-01-01T00:00:05", "message", "target"), + sortValues + ); + + SearchHit beforeHit1 = createMockHit( + "before-1", + "test-index", + Map.of("@timestamp", "2024-01-01T00:00:03", "message", "before1"), + new Object[] { 998L, 3L } + ); + SearchHit beforeHit2 = createMockHit( + "before-2", + "test-index", + Map.of("@timestamp", "2024-01-01T00:00:04", "message", "before2"), + new Object[] { 999L, 4L } + ); + + SearchHit afterHit1 = createMockHit( + "after-1", + "test-index", + Map.of("@timestamp", "2024-01-01T00:00:06", "message", "after1"), + new Object[] { 1001L, 6L } + ); + SearchHit afterHit2 = createMockHit( + "after-2", + "test-index", + Map.of("@timestamp", "2024-01-01T00:00:07", "message", "after2"), + new Object[] { 1002L, 7L } + ); + + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + // Before: returned in DESC order (before2, before1) - tool reverses to chronological + SearchResponse 
beforeResponse = createMockSearchResponse(new SearchHit[] { beforeHit2, beforeHit1 }); + SearchResponse afterResponse = createMockSearchResponse(new SearchHit[] { afterHit1, afterHit2 }); + + mockThreeSearchCalls(targetResponse, beforeResponse, afterResponse); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "target-doc"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "2"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + String result = future.get(); + assertNotNull(result); + + List> docs = GSON.fromJson(result, new TypeToken>>() { + }.getType()); + assertEquals(5, docs.size()); + + // Verify chronological order: before1, before2, target, after1, after2 + assertEquals("before-1", docs.get(0).get("_id")); + assertEquals("before-2", docs.get(1).get("_id")); + assertEquals("target-doc", docs.get(2).get("_id")); + assertEquals("after-1", docs.get(3).get("_id")); + assertEquals("after-2", docs.get(4).get("_id")); + } + + @Test + public void testRunWithJsonInput() throws Exception { + Object[] sortValues = new Object[] { 1000L, 5L }; + + SearchHit targetHit = createMockHit( + "doc-1", + "my-index", + Map.of("@timestamp", "2024-01-01T00:00:05", "message", "target"), + sortValues + ); + + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + SearchResponse beforeResponse = createMockSearchResponse(new SearchHit[] {}); + SearchResponse afterResponse = createMockSearchResponse(new SearchHit[] {}); + + mockThreeSearchCalls(targetResponse, beforeResponse, afterResponse); + + Map params = new HashMap<>(); + params.put("input", "{\"index\":\"my-index\",\"doc_id\":\"doc-1\",\"timestamp_field\":\"@timestamp\",\"count\":2}"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + 
String result = future.get(); + assertNotNull(result); + + List> docs = GSON.fromJson(result, new TypeToken>>() { + }.getType()); + assertEquals(1, docs.size()); + assertEquals("doc-1", docs.get(0).get("_id")); + } + + @Test + public void testRunWithNoBeforeOrAfterDocuments() throws Exception { + Object[] sortValues = new Object[] { 1000L, 5L }; + + SearchHit targetHit = createMockHit("doc-1", "test-index", Map.of("message", "only doc"), sortValues); + + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + SearchResponse beforeResponse = createMockSearchResponse(new SearchHit[] {}); + SearchResponse afterResponse = createMockSearchResponse(new SearchHit[] {}); + + mockThreeSearchCalls(targetResponse, beforeResponse, afterResponse); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "5"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + String result = future.get(); + List> docs = GSON.fromJson(result, new TypeToken>>() { + }.getType()); + assertEquals(1, docs.size()); + assertEquals("doc-1", docs.get(0).get("_id")); + } + + @Test + public void testRunWithCountAsDouble() throws Exception { + Object[] sortValues = new Object[] { 1000L, 5L }; + + SearchHit targetHit = createMockHit("doc-1", "test-index", Map.of("message", "target"), sortValues); + + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + SearchResponse beforeResponse = createMockSearchResponse(new SearchHit[] {}); + SearchResponse afterResponse = createMockSearchResponse(new SearchHit[] {}); + + mockThreeSearchCalls(targetResponse, beforeResponse, afterResponse); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", 
"@timestamp"); + params.put("count", "2.0"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + String result = future.get(); + assertNotNull(result); + } + + @Test + public void testRunResponseContainsSortValues() throws Exception { + Object[] sortValues = new Object[] { 1000L, 5L }; + + SearchHit targetHit = createMockHit("doc-1", "test-index", Map.of("message", "target"), sortValues); + + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + SearchResponse beforeResponse = createMockSearchResponse(new SearchHit[] {}); + SearchResponse afterResponse = createMockSearchResponse(new SearchHit[] {}); + + mockThreeSearchCalls(targetResponse, beforeResponse, afterResponse); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "1"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + String result = future.get(); + List> docs = GSON.fromJson(result, new TypeToken>>() { + }.getType()); + assertEquals(1, docs.size()); + + Map doc = docs.get(0); + assertEquals("doc-1", doc.get("_id")); + assertEquals("test-index", doc.get("_index")); + assertNotNull(doc.get("_source")); + assertNotNull(doc.get("sort")); + } + + // ========== Run - Error Tests ========== + + @Test + public void testRunWithDocumentNotFound() throws Exception { + SearchResponse emptyResponse = createMockSearchResponse(new SearchHit[] {}); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(emptyResponse); + return null; + }).when(client).search(any(SearchRequest.class), any()); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "nonexistent"); + 
params.put("timestamp_field", "@timestamp"); + params.put("count", "2"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + try { + future.get(); + assertTrue("Should have thrown", false); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof IllegalArgumentException); + assertTrue(e.getCause().getMessage().contains("Document not found")); + } + } + + @Test + public void testRunWithMissingSortValues() throws Exception { + SearchHit targetHit = createMockHit("doc-1", "test-index", Map.of("message", "target"), new Object[] {}); + + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(targetResponse); + return null; + }).when(client).search(any(SearchRequest.class), any()); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "2"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + try { + future.get(); + assertTrue("Should have thrown", false); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof IllegalArgumentException); + assertTrue(e.getCause().getMessage().contains("sort values")); + } + } + + @Test + public void testRunWithNullSortValues() throws Exception { + SearchHit targetHit = createMockHit("doc-1", "test-index", Map.of("message", "target"), null); + + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(targetResponse); + return null; + }).when(client).search(any(SearchRequest.class), any()); + + Map params = new 
HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "2"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + try { + future.get(); + assertTrue("Should have thrown", false); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof IllegalArgumentException); + assertTrue(e.getCause().getMessage().contains("sort values")); + } + } + + @Test + public void testRunWithOnlySingleSortValue() throws Exception { + SearchHit targetHit = createMockHit("doc-1", "test-index", Map.of("message", "target"), new Object[] { 1000L }); + + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(targetResponse); + return null; + }).when(client).search(any(SearchRequest.class), any()); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "2"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + try { + future.get(); + assertTrue("Should have thrown", false); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof IllegalArgumentException); + assertTrue(e.getCause().getMessage().contains("sort values")); + } + } + + @Test + public void testRunWithMissingRequiredParameters() throws Exception { + Map params = new HashMap<>(); + params.put("index", "test-index"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + try { + future.get(); + assertTrue("Should have thrown", false); + } catch (ExecutionException e) { 
+ assertTrue(e.getCause() instanceof IllegalArgumentException); + assertTrue(e.getCause().getMessage().contains("requires")); + } + } + + @Test + public void testRunWithTargetSearchFailure() throws Exception { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Target search failed")); + return null; + }).when(client).search(any(SearchRequest.class), any()); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "2"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + try { + future.get(); + assertTrue("Should have thrown", false); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof RuntimeException); + assertEquals("Target search failed", e.getCause().getMessage()); + } + } + + @Test + public void testRunWithBeforeSearchFailure() throws Exception { + Object[] sortValues = new Object[] { 1000L, 5L }; + SearchHit targetHit = createMockHit("doc-1", "test-index", Map.of("message", "target"), sortValues); + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + + AtomicInteger callCount = new AtomicInteger(0); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + int call = callCount.getAndIncrement(); + if (call == 0) { + listener.onResponse(targetResponse); + } else { + listener.onFailure(new RuntimeException("Before search failed")); + } + return null; + }).when(client).search(any(SearchRequest.class), any()); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "2"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, 
ActionListener.wrap(future::complete, future::completeExceptionally)); + + try { + future.get(); + assertTrue("Should have thrown", false); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof RuntimeException); + assertEquals("Before search failed", e.getCause().getMessage()); + } + } + + @Test + public void testRunWithAfterSearchFailure() throws Exception { + Object[] sortValues = new Object[] { 1000L, 5L }; + SearchHit targetHit = createMockHit("doc-1", "test-index", Map.of("message", "target"), sortValues); + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + SearchResponse beforeResponse = createMockSearchResponse(new SearchHit[] {}); + + AtomicInteger callCount = new AtomicInteger(0); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + int call = callCount.getAndIncrement(); + if (call == 0) { + listener.onResponse(targetResponse); + } else if (call == 1) { + listener.onResponse(beforeResponse); + } else { + listener.onFailure(new RuntimeException("After search failed")); + } + return null; + }).when(client).search(any(SearchRequest.class), any()); + + Map params = new HashMap<>(); + params.put("index", "test-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "2"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + try { + future.get(); + assertTrue("Should have thrown", false); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof RuntimeException); + assertEquals("After search failed", e.getCause().getMessage()); + } + } + + // ========== Run - Ordering Tests ========== + + @Test + public void testRunBeforeDocsAreReversedToChronological() throws Exception { + Object[] sortValues = new Object[] { 1000L, 5L }; + + SearchHit targetHit = createMockHit("target", "idx", Map.of("ts", 1000), 
sortValues); + + // Before search returns in DESC order: doc3 (newest before), doc2, doc1 (oldest before) + SearchHit beforeHit3 = createMockHit("before-3", "idx", Map.of("ts", 999), new Object[] { 999L, 4L }); + SearchHit beforeHit2 = createMockHit("before-2", "idx", Map.of("ts", 998), new Object[] { 998L, 3L }); + SearchHit beforeHit1 = createMockHit("before-1", "idx", Map.of("ts", 997), new Object[] { 997L, 2L }); + + SearchHit afterHit1 = createMockHit("after-1", "idx", Map.of("ts", 1001), new Object[] { 1001L, 6L }); + + SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + SearchResponse beforeResponse = createMockSearchResponse(new SearchHit[] { beforeHit3, beforeHit2, beforeHit1 }); + SearchResponse afterResponse = createMockSearchResponse(new SearchHit[] { afterHit1 }); + + mockThreeSearchCalls(targetResponse, beforeResponse, afterResponse); + + Map params = new HashMap<>(); + params.put("index", "idx"); + params.put("doc_id", "target"); + params.put("timestamp_field", "ts"); + params.put("count", "3"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + String result = future.get(); + List> docs = GSON.fromJson(result, new TypeToken>>() { + }.getType()); + assertEquals(5, docs.size()); + + // Before docs should be reversed to chronological (oldest first) + assertEquals("before-1", docs.get(0).get("_id")); + assertEquals("before-2", docs.get(1).get("_id")); + assertEquals("before-3", docs.get(2).get("_id")); + assertEquals("target", docs.get(3).get("_id")); + assertEquals("after-1", docs.get(4).get("_id")); + } + + @Test + public void testRunResponseDocStructure() throws Exception { + Object[] sortValues = new Object[] { 1000L, 5L }; + Map source = Map.of("@timestamp", "2024-01-01", "level", "INFO", "message", "test log"); + + SearchHit targetHit = createMockHit("doc-1", "logs-index", source, sortValues); + + 
SearchResponse targetResponse = createMockSearchResponse(new SearchHit[] { targetHit }); + SearchResponse beforeResponse = createMockSearchResponse(new SearchHit[] {}); + SearchResponse afterResponse = createMockSearchResponse(new SearchHit[] {}); + + mockThreeSearchCalls(targetResponse, beforeResponse, afterResponse); + + Map params = new HashMap<>(); + params.put("index", "logs-index"); + params.put("doc_id", "doc-1"); + params.put("timestamp_field", "@timestamp"); + params.put("count", "1"); + + CompletableFuture future = new CompletableFuture<>(); + tool.run(params, ActionListener.wrap(future::complete, future::completeExceptionally)); + + String result = future.get(); + List> docs = GSON.fromJson(result, new TypeToken>>() { + }.getType()); + assertEquals(1, docs.size()); + + Map doc = docs.get(0); + assertEquals("doc-1", doc.get("_id")); + assertEquals("logs-index", doc.get("_index")); + assertNotNull(doc.get("_score")); + assertNotNull(doc.get("sort")); + + @SuppressWarnings("unchecked") + Map returnedSource = (Map) doc.get("_source"); + assertEquals("2024-01-01", returnedSource.get("@timestamp")); + assertEquals("INFO", returnedSource.get("level")); + assertEquals("test log", returnedSource.get("message")); + } +} diff --git a/src/test/java/org/opensearch/agent/tools/ToolHelperTests.java b/src/test/java/org/opensearch/agent/tools/ToolHelperTests.java index 5b6dfa7f..4008904d 100644 --- a/src/test/java/org/opensearch/agent/tools/ToolHelperTests.java +++ b/src/test/java/org/opensearch/agent/tools/ToolHelperTests.java @@ -12,6 +12,9 @@ import org.junit.Test; import org.opensearch.agent.tools.utils.ToolHelper; +import org.opensearch.agent.tools.utils.mergeMetaData.MergeRuleHelper; + +import com.google.gson.Gson; import lombok.extern.log4j.Log4j2; @@ -87,4 +90,224 @@ private void assertMapEquals(Map expected, Map a assertEquals(entry.getValue(), actual.get(entry.getKey())); } } + + private Gson gson = new Gson(); + + private Map prepareMap1() { + String mapBlock 
= """ + { + "event": { + "properties": { + "field1": { + "type": "string" + } + } + } + } + + """; + Map tmpMap = gson.fromJson(mapBlock, Map.class); + return tmpMap; + } + + private Map prepareMap2() { + String mapBlock = """ + { + "event": { + "properties": { + "field2": { + "type": "string" + } + } + } + } + + """; + Map tmpMap = gson.fromJson(mapBlock, Map.class); + return tmpMap; + } + + private Map prepareNormalMap1() { + String mapBlock = """ + { + "event1": { + "properties": { + "field1": { + "type": "string" + } + } + }, + "replace" : { + "type":"string" + } + + } + + """; + Map tmpMap = gson.fromJson(mapBlock, Map.class); + return tmpMap; + } + + private Map prepareNormalMap2() { + String mapBlock = """ + { + "event2": { + "properties": { + "field2": { + "type": "string" + } + } + }, + "replace" : { + "type":"keyword" + } + } + + """; + Map tmpMap = gson.fromJson(mapBlock, Map.class); + return tmpMap; + } + + @Test + public void testMergeTwoObjectMaps() { + String mapBlock = """ + { + "event": { + "properties": { + "field1": { + "type": "string" + }, + "field2": { + "type": "string" + } + } + } + } + + """; + Map allFields = new HashMap<>(); + Map map1 = prepareMap1(); + Map map2 = prepareMap2(); + MergeRuleHelper.merge(map1, allFields); + MergeRuleHelper.merge(map2, allFields); + assertEquals(allFields, gson.fromJson(mapBlock, Map.class)); + } + + @Test + public void testMergeTwoNormalMaps() { + String mapBlock = """ + { + "event1": { + "properties": { + "field1": { + "type": "string" + } + } + }, + "event2": { + "properties": { + "field2": { + "type": "string" + } + } + }, + "replace" : { + "type":"keyword" + } + } + + """; + Map allFields = new HashMap<>(); + Map map1 = prepareNormalMap1(); + Map map2 = prepareNormalMap2(); + MergeRuleHelper.merge(map1, allFields); + MergeRuleHelper.merge(map2, allFields); + assertEquals(allFields, gson.fromJson(mapBlock, Map.class)); + } + + @Test + public void testMergeTwoDeepMaps() { + String mapBlock = """ + { + 
"event": { + "properties": { + "field1": { + "type": "string" + }, + "field2": { + "type": "string" + }, + "deep": { + "properties": { + "field1": { + "type": "string" + }, + "field2": { + "type": "string" + } + } + } + } + } + + } + + """; + Map allFields = new HashMap<>(); + Map map1 = prepareDeepMap1(); + Map map2 = prepareDeepMap2(); + MergeRuleHelper.merge(map1, allFields); + MergeRuleHelper.merge(map2, allFields); + assertEquals(allFields, gson.fromJson(mapBlock, Map.class)); + } + + private Map prepareDeepMap1() { + String mapBlock = """ + { + "event": { + "properties": { + "field1": { + "type": "string" + }, + "deep": { + "properties": { + "field1": { + "type": "string" + } + } + } + } + } + + } + + """; + Map tmpMap = gson.fromJson(mapBlock, Map.class); + return tmpMap; + } + + private Map prepareDeepMap2() { + String mapBlock = """ + { + "event": { + "properties": { + "field2": { + "type": "string" + }, + "deep": { + "properties": { + "field2": { + "type": "string" + } + } + } + } + } + } + + """; + Map tmpMap = gson.fromJson(mapBlock, Map.class); + return tmpMap; + } + } diff --git a/src/test/java/org/opensearch/agent/tools/utils/ClusteringHelperTests.java b/src/test/java/org/opensearch/agent/tools/utils/ClusteringHelperTests.java new file mode 100644 index 00000000..81345627 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/utils/ClusteringHelperTests.java @@ -0,0 +1,164 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools.utils; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.agent.tools.utils.clustering.ClusteringHelper; +import org.opensearch.test.OpenSearchTestCase; + +public class ClusteringHelperTests extends OpenSearchTestCase { + + public void testConstructorWithValidThreshold() { + new ClusteringHelper(0.0); + new ClusteringHelper(0.5); + new ClusteringHelper(1.0); + } + + public void 
testConstructorWithInvalidThreshold() { + assertThrows(IllegalArgumentException.class, () -> new ClusteringHelper(-0.1)); + assertThrows(IllegalArgumentException.class, () -> new ClusteringHelper(1.1)); + } + + public void testClusterLogVectorsWithNullInput() { + ClusteringHelper helper = new ClusteringHelper(0.8); + assertTrue(helper.clusterLogVectorsAndGetRepresentative(null).isEmpty()); + } + + public void testClusterLogVectorsWithEmptyInput() { + ClusteringHelper helper = new ClusteringHelper(0.8); + assertTrue(helper.clusterLogVectorsAndGetRepresentative(new HashMap<>()).isEmpty()); + } + + public void testClusterLogVectorsWithSingleVector() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + logVectors.put("trace1", new double[] { 1.0, 2.0, 3.0 }); + + List result = helper.clusterLogVectorsAndGetRepresentative(logVectors); + assertEquals(1, result.size()); + assertEquals("trace1", result.get(0)); + } + + public void testClusterLogVectorsWithSmallDataset() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + logVectors.put("trace1", new double[] { 1.0, 0.0, 0.0 }); + logVectors.put("trace2", new double[] { 0.9, 0.1, 0.0 }); + logVectors.put("trace3", new double[] { 0.0, 1.0, 0.0 }); + + List result = helper.clusterLogVectorsAndGetRepresentative(logVectors); + assertFalse(result.isEmpty()); + assertTrue(result.size() <= 3); + } + + public void testValidateLogVectorsWithNullTraceId() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + logVectors.put(null, new double[] { 1.0, 2.0 }); + + assertThrows(IllegalArgumentException.class, () -> helper.clusterLogVectorsAndGetRepresentative(logVectors)); + } + + public void testValidateLogVectorsWithEmptyTraceId() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + logVectors.put("", new double[] { 1.0, 2.0 }); + + 
assertThrows(IllegalArgumentException.class, () -> helper.clusterLogVectorsAndGetRepresentative(logVectors)); + } + + public void testValidateLogVectorsWithNullVector() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + logVectors.put("trace1", null); + + assertThrows(IllegalArgumentException.class, () -> helper.clusterLogVectorsAndGetRepresentative(logVectors)); + } + + public void testValidateLogVectorsWithEmptyVector() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + logVectors.put("trace1", new double[] {}); + + assertThrows(IllegalArgumentException.class, () -> helper.clusterLogVectorsAndGetRepresentative(logVectors)); + } + + public void testValidateLogVectorsWithDimensionMismatch() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + logVectors.put("trace1", new double[] { 1.0, 2.0 }); + logVectors.put("trace2", new double[] { 1.0, 2.0, 3.0 }); + + assertThrows(IllegalArgumentException.class, () -> helper.clusterLogVectorsAndGetRepresentative(logVectors)); + } + + public void testValidateLogVectorsWithNaNValue() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + logVectors.put("trace1", new double[] { 1.0, Double.NaN }); + + assertThrows(IllegalArgumentException.class, () -> helper.clusterLogVectorsAndGetRepresentative(logVectors)); + } + + public void testValidateLogVectorsWithInfiniteValue() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + logVectors.put("trace1", new double[] { 1.0, Double.POSITIVE_INFINITY }); + + assertThrows(IllegalArgumentException.class, () -> helper.clusterLogVectorsAndGetRepresentative(logVectors)); + } + + public void testClusterLogVectorsWithLargeDataset() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + + // Create 1500 vectors to trigger large dataset processing + 
for (int i = 0; i < 1500; i++) { + double[] vector = new double[] { Math.random(), Math.random(), Math.random() }; + logVectors.put("trace" + i, vector); + } + + List result = helper.clusterLogVectorsAndGetRepresentative(logVectors); + assertFalse(result.isEmpty()); + assertTrue(result.size() < 1500); // Should reduce the number of representatives + } + + public void testClusterLogVectorsWithIdenticalVectors() { + ClusteringHelper helper = new ClusteringHelper(0.8); + Map logVectors = new HashMap<>(); + double[] vector = { 1.0, 2.0, 3.0 }; + + for (int i = 0; i < 5; i++) { + logVectors.put("trace" + i, vector.clone()); + } + + List result = helper.clusterLogVectorsAndGetRepresentative(logVectors); + assertEquals(1, result.size()); // Should cluster identical vectors into one + } + + public void testClusterLogVectorsWithHighThreshold() { + ClusteringHelper helper = new ClusteringHelper(0.99); + Map logVectors = new HashMap<>(); + logVectors.put("trace1", new double[] { 1.0, 0.0 }); + logVectors.put("trace2", new double[] { 0.0, 1.0 }); + + List result = helper.clusterLogVectorsAndGetRepresentative(logVectors); + assertEquals(2, result.size()); // High threshold should keep vectors separate + } + + public void testClusterLogVectorsWithLowThreshold() { + ClusteringHelper helper = new ClusteringHelper(0.1); + Map logVectors = new HashMap<>(); + logVectors.put("trace1", new double[] { 1.0, 0.1 }); + logVectors.put("trace2", new double[] { 0.9, 0.2 }); + + List result = helper.clusterLogVectorsAndGetRepresentative(logVectors); + assertTrue(result.size() <= 2); // Low threshold may cluster similar vectors + } +} diff --git a/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java b/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java index 76d5c72e..66cedc2f 100644 --- a/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java +++ b/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java @@ -65,6 +65,7 @@ public void updateClusterSettings() { 
updateClusterSettings("plugins.ml_commons.jvm_heap_memory_threshold", 100); updateClusterSettings("plugins.ml_commons.allow_registering_model_via_url", true); updateClusterSettings("plugins.ml_commons.agent_framework_enabled", true); + updateClusterSettings("plugins.ml_commons.connector.private_ip_enabled", true); } @SneakyThrows diff --git a/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java b/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java new file mode 100644 index 00000000..8c2f197a --- /dev/null +++ b/src/test/java/org/opensearch/integTest/DataDistributionToolIT.java @@ -0,0 +1,476 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import static org.hamcrest.Matchers.containsString; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Locale; + +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Before; +import org.opensearch.ml.common.utils.StringUtils; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; + +import lombok.SneakyThrows; + +public class DataDistributionToolIT extends BaseAgentToolsIT { + + public static String requestBodyResourceFile = + "org/opensearch/agent/tools/register_flow_agent_of_data_distribution_tool_request_body.json"; + public String registerAgentRequestBody; + public static String TEST_DATA_INDEX_NAME = "test_data_distribution_index"; + + private String agentId; + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + prepareDataIndex(); + registerAgentRequestBody = Files.readString(Path.of(this.getClass().getClassLoader().getResource(requestBodyResourceFile).toURI())); + agentId = createAgent(registerAgentRequestBody); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteExternalIndices(); + } + + @SneakyThrows + private void 
prepareDataIndex() { + createIndexWithConfiguration( + TEST_DATA_INDEX_NAME, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"@timestamp\": {\n" + + " \"type\": \"date\",\n" + + " \"format\": \"yyyy-MM-dd HH:mm:ss||strict_date_optional_time||epoch_millis\"\n" + + " },\n" + + " \"status\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"level\": {\n" + + " \"type\": \"integer\"\n" + + " },\n" + + " \"host\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"response_time\": {\n" + + " \"type\": \"float\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + + // Add baseline data (09:00:00 to 10:00:00) + addDocToIndex( + TEST_DATA_INDEX_NAME, + "base1", + List.of("@timestamp", "status", "level", "host", "response_time"), + List.of("2025-01-01 09:30:00", "success", 1, "server-01", 120.5) + ); + addDocToIndex( + TEST_DATA_INDEX_NAME, + "base2", + List.of("@timestamp", "status", "level", "host", "response_time"), + List.of("2025-01-01 09:45:00", "success", 1, "server-02", 95.2) + ); + addDocToIndex( + TEST_DATA_INDEX_NAME, + "base3", + List.of("@timestamp", "status", "level", "host", "response_time"), + List.of("2025-01-01 09:50:00", "info", 2, "server-01", 110.8) + ); + + // Add selection data (10:00:00 to 11:00:00) + addDocToIndex( + TEST_DATA_INDEX_NAME, + "sel1", + List.of("@timestamp", "status", "level", "host", "response_time"), + List.of("2025-01-01 10:15:00", "error", 3, "server-01", 250.3) + ); + addDocToIndex( + TEST_DATA_INDEX_NAME, + "sel2", + List.of("@timestamp", "status", "level", "host", "response_time"), + List.of("2025-01-01 10:30:00", "error", 4, "server-02", 180.7) + ); + addDocToIndex( + TEST_DATA_INDEX_NAME, + "sel3", + List.of("@timestamp", "status", "level", "host", "response_time"), + List.of("2025-01-01 10:45:00", "warning", 2, "server-03", 140.1) + ); + addDocToIndex( + TEST_DATA_INDEX_NAME, + "sel4", + List.of("@timestamp", "status", "level", "host", "response_time"), + List.of("2025-01-01 10:50:00", "error", 3, 
"server-01", 300.5) + ); + } + + @SneakyThrows + public void testDataDistributionToolSingleAnalysis() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; + assertResults(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolComparisonAnalysis() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"baselineTimeRangeStart\": \"2025-01-01 09:00:00\", \"baselineTimeRangeEnd\": \"2025-01-01 10:00:00\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + 
"{\"comparisonAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75,\"baselinePercentage\":0.0},{\"value\":\"success\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.67},{\"value\":\"info\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.33},{\"value\":\"warning\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"1\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.67},{\"value\":\"3\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"2\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.33},{\"value\":\"4\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"110.8\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.33},{\"value\":\"95.2\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.33},{\"value\":\"120.5\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.33},{\"value\":\"140.1\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"250.3\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.33},{\"value\":\"server-03\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]}]}"; + assertEquals(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolWithFilter() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", 
\"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"filter\": \"[\\\"{'term': {'status': 'error'}}\\\"]\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; + assertResults(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolMissingRequiredParameters() { + Exception exception = assertThrows(Exception.class, () -> executeAgent(agentId, "{\"parameters\": {\"index\": \"test_index\"}}")); + MatcherAssert.assertThat(exception.getMessage(), containsString("Invalid time format")); + } + + @SneakyThrows + public void testDataDistributionToolInvalidIndex() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent( + agentId, + "{\"parameters\": {\"index\": \"non_existent_index\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\"}}" + ) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("no such index")); + } + + @SneakyThrows + public void testDataDistributionToolPPLSingleAnalysis() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"queryType\": 
\"ppl\", \"ppl\": \"source=%s\"}}", + TEST_DATA_INDEX_NAME, + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3.0\",\"selectionPercentage\":0.5},{\"value\":\"2.0\",\"selectionPercentage\":0.25},{\"value\":\"4.0\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; + assertResults(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolPPLComparisonAnalysis() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"baselineTimeRangeStart\": \"2025-01-01 09:00:00\", \"baselineTimeRangeEnd\": \"2025-01-01 10:00:00\", \"queryType\": \"ppl\", \"ppl\": \"source=%s\"}}", + TEST_DATA_INDEX_NAME, + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + 
"{\"comparisonAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75,\"baselinePercentage\":0.0},{\"value\":\"success\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.67},{\"value\":\"info\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.33},{\"value\":\"warning\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"1.0\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.67},{\"value\":\"3.0\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.0},{\"value\":\"2.0\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.33},{\"value\":\"4.0\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"110.8\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.33},{\"value\":\"95.2\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.33},{\"value\":\"120.5\",\"selectionPercentage\":0.0,\"baselinePercentage\":0.33},{\"value\":\"140.1\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"250.3\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"180.7\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0},{\"value\":\"300.5\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]},{\"field\":\"host\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5,\"baselinePercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.33},{\"value\":\"server-03\",\"selectionPercentage\":0.25,\"baselinePercentage\":0.0}]}]}"; + assertEquals(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolPPLWithCustomQuery() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 
10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"queryType\": \"ppl\", \"ppl\": \"source=%s | where level > 2\"}}", + TEST_DATA_INDEX_NAME, + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3.0\",\"selectionPercentage\":0.67},{\"value\":\"4.0\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; + assertResults(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolWithDSLQueryType() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"queryType\": \"dsl\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; + assertResults(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolWithMultipleFilters() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"filter\": \"[\\\"{'term': {'status': 'error'}}\\\", \\\"{'range': {'level': {'gte': 3}}}\\\"]\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; + assertResults(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolWithCustomSize() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"size\": \"500\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; + assertResults(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolWithCustomTimeField() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"timeField\": \"@timestamp\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; + assertResults(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolWithRangeFilter() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"filter\": \"[\\\"{'range': {'response_time': {'gte': 150.0}}}\\\"]\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; + assertResults(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolWithMatchFilter() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"filter\": \"[\\\"{'match': {'status': 'error'}}\\\"]\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; + assertResults(expectedResult, result); + } + + @SneakyThrows + public void 
testDataDistributionToolWithRawDSLQuery() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"dsl\": \"{\\\"bool\\\": {\\\"must\\\": [{\\\"term\\\": {\\\"status\\\": \\\"error\\\"}}]}}\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.67},{\"value\":\"4\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; + assertResults(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolWithExistsFilter() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"filter\": \"[\\\"{'exists': {'field': 'response_time'}}\\\"]\"}}", + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + 
"{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":0.75,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":0.75},{\"value\":\"warning\",\"selectionPercentage\":0.25}]},{\"field\":\"level\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"3\",\"selectionPercentage\":0.5},{\"value\":\"2\",\"selectionPercentage\":0.25},{\"value\":\"4\",\"selectionPercentage\":0.25}]},{\"field\":\"host\",\"divergence\":0.5,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.5},{\"value\":\"server-02\",\"selectionPercentage\":0.25},{\"value\":\"server-03\",\"selectionPercentage\":0.25}]},{\"field\":\"response_time\",\"divergence\":0.25,\"topChanges\":[{\"value\":\"140.1\",\"selectionPercentage\":0.25},{\"value\":\"250.3\",\"selectionPercentage\":0.25},{\"value\":\"180.7\",\"selectionPercentage\":0.25},{\"value\":\"300.5\",\"selectionPercentage\":0.25}]}]}"; + assertResults(expectedResult, result); + } + + @SneakyThrows + public void testDataDistributionToolInvalidFilterFormat() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"filter\": \"invalid-json\"}}", + TEST_DATA_INDEX_NAME + ) + ) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("Invalid 'filter' parameter")); + } + + @SneakyThrows + public void testDataDistributionToolInvalidSizeParameter() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"size\": \"not-a-number\"}}", + TEST_DATA_INDEX_NAME + ) + ) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("Invalid 'size' parameter")); + } + + @SneakyThrows + public void 
testDataDistributionToolInvalidTimeFormat() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"invalid-time-format\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\"}}", + TEST_DATA_INDEX_NAME + ) + ) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("Invalid time format")); + } + + @SneakyThrows + public void testDataDistributionToolPPLWithComplexQuery() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 11:00:00\", \"queryType\": \"ppl\", \"ppl\": \"source=%s | where level > 2\"}}", + TEST_DATA_INDEX_NAME, + TEST_DATA_INDEX_NAME + ) + ); + + String expectedResult = + "{\"singleAnalysis\":[{\"field\":\"status\",\"divergence\":1.0,\"topChanges\":[{\"value\":\"error\",\"selectionPercentage\":1.0}]},{\"field\":\"level\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"3.0\",\"selectionPercentage\":0.67},{\"value\":\"4.0\",\"selectionPercentage\":0.33}]},{\"field\":\"host\",\"divergence\":0.6666666666666666,\"topChanges\":[{\"value\":\"server-01\",\"selectionPercentage\":0.67},{\"value\":\"server-02\",\"selectionPercentage\":0.33}]},{\"field\":\"response_time\",\"divergence\":0.3333333333333333,\"topChanges\":[{\"value\":\"250.3\",\"selectionPercentage\":0.33},{\"value\":\"180.7\",\"selectionPercentage\":0.33},{\"value\":\"300.5\",\"selectionPercentage\":0.33}]}]}"; + assertResults(expectedResult, result); + } + + private void assertResults(String expectedResult, String result) { + try { + JsonNode resultJson = StringUtils.MAPPER.readTree(result); + JsonNode expectedJson = StringUtils.MAPPER.readTree(expectedResult); + JsonNode expectedAnalysis = expectedJson.get("singleAnalysis"); + JsonNode resultAnalysis = 
resultJson.get("singleAnalysis"); + for (int i = 0; i < expectedAnalysis.size(); i++) { + assertEquals(expectedAnalysis.get(i).get("field").asText(), resultAnalysis.get(i).get("field").asText()); + assertEquals(expectedAnalysis.get(i).get("divergence").asText(), resultAnalysis.get(i).get("divergence").asText()); + JsonNode expectedTopChanges = expectedAnalysis.get(i).get("topChanges"); + JsonNode resultTopChanges = resultAnalysis.get(i).get("topChanges"); + for (int j = 0; j < expectedTopChanges.size(); j++) { + assertEquals( + expectedTopChanges.get(j).get("selectionPercentage").asText(), + resultTopChanges.get(j).get("selectionPercentage").asText() + ); + assertEquals(expectedTopChanges.get(j).get("value").asText(), resultTopChanges.get(j).get("value").asText()); + } + } + } catch (JsonProcessingException e) { + fail("Failed to process jsons"); + } + } +} diff --git a/src/test/java/org/opensearch/integTest/LogPatternAnalysisToolIT.java b/src/test/java/org/opensearch/integTest/LogPatternAnalysisToolIT.java new file mode 100644 index 00000000..9ea60485 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/LogPatternAnalysisToolIT.java @@ -0,0 +1,268 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import static org.hamcrest.Matchers.containsString; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Locale; + +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Before; + +import lombok.SneakyThrows; + +public class LogPatternAnalysisToolIT extends BaseAgentToolsIT { + + public static String requestBodyResourceFile = + "org/opensearch/agent/tools/register_flow_agent_of_log_pattern_analysis_tool_request_body.json"; + public String registerAgentRequestBody; + public static String TEST_LOG_INDEX_NAME = "test_log_analysis_index"; + + private String agentId; + + @Before + @SneakyThrows + public void setUp() { + 
super.setUp(); + prepareLogIndex(); + registerAgentRequestBody = Files.readString(Path.of(this.getClass().getClassLoader().getResource(requestBodyResourceFile).toURI())); + agentId = createAgent(registerAgentRequestBody); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteExternalIndices(); + } + + @SneakyThrows + private void prepareLogIndex() { + createIndexWithConfiguration( + TEST_LOG_INDEX_NAME, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"@timestamp\": {\n" + + " \"type\": \"date\",\n" + + " \"format\": \"yyyy-MM-dd HH:mm:ss||strict_date_optional_time||epoch_millis\"\n" + + " },\n" + + " \"message\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"traceId\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"serviceName\": {\n" + + " \"type\": \"keyword\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + + // Add baseline data in base time range (09:00:00 to 10:00:00) + addDocToIndex( + TEST_LOG_INDEX_NAME, + "base1", + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 09:30:00", "System startup completed", "trace-base-001", "auth-service") + ); + addDocToIndex( + TEST_LOG_INDEX_NAME, + "base2", + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 09:45:00", "Database connection established", "trace-base-002", "db-service") + ); + addDocToIndex( + TEST_LOG_INDEX_NAME, + "base3", + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 09:50:00", "User session initialized", "trace-base-003", "auth-service") + ); + + // Add test log data with error keywords for logInsight + addDocToIndex( + TEST_LOG_INDEX_NAME, + "1", + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 10:00:00", "User login successful", "trace-001", "auth-service") + ); + addDocToIndex( + TEST_LOG_INDEX_NAME, + "2", + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 
10:01:00", "Database connection established", "trace-001", "db-service") + ); + addDocToIndex( + TEST_LOG_INDEX_NAME, + "3", + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 10:02:00", "Error connection timeout failed", "trace-002", "db-service") + ); + addDocToIndex( + TEST_LOG_INDEX_NAME, + "4", + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 10:03:00", "User logout completed", "trace-001", "auth-service") + ); + addDocToIndex( + TEST_LOG_INDEX_NAME, + "5", + List.of("@timestamp", "message", "traceId", "serviceName"), + List.of("2025-01-01 10:04:00", "Exception in authentication service", "trace-003", "auth-service") + ); + } + + @SneakyThrows + public void testLogPatternAnalysisToolLogInsight() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\"}}", + TEST_LOG_INDEX_NAME + ) + ); + assertNotNull(result); + assertTrue(result.contains("logInsights")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolWithBaseTimeRange() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", \"baseTimeRangeStart\": \"2025-01-01 09:00:00\", \"baseTimeRangeEnd\": \"2025-01-01 10:00:00\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\"}}", + TEST_LOG_INDEX_NAME + ) + ); + assertNotNull(result); + assertTrue(result.contains("patternMapDifference")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolWithTraceField() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": 
\"message\", \"traceFieldName\": \"traceId\", \"baseTimeRangeStart\": \"2025-01-01 09:00:00\", \"baseTimeRangeEnd\": \"2025-01-01 10:00:00\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\"}}", + TEST_LOG_INDEX_NAME + ) + ); + System.out.println(result); + assertNotNull(result); + assertTrue(result.contains("BASE") || result.contains("EXCEPTIONAL")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolMissingRequiredParameters() { + Exception exception = assertThrows(Exception.class, () -> executeAgent(agentId, "{\"parameters\": {\"index\": \"%s\"}}")); + MatcherAssert.assertThat(exception.getMessage(), containsString("Missing required parameters")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolInvalidIndex() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent( + agentId, + "{\"parameters\": {\"index\": \"non_existent_index\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\"}}" + ) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("no such index")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolNonExistentLogField() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent( + agentId, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"nonexistent_field\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\"}}" + ) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("not a valid term")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolInvalidTimeFormat() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent( + agentId, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", 
\"selectionTimeRangeStart\": \"invalid-time-format\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\"}}" + ) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("not a valid term")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolLogInsightWithFilter() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\", \"filter\": \"serviceName='db-service'\"}}", + TEST_LOG_INDEX_NAME + ) + ); + assertNotNull(result); + assertTrue(result.contains("logInsights")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolWithBaseTimeRangeAndFilter() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", \"baseTimeRangeStart\": \"2025-01-01 09:00:00\", \"baseTimeRangeEnd\": \"2025-01-01 10:00:00\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\", \"filter\": \"serviceName='auth-service'\"}}", + TEST_LOG_INDEX_NAME + ) + ); + assertNotNull(result); + assertTrue(result.contains("patternMapDifference")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolWithTraceFieldAndFilter() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", \"traceFieldName\": \"traceId\", \"baseTimeRangeStart\": \"2025-01-01 09:00:00\", \"baseTimeRangeEnd\": \"2025-01-01 10:00:00\", \"selectionTimeRangeStart\": \"2025-01-01 10:00:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:05:00\", \"filter\": \"serviceName='auth-service'\"}}", + TEST_LOG_INDEX_NAME + ) + ); + assertNotNull(result); + 
assertTrue(result.contains("BASE") || result.contains("EXCEPTIONAL")); + } + + @SneakyThrows + public void testLogPatternAnalysisToolEmptyTimeRange() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent( + agentId, + "{\"parameters\": {\"index\": \"%s\", \"timeField\": \"@timestamp\", \"logFieldName\": \"message\", \"selectionTimeRangeStart\": \"2025-01-01 10:05:00\", \"selectionTimeRangeEnd\": \"2025-01-01 10:00:00\"}}" + ) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("not a valid term")); + } +} diff --git a/src/test/java/org/opensearch/integTest/PPLToolIT.java b/src/test/java/org/opensearch/integTest/PPLToolIT.java index cf576be8..9ca930aa 100644 --- a/src/test/java/org/opensearch/integTest/PPLToolIT.java +++ b/src/test/java/org/opensearch/integTest/PPLToolIT.java @@ -54,9 +54,9 @@ public void testPPLTool() { String agentId = registerAgent(); String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"correct\", \"index\": \"employee\"}}"); assertEquals( - "{\"ppl\":\"source\\u003demployee| where age \\u003e 56 | stats COUNT() as cnt\"," + "{\"ppl\":\"source=employee | where age > 56 | stats COUNT() as cnt\"," + "\"executionResult\":\"{\\n \\\"schema\\\": [\\n {\\n \\\"name\\\": \\\"cnt\\\",\\n " - + "\\\"type\\\": \\\"int\\\"\\n }\\n ],\\n \\\"datarows\\\": [\\n [\\n 0\\n ]\\n ],\\n " + + "\\\"type\\\": \\\"bigint\\\"\\n }\\n ],\\n \\\"datarows\\\": [\\n [\\n 0\\n ]\\n ],\\n " + "\\\"total\\\": 1,\\n \\\"size\\\": 1\\n}\"}", result ); @@ -114,7 +114,7 @@ public void testPPLTool_withNonExistingIndex_thenThrowException() { exception.getMessage(), allOf( containsString( - "Return this final answer to human directly and do not use other tools: 'Please provide index name'. Please try to directly send this message to human to ask for index name" + "Return this final answer to human directly and do not use other tools: 'Please provide the existing index name(s)'. 
Please try to directly send this message to human to ask for index name" ) ) ); diff --git a/src/test/java/org/opensearch/integTest/SearchAroundDocumentToolIT.java b/src/test/java/org/opensearch/integTest/SearchAroundDocumentToolIT.java new file mode 100644 index 00000000..50f45b29 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/SearchAroundDocumentToolIT.java @@ -0,0 +1,280 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import static org.hamcrest.Matchers.containsString; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Locale; + +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Before; + +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonParser; + +import lombok.SneakyThrows; + +public class SearchAroundDocumentToolIT extends BaseAgentToolsIT { + + private static final String TEST_INDEX_NAME = "test_search_around_document_index"; + private static final String REGISTER_AGENT_RESOURCE = + "org/opensearch/agent/tools/register_flow_agent_of_search_around_document_tool_request_body.json"; + + private String registerAgentRequestBody; + private String agentId; + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + registerAgentRequestBody = Files.readString(Path.of(this.getClass().getClassLoader().getResource(REGISTER_AGENT_RESOURCE).toURI())); + prepareDataIndex(); + agentId = createAgent(registerAgentRequestBody); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteExternalIndices(); + } + + @SneakyThrows + private void prepareDataIndex() { + createIndexWithConfiguration(TEST_INDEX_NAME, """ + { + "mappings": { + "properties": { + "@timestamp": { + "type": "date", + "format": "yyyy-MM-dd HH:mm:ss||strict_date_optional_time||epoch_millis" + }, + "message": { + "type": "text" + }, + "level": { + 
"type": "keyword" + } + } + } + }"""); + + // Index 7 documents with known timestamps and IDs + addDocToIndex( + TEST_INDEX_NAME, + "doc1", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 09:00:00", "First log entry", "INFO") + ); + addDocToIndex( + TEST_INDEX_NAME, + "doc2", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 09:10:00", "Second log entry", "INFO") + ); + addDocToIndex( + TEST_INDEX_NAME, + "doc3", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 09:20:00", "Third log entry", "WARN") + ); + addDocToIndex( + TEST_INDEX_NAME, + "doc4", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 09:30:00", "Fourth log entry - target", "ERROR") + ); + addDocToIndex( + TEST_INDEX_NAME, + "doc5", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 09:40:00", "Fifth log entry", "WARN") + ); + addDocToIndex( + TEST_INDEX_NAME, + "doc6", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 09:50:00", "Sixth log entry", "INFO") + ); + addDocToIndex( + TEST_INDEX_NAME, + "doc7", + List.of("@timestamp", "message", "level"), + List.of("2025-01-01 10:00:00", "Seventh log entry", "ERROR") + ); + } + + @SneakyThrows + public void testSearchAroundDocument_basicSearch() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"doc_id\": \"doc4\", \"timestamp_field\": \"@timestamp\", \"count\": \"2\"}}", + TEST_INDEX_NAME + ) + ); + + JsonArray docs = JsonParser.parseString(result).getAsJsonArray(); + + // Should have 5 documents: 2 before + target + 2 after + assertEquals(5, docs.size()); + + // Verify chronological order: doc2, doc3, doc4 (target), doc5, doc6 + assertEquals("doc2", docs.get(0).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc3", docs.get(1).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc4", docs.get(2).getAsJsonObject().get("_id").getAsString()); + 
assertEquals("doc5", docs.get(3).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc6", docs.get(4).getAsJsonObject().get("_id").getAsString()); + } + + @SneakyThrows + public void testSearchAroundDocument_countExceedsAvailable() { + // doc1 is the first document, requesting 5 before but only 0 exist before it + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"doc_id\": \"doc1\", \"timestamp_field\": \"@timestamp\", \"count\": \"5\"}}", + TEST_INDEX_NAME + ) + ); + + JsonArray docs = JsonParser.parseString(result).getAsJsonArray(); + + // Should have target + up to 5 after (doc2-doc7 = 6 after, but count=5) + // No before docs since doc1 is the earliest + assertEquals(6, docs.size()); + + // First should be the target (doc1), followed by 5 after docs + assertEquals("doc1", docs.get(0).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc2", docs.get(1).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc3", docs.get(2).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc4", docs.get(3).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc5", docs.get(4).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc6", docs.get(5).getAsJsonObject().get("_id").getAsString()); + } + + @SneakyThrows + public void testSearchAroundDocument_lastDocument() { + // doc7 is the last document, requesting 3 before + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"doc_id\": \"doc7\", \"timestamp_field\": \"@timestamp\", \"count\": \"3\"}}", + TEST_INDEX_NAME + ) + ); + + JsonArray docs = JsonParser.parseString(result).getAsJsonArray(); + + // Should have 3 before + target, no after docs + assertEquals(4, docs.size()); + + assertEquals("doc4", docs.get(0).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc5", docs.get(1).getAsJsonObject().get("_id").getAsString()); + 
assertEquals("doc6", docs.get(2).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc7", docs.get(3).getAsJsonObject().get("_id").getAsString()); + } + + @SneakyThrows + public void testSearchAroundDocument_countOfOne() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"doc_id\": \"doc4\", \"timestamp_field\": \"@timestamp\", \"count\": \"1\"}}", + TEST_INDEX_NAME + ) + ); + + JsonArray docs = JsonParser.parseString(result).getAsJsonArray(); + + // 1 before + target + 1 after = 3 + assertEquals(3, docs.size()); + + assertEquals("doc3", docs.get(0).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc4", docs.get(1).getAsJsonObject().get("_id").getAsString()); + assertEquals("doc5", docs.get(2).getAsJsonObject().get("_id").getAsString()); + } + + @SneakyThrows + public void testSearchAroundDocument_jsonInput() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"input\": \"{\\\"index\\\": \\\"%s\\\", \\\"doc_id\\\": \\\"doc4\\\", \\\"timestamp_field\\\": \\\"@timestamp\\\", \\\"count\\\": 2}\"}}", + TEST_INDEX_NAME + ) + ); + + JsonArray docs = JsonParser.parseString(result).getAsJsonArray(); + assertEquals(5, docs.size()); + assertEquals("doc4", docs.get(2).getAsJsonObject().get("_id").getAsString()); + } + + @SneakyThrows + public void testSearchAroundDocument_responseContainsSource() { + String result = executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"doc_id\": \"doc4\", \"timestamp_field\": \"@timestamp\", \"count\": \"1\"}}", + TEST_INDEX_NAME + ) + ); + + JsonArray docs = JsonParser.parseString(result).getAsJsonArray(); + // Verify the target document has _source with correct fields + JsonElement targetDoc = docs.get(1); + assertTrue(targetDoc.getAsJsonObject().has("_source")); + assertTrue(targetDoc.getAsJsonObject().has("_id")); + 
assertTrue(targetDoc.getAsJsonObject().has("_index")); + + String source = targetDoc.getAsJsonObject().get("_source").toString(); + assertTrue(source.contains("Fourth log entry - target")); + assertTrue(source.contains("ERROR")); + } + + @SneakyThrows + public void testSearchAroundDocument_nonExistentDoc() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent( + agentId, + String + .format( + Locale.ROOT, + "{\"parameters\": {\"index\": \"%s\", \"doc_id\": \"nonexistent\", \"timestamp_field\": \"@timestamp\", \"count\": \"2\"}}", + TEST_INDEX_NAME + ) + ) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("Document not found")); + } + + @SneakyThrows + public void testSearchAroundDocument_missingRequiredParameters() { + Exception exception = assertThrows( + Exception.class, + () -> executeAgent(agentId, String.format(Locale.ROOT, "{\"parameters\": {\"index\": \"%s\"}}", TEST_INDEX_NAME)) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("requires")); + } +} diff --git a/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java b/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java index afdfb3d4..33f26e99 100644 --- a/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java +++ b/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java @@ -19,6 +19,7 @@ import org.opensearch.client.Response; import com.google.gson.Gson; +import com.google.gson.JsonObject; import com.google.gson.JsonParser; import com.sun.net.httpserver.HttpServer; @@ -46,8 +47,6 @@ public void setupTestAgent() throws IOException, InterruptedException { connectorId = setUpConnectorWithRetry(5); modelGroupId = setupModelGroup(); modelId = setupLLMModel(connectorId, modelGroupId); - // wait for model to get deployed - TimeUnit.SECONDS.sleep(1); agentId = setupConversationalAgent(modelId); log.info("model_id: {}, agent_id: {}", modelId, agentId); } @@ -172,10 +171,11 @@ private String setupLLMModel(String 
connectorId, String modelGroupId) throws IOE + "}" ); Response response = executeRequest(request); - String resp = readResponse(response); - - return JsonParser.parseString(resp).getAsJsonObject().get("model_id").getAsString(); + JsonObject respObj = JsonParser.parseString(resp).getAsJsonObject(); + String taskId = respObj.get("task_id").getAsString(); + waitTaskComplete(taskId); + return respObj.get("model_id").getAsString(); } private String setupConversationalAgent(String modelId) throws IOException { diff --git a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_data_distribution_tool_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_data_distribution_tool_request_body.json new file mode 100644 index 00000000..31aba32a --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_data_distribution_tool_request_body.json @@ -0,0 +1,10 @@ +{ + "name": "Test_data_distribution_tool_flow_agent", + "type": "flow", + "tools": [ + { + "type": "DataDistributionTool", + "parameters": {} + } + ] +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_log_pattern_analysis_tool_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_log_pattern_analysis_tool_request_body.json new file mode 100644 index 00000000..86bdbc7c --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_log_pattern_analysis_tool_request_body.json @@ -0,0 +1,10 @@ +{ + "name": "Test_log_pattern_analysis_tool_flow_agent", + "type": "flow", + "tools": [ + { + "type": "LogPatternAnalysisTool", + "parameters": {} + } + ] +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_search_around_document_tool_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_search_around_document_tool_request_body.json new file 
mode 100644 index 00000000..c97af635 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_search_around_document_tool_request_body.json @@ -0,0 +1,9 @@ +{ + "name": "Test_Search_Around_Document_Agent", + "type": "flow", + "tools": [ + { + "type": "SearchAroundDocumentTool" + } + ] +} \ No newline at end of file