diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index a8faae88..7dbfaf8e 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -13,6 +13,7 @@ import org.opensearch.agent.tools.CreateAlertTool; import org.opensearch.agent.tools.CreateAnomalyDetectorTool; +import org.opensearch.agent.tools.CreateAnomalyDetectorToolEnhanced; import org.opensearch.agent.tools.DataDistributionTool; import org.opensearch.agent.tools.LogPatternAnalysisTool; import org.opensearch.agent.tools.LogPatternTool; @@ -100,6 +101,7 @@ public Collection createComponents( SearchMonitorsTool.Factory.getInstance().init(client); CreateAlertTool.Factory.getInstance().init(client); CreateAnomalyDetectorTool.Factory.getInstance().init(client); + CreateAnomalyDetectorToolEnhanced.Factory.getInstance().init(client, namedWriteableRegistry); LogPatternTool.Factory.getInstance().init(client, xContentRegistry); WebSearchTool.Factory.getInstance().init(threadPool); LogPatternAnalysisTool.Factory.getInstance().init(client); @@ -123,6 +125,7 @@ public List> getToolFactories() { SearchMonitorsTool.Factory.getInstance(), CreateAlertTool.Factory.getInstance(), CreateAnomalyDetectorTool.Factory.getInstance(), + CreateAnomalyDetectorToolEnhanced.Factory.getInstance(), LogPatternTool.Factory.getInstance(), WebSearchTool.Factory.getInstance(), LogPatternAnalysisTool.Factory.getInstance(), diff --git a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java index 2c4a2273..e52963fb 100644 --- a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java +++ b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java @@ -16,7 +16,6 @@ import java.security.PrivilegedExceptionAction; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; @@ -209,7 +208,7 @@ public void run(Map parameters, ActionListener listener) ToolHelper.extractFieldNamesTypes(mappingSource, fieldsToType, "", true); // find all date type fields from the mapping - final Set dateFields = findDateTypeFields(fieldsToType); + final Set dateFields = ToolHelper.findDateTypeFields(fieldsToType); if (dateFields.isEmpty()) { throw new IllegalArgumentException( "The index " + indexName + " doesn't have date type fields, cannot create an anomaly detector for it." @@ -228,6 +227,8 @@ public void run(Map parameters, ActionListener listener) // construct the prompt String prompt = constructPrompt(filteredMapping, firstIndexName); + log.info("Using prompt for anomaly detector creation (simple): {}", prompt); + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet .builder() .parameters(Collections.singletonMap("prompt", prompt)) @@ -335,21 +336,8 @@ private Map enrichParameters(Map parameters) { /** * - * @param fieldsToType the flattened field-> field type mapping - * @return a list containing all the date type fields - */ - private Set findDateTypeFields(final Map fieldsToType) { - Set result = new HashSet<>(); - for (Map.Entry entry : fieldsToType.entrySet()) { - String value = entry.getValue(); - if (value.equals("date") || value.equals("date_nanos")) { - result.add(entry.getKey()); - } - } - return result; - } - @SuppressWarnings("unchecked") + **/ private static Map loadDefaultPromptFromFile() { try (InputStream inputStream = CreateAnomalyDetectorTool.class.getResourceAsStream("CreateAnomalyDetectorDefaultPrompt.json")) { if (inputStream != null) { diff --git a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolEnhanced.java b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolEnhanced.java new file mode 100644 index 00000000..f749aa0f --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolEnhanced.java @@ -0,0 +1,2489 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.agent.tools.utils.CommonConstants.COMMON_MODEL_ID_FIELD; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.StringJoiner; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import org.apache.commons.text.StringSubstitutor; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.ad.client.AnomalyDetectionNodeClient; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.transport.IndexAnomalyDetectorRequest; +import org.opensearch.ad.transport.IndexAnomalyDetectorResponse; +import org.opensearch.agent.tools.utils.AnomalyDetectorToolHelper; +import org.opensearch.agent.tools.utils.ToolHelper; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.common.Randomness; +import org.opensearch.common.UUIDs; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.indexInsight.IndexInsight; +import org.opensearch.ml.common.indexInsight.MLIndexInsightType; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.spi.tools.WithModelTool; +import org.opensearch.ml.common.transport.indexInsight.MLIndexInsightGetAction; +import org.opensearch.ml.common.transport.indexInsight.MLIndexInsightGetRequest; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.utils.ToolUtils; +import org.opensearch.rest.RestRequest; +import org.opensearch.search.SearchHit; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.bucket.filter.InternalFilter; +import org.opensearch.search.aggregations.bucket.sampler.InternalSampler; +import org.opensearch.search.aggregations.bucket.sampler.SamplerAggregationBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.TimeConfiguration; +import org.opensearch.timeseries.transport.JobRequest; +import org.opensearch.timeseries.transport.SuggestConfigParamRequest; +import org.opensearch.timeseries.transport.SuggestConfigParamResponse; +import org.opensearch.timeseries.transport.ValidateConfigRequest; +import org.opensearch.timeseries.transport.ValidateConfigResponse; +import org.opensearch.transport.client.Client; + +import com.google.common.collect.ImmutableMap; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * Enhanced tool for creating anomaly detectors with LLM-assisted configuration and validation. + * Analyzes index mappings, generates detector configurations using LLM, validates through multiple phases, + * and automatically creates and starts detectors for multiple indices. + * + * Usage: + * 1. Register agent: + * POST /_plugins/_ml/agents/_register + * { + * "name": "AnomalyDetectorEnhanced", + * "type": "flow", + * "tools": [ + * { + * "name": "create_anomaly_detector_enhanced", + * "type": "CreateAnomalyDetectorToolEnhanced", + * "parameters": { + * "model_id": "model-id", + * "model_type": "CLAUDE" + * } + * } + * ] + * } + * 2. Execute agent: + * POST /_plugins/_ml/agents/{agent_id}/_execute + * { + * "parameters": { + * "input": ["ecommerce-data", "server-logs"] + * } + * } + * 3. Result: detector creation status for each index + * { + * "ecommerce-data": { + * "status": "success", + * "detectorId": "abc123", + * "detectorName": "ecommerce-data-detector-xyz", + * "createResponse": "Detector created successfully", + * "startResponse": "Detector started successfully" + * }, + * "server-logs": { + * "status": "failed_validation", + * "error": "Insufficient data for model training" + * } + * } + */ +@Log4j2 +@Setter +@Getter +@ToolAnnotation(CreateAnomalyDetectorToolEnhanced.TYPE) +public class CreateAnomalyDetectorToolEnhanced implements WithModelTool { + public static final String TYPE = "CreateAnomalyDetectorToolEnhanced"; + + private static final String DEFAULT_DESCRIPTION = + "Enhanced tool for creating anomaly detector configurations. Takes an index name, extracts the index mappings, and uses LLM to generate complete detector JSON configurations ready for the create detector API."; + + // LLM output format: {key=value|key=value|...} + // Parsed into a map by extracting content between { }, splitting on |, splitting on =. + // Handles any key order, unknown keys, and whitespace. + private static final Pattern BRACES_PATTERN = Pattern.compile("\\{([^}]+)}"); + private static final Pattern INTERVAL_MINUTES_PATTERN = Pattern.compile("(\\d+)\\s*[Mm]inute"); + + private static final String NONE_SIGNAL = "{NONE}"; + private static final int MAX_DETECTORS_PER_INDEX = 3; + + private static final Set VALID_FIELD_TYPES = Set + .of( + "keyword", + "constant_keyword", + "wildcard", + "long", + "integer", + "short", + "byte", + "double", + "float", + "half_float", + "scaled_float", + "unsigned_long", + "ip", + "date", + "date_nanos" + ); + + private static final String OUTPUT_KEY_INDEX = "index"; + private static final String OUTPUT_KEY_CATEGORY_FIELD = "categoryField"; + private static final String OUTPUT_KEY_AGGREGATION_FIELD = "aggregationField"; + private static final String OUTPUT_KEY_AGGREGATION_METHOD = "aggregationMethod"; + private static final String OUTPUT_KEY_DATE_FIELDS = "dateFields"; + private static final Map DEFAULT_PROMPT_DICT = loadDefaultPromptFromFile(); + + private static final int MAX_DETECTOR_VALIDATION_RETRIES = 3; + private static final int MAX_MODEL_VALIDATION_RETRIES = 3; + private static final int MAX_FORMAT_FIX_RETRIES = 1; + + // Detector configuration defaults + private static final int DEFAULT_INTERVAL_MINUTES = 10; + private static final int DEFAULT_OTEL_INTERVAL_MINUTES = 2; + private static final int DEFAULT_WINDOW_DELAY_MINUTES = 1; + private static final int DEFAULT_SHINGLE_SIZE = 8; + private static final int DEFAULT_SCHEMA_VERSION = 0; + private static final String DEFAULT_DETECTOR_DESCRIPTION = "Detector generated by OpenSearch Assistant"; + private static final int MAX_DETECTOR_NAME_LENGTH = 64; + + private static final int SUGGEST_API_TIMEOUT_SECONDS = 30; + private static final int DATE_FIELD_QUERY_TIMEOUT_SECONDS = 10; + private static final String DATE_FIELD_LOOKBACK_PERIOD = "now-30d"; + + private static final String DEFAULT_CUSTOM_RESULT_INDEX = "opensearch-ad-plugin-result-auto-insights"; + private static final int MAX_FREQUENCY_MINUTES = 1440; // 24 hours + private static final int MAX_INDICES_PER_REQUEST = 100; + private static final String LLM_OUTPUT_FORMAT = + "{category_field=FIELD_OR_EMPTY|aggregation_field=FIELD1,FIELD2|aggregation_method=METHOD1,METHOD2|date_field=DATE_FIELD" + + "|filter=FIELD:OP:VALUE_OR_EMPTY|interval=MINUTES|description=ONE_SENTENCE}"; + private static final int MAX_PROMPT_FIELDS = 200; + private static final int FIELD_FILTER_THRESHOLD = 30; + private static final int FIELD_FILTER_SAMPLE_SIZE = 100000; + private static final double FIELD_NULL_THRESHOLD = 0.001; // drop fields present in <0.1% of docs + + private String name = TYPE; + private String description = DEFAULT_DESCRIPTION; + private String version; + private Client client; + private AnomalyDetectionNodeClient adClient; + private String modelId; + private ModelType modelType; + private String contextPrompt; + private String customResultIndex; + private Map attributes; + + enum ModelType { + CLAUDE, + OPENAI; + + public static ModelType from(String value) { + return valueOf(value.toUpperCase(Locale.ROOT)); + } + + } + + /** + * + * @param client the OpenSearch transport client + * @param modelId the model ID of LLM + * @param modelType the model type (CLAUDE or OPENAI) + * @param contextPrompt custom prompt (if empty, loads from default file) + * @param namedWriteableRegistry the named writeable registry + */ + public CreateAnomalyDetectorToolEnhanced( + Client client, + String modelId, + String modelType, + String contextPrompt, + String customResultIndex, + NamedWriteableRegistry namedWriteableRegistry + ) { + this.client = client; + this.adClient = new AnomalyDetectionNodeClient(client, namedWriteableRegistry); + this.modelId = modelId; + this.customResultIndex = (!Strings.isNullOrEmpty(customResultIndex)) ? customResultIndex : DEFAULT_CUSTOM_RESULT_INDEX; + if (!ModelType.OPENAI.toString().equalsIgnoreCase(modelType) && !ModelType.CLAUDE.toString().equalsIgnoreCase(modelType)) { + throw new IllegalArgumentException("Unsupported model_type: " + modelType); + } + this.modelType = ModelType.from(modelType); + + if (Strings.isNullOrEmpty(contextPrompt)) { + this.contextPrompt = DEFAULT_PROMPT_DICT.getOrDefault(this.modelType.toString(), ""); + } else { + this.contextPrompt = contextPrompt; + } + + if (this.contextPrompt == null || this.contextPrompt.trim().isEmpty()) { + throw new IllegalArgumentException("Configuration error: detector creation prompt not found"); + } + } + + /** + * Creates anomaly detectors for specified indices with automatic validation and error handling. + * + * @param parameters Map containing "input" with JSON array of index names + * @param listener ActionListener to receive results as JSON string + */ + @Override + public void run(Map parameters, ActionListener listener) { + log.debug("[CreateAnomalyDetectorToolEnhanced] run() invoked — customResultIndex={}", customResultIndex); + if (parameters.containsKey("input")) { + String inputStr = parameters.get("input"); + if (inputStr != null && inputStr.trim().startsWith("[")) { + parameters.put("input", "{\"indices\": " + inputStr + "}"); + } + } + parameters = ToolUtils.extractInputParameters(parameters, attributes); + final String tenantId = parameters.get(TENANT_ID_FIELD); + try { + List indices = AnomalyDetectorToolHelper.extractIndicesList(parameters); + validateIndices(indices); + int maxRetries = Math.min(Integer.parseInt(parameters.getOrDefault("maxRetries", String.valueOf(MAX_FORMAT_FIX_RETRIES))), 3); + + processMultipleIndices(indices, tenantId, maxRetries, listener); + } catch (Exception e) { + listener.onFailure(e); + } + } + + private static void validateIndices(List indices) { + if (indices.size() > MAX_INDICES_PER_REQUEST) { + throw new IllegalArgumentException("Too many indices: " + indices.size() + ". Maximum is " + MAX_INDICES_PER_REQUEST + "."); + } + for (String idx : indices) { + if (Strings.isNullOrEmpty(idx)) { + throw new IllegalArgumentException("Index name cannot be empty."); + } + if (idx.startsWith(".")) { + throw new IllegalArgumentException("System indices not supported: " + idx); + } + if (idx.length() > 255 || idx.contains("\n") || idx.contains("\r")) { + throw new IllegalArgumentException( + "Invalid index name: " + idx.substring(0, Math.min(50, idx.length())).replaceAll("[\\n\\r]", "") + ); + } + } + } + + private void processMultipleIndices(List indices, String tenantId, int maxRetries, ActionListener listener) { + Map results = new HashMap<>(); + processNextIndex(indices, 0, tenantId, maxRetries, results, listener); + } + + private void processNextIndex( + List indices, + int currentIndex, + String tenantId, + int maxRetries, + Map results, + ActionListener listener + ) { + if (currentIndex >= indices.size()) { + listener.onResponse((T) gson.toJson(results)); + return; + } + + String indexName = indices.get(currentIndex); + processSingleIndex(indexName, tenantId, maxRetries, new ActionListener() { + @Override + public void onResponse(String result) { + // All paths now return a JSON array of detector results + List> resultList = gson.fromJson(result, List.class); + results.put(indexName, resultList); + processNextIndex(indices, currentIndex + 1, tenantId, maxRetries, results, listener); + } + + @Override + public void onFailure(Exception e) { + results.put(indexName, List.of(DetectorResult.failedValidation(indexName, e.getMessage()).toMap())); + processNextIndex(indices, currentIndex + 1, tenantId, maxRetries, results, listener); + } + }); + } + + // Flow: get index insight -> get mappings -> filter null fields -> LLM generates config -> validate -> create detector + private void processSingleIndex(String indexName, String tenantId, int maxRetries, ActionListener listener) { + // First, try to get Index Insight analysis (graceful fallback if unavailable) + getIndexInsight(indexName, tenantId, ActionListener.wrap(indexInsight -> { + getMappingsAndFilterFields(indexName, ActionListener.wrap(mappingContext -> { + MappingContext enhancedContext = mappingContext.withIndexInsight(indexInsight); + filterNullFieldsIfNeeded(enhancedContext, tenantId, maxRetries, listener); + }, listener::onFailure)); + }, e -> { + log.warn("Index Insight failed for '{}', proceeding without: {}", indexName, e.getMessage()); + getMappingsAndFilterFields( + indexName, + ActionListener + .wrap(mappingContext -> filterNullFieldsIfNeeded(mappingContext, tenantId, maxRetries, listener), listener::onFailure) + ); + })); + } + + /** + * If mapping has more than FIELD_FILTER_THRESHOLD fields, run a sampler aggregation to drop + * fields that are null in >99.9% of sampled docs (same approach as Index Insight). + * Otherwise, proceed directly. + */ + private void filterNullFieldsIfNeeded(MappingContext ctx, String tenantId, int maxRetries, ActionListener listener) { + if (ctx.filteredMapping.size() <= FIELD_FILTER_THRESHOLD) { + proceedWithLLM(ctx, tenantId, maxRetries, listener); + return; + } + log + .info( + "Index '{}' has {} fields (>{} threshold), running null field filter", + ctx.indexName, + ctx.filteredMapping.size(), + FIELD_FILTER_THRESHOLD + ); + + filterNullFields(ctx.indexName, ctx.filteredMapping, ActionListener.wrap(filtered -> { + log.info("Null filter reduced '{}' from {} to {} fields", ctx.indexName, ctx.filteredMapping.size(), filtered.size()); + MappingContext filteredCtx = new MappingContext( + ctx.indexName, + filtered, + ctx.dateFields, + ctx.indexInsight, + ctx.detectorDecisions, + ctx.sampleDocs, + ctx.dataDensity24h + ); + proceedWithLLM(filteredCtx, tenantId, maxRetries, listener); + }, e -> { + log.warn("Null field filter failed for '{}', using unfiltered mapping: {}", ctx.indexName, e.getMessage()); + proceedWithLLM(ctx, tenantId, maxRetries, listener); + })); + } + + private void proceedWithLLM(MappingContext mappingContext, String tenantId, int maxRetries, ActionListener listener) { + // Check for OTel fast-path before LLM + OtelSignalType otelType = detectOtelSignal(mappingContext.filteredMapping); + if (otelType != null) { + log.info("OTel {} mapping detected for '{}', using predefined detectors", otelType, mappingContext.indexName); + createOtelDetectors(mappingContext.indexName, otelType, mappingContext.filteredMapping, listener); + return; + } + + // Pre-filter date fields: drop any with 0 docs in last 30d so LLM only sees viable options + filterDateFieldsByDensity( + mappingContext, + ActionListener.wrap(filteredCtx -> enrichAndCreateDetectors(filteredCtx, tenantId, maxRetries, listener), e -> { + log.warn("Date field filtering failed for '{}', using all date fields: {}", mappingContext.indexName, e.getMessage()); + enrichAndCreateDetectors(mappingContext, tenantId, maxRetries, listener); + }) + ); + } + + /** Gather sample docs and data density, then run the sequential multi-detector loop. */ + private void enrichAndCreateDetectors(MappingContext ctx, String tenantId, int maxRetries, ActionListener listener) { + String dateField = ctx.dateFields.iterator().next(); + getSampleDocuments(ctx.indexName, 10, ActionListener.wrap(sampleDocs -> { + getDataDensity(ctx.indexName, dateField, ActionListener.wrap(density -> { + MappingContext enrichedCtx = ctx.withSampleDocs(sampleDocs).withDataDensity(density); + createMultipleDetectors(enrichedCtx, tenantId, maxRetries, new ArrayList<>(), new ArrayList<>(), 0, listener); + }, e -> { + MappingContext enrichedCtx = ctx.withSampleDocs(sampleDocs); + createMultipleDetectors(enrichedCtx, tenantId, maxRetries, new ArrayList<>(), new ArrayList<>(), 0, listener); + })); + }, e -> { createMultipleDetectors(ctx, tenantId, maxRetries, new ArrayList<>(), new ArrayList<>(), 0, listener); })); + } + + // ── OTel fast-path ──────────────────────────────────────────────────────── + + enum OtelSignalType { + TRACES, + LOGS + } + + /** + * Detect OTel signal type from index mapping fields. + * Traces: Data Prepper otel-v1-apm-span standard template fields. + * Logs: SS4O log schema fields. + * Metrics: intentionally deferred — key-value schema requires filter-by-name support. + */ + /** Check for field presence, accounting for .keyword sub-fields in text mappings. */ + private static boolean hasField(Map fields, String name) { + return fields.containsKey(name) || fields.containsKey(name + ".keyword"); + } + + /** Resolve to the keyword variant of a field. AD category fields must be keyword type. */ + private static String resolveKeywordField(Map fields, String name) { + if (fields.containsKey(name) && "keyword".equals(fields.get(name))) + return name; + if (fields.containsKey(name + ".keyword")) + return name + ".keyword"; + return name; // fallback — validation will catch if wrong type + } + + private OtelSignalType detectOtelSignal(Map fields) { + if (hasField(fields, "traceId") + && hasField(fields, "spanId") + && hasField(fields, "durationInNanos") + && hasField(fields, "serviceName")) { + return OtelSignalType.TRACES; + } + if (hasField(fields, "severityNumber") + && hasField(fields, "severityText") + && hasField(fields, "resource.attributes.service.name")) { + return OtelSignalType.LOGS; + } + return null; + } + + /** Predefined OTel detector configuration. */ + private static class OtelDetectorSpec { + final String nameSuffix; + final String timeField; + final String categoryField; + final String featureField; + final QueryBuilder featureFilter; // null = plain count, non-null = filter-wrapped count + + OtelDetectorSpec(String nameSuffix, String timeField, String categoryField, String featureField, QueryBuilder featureFilter) { + this.nameSuffix = nameSuffix; + this.timeField = timeField; + this.categoryField = categoryField; + this.featureField = featureField; + this.featureFilter = featureFilter; + } + } + + /** Resolve the best time field for OTel detectors from the mapping. Prefers the canonical field for each signal type. */ + private static String resolveOtelTimeField(OtelSignalType type, Map fields) { + // Priority order per signal type based on official OTel/SS4O schemas + List candidates = type == OtelSignalType.TRACES + ? List.of("startTime", "@timestamp", "time") + : List.of("@timestamp", "time", "observedTimestamp"); + for (String candidate : candidates) { + String fieldType = fields.get(candidate); + if (fieldType != null && (fieldType.equals("date") || fieldType.equals("date_nanos"))) { + return candidate; + } + } + return candidates.get(0); // fallback to first preference + } + + private List buildOtelSpecs(OtelSignalType type, Map fields) { + String timeField = resolveOtelTimeField(type, fields); + List specs = new ArrayList<>(); + if (type == OtelSignalType.TRACES) { + specs + .add( + new OtelDetectorSpec( + "trace-errors", + timeField, + resolveKeywordField(fields, "serviceName"), + timeField, + QueryBuilders.termQuery("status.code", 2) + ) + ); + specs.add(new OtelDetectorSpec("trace-throughput", timeField, resolveKeywordField(fields, "serviceName"), timeField, null)); + } else { + specs + .add( + new OtelDetectorSpec( + "log-errors", + timeField, + resolveKeywordField(fields, "resource.attributes.service.name"), + timeField, + QueryBuilders.rangeQuery("severityNumber").gte(17) + ) + ); + specs + .add( + new OtelDetectorSpec( + "log-volume", + timeField, + resolveKeywordField(fields, "resource.attributes.service.name"), + timeField, + null + ) + ); + } + return specs; + } + + private void createOtelDetectors(String indexName, OtelSignalType type, Map fields, ActionListener listener) { + List specs = buildOtelSpecs(type, fields); + List> results = new ArrayList<>(); + createOtelDetectorSequentially(specs, indexName, 0, results, listener); + } + + @SuppressWarnings("unchecked") + private void createOtelDetectorSequentially( + List specs, + String indexName, + int idx, + List> results, + ActionListener listener + ) { + if (idx >= specs.size()) { + log + .info( + "OTel detectors done for '{}': {}/{} succeeded", + indexName, + results.stream().filter(r -> "success".equals(r.get("status"))).count(), + results.size() + ); + listener.onResponse((T) gson.toJson(results)); + return; + } + OtelDetectorSpec spec = specs.get(idx); + String suffix = "-" + spec.nameSuffix + "-" + UUIDs.randomBase64UUID().substring(0, 6); + String truncatedIndex = indexName.substring(0, Math.min(indexName.length(), MAX_DETECTOR_NAME_LENGTH - suffix.length())); + String detectorName = truncatedIndex + suffix; + + AggregationBuilder innerAgg = AnomalyDetectorToolHelper.createAggregationBuilder("count", spec.featureField); + // Wrap filter inside the feature aggregation (not on the detector) so HC entities + // with 0 matching docs still get a model with value=0, enabling 0→N anomaly detection. + AggregationBuilder featureAgg = spec.featureFilter != null + ? AggregationBuilders.filter(spec.nameSuffix.replace("-", "_"), spec.featureFilter).subAggregation(innerAgg) + : innerAgg; + Feature feature = new Feature(UUIDs.randomBase64UUID(), spec.nameSuffix.replace("-", "_"), true, featureAgg); + + AnomalyDetector detector = new AnomalyDetector( + null, + null, + detectorName, + DEFAULT_DETECTOR_DESCRIPTION, + spec.timeField, + List.of(indexName), + List.of(feature), + QueryBuilders.matchAllQuery(), + new IntervalTimeConfiguration(DEFAULT_OTEL_INTERVAL_MINUTES, ChronoUnit.MINUTES), + new IntervalTimeConfiguration(DEFAULT_WINDOW_DELAY_MINUTES, ChronoUnit.MINUTES), + DEFAULT_SHINGLE_SIZE, + null, + DEFAULT_SCHEMA_VERSION, + Instant.now(), + List.of(spec.categoryField), + null, + customResultIndex, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + new IntervalTimeConfiguration(calculateFrequencyMinutes(DEFAULT_OTEL_INTERVAL_MINUTES), ChronoUnit.MINUTES), + true + ); + + // Use suggest API to find optimal interval, then create + suggestAndCreateOtelDetector(detector, specs, indexName, idx, results, listener); + } + + /** Call suggest API for interval optimization, then create and start the detector. */ + @SuppressWarnings("unchecked") + private void suggestAndCreateOtelDetector( + AnomalyDetector detector, + List specs, + String indexName, + int idx, + List> results, + ActionListener listener + ) { + SuggestConfigParamRequest suggestRequest = new SuggestConfigParamRequest( + AnalysisType.AD, + detector, + "interval", + TimeValue.timeValueSeconds(SUGGEST_API_TIMEOUT_SECONDS) + ); + + adClient.suggestAnomalyDetector(suggestRequest, ActionListener.wrap(suggestResp -> { + AnomalyDetector optimized = suggestResp.getInterval() != null ? applySuggestionsToDetector(detector, suggestResp) : detector; + createAndStartOtelDetector(optimized, specs, indexName, idx, results, listener); + }, e -> { + log.warn("Suggest API failed for OTel detector '{}', using default interval: {}", detector.getName(), e.getMessage()); + createAndStartOtelDetector(detector, specs, indexName, idx, results, listener); + })); + } + + @SuppressWarnings("unchecked") + private void createAndStartOtelDetector( + AnomalyDetector detector, + List specs, + String indexName, + int idx, + List> results, + ActionListener listener + ) { + // Validate config before creating — catches invalid time/category/feature fields + callValidationAPI(detector, "detector", new ActionListener() { + @Override + public void onResponse(ValidateConfigResponse response) { + if (response.getIssue() != null) { + String error = response.getIssue().getMessage(); + log.warn("OTel detector validation failed for '{}': {}", detector.getName(), error); + results.add(DetectorResult.failedValidation(indexName, error).toMap()); + createOtelDetectorSequentially(specs, indexName, idx + 1, results, listener); + return; + } + doCreateAndStartOtelDetector(detector, specs, indexName, idx, results, listener); + } + + @Override + public void onFailure(Exception e) { + log.warn("OTel detector validation API failed for '{}', proceeding anyway: {}", detector.getName(), e.getMessage()); + doCreateAndStartOtelDetector(detector, specs, indexName, idx, results, listener); + } + }); + } + + @SuppressWarnings("unchecked") + private void doCreateAndStartOtelDetector( + AnomalyDetector detector, + List specs, + String indexName, + int idx, + List> results, + ActionListener listener + ) { + String detectorName = detector.getName(); + IndexAnomalyDetectorRequest createReq = new IndexAnomalyDetectorRequest("", detector, RestRequest.Method.POST); + adClient.createAnomalyDetector(createReq, ActionListener.wrap(createResp -> { + String detectorId = createResp.getId(); + log.info("OTel detector created: {} ({})", detectorName, detectorId); + JobRequest startReq = new JobRequest( + detectorId, + ".opendistro-anomaly-detectors", + null, + false, + "/_plugins/_anomaly_detection/detectors/" + detectorId + "/_start" + ); + adClient.startAnomalyDetector(startReq, ActionListener.wrap(startResp -> { + results + .add( + DetectorResult + .success(indexName, detectorId, detectorName, "Detector created successfully", "Detector started successfully") + .toMap() + ); + createOtelDetectorSequentially(specs, indexName, idx + 1, results, listener); + }, e -> { + log.error("Failed to start OTel detector {}: {}", detectorName, e.getMessage()); + results.add(DetectorResult.failedStart(indexName, detectorId, e.getMessage()).toMap()); + createOtelDetectorSequentially(specs, indexName, idx + 1, results, listener); + })); + }, e -> { + log.error("Failed to create OTel detector {}: {}", detectorName, e.getMessage()); + results.add(DetectorResult.failedCreate(indexName, e.getMessage()).toMap()); + createOtelDetectorSequentially(specs, indexName, idx + 1, results, listener); + })); + } + + /** + * Get Index Insight analysis for the given index using ALL type. + * Returns null content on empty response; calls onFailure if API is unavailable. + */ + private void getIndexInsight(String indexName, String tenantId, ActionListener listener) { + log.info("Fetching Index Insight for index '{}'", indexName); + + MLIndexInsightGetRequest request = new MLIndexInsightGetRequest(indexName, MLIndexInsightType.ALL, tenantId); + + client.execute(MLIndexInsightGetAction.INSTANCE, request, ActionListener.wrap(response -> { + IndexInsight insight = response.getIndexInsight(); + String content = insight != null ? insight.getContent() : null; + + if (content != null && !content.isEmpty()) { + log.info("Index Insight for '{}': {} chars", indexName, content.length()); + listener.onResponse(content); + } else { + log.warn("Index Insight returned empty content for '{}'", indexName); + listener.onResponse(null); + } + }, e -> { + log.warn("Index Insight API call failed for '{}': {} ({})", indexName, e.getMessage(), e.getClass().getSimpleName()); + listener.onFailure(e); + })); + } + + /** + * Run a sampler aggregation with not_null filters per field to drop fields that are + * null in >99.9% of sampled docs. Same approach as Index Insight's StatisticalDataTask. + */ + private void filterNullFields(String indexName, Map mapping, ActionListener> listener) { + AggregatorFactories.Builder filters = new AggregatorFactories.Builder(); + // Use index-based names to avoid collisions from dot-to-underscore replacement + List fieldOrder = new ArrayList<>(mapping.keySet()); + for (int i = 0; i < fieldOrder.size(); i++) { + filters.addAggregator(AggregationBuilders.filter("f_" + i, QueryBuilders.existsQuery(fieldOrder.get(i)))); + } + SamplerAggregationBuilder sampler = AggregationBuilders + .sampler("sample") + .shardSize(FIELD_FILTER_SAMPLE_SIZE) + .subAggregations(filters); + + SearchRequest request = new SearchRequest(indexName) + .source(new SearchSourceBuilder().size(0).query(QueryBuilders.matchAllQuery()).aggregation(sampler)); + + client.search(request, ActionListener.wrap(response -> { + InternalSampler sampleAgg = (InternalSampler) response.getAggregations().getAsMap().get("sample"); + long totalDocs = sampleAgg.getDocCount(); + if (totalDocs == 0) { + listener.onResponse(mapping); + return; + } + Map aggMap = sampleAgg.getAggregations().getAsMap(); + Map result = new LinkedHashMap<>(); + for (int i = 0; i < fieldOrder.size(); i++) { + Aggregation agg = aggMap.get("f_" + i); + if (agg instanceof InternalFilter) { + long docCount = ((InternalFilter) agg).getDocCount(); + if (docCount >= FIELD_NULL_THRESHOLD * totalDocs) { + String fieldName = fieldOrder.get(i); + result.put(fieldName, mapping.get(fieldName)); + } + } + } + // Safety: if filter is too aggressive, fall back to original + listener.onResponse(result.size() >= 5 ? result : mapping); + }, e -> { + log.warn("Sampler aggregation failed for '{}': {}", indexName, e.getMessage()); + listener.onResponse(mapping); + })); + } + + private void getSampleDocuments(String indexName, int size, ActionListener listener) { + SearchRequest request = new SearchRequest(indexName) + .source(new SearchSourceBuilder().size(size).query(QueryBuilders.matchAllQuery()).sort("_doc").trackTotalHits(false)); + client.search(request, ActionListener.wrap(response -> { + SearchHit[] hits = response.getHits().getHits(); + if (hits.length == 0) { + listener.onResponse(null); + return; + } + List> docs = new ArrayList<>(); + for (SearchHit hit : hits) { + docs.add(hit.getSourceAsMap()); + } + listener.onResponse(gson.toJson(docs)); + }, e -> { + log.warn("Failed to fetch sample docs for '{}': {}", indexName, e.getMessage()); + listener.onResponse(null); + })); + } + + private void getDataDensity(String indexName, String dateField, ActionListener listener) { + SearchRequest request = new SearchRequest(indexName) + .source(new SearchSourceBuilder().size(0).query(QueryBuilders.rangeQuery(dateField).gte("now-24h")).trackTotalHits(true)); + client.search(request, ActionListener.wrap(response -> listener.onResponse(response.getHits().getTotalHits().value()), e -> { + log.warn("Failed to get data density for '{}': {}", indexName, e.getMessage()); + listener.onResponse(-1L); + })); + } + + /** + * Filter date fields to only those with data in the last 30 days. + * The LLM picks the semantically best date field from the survivors. + * Falls back to all date fields if none have recent data. + */ + private static final int MAX_DATE_FIELDS_TO_CHECK = 10; + + private void filterDateFieldsByDensity(MappingContext ctx, ActionListener listener) { + if (ctx.dateFields.size() <= 1) { + listener.onResponse(ctx); + return; + } + List dateFieldList = new ArrayList<>(ctx.dateFields); + if (dateFieldList.size() > MAX_DATE_FIELDS_TO_CHECK) { + dateFieldList = dateFieldList.subList(0, MAX_DATE_FIELDS_TO_CHECK); + } + Map counts = new HashMap<>(); + checkNextDateField(ctx, dateFieldList, 0, counts, listener); + } + + private void checkNextDateField( + MappingContext ctx, + List dateFieldList, + int idx, + Map counts, + ActionListener listener + ) { + if (idx >= dateFieldList.size()) { + // All fields checked — filter to those with data + Set viable = dateFieldList + .stream() + .filter(f -> counts.getOrDefault(f, 0L) > 0) + .collect(Collectors.toCollection(java.util.LinkedHashSet::new)); + if (viable.isEmpty()) { + log.info("No date fields with recent data for '{}', keeping all", ctx.indexName); + listener.onResponse(ctx); + } else { + log.info("Filtered date fields for '{}': {} → {}", ctx.indexName, ctx.dateFields, viable); + listener + .onResponse( + new MappingContext( + ctx.indexName, + ctx.filteredMapping, + viable, + ctx.indexInsight, + ctx.detectorDecisions, + ctx.sampleDocs, + ctx.dataDensity24h + ) + ); + } + return; + } + String dateField = dateFieldList.get(idx); + SearchRequest request = new SearchRequest(ctx.indexName) + .source( + new SearchSourceBuilder() + .size(0) + .query(QueryBuilders.rangeQuery(dateField).gte(DATE_FIELD_LOOKBACK_PERIOD)) + .timeout(TimeValue.timeValueSeconds(DATE_FIELD_QUERY_TIMEOUT_SECONDS)) + .trackTotalHits(true) + ); + client.search(request, ActionListener.wrap(response -> { + counts.put(dateField, response.getHits().getTotalHits().value()); + checkNextDateField(ctx, dateFieldList, idx + 1, counts, listener); + }, e -> { + counts.put(dateField, 0L); + checkNextDateField(ctx, dateFieldList, idx + 1, counts, listener); + })); + } + + private String extractResponseFromDataAsMap(Map dataAsMap) { + if (dataAsMap == null) { + return null; + } + if (dataAsMap.containsKey("response")) { + return (String) dataAsMap.get("response"); + } + // Bedrock and Claude both end with content[0].text — just navigate to the content array + List> content = null; + if (dataAsMap.containsKey("output")) { + try { + Map output = (Map) dataAsMap.get("output"); + Map message = (Map) output.get("message"); + content = (List>) message.get("content"); + } catch (Exception e) { + log.error("Failed to parse Bedrock response format", e); + } + } else if (dataAsMap.containsKey("content")) { + try { + content = (List>) dataAsMap.get("content"); + } catch (Exception e) { + log.error("Failed to parse Claude content format", e); + } + } + if (content != null && !content.isEmpty()) { + return (String) content.getFirst().get("text"); + } + log.error("Unknown or unparseable response format. Available keys: {}", dataAsMap.keySet()); + return null; + } + + @SuppressWarnings("unchecked") + private static Map loadDefaultPromptFromFile() { + try ( + InputStream inputStream = CreateAnomalyDetectorToolEnhanced.class + .getResourceAsStream("CreateAnomalyDetectorEnhancedPrompt.json") + ) { + if (inputStream != null) { + Map raw = gson.fromJson(new String(inputStream.readAllBytes(), StandardCharsets.UTF_8), Map.class); + String base = raw.getOrDefault("prompt", ""); + if (!base.isEmpty()) { + // Generate both variants from single base prompt + Map result = new HashMap<>(); + result.put("OPENAI", base); + result.put("CLAUDE", "\n\nHuman: " + base + "\n\nAssistant:"); + return result; + } + return raw; // fallback: legacy format with CLAUDE/OPENAI keys + } + } catch (IOException e) { + log.error("Failed to load prompt from the file CreateAnomalyDetectorEnhancedPrompt.json, error: ", e); + } + return new HashMap<>(); + } + + /** + * + * @param fieldsToType the flattened field-> field type mapping + * @param indexName the index name + * @param dateFields the comma-separated date fields + * @param indexInsight the index insight analysis (can be null) + * @return the prompt about creating anomaly detector + */ + private String constructPrompt( + final Map fieldsToType, + final String indexName, + final String dateFields, + final String indexInsight + ) { + Map fields = fieldsToType; + if (fields.size() > MAX_PROMPT_FIELDS) { + log.info("Truncating mapping from {} to {} fields for LLM prompt", fields.size(), MAX_PROMPT_FIELDS); + fields = fields.entrySet().stream().limit(MAX_PROMPT_FIELDS).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + StringJoiner tableInfoJoiner = new StringJoiner("\n"); + for (Map.Entry entry : fields.entrySet()) { + tableInfoJoiner.add("- " + entry.getKey() + ": " + entry.getValue()); + } + + String insightSection = ""; + if (!Strings.isNullOrEmpty(indexInsight)) { + insightSection = "\n\nINDEX ANALYSIS (from Index Insight):\n" + + indexInsight + + "\n\nUse the above analysis to inform your detector configuration choices."; + } + + Map indexInfo = ImmutableMap + .of( + "indexName", + indexName, + "indexMapping", + tableInfoJoiner.toString(), + "dateFields", + dateFields, + "indexInsight", + insightSection + ); + StringSubstitutor substitutor = new StringSubstitutor(indexInfo, "${indexInfo.", "}"); + + String basePrompt = substitutor.replace(contextPrompt); + // If prompt template doesn't have ${indexInfo.indexInsight} placeholder, append insight before OUTPUT FORMAT + if (!contextPrompt.contains("${indexInfo.indexInsight}") && !insightSection.isEmpty()) { + int outputFormatIndex = basePrompt.indexOf("OUTPUT FORMAT:"); + if (outputFormatIndex > 0) { + basePrompt = basePrompt.substring(0, outputFormatIndex) + insightSection + "\n\n" + basePrompt.substring(outputFormatIndex); + } else { + basePrompt = basePrompt + insightSection; + } + } + + return basePrompt; + } + + @Override + public boolean validate(Map parameters) { + return parameters != null && !parameters.isEmpty(); + } + + @Override + public String getType() { + return TYPE; + } + + /** + * Calculate a jittered frequency as a random multiple of interval between 2×interval and 24h. + * Spreads detector queries across time to avoid concurrent load spikes. + */ + private static int calculateFrequencyMinutes(int intervalMinutes) { + int minFreq = intervalMinutes * 2; + if (minFreq >= MAX_FREQUENCY_MINUTES) + return MAX_FREQUENCY_MINUTES; + int multiples = (MAX_FREQUENCY_MINUTES - minFreq) / intervalMinutes; + return minFreq + Randomness.get().nextInt(multiples + 1) * intervalMinutes; + } + + private AnomalyDetector buildAnomalyDetectorFromSuggestions(Map suggestions) { + return buildAnomalyDetectorFromSuggestions(suggestions, null); + } + + private AnomalyDetector buildAnomalyDetectorFromSuggestions(Map suggestions, QueryBuilder filterQuery) { + String indexName = suggestions.get(OUTPUT_KEY_INDEX); + String categoryField = suggestions.get(OUTPUT_KEY_CATEGORY_FIELD); + String aggregationFields = suggestions.get(OUTPUT_KEY_AGGREGATION_FIELD); + String aggregationMethods = suggestions.get(OUTPUT_KEY_AGGREGATION_METHOD); + String dateFields = suggestions.get(OUTPUT_KEY_DATE_FIELDS); + String intervalStr = suggestions.getOrDefault("interval", String.valueOf(DEFAULT_INTERVAL_MINUTES)); + + // Parse filter from suggestions if present (from LLM output) + QueryBuilder featureFilter = filterQuery; + if (featureFilter == null) { + String filterExpr = suggestions.get("filter"); + featureFilter = parseFilterExpression(filterExpr); + } + + // Parse interval (default to 10 minutes) + int intervalMinutes = DEFAULT_INTERVAL_MINUTES; + try { + intervalMinutes = Integer.parseInt(intervalStr); + } catch (NumberFormatException e) { + log.warn("Invalid interval '{}', using default {} minutes", intervalStr, DEFAULT_INTERVAL_MINUTES); + } + + // Parse comma-separated fields and methods + String[] fields = aggregationFields.split(","); + String[] methods = aggregationMethods.split(","); + + if (fields.length != methods.length) { + throw new IllegalArgumentException("Number of aggregation fields and methods must match"); + } + + // Determine if this is an HC detector (has category field) + categoryField = categoryField != null ? categoryField.trim() : ""; + boolean isHC = !categoryField.isEmpty() && !categoryField.equalsIgnoreCase("null") && !categoryField.equalsIgnoreCase("none"); + + List features = new ArrayList<>(); + boolean filterAppliedInFeature = false; + for (int i = 0; i < fields.length; i++) { + String field = fields[i].trim(); + String method = methods[i].trim(); + + if (field.isEmpty() || method.isEmpty()) { + continue; + } + + String cleanField = field.startsWith("feature_") ? field.substring(8) : field; + // Handle template variable leak - LLM sometimes outputs literal template variables + String actualDateField = dateFields.split(",")[0].trim(); + cleanField = cleanField.replace("${dateFields}", actualDateField).replace("${indexInfo.dateFields}", actualDateField); + + AggregationBuilder innerAgg = AnomalyDetectorToolHelper.createAggregationBuilder(method, cleanField); + // Filter-in-feature: for HC+count, apply to the first count feature found. + // This lets HC entities with 0 matches still get a model (0→N detection). + AggregationBuilder featureAgg; + if (featureFilter != null && isHC && !filterAppliedInFeature && "count".equalsIgnoreCase(method)) { + featureAgg = AggregationBuilders + .filter("feature_" + cleanField + "_" + method + "_filter", featureFilter) + .subAggregation(innerAgg); + filterAppliedInFeature = true; + } else { + featureAgg = innerAgg; + } + Feature feature = new Feature(UUIDs.randomBase64UUID(), "feature_" + cleanField + "_" + method, true, featureAgg); + features.add(feature); + } + + if (features.isEmpty()) { + throw new IllegalArgumentException( + "No valid features could be built from LLM suggestions. " + + "Fields: [" + + aggregationFields + + "], Methods: [" + + aggregationMethods + + "]" + ); + } + + List categoryFields = null; + if (isHC) { + categoryFields = java.util.Arrays + .stream(categoryField.split(",")) + .map(String::trim) + .filter(s -> !s.isEmpty()) + .collect(Collectors.toList()); + if (categoryFields.size() > 2) { + categoryFields = categoryFields.subList(0, 2); + } + } + + // If filter was applied inside feature agg, detector-level is matchAll. + // Otherwise, put filter on detector level. + QueryBuilder detectorFilter = (featureFilter != null && !filterAppliedInFeature) ? featureFilter : QueryBuilders.matchAllQuery(); + + String timeField = dateFields.split(",")[0].trim(); + + String nameSuffix = "-detector-" + UUIDs.randomBase64UUID().substring(0, 8); + String truncatedIndex = indexName.substring(0, Math.min(indexName.length(), MAX_DETECTOR_NAME_LENGTH - nameSuffix.length())); + + String description = suggestions.getOrDefault("description", DEFAULT_DETECTOR_DESCRIPTION); + if (description.isEmpty()) + description = DEFAULT_DETECTOR_DESCRIPTION; + + return new AnomalyDetector( + null, + null, + truncatedIndex + nameSuffix, + description, + timeField, + List.of(indexName), + features, + detectorFilter, + new IntervalTimeConfiguration(intervalMinutes, ChronoUnit.MINUTES), + new IntervalTimeConfiguration(DEFAULT_WINDOW_DELAY_MINUTES, ChronoUnit.MINUTES), + DEFAULT_SHINGLE_SIZE, + null, + DEFAULT_SCHEMA_VERSION, + Instant.now(), + categoryFields, + null, + customResultIndex, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + new IntervalTimeConfiguration(calculateFrequencyMinutes(intervalMinutes), ChronoUnit.MINUTES), + true + ); + } + + private void callLLM(String prompt, String tenantId, ActionListener listener) { + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters(Collections.singletonMap("prompt", prompt)) + .build(); + ActionRequest request = new MLPredictionTaskRequest( + modelId, + MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), + null, + tenantId + ); + + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(mlTaskResponse -> { + ModelTensorOutput output = (ModelTensorOutput) mlTaskResponse.getOutput(); + List outputList = output != null ? output.getMlModelOutputs() : null; + if (outputList == null || outputList.isEmpty()) { + listener.onFailure(new IllegalStateException("Remote endpoint returned empty output.")); + return; + } + List tensorList = outputList.get(0).getMlModelTensors(); + if (tensorList == null || tensorList.isEmpty()) { + listener.onFailure(new IllegalStateException("Remote endpoint returned empty tensors.")); + return; + } + Map dataAsMap = (Map) tensorList.get(0).getDataAsMap(); + String response = extractResponseFromDataAsMap(dataAsMap); + + if (Strings.isNullOrEmpty(response)) { + listener.onFailure(new IllegalStateException("Remote endpoint fails to inference.")); + } else { + listener.onResponse(response); + } + }, listener::onFailure)); + } + + private void respondWithError(ActionListener listener, String indexName, String status, String errorMessage) { + String errorId = UUIDs.randomBase64UUID().substring(0, 8); + log.error("Detector operation failed [{}] - Index: {}, Status: {}, Error: {}", errorId, indexName, status, errorMessage); + + DetectorResult result; + switch (status) { + case "validation": + result = DetectorResult.failedValidation(indexName, errorMessage); + break; + case "create": + result = DetectorResult.failedCreate(indexName, errorMessage); + break; + case "start": + result = DetectorResult.failedStart(indexName, null, errorMessage); + break; + default: + listener.onFailure(new IllegalArgumentException("Unknown error status: " + status)); + return; + } + listener.onResponse((T) gson.toJson(List.of(result.toMap()))); + } + + // Note: retry prompts intentionally omit advanced rules (e.g., filter-first-feature for HC). + // Retry prompts should be focused on fixing the specific error — adding all rules creates + // cognitive overload and distracts the LLM from the fix. The main prompt covers these rules. + private void retryWithFormatFix( + String parseError, + String indexName, + String dateFields, + MappingContext mappingContext, + String tenantId, + int maxRetries, + int currentRetry, + String validationType, + ActionListener listener, + java.util.function.BiConsumer, ActionListener> nextPhaseCallback + ) { + String contextPrefix = mappingContext != null ? buildRetryContext(mappingContext) : ""; + String fixPrompt = contextPrefix + + "The previous response had incorrect format. " + + parseError + + "\n\nPlease provide suggestions for index '" + + indexName + + "' in the exact format: " + + LLM_OUTPUT_FORMAT + + "\n\nInterval should be in minutes (default: 10). Only return the configuration in curly braces."; + + callLLM(fixPrompt, tenantId, ActionListener.wrap(fixedResponse -> { + parseAndRetryWithLLM( + fixedResponse, + indexName, + dateFields, + mappingContext, + tenantId, + maxRetries, + currentRetry, + validationType, + listener, + nextPhaseCallback + ); + }, e -> { + log.error("LLM fix request failed: {}", e.getMessage()); + listener.onFailure(e); + })); + } + + private void callValidationAPI(AnomalyDetector detector, String validationType, ActionListener listener) { + try { + ValidateConfigRequest validateRequest = new ValidateConfigRequest(AnalysisType.AD, detector, validationType); + adClient.validateAnomalyDetector(validateRequest, listener); + } catch (Throwable e) { + log.error("Validation API call failed: {}", e.getMessage(), e); + listener.onFailure(new RuntimeException("Validation API failed: " + e.getMessage(), e)); + } + } + + private void parseAndRetryWithLLM( + String llmResponse, + String indexName, + String dateFields, + MappingContext mappingContext, + String tenantId, + int maxRetries, + int currentRetry, + String validationType, + ActionListener listener, + java.util.function.BiConsumer, ActionListener> nextPhaseCallback + ) { + log.debug("LLM_RESPONSE,index={},retry={},response={}", indexName, currentRetry, llmResponse); + + // Parse {key=value|key=value|...} format into a map + Map parsed = parseLLMResponse(llmResponse); + + if (parsed == null) { + log.error("Parsing failed for response: {}", llmResponse); + if (currentRetry < maxRetries) { + String parseError = "Cannot parse response format. Expected: " + LLM_OUTPUT_FORMAT; + retryWithFormatFix( + parseError, + indexName, + dateFields, + mappingContext, + tenantId, + maxRetries, + currentRetry + 1, + validationType, + listener, + nextPhaseCallback + ); + } else { + listener.onFailure(new IllegalStateException("Cannot parse LLM response after " + maxRetries + " retries")); + } + return; + } + + String categoryField = parsed.getOrDefault("category_field", "").strip(); + String aggregationField = parsed.getOrDefault("aggregation_field", "").strip(); + String aggregationMethod = parsed.getOrDefault("aggregation_method", "").strip(); + final String parsedFilter = parsed.getOrDefault("filter", "").strip(); + final String parsedDescription = parsed.getOrDefault("description", "").strip(); + final String parsedInterval = parsed.getOrDefault("interval", String.valueOf(DEFAULT_INTERVAL_MINUTES)).strip(); + + // Use LLM's date field choice; validate against known fields, fall back to first if invalid + String parsedDateField = parsed.getOrDefault("date_field", "").strip(); + Set knownDateFields = java.util.Arrays.stream(dateFields.split(",")).map(String::trim).collect(Collectors.toSet()); + String selectedDateField = (!parsedDateField.isEmpty() && knownDateFields.contains(parsedDateField)) + ? parsedDateField + : dateFields.split(",")[0].trim(); + + Map suggestions = new HashMap<>( + Map + .of( + OUTPUT_KEY_INDEX, + indexName, + OUTPUT_KEY_CATEGORY_FIELD, + categoryField, + OUTPUT_KEY_AGGREGATION_FIELD, + aggregationField, + OUTPUT_KEY_AGGREGATION_METHOD, + aggregationMethod, + OUTPUT_KEY_DATE_FIELDS, + selectedDateField, + "interval", + parsedInterval + ) + ); + if (!parsedFilter.isEmpty()) { + suggestions.put("filter", parsedFilter); + } + if (!parsedDescription.isEmpty()) { + suggestions.put("description", parsedDescription); + } + nextPhaseCallback.accept(suggestions, listener); + } + + // Checks if detector config is valid (fields exist, aggregations work, etc.) + private void validateDetectorPhase( + Map suggestions, + MappingContext mappingContext, + String tenantId, + int maxRetries, + int currentRetry, + ActionListener listener + ) { + if (currentRetry >= MAX_DETECTOR_VALIDATION_RETRIES) { + listener.onFailure(new RuntimeException("Detector validation failed after " + MAX_DETECTOR_VALIDATION_RETRIES + " retries")); + return; + } + try { + log.info("Validating detector configuration"); + AnomalyDetector detector = buildAnomalyDetectorFromSuggestions(suggestions); + + callValidationAPI(detector, "detector", new ActionListener() { + @Override + public void onResponse(ValidateConfigResponse response) { + if (response.getIssue() != null) { + String errorMessage = response.getIssue().getMessage(); + String issueType = response.getIssue().getType().toString(); + + // GENERAL_SETTINGS = max detectors reached for cluster + if ("GENERAL_SETTINGS".equals(issueType)) { + log.error("System limit error (non-retryable): {}", errorMessage); + respondWithError(listener, suggestions.get(OUTPUT_KEY_INDEX), "validation", errorMessage); + return; + } + if (currentRetry < MAX_DETECTOR_VALIDATION_RETRIES) { + retryDetectorValidation( + suggestions, + mappingContext, + errorMessage, + tenantId, + maxRetries, + currentRetry + 1, + listener + ); + } else { + log.error("Max detector validation retries reached: {}", errorMessage); + listener + .onFailure( + new RuntimeException( + "Detector validation failed after " + MAX_DETECTOR_VALIDATION_RETRIES + " retries: " + errorMessage + ) + ); + } + return; + } + suggestHyperParametersPhase(detector, mappingContext, tenantId, maxRetries, listener); + } + + @Override + public void onFailure(Exception e) { + log.error("Detector validation API failed: {}", e.getMessage(), e); + if (currentRetry < MAX_DETECTOR_VALIDATION_RETRIES) { + retryDetectorValidation( + suggestions, + mappingContext, + e.getMessage(), + tenantId, + maxRetries, + currentRetry + 1, + listener + ); + } else { + listener + .onFailure( + new RuntimeException( + "Detector validation failed after " + MAX_DETECTOR_VALIDATION_RETRIES + " retries: " + e.getMessage() + ) + ); + } + } + }); + + } catch (Exception e) { + log.error("Error building detector: {}", e.getMessage(), e); + if (currentRetry < MAX_FORMAT_FIX_RETRIES) { + retryDetectorValidation(suggestions, mappingContext, e.getMessage(), tenantId, maxRetries, currentRetry + 1, listener); + } else { + listener.onFailure(e); + } + } + } + + private void retryDetectorValidation( + Map originalSuggestions, + MappingContext mappingContext, + String validationError, + String tenantId, + int maxRetries, + int currentRetry, + ActionListener listener + ) { + // Record this failed attempt for retry memory + MappingContext updatedCtx = mappingContext.withDecision(buildAttemptSummary(originalSuggestions, validationError)); + + String fixPrompt = createFixPrompt(originalSuggestions, validationError, updatedCtx.detectorDecisions); + String fullPrompt = buildRetryContext(updatedCtx) + fixPrompt; + + log + .info( + "LLM_FIX_PROMPT,index={},retry={},error={}", + updatedCtx.indexName, + currentRetry, + validationError != null ? validationError.substring(0, Math.min(200, validationError.length())) : "unknown error" + ); + + callLLM(fullPrompt, tenantId, ActionListener.wrap(fixedResponse -> { + parseAndRetryWithLLM( + fixedResponse, + originalSuggestions.get(OUTPUT_KEY_INDEX), + originalSuggestions.get(OUTPUT_KEY_DATE_FIELDS), + updatedCtx, + tenantId, + maxRetries, + currentRetry, + "detector", + listener, + (suggestions, listenerCallback) -> validateDetectorPhase( + suggestions, + updatedCtx, + tenantId, + maxRetries, + currentRetry, + listenerCallback + ) + ); + }, e -> { + log.error("LLM fix request failed: {}", e.getMessage()); + listener.onFailure(e); + })); + } + + /** Build mapping context string for retry prompts. Shared by detector and model validation retries. */ + private static String buildRetryContext(MappingContext ctx) { + StringBuilder sb = new StringBuilder(); + sb.append("You are creating an anomaly detector for the following index:\n\n"); + sb.append("Index: ").append(ctx.indexName).append("\n"); + sb.append("Available fields:\n"); + for (Map.Entry field : ctx.filteredMapping.entrySet()) { + sb.append("- ").append(field.getKey()).append(": ").append(field.getValue()).append("\n"); + } + sb.append("Available date fields: ").append(String.join(", ", ctx.dateFields)).append("\n"); + if (ctx.dataDensity24h >= 0) { + sb.append("Data density: ").append(ctx.dataDensity24h).append(" documents in the last 24 hours\n"); + } + if (ctx.sampleDocs != null) { + String truncated = ctx.sampleDocs.length() > 1000 ? ctx.sampleDocs.substring(0, 1000) + "..." : ctx.sampleDocs; + sb.append("Sample documents: ").append(truncated).append("\n"); + } + sb.append("\n"); + return sb.toString(); + } + + private AnomalyDetector applySuggestionsToDetector(AnomalyDetector originalDetector, SuggestConfigParamResponse response) { + TimeConfiguration newInterval = originalDetector.getInterval(); + TimeConfiguration newWindowDelay = originalDetector.getWindowDelay(); + Integer newHistoryIntervals = originalDetector.getHistoryIntervals(); + if (response.getInterval() != null) { + newInterval = response.getInterval(); + } + if (response.getWindowDelay() != null) { + newWindowDelay = response.getWindowDelay(); + } + if (response.getHistory() != null) { + newHistoryIntervals = response.getHistory(); + } + // Recalculate frequency when interval changes to maintain kaituo's invariant: frequency ≥ 2×interval + TimeConfiguration newFrequency = originalDetector.getFrequency(); + if (response.getInterval() != null) { + int mins = (int) ((IntervalTimeConfiguration) newInterval).toDuration().toMinutes(); + newFrequency = new IntervalTimeConfiguration(calculateFrequencyMinutes(mins), ChronoUnit.MINUTES); + } + // Create new detector with applied suggestions + return new AnomalyDetector( + originalDetector.getId(), + originalDetector.getVersion(), + originalDetector.getName(), + originalDetector.getDescription(), + originalDetector.getTimeField(), + originalDetector.getIndices(), + originalDetector.getFeatureAttributes(), + originalDetector.getFilterQuery(), + newInterval, + newWindowDelay, + originalDetector.getShingleSize(), + originalDetector.getUiMetadata(), + originalDetector.getSchemaVersion(), + originalDetector.getLastUpdateTime(), + originalDetector.getCategoryFields(), + originalDetector.getUser(), + originalDetector.getCustomResultIndexOrAlias(), + originalDetector.getImputationOption(), + originalDetector.getRecencyEmphasis(), + originalDetector.getSeasonIntervals(), + newHistoryIntervals, + originalDetector.getRules(), + originalDetector.getCustomResultIndexMinSize(), + originalDetector.getCustomResultIndexMinAge(), + originalDetector.getCustomResultIndexTTL(), + originalDetector.getFlattenResultIndexMapping(), + null, + newFrequency, + originalDetector.getAutoCreated() + ); + } + + /** + * Apply a structured interval or window delay suggestion from model validation directly, + * avoiding an unnecessary LLM round-trip for issues the validation API already solved. + */ + private AnomalyDetector applyIntervalSuggestion(AnomalyDetector detector, IntervalTimeConfiguration suggestion, String issueType) { + TimeConfiguration newInterval = detector.getInterval(); + TimeConfiguration newWindowDelay = detector.getWindowDelay(); + TimeConfiguration newFrequency = detector.getFrequency(); + if ("window_delay".equals(issueType)) { + newWindowDelay = suggestion; + } else { + newInterval = suggestion; + int mins = (int) suggestion.toDuration().toMinutes(); + newFrequency = new IntervalTimeConfiguration(calculateFrequencyMinutes(mins), ChronoUnit.MINUTES); + } + return new AnomalyDetector( + detector.getId(), + detector.getVersion(), + detector.getName(), + detector.getDescription(), + detector.getTimeField(), + detector.getIndices(), + detector.getFeatureAttributes(), + detector.getFilterQuery(), + newInterval, + newWindowDelay, + detector.getShingleSize(), + detector.getUiMetadata(), + detector.getSchemaVersion(), + detector.getLastUpdateTime(), + detector.getCategoryFields(), + detector.getUser(), + detector.getCustomResultIndexOrAlias(), + detector.getImputationOption(), + detector.getRecencyEmphasis(), + detector.getSeasonIntervals(), + detector.getHistoryIntervals(), + detector.getRules(), + detector.getCustomResultIndexMinSize(), + detector.getCustomResultIndexMinAge(), + detector.getCustomResultIndexTTL(), + detector.getFlattenResultIndexMapping(), + null, + newFrequency, + detector.getAutoCreated() + ); + } + + // Use AD suggest API to find better interval/window-delay/history based on actual data density. + // Intentionally overrides LLM's interval — LLM guesses from field names, suggest API uses real data. + private void suggestHyperParametersPhase( + AnomalyDetector detector, + MappingContext mappingContext, + String tenantId, + int maxRetries, + ActionListener listener + ) { + log.info("Starting suggest api step"); + + // Create suggest request for interval, history, window_delay + SuggestConfigParamRequest suggestRequest = new SuggestConfigParamRequest( + AnalysisType.AD, + detector, + "interval,history,window_delay", + TimeValue.timeValueSeconds(SUGGEST_API_TIMEOUT_SECONDS) + ); + + adClient.suggestAnomalyDetector(suggestRequest, new ActionListener() { + @Override + public void onResponse(SuggestConfigParamResponse response) { + try { + AnomalyDetector optimizedDetector = applySuggestionsToDetector(detector, response); + validateModelPhase(optimizedDetector, mappingContext, tenantId, maxRetries, 0, List.of(), listener); + + } catch (Exception e) { + validateModelPhase(detector, mappingContext, tenantId, maxRetries, 0, List.of(), listener); + } + } + + @Override + public void onFailure(Exception e) { + // Continue to model validation with original detector even if suggest fails + validateModelPhase(detector, mappingContext, tenantId, maxRetries, 0, List.of(), listener); + } + }); + } + + // Checks if there's enough data to train the model (final validation before creating detector) + private void validateModelPhase( + AnomalyDetector detector, + MappingContext mappingContext, + String tenantId, + int maxRetries, + int currentRetry, + List previousAttempts, + ActionListener listener + ) { + log.info("Starting model validation"); + + callValidationAPI(detector, "model", new ActionListener() { + @Override + public void onResponse(ValidateConfigResponse response) { + + if (response.getIssue() != null) { + String issueAspect = response.getIssue().getAspect().toString(); + boolean isBlockingError = "DETECTOR".equals(issueAspect); + String errorMessage = response.getIssue().getMessage(); + + // If validation provides a structured interval or window delay suggestion, + // apply it directly and re-validate — no LLM call needed for these. + IntervalTimeConfiguration intervalSuggestion = response.getIssue().getIntervalSuggestion(); + if (intervalSuggestion != null && currentRetry < MAX_MODEL_VALIDATION_RETRIES) { + String issueType = response.getIssue().getType().getName(); + log + .info( + "Applying validation suggestion: {}={} for '{}'", + issueType, + intervalSuggestion, + getDetectorIndex(detector) + ); + AnomalyDetector adjusted = applyIntervalSuggestion(detector, intervalSuggestion, issueType); + validateModelPhase(adjusted, mappingContext, tenantId, maxRetries, currentRetry + 1, previousAttempts, listener); + return; + } + + if (currentRetry < MAX_MODEL_VALIDATION_RETRIES) { + retryModelValidation( + detector, + mappingContext, + errorMessage, + tenantId, + maxRetries, + currentRetry + 1, + previousAttempts, + listener + ); + } else { + // Max retries reached + if (isBlockingError) { + log.error("Max retries reached with blocking validation error: {}", errorMessage); + respondWithError(listener, getDetectorIndex(detector), "validation", errorMessage); + } else { + DetectorResult result = DetectorResult + .failedValidation(getDetectorIndex(detector), "Non-blocking warning: " + errorMessage); + listener.onResponse((T) gson.toJson(List.of(result.toMap()))); + } + } + return; + } + createDetector(detector, listener); + } + + @Override + public void onFailure(Exception e) { + log.error("Model validation API failed: {}", e.getMessage(), e); + if (currentRetry < MAX_MODEL_VALIDATION_RETRIES) { + retryModelValidation( + detector, + mappingContext, + e.getMessage(), + tenantId, + maxRetries, + currentRetry + 1, + previousAttempts, + listener + ); + } else { + DetectorResult result = DetectorResult.failedValidation(getDetectorIndex(detector), "API failure: " + e.getMessage()); + listener.onResponse((T) gson.toJson(List.of(result.toMap()))); + } + } + }); + } + + private void retryModelValidation( + AnomalyDetector detector, + MappingContext mappingContext, + String validationError, + String tenantId, + int maxRetries, + int currentRetry, + List previousAttempts, + ActionListener listener + ) { + Map currentSuggestions = Map + .of( + OUTPUT_KEY_INDEX, + String.join(",", detector.getIndices()), + OUTPUT_KEY_CATEGORY_FIELD, + detector.getCategoryFields() == null || detector.getCategoryFields().isEmpty() + ? "" + : String.join(",", detector.getCategoryFields()), + OUTPUT_KEY_AGGREGATION_FIELD, + detector.getFeatureAttributes().stream().map(this::getFeatureField).collect(java.util.stream.Collectors.joining(",")), + OUTPUT_KEY_AGGREGATION_METHOD, + detector.getFeatureAttributes().stream().map(this::getAggMethod).collect(java.util.stream.Collectors.joining(",")), + OUTPUT_KEY_DATE_FIELDS, + detector.getTimeField(), + "interval", + String.valueOf(detector.getIntervalInMinutes()) + ); + + List updatedAttempts = new ArrayList<>(previousAttempts); + updatedAttempts.add(buildAttemptSummary(currentSuggestions, validationError)); + + String fixPrompt = createFixPrompt(currentSuggestions, validationError, updatedAttempts); + String fullPrompt = buildRetryContext(mappingContext) + fixPrompt; + + callLLM(fullPrompt, tenantId, ActionListener.wrap(fixedResponse -> { + parseAndRetryWithLLM( + fixedResponse, + currentSuggestions.get(OUTPUT_KEY_INDEX), + currentSuggestions.get(OUTPUT_KEY_DATE_FIELDS), + mappingContext, + tenantId, + maxRetries, + currentRetry, + "model", + listener, + (suggestions, listenerCallback) -> { + try { + AnomalyDetector newDetector = buildAnomalyDetectorFromSuggestions(suggestions); + validateModelPhase( + newDetector, + mappingContext, + tenantId, + maxRetries, + currentRetry, + updatedAttempts, + listenerCallback + ); + } catch (Exception e) { + log.error("Error building detector from LLM fix: {}", e.getMessage()); + DetectorResult result = DetectorResult + .failedValidation(getDetectorIndex(detector), "Failed to build detector: " + e.getMessage()); + listenerCallback.onResponse((T) gson.toJson(List.of(result.toMap()))); + } + } + ); + }, e -> { + log.error("LLM fix request failed: {}", e.getMessage()); + DetectorResult result = DetectorResult + .failedValidation(getDetectorIndex(detector), "LLM fix request failed: " + e.getMessage()); + listener.onResponse((T) gson.toJson(List.of(result.toMap()))); + })); + } + + private void createDetector(AnomalyDetector detector, ActionListener listener) { + IndexAnomalyDetectorRequest request = new IndexAnomalyDetectorRequest("", detector, RestRequest.Method.POST); + adClient.createAnomalyDetector(request, new ActionListener() { + @Override + public void onResponse(IndexAnomalyDetectorResponse response) { + String detectorId = response.getId(); + startDetector(getDetectorIndex(detector), detectorId, detector.getName(), listener); + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to create detector: {}", e.getMessage(), e); + respondWithError(listener, getDetectorIndex(detector), "create", e.getMessage()); + } + }); + } + + private void startDetector(String indexName, String detectorId, String detectorName, ActionListener listener) { + JobRequest request = new JobRequest( + detectorId, + ".opendistro-anomaly-detectors", + null, + false, + "/_plugins/_anomaly_detection/detectors/" + detectorId + "/_start" + ); + + adClient.startAnomalyDetector(request, ActionListener.wrap(response -> { + DetectorResult result = DetectorResult + .success(indexName, detectorId, detectorName, "Detector created successfully", "Detector started successfully"); + listener.onResponse((T) gson.toJson(List.of(result.toMap()))); + }, e -> { + log.error("Failed to start detector: {}", e.getMessage()); + respondWithError(listener, indexName, "start", e.getMessage()); + })); + } + + String getAggMethod(Feature feature) { + String type = feature.getAggregation().getType(); + return "value_count".equals(type) ? "count" : type; + } + + /** Extract the original index field name from a feature, reversing the "feature_{field}_{method}" naming. */ + private String getFeatureField(Feature feature) { + String name = feature.getName(); + String method = getAggMethod(feature); + String suffix = "_" + method; + if (name.startsWith("feature_") && name.endsWith(suffix)) { + return name.substring("feature_".length(), name.length() - suffix.length()); + } + return name; // fallback + } + + /** Safely get the first index from a detector, avoiding IndexOutOfBoundsException. */ + private static String getDetectorIndex(AnomalyDetector detector) { + List indices = detector.getIndices(); + return (indices != null && !indices.isEmpty()) ? indices.get(0) : "unknown"; + } + + private static final Set KNOWN_LLM_KEYS = Set + .of("category_field", "aggregation_field", "aggregation_method", "filter", "interval", "date_field"); + + /** + * Parse LLM response in {key=value|key=value|...} format into a map. + * Finds the first {...} block containing "category_field", splits on |, splits on =. + * Known keys are parsed by name (order-independent). "description" is the only free-text + * key that may contain | — its value is reconstructed from any non-known-key segments. + * Returns null if no valid block found. + */ + private static Map parseLLMResponse(String llmResponse) { + if (llmResponse == null) + return null; + Matcher braces = BRACES_PATTERN.matcher(llmResponse); + while (braces.find()) { + String content = braces.group(1).strip(); + if (!content.contains("category_field")) + continue; + + Map result = new HashMap<>(); + StringBuilder descBuilder = new StringBuilder(); + boolean inDescription = false; + + for (String part : content.split("\\|")) { + int eq = part.indexOf('='); + String key = eq > 0 ? part.substring(0, eq).strip().replaceAll("\"", "") : ""; + if (KNOWN_LLM_KEYS.contains(key)) { + inDescription = false; + result.put(key, part.substring(eq + 1).strip().replaceAll("\"", "")); + } else if ("description".equals(key)) { + inDescription = true; + descBuilder.append(part.substring(eq + 1).strip()); + } else if (inDescription) { + // Continuation of description value that contained | + descBuilder.append("|").append(part); + } + } + if (descBuilder.length() > 0) { + result.put("description", descBuilder.toString().replaceAll("\"", "").strip()); + } + if (result.containsKey("category_field") && result.containsKey("aggregation_field")) { + return result; + } + } + return null; + } + + /** + * Parse a filter expression (field:operator:value) into a QueryBuilder. + * Returns null if the expression is empty, invalid, or unparseable. + */ + QueryBuilder parseFilterExpression(String filterExpr) { + if (Strings.isNullOrEmpty(filterExpr)) + return null; + String[] parts = filterExpr.split(":", 3); + if (parts.length != 3) { + log.warn("Invalid filter expression '{}', ignoring", filterExpr); + return null; + } + try { + String field = parts[0].trim(); + String operator = parts[1].trim().toLowerCase(Locale.ROOT); + String value = parts[2].trim(); + switch (operator) { + case "gte": + return QueryBuilders.rangeQuery(field).gte(value); + case "gt": + return QueryBuilders.rangeQuery(field).gt(value); + case "lte": + return QueryBuilders.rangeQuery(field).lte(value); + case "lt": + return QueryBuilders.rangeQuery(field).lt(value); + case "eq": + return QueryBuilders.termQuery(field, value); + default: + log.warn("Unknown filter operator '{}', ignoring filter", operator); + return null; + } + } catch (Exception e) { + log.warn("Failed to parse filter expression '{}': {}", filterExpr, e.getMessage()); + return null; + } + } + + private String createFixPrompt(Map originalSuggestions, String validationError) { + return createFixPrompt(originalSuggestions, validationError, List.of()); + } + + private String createFixPrompt(Map originalSuggestions, String validationError, List previousAttempts) { + validationError = validationError != null ? validationError : "unknown error"; + String currentInterval = originalSuggestions.getOrDefault("interval", "10"); + String categoryField = originalSuggestions.get(OUTPUT_KEY_CATEGORY_FIELD); + String filter = originalSuggestions.getOrDefault("filter", ""); + + // Check if this is a sparse data issue with unreasonably high suggested interval (>= 4 hours) + boolean isUnreasonableInterval = false; + if (validationError.contains("interval")) { + java.util.regex.Matcher intervalMatcher = INTERVAL_MINUTES_PATTERN.matcher(validationError); + if (intervalMatcher.find()) { + isUnreasonableInterval = Integer.parseInt(intervalMatcher.group(1)) >= 240; + } + } + + // Check if this is a field type incompatibility error + boolean isFieldTypeError = validationError.contains("not supported for aggregation") + || validationError.contains("Text fields are not optimised"); + + String sparseDataGuidance = ""; + if (isFieldTypeError) { + sparseDataGuidance = "\n**FIELD TYPE ERROR - CRITICAL FIX REQUIRED**:\n" + + "- You selected a keyword/text field with avg/sum/max/min aggregation - THIS WILL NOT WORK\n" + + "- RULE: keyword/text fields can ONLY use 'count' aggregation\n" + + "- RULE: avg/sum/max/min ONLY work on numeric fields (long, integer, double, float)\n" + + "- FIX OPTIONS:\n" + + " 1. Change aggregation method to 'count' for the keyword field, OR\n" + + " 2. Pick a DIFFERENT field that is numeric (long/integer/double/float)\n" + + "- Look at the field types in the mapping and pick numeric fields for avg/sum/max/min\n"; + } else if (isUnreasonableInterval) { + sparseDataGuidance = "\n**HIGH INTERVAL DETECTED**:\n" + + "- Validation suggests interval >4 hours\n" + + "- For operational metrics (latency, errors, CPU, memory): intervals >4 hours are too slow for actionable alerts\n" + + " → PREFERRED: Remove category field entirely (set to empty) to achieve 10-60 min intervals\n" + + " → ALTERNATIVE: Choose different category field with lower cardinality\n" + + "- For business metrics (revenue, sales, users): longer intervals may be acceptable\n" + + " → Consider if the suggested interval fits the use case\n" + + "- Evaluate based on the aggregation fields being monitored\n"; + } else if (validationError.contains("sparse data") || validationError.contains("interval")) { + sparseDataGuidance = "\n**SPARSE DATA GUIDANCE**:\n" + + "- For intervals 60-120 min: acceptable, proceed with suggestion\n" + + "- For intervals >120 min: consider removing category field instead\n"; + } + + return "VALIDATION ERROR: " + + validationError + + "\n\n" + + "Current Configuration:\n" + + "- Category Field: " + + (Strings.isNullOrEmpty(categoryField) ? "NONE" : categoryField) + + "\n" + + "- Aggregation Fields: " + + originalSuggestions.get(OUTPUT_KEY_AGGREGATION_FIELD) + + "\n" + + "- Aggregation Methods: " + + originalSuggestions.get(OUTPUT_KEY_AGGREGATION_METHOD) + + "\n" + + "- Interval: " + + currentInterval + + " minutes\n" + + "- Date Field: " + + originalSuggestions.getOrDefault(OUTPUT_KEY_DATE_FIELDS, "") + + "\n" + + "- Filter: " + + (filter.isEmpty() ? "NONE" : filter) + + "\n" + + sparseDataGuidance + + "\n\nFIX STRATEGY:\n" + + "1. Evaluate aggregation fields: operational metrics need shorter intervals, business metrics can use longer\n" + + "2. For operational metrics with intervals >240 min: REMOVE category field (set to empty string)\n" + + "3. For business metrics: accept suggested interval if appropriate for use case\n" + + "4. For 'invalid query' errors: fix only the problematic field/method\n\n" + + "CRITICAL RULES:\n" + + "- ONLY valid aggregation methods: avg, sum, min, max, count\n" + + "- Keyword fields can ONLY use 'count'\n" + + "- NEVER sum/avg status_code or http_status - use bytes, duration instead\n" + + "- Prefer numeric fields: bytes_sent, total_time, response.bytes, duration\n" + + "- Keep the same aggregation method unless it caused the error\n\n" + + formatPreviousAttempts(previousAttempts) + + "Return ONLY the corrected configuration in this EXACT format:\n" + + LLM_OUTPUT_FORMAT + + "\n\n" + + "Use empty string for category_field if removing it. DO NOT include explanations."; + } + + private static String formatPreviousAttempts(List attempts) { + if (attempts == null || attempts.isEmpty()) + return ""; + StringBuilder sb = new StringBuilder("PREVIOUS ATTEMPTS (do NOT repeat these):\n"); + for (int i = 0; i < attempts.size(); i++) { + sb.append("- Attempt ").append(i + 1).append(": ").append(attempts.get(i)).append("\n"); + } + sb.append("Your response MUST be different from all previous attempts.\n\n"); + return sb.toString(); + } + + private static String buildAttemptSummary(Map suggestions, String error) { + return "category=" + + suggestions.getOrDefault(OUTPUT_KEY_CATEGORY_FIELD, "") + + ", field=" + + suggestions.getOrDefault(OUTPUT_KEY_AGGREGATION_FIELD, "") + + ":" + + suggestions.getOrDefault(OUTPUT_KEY_AGGREGATION_METHOD, "") + + ", interval=" + + suggestions.getOrDefault("interval", "?") + + " → " + + (error != null ? error.substring(0, Math.min(100, error.length())) : "unknown"); + } + + /** + * Context object to hold mapping data and index insight + */ + private static class MappingContext { + final String indexName; + final Map filteredMapping; + final Set dateFields; + final String indexInsight; + final List detectorDecisions; + final String sampleDocs; + final long dataDensity24h; + + MappingContext(String indexName, Map filteredMapping, Set dateFields) { + this(indexName, filteredMapping, dateFields, null, new ArrayList<>(), null, -1L); + } + + MappingContext(String indexName, Map filteredMapping, Set dateFields, String indexInsight) { + this(indexName, filteredMapping, dateFields, indexInsight, new ArrayList<>(), null, -1L); + } + + MappingContext( + String indexName, + Map filteredMapping, + Set dateFields, + String indexInsight, + List detectorDecisions, + String sampleDocs, + long dataDensity24h + ) { + this.indexName = indexName; + this.filteredMapping = filteredMapping; + this.dateFields = dateFields; + this.indexInsight = indexInsight; + this.detectorDecisions = detectorDecisions; + this.sampleDocs = sampleDocs; + this.dataDensity24h = dataDensity24h; + } + + MappingContext withIndexInsight(String insight) { + return new MappingContext(indexName, filteredMapping, dateFields, insight, detectorDecisions, sampleDocs, dataDensity24h); + } + + MappingContext withDecision(String decision) { + List updated = new ArrayList<>(detectorDecisions); + updated.add(decision); + return new MappingContext(indexName, filteredMapping, dateFields, indexInsight, updated, sampleDocs, dataDensity24h); + } + + MappingContext withSampleDocs(String docs) { + return new MappingContext(indexName, filteredMapping, dateFields, indexInsight, detectorDecisions, docs, dataDensity24h); + } + + MappingContext withDataDensity(long density) { + return new MappingContext(indexName, filteredMapping, dateFields, indexInsight, detectorDecisions, sampleDocs, density); + } + } + + /** + * Result object to track detector creation status per index + */ + private enum DetectorStatus { + SUCCESS("success"), + FAILED_VALIDATION("failed_validation"), + FAILED_CREATE("failed_create"), + FAILED_START("failed_start"); + + private final String value; + + DetectorStatus(String value) { + this.value = value; + } + + @Override + public String toString() { + return value; + } + } + + private static class DetectorResult { + String indexName; + DetectorStatus status; + String detectorId; + String detectorName; + String error; + String createResponse; + String startResponse; + + String toJson() { + return gson.toJson(toMap()); + } + + Map toMap() { + Map map = new HashMap<>(); + if (indexName != null) + map.put("indexName", indexName); + if (status != null) + map.put("status", status.toString()); + if (detectorId != null) + map.put("detectorId", detectorId); + if (detectorName != null) + map.put("detectorName", detectorName); + if (error != null) + map.put("error", error); + if (createResponse != null) + map.put("createResponse", createResponse); + if (startResponse != null) + map.put("startResponse", startResponse); + return map; + } + + static DetectorResult failedValidation(String indexName, String error) { + DetectorResult result = new DetectorResult(); + result.indexName = indexName; + result.status = DetectorStatus.FAILED_VALIDATION; + result.error = error; + return result; + } + + static DetectorResult failedCreate(String indexName, String error) { + DetectorResult result = new DetectorResult(); + result.indexName = indexName; + result.status = DetectorStatus.FAILED_CREATE; + result.error = error; + return result; + } + + static DetectorResult failedStart(String indexName, String detectorId, String error) { + DetectorResult result = new DetectorResult(); + result.indexName = indexName; + result.status = DetectorStatus.FAILED_START; + result.detectorId = detectorId; + result.error = error; + return result; + } + + static DetectorResult success( + String indexName, + String detectorId, + String detectorName, + String createResponse, + String startResponse + ) { + DetectorResult result = new DetectorResult(); + result.indexName = indexName; + result.status = DetectorStatus.SUCCESS; + result.detectorId = detectorId; + result.detectorName = detectorName; + result.createResponse = createResponse; + result.startResponse = startResponse; + return result; + } + } + + /** + * Step 2: Get mappings and filter fields + */ + private void getMappingsAndFilterFields(String indexName, ActionListener listener) { + GetMappingsRequest getMappingsRequest = new GetMappingsRequest().indices(indexName); + + client.admin().indices().getMappings(getMappingsRequest, ActionListener.wrap(response -> { + Map mappings = response.getMappings(); + + if (mappings.isEmpty()) { + listener.onFailure(new IllegalArgumentException("No mapping found for the index: " + indexName)); + return; + } + + String firstIndexName = mappings.keySet().iterator().next(); + MappingMetadata mappingMetadata = mappings.get(firstIndexName); + Map mappingSource = (Map) mappingMetadata.getSourceAsMap().get("properties"); + + if (mappingSource == null) { + listener.onFailure(new IllegalArgumentException("Index '" + indexName + "' has no mapping metadata")); + return; + } + + Map fieldsToType = new HashMap<>(); + ToolHelper.extractFieldNamesTypes(mappingSource, fieldsToType, "", true); + + final Set dateFields = ToolHelper.findDateTypeFields(fieldsToType); + if (dateFields.isEmpty()) { + listener.onFailure(new IllegalArgumentException("Index '" + indexName + "' has no date fields")); + return; + } + + Map filteredMapping = fieldsToType + .entrySet() + .stream() + .filter(entry -> VALID_FIELD_TYPES.contains(entry.getValue())) + .collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue)); + + MappingContext context = new MappingContext(firstIndexName, filteredMapping, dateFields); + listener.onResponse(context); + + }, e -> { + log.error("Failed to get mapping: {}", e.getMessage()); + if (e instanceof IndexNotFoundException) { + listener.onFailure(new IllegalArgumentException("Index '" + indexName + "' does not exist")); + } else { + listener.onFailure(e); + } + })); + } + + // ── Multi-detector sequential creation ────────────────────────────────── + + @SuppressWarnings("unchecked") + private void createMultipleDetectors( + MappingContext ctx, + String tenantId, + int maxRetries, + List alreadyCreated, + List> results, + int totalAttempts, + ActionListener listener + ) { + if (alreadyCreated.size() >= MAX_DETECTORS_PER_INDEX || totalAttempts >= MAX_DETECTORS_PER_INDEX + 2) { + listener.onResponse((T) gson.toJson(results)); + return; + } + + String alreadyCreatedContext = buildAlreadyCreatedContext(alreadyCreated); + + // Build prompt with context prepended + StringJoiner dateFieldsJoiner = new StringJoiner(","); + ctx.dateFields.forEach(dateFieldsJoiner::add); + String basePrompt = constructPrompt(ctx.filteredMapping, ctx.indexName, dateFieldsJoiner.toString(), ctx.indexInsight); + + // Inject sample docs and data density + StringBuilder extraContext = new StringBuilder(); + if (ctx.dataDensity24h >= 0) { + extraContext.append("DATA DENSITY: ").append(ctx.dataDensity24h).append(" documents in the last 24 hours\n"); + if (ctx.dataDensity24h == 0) { + extraContext.append("WARNING: This index has no data in the last 24 hours.\n"); + } + } + if (ctx.sampleDocs != null) { + String truncated = ctx.sampleDocs.length() > 2000 ? ctx.sampleDocs.substring(0, 2000) + "..." : ctx.sampleDocs; + extraContext.append("\nSAMPLE DOCUMENTS:\n").append(truncated).append("\n"); + } + + String fullPrompt = alreadyCreatedContext + extraContext + basePrompt; + + callLLM(fullPrompt, tenantId, ActionListener.wrap(llmResponse -> { + // Check for NONE signal before parsing — only valid in multi-detector loop + if (llmResponse != null && llmResponse.toUpperCase(Locale.ROOT).contains(NONE_SIGNAL)) { + log.info("LLM returned NONE for '{}' after {} detectors", ctx.indexName, alreadyCreated.size()); + listener.onResponse((T) gson.toJson(results)); + return; + } + parseAndRetryWithLLM( + llmResponse, + ctx.indexName, + dateFieldsJoiner.toString(), + ctx, + tenantId, + maxRetries, + 0, + "model", + listener, + (suggestions, listenerCallback) -> { + // Run through existing validation pipeline + validateDetectorPhase(suggestions, ctx, tenantId, maxRetries, 0, new ActionListener() { + @Override + public void onResponse(T resultJson) { + List> detectorResults = gson.fromJson((String) resultJson, List.class); + Map result = detectorResults.get(0); + results.add(result); + + // Stop on systemic failures + String error = (String) result.get("error"); + if (error != null && error.contains("GENERAL_SETTINGS")) { + listenerCallback.onResponse((T) gson.toJson(results)); + return; + } + + // Build summary and continue + if (DetectorStatus.SUCCESS.toString().equals(result.get("status"))) { + List updated = new ArrayList<>(alreadyCreated); + updated.add(buildDetectorSummary(suggestions)); + createMultipleDetectors(ctx, tenantId, maxRetries, updated, results, totalAttempts + 1, listenerCallback); + } else { + // Non-systemic failure — still try next detector + createMultipleDetectors( + ctx, + tenantId, + maxRetries, + alreadyCreated, + results, + totalAttempts + 1, + listenerCallback + ); + } + } + + @Override + public void onFailure(Exception e) { + results.add(DetectorResult.failedValidation(ctx.indexName, e.getMessage()).toMap()); + listenerCallback.onResponse((T) gson.toJson(results)); + } + }); + } + ); + }, e -> { + log.error("LLM call failed for multi-detector on '{}': {}", ctx.indexName, e.getMessage()); + if (results.isEmpty()) { + // First call failed — propagate failure + listener.onFailure(e); + } else { + // Subsequent call failed — return what we have + listener.onResponse((T) gson.toJson(results)); + } + })); + } + + private String buildAlreadyCreatedContext(List summaries) { + if (summaries.isEmpty()) + return ""; + StringBuilder sb = new StringBuilder(); + sb.append("ALREADY CREATED DETECTORS FOR THIS INDEX (do NOT create similar detectors):\n\n"); + for (int i = 0; i < summaries.size(); i++) { + sb.append("Detector ").append(i + 1).append(": ").append(summaries.get(i)).append("\n"); + } + sb.append("\nCreate a DIFFERENT detector monitoring a DIFFERENT signal.\n"); + sb.append("Do NOT use the same aggregation field or monitor the same type of anomaly.\n"); + sb.append("If no more useful, non-overlapping signals exist, return exactly: ").append(NONE_SIGNAL).append("\n\n"); + sb.append("---\n\n"); + return sb.toString(); + } + + private String buildDetectorSummary(Map suggestions) { + String field = suggestions.get(OUTPUT_KEY_AGGREGATION_FIELD); + String method = suggestions.get(OUTPUT_KEY_AGGREGATION_METHOD); + String category = suggestions.get(OUTPUT_KEY_CATEGORY_FIELD); + String filter = suggestions.getOrDefault("filter", ""); + StringBuilder sb = new StringBuilder(); + sb.append(method).append("(").append(field).append(")"); + if (!filter.isEmpty()) + sb.append(" WHERE ").append(filter); + if (!Strings.isNullOrEmpty(category)) + sb.append(" per ").append(category); + return sb.toString(); + } + + /** + * The tool factory + */ + public static class Factory implements WithModelTool.Factory { + private Client client; + private NamedWriteableRegistry namedWriteableRegistry; + + private static volatile CreateAnomalyDetectorToolEnhanced.Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static CreateAnomalyDetectorToolEnhanced.Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (CreateAnomalyDetectorToolEnhanced.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new CreateAnomalyDetectorToolEnhanced.Factory(); + return INSTANCE; + } + } + + public void init(Client client, NamedWriteableRegistry namedWriteableRegistry) { + this.client = client; + this.namedWriteableRegistry = namedWriteableRegistry; + } + + /** + * + * @param map the input parameters + * @return the instance of this tool + */ + @Override + public CreateAnomalyDetectorToolEnhanced create(Map map) { + String modelId = (String) map.getOrDefault(COMMON_MODEL_ID_FIELD, ""); + if (modelId.isEmpty()) { + throw new IllegalArgumentException("model_id cannot be empty."); + } + String modelType = (String) map.getOrDefault("model_type", ModelType.CLAUDE.toString()); + // if model type is empty, use the default value + if (modelType.isEmpty()) { + modelType = ModelType.CLAUDE.toString(); + } else if (!ModelType.OPENAI.toString().equalsIgnoreCase(modelType) + && !ModelType.CLAUDE.toString().equalsIgnoreCase(modelType)) { + throw new IllegalArgumentException("Unsupported model_type: " + modelType); + } + String prompt = (String) map.getOrDefault("prompt", ""); + String resultIndex = (String) map.getOrDefault("result_index", ""); + return new CreateAnomalyDetectorToolEnhanced(client, modelId, modelType, prompt, resultIndex, namedWriteableRegistry); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + + @Override + public List getAllModelKeys() { + return List.of(COMMON_MODEL_ID_FIELD); + } + } +} diff --git a/src/main/java/org/opensearch/agent/tools/utils/AnomalyDetectorToolHelper.java b/src/main/java/org/opensearch/agent/tools/utils/AnomalyDetectorToolHelper.java new file mode 100644 index 00000000..8c85e833 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/utils/AnomalyDetectorToolHelper.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools.utils; + +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; + +/** + * Helper utilities for anomaly detector tools + */ +public class AnomalyDetectorToolHelper { + + /** + * Create an aggregation builder based on the aggregation method + * @param method aggregation method (avg, sum, min, max, count) + * @param field field name to aggregate on + * @return aggregation builder + */ + public static AggregationBuilder createAggregationBuilder(String method, String field) { + return switch (method.toLowerCase(Locale.ROOT)) { + case "avg" -> AggregationBuilders.avg(field).field(field); + case "sum" -> AggregationBuilders.sum(field).field(field); + case "min" -> AggregationBuilders.min(field).field(field); + case "max" -> AggregationBuilders.max(field).field(field); + case "count" -> AggregationBuilders.count(field).field(field); + default -> throw new IllegalArgumentException("Unsupported aggregation method: " + method); + }; + } + + /** + * Extract list of indices from tool parameters + * @param parameters tool parameters containing "input" with JSON array of indices + * @return list of index names + */ + public static List extractIndicesList(Map parameters) { + String inputStr = parameters.get("input"); + if (inputStr == null || inputStr.trim().isEmpty()) { + throw new IllegalArgumentException("Input parameter is required"); + } + + try { + Map input = StringUtils.gson.fromJson(inputStr, Map.class); + List indices = (List) input.get("indices"); + + if (indices == null || indices.isEmpty()) { + throw new IllegalArgumentException("No indices provided"); + } + + for (String index : indices) { + if (index.startsWith(".")) { + throw new IllegalArgumentException("System indices not supported: " + index); + } + } + + return indices; + } catch (Exception e) { + throw new IllegalArgumentException("Failed to parse indices: " + e.getMessage()); + } + } +} diff --git a/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java b/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java index e4ff38f8..296cc2b3 100644 --- a/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java +++ b/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java @@ -10,6 +10,8 @@ import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; @@ -79,6 +81,21 @@ public static void extractFieldNamesTypes( } } + /** + * Find all date type fields from a field-to-type mapping + * @param fieldsToType map of field names to field types + * @return set of field names that are date or date_nanos type + */ + public static Set findDateTypeFields(Map fieldsToType) { + Set dateTypes = Set.of("date", "date_nanos"); + return fieldsToType + .entrySet() + .stream() + .filter(e -> dateTypes.contains(e.getValue())) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + } + /** * Wrapper to get PPL transport action listener * @param listener input action listener diff --git a/src/main/resources/org/opensearch/agent/tools/CreateAnomalyDetectorEnhancedPrompt.json b/src/main/resources/org/opensearch/agent/tools/CreateAnomalyDetectorEnhancedPrompt.json new file mode 100644 index 00000000..e9314d38 --- /dev/null +++ b/src/main/resources/org/opensearch/agent/tools/CreateAnomalyDetectorEnhancedPrompt.json @@ -0,0 +1,3 @@ +{ + "prompt": "Analyze this index and suggest ONE anomaly detector that generates clear, actionable alerts.\n\nIndex: ${indexInfo.indexName}\nMapping: ${indexInfo.indexMapping}\nAvailable date fields: ${indexInfo.dateFields}\n${indexInfo.indexInsight}\n\nCORE PRINCIPLE:\nCreate detectors that find anomalies a domain expert would want to be alerted about — whether operational issues, business KPI changes, unusual patterns, or threshold violations.\nFocus on 1-2 RELATED fields that will be used for defining our features which are the important KPIs that we are looking for anomalies in.\n\nSTEP 1 - IDENTIFY MONITORING PRIORITY:\nImpact priority (operational, business, or analytics):\n\n1. Service Reliability: error_count, failed_requests, 5xx_status, exceptions, timeout_count, failures\n2. Performance Issues: response_time, latency, processing_time, duration, delay_minutes\n3. Resource Problems: cpu_usage, memory_usage, disk_usage, connection_count\n4. Traffic/Capacity: request_count, bytes_transferred, active_connections, queue_size, throughput\n5. Security Events: blocked_requests, authentication_failures, suspicious_activity\n6. Business Metrics: revenue, conversion_rate, transaction_amount, order_count, cart_value\n7. User Analytics: page_views, session_duration, bounce_rate, click_count\n\nSTEP 2 - FEATURE SELECTION STRATEGY:\nDEFAULT TO 1 FEATURE unless multiple features provide complementary evidence of the SAME operational issue.\n\n✓ EXCELLENT 2-FEATURE COMBINATIONS (investigated together):\n\n* [error_count, timeout_count] → \"Service degradation: errors up 300%, timeouts up 150%\"\n* [response_time, error_rate] → \"Performance issue: response time 2x higher, error rate spiked\"\n* [cpu_usage, memory_usage] → \"Resource exhaustion: CPU at 90%, memory at 85%\"\n* [failed_requests, retry_count] → \"Service instability: failures up 400%, retries up 250%\"\n* [bytes_sent, bytes_received] → \"Network anomaly: traffic pattern changed significantly\"\n* [blocked_requests, failed_auth] → \"Security event: attack pattern detected\"\n* [request_count, error_count] → \"Service stress: high load with increasing failures\"\n\n✓ GOOD SINGLE FEATURES (clear, actionable alerts):\n\n* [error_count] → \"Error spike: 400% increase in errors\"\n* [response_time] → \"Performance issue: response time 3x normal\"\n* [bytes_transferred] → \"Traffic anomaly: data transfer volume unusual\"\n* [cpu_usage] → \"Resource issue: CPU utilization spiked\"\n* [document count] → \"Volume anomaly: ingestion rate dropped\" (use count on any field)\n* [order_total] → \"Revenue anomaly: order values changed significantly\"\n\nSTEP 3 - AGGREGATION METHOD RULES:\nCRITICAL: OpenSearch Anomaly Detection supports ONLY these 5 aggregation methods:\n\n* avg() - Average value of numeric fields\n* sum() - Sum total of numeric fields\n* min() - Minimum value of numeric fields\n* max() - Maximum value of numeric fields\n* count() - Count of documents (works on any field type)\n\nField Type Constraints:\n\n* Numeric fields (long, integer, double, float): Can use avg, sum, min, max, count\n* Keyword fields: Can ONLY use count ('count' ONLY when meaningful - avoid mixed good/bad values)\n* NEVER use sum/avg/min/max on keyword fields - Will cause errors\n\nOperational Logic Rules:\n\n* Times/Durations/Delays: ALWAYS 'avg' (NEVER 'sum' - summing time is meaningless)\n* Errors/Failures/Counts: 'sum' for totals, 'avg' for rates/percentages\n* Bytes/Size fields: 'sum' for total volume (bytes, object_size, response.bytes)\n* Memory/Resource fields: 'avg' for percentages, 'max' for absolute values\n* Business metrics: 'avg' for per-transaction values, 'sum' for revenue totals\n* Keyword fields for traffic: 'count' (counts specific when there is an error or something specific like bad status codes, not just all status codes)\n\nWhen to be CAREFUL with 'count' on keyword fields:\n\n* If a field mixes success and error values (e.g., status_code.keyword with 200, 404, 500), plain count becomes total traffic — use a filter to isolate errors instead (filter=status:gte:400)\n* Check sample documents to understand what values a field contains before deciding\n* count() without a filter counts ALL documents — this is valid for traffic/volume monitoring but may not detect specific issues\n\nField Pattern Recognition:\n\n* *_count, *_errors, *_failures, *_requests → 'sum' (if numeric) OR 'count' (if keyword)\n* *_bytes, *_size, object_size → 'sum' (if numeric)\n* status_code, method, protocol → 'count' (if keyword)\n\nSTEP 3.5 - FILTER CONDITIONS (optional):\nFilters isolate a SUBSET of documents for counting. Use them to detect anomalies in error rates, failure rates, or specific event types.\n\nWHEN TO USE FILTERS:\n* Counting HTTP errors: filter=status:gte:400 with method=count\n* Counting failed requests: filter=failed:eq:true with method=count\n* Counting high-severity logs: filter=severityNumber:gte:17 with method=count\n* Counting specific error codes: filter=status.code:eq:2 with method=count\n\nWHEN NOT TO USE FILTERS:\n* Averaging response time (want ALL requests, not just errors)\n* Summing bytes transferred (want total volume)\n* Counting total throughput (want all documents)\n\nFILTER SYNTAX: field:operator:value\n* Range: field:gte:value, field:lte:value, field:gt:value, field:lt:value\n* Term: field:eq:value\n* Leave empty if no filter needed\n\nFILTER + AGGREGATION COMBINATIONS:\n* count + filter = \"How many error documents per interval?\" (ERROR RATE)\n* avg + no filter = \"What's the average latency across all requests?\" (LATENCY)\n* sum + no filter = \"What's the total bytes transferred?\" (THROUGHPUT)\n* count + no filter = \"How many total documents per interval?\" (TRAFFIC VOLUME)\n\nNOTE: Filters pair naturally with count() features. The system handles filter placement automatically — for detectors with a category field and a count feature, the filter is applied inside the count feature to enable detection of new anomalies (e.g., a service going from 0 errors to many). For other configurations, the filter narrows the entire dataset.\n\nCOMMON MISTAKES TO AVOID:\n* Do NOT use avg/sum on status codes, error codes, or severity numbers — these are categorical codes, use count with a filter instead\n* If a field mixes success and error values (status_code with 200+404+500), use a filter to isolate errors (e.g., filter=status:gte:400)\n* Do NOT create a detector that just counts all documents with no filter and no category field unless you specifically want to monitor total ingestion volume. For most use cases, adding a category field (per-service breakdown) or a filter (error subset) makes the signal more actionable.\n\nSTEP 4 - CATEGORY FIELD SELECTION:\nIMPORTANT: Category fields are OPTIONAL. If no field provides actionable segmentation, leave empty.\n\nCRITICAL CONSTRAINT: Category field MUST be a keyword or ip field type from the mapping above.\nCheck the field type - ONLY use fields marked as \"keyword\" or \"ip\".\nIf no keyword/ip fields exist, leave category_field empty.\n\nChoose 1-2 keyword fields that provide actionable segmentation (comma-separated if 2):\n\n✓ EXCELLENT choices (actionable alerts):\n\n* service_name, endpoint, api_path → \"Error spike on /checkout endpoint\"\n* host, instance_id, server → \"CPU spike on web-server-01\"\n* region, datacenter, availability_zone → \"Network issues in us-west-2\"\n* status_code, error_type → \"500 errors spiking\"\n* method, protocol → \"POST requests failing\"\n\n✓ GOOD choices (moderate cardinality):\n\n* device_type, browser, user_agent for web analytics\n* database_name, table_name for DB monitoring\n* queue_name, topic for messaging systems\n* payment_method, transaction_type for financial monitoring\n\n✗ AVOID (too specific or not actionable for general monitoring):\n\n* Unique identifiers: transaction_id, session_id, request_id\n* High-cardinality user data: user_id, customer_id\n* Timestamp fields, hash fields, UUIDs\n* Fields ending in _key, _hash, _uuid\n\nCARDINALITY GUIDELINES:\n\n* Ideal: 5-50 unique values\n* Acceptable: 50-500 values (if actionable segmentation)\n* AVOID: >500 unique values — each creates a separate model using cluster memory\n* No category field: Perfectly fine — creates a single model monitoring aggregate behavior across the entire index\n* With category field: Creates one model PER unique value (high-cardinality mode) — provides per-entity anomaly detection but uses more cluster memory\n\nSTEP 5 - DATE FIELD SELECTION:\nChoose the date field that represents WHEN the event actually occurred, not when it was ingested.\n\n* Business data (orders, transactions): use the business timestamp (order_date, transaction_date) not @timestamp\n* IoT/sensor data: use the measurement time (sensor_reading_time) not the upload/ingest time\n* Log data: @timestamp is usually correct (it represents when the event was logged)\n* CDC/replication data: use the source timestamp (db_updated_at) not the replication timestamp\n* If unsure, prefer the field that best represents the real-world timing of the event being monitored\n\nSTEP 6 - DETECTION INTERVAL GUIDELINES:\nMatch interval to data frequency and operational needs:\n\n* Real-time systems: 10-15 minutes (APIs, web services, errors, response times)\n* Infrastructure monitoring: 15-30 minutes (servers, databases, resource usage)\n* Business processes: 30-60 minutes (transactions, conversions)\n* Security logs: 10-30 minutes (access logs, firewalls, authentication)\n* Batch/ETL processes: 60+ minutes (data pipeline monitoring)\n* Sparse data: 60+ minutes (avoid false positives from empty buckets)\n\nIf unsure about the interval, use 10 minutes as default. The system will automatically adjust the interval based on actual data density after your suggestion.\n\nSTEP 7 - SPARSE DATA AWARENESS:\nIf the index has low data volume or infrequent writes:\n\n* Use longer intervals (60+ minutes) to ensure enough data points per bucket\n* PREFERRED: Remove category field entirely rather than using very long intervals (>4 hours)\n* ALTERNATIVE: Choose a lower-cardinality category field\n* PRINCIPLE: A detector with a shorter interval and no category field is more useful than one with a 12-hour interval and a category field\n* The system will automatically adjust the interval based on actual data density after your suggestion\n\nIMPORTANT CONSTRAINTS:\n* Create the SINGLE most valuable detector. Only suggest additional detectors on subsequent calls if they monitor a genuinely DIFFERENT and important signal.\n* If a previous detector already covers the primary monitoring need, return {NONE}.\n* Do NOT create detectors just to fill a quota — quality over quantity.\n* If you already created a detector WITH a category field, prefer the next one WITHOUT a category field, or return {NONE}.\n\nOUTPUT FORMAT:\nReturn ONLY this structured format, no explanation:\n{category_field=field_or_empty|aggregation_field=field1,field2|aggregation_method=method1,method2|date_field=date_field_name|filter=field:operator:value_or_empty|interval=minutes|description=one_sentence}\n\nExamples:\n{category_field=endpoint|aggregation_field=status|aggregation_method=count|date_field=@timestamp|filter=status:gte:400|interval=10|description=HTTP error rate per endpoint}\n{category_field=service_name|aggregation_field=response_time|aggregation_method=avg|date_field=@timestamp|filter=|interval=15|description=Average response time per service}\n{category_field=|aggregation_field=order_total|aggregation_method=sum|date_field=order_date|filter=|interval=30|description=Total order revenue}\n{category_field=|aggregation_field=bytes_sent|aggregation_method=sum|date_field=@timestamp|filter=|interval=10|description=Network traffic volume}\n\nIf no suitable fields exist, use empty strings for aggregation_field and aggregation_method.\nIf no more useful non-overlapping detectors can be created, return exactly:\n{NONE}" +} diff --git a/src/test/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolEnhancedTests.java b/src/test/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolEnhancedTests.java new file mode 100644 index 00000000..383c2935 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolEnhancedTests.java @@ -0,0 +1,1326 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.jsoup.helper.Validate.fail; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.lucene.search.TotalHits; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.agent.tools.utils.AnomalyDetectorToolHelper; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.indexInsight.IndexInsight; +import org.opensearch.ml.common.indexInsight.IndexInsightTaskStatus; +import org.opensearch.ml.common.indexInsight.MLIndexInsightType; +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.indexInsight.MLIndexInsightGetAction; +import org.opensearch.ml.common.transport.indexInsight.MLIndexInsightGetResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.transport.client.AdminClient; +import org.opensearch.transport.client.Client; +import org.opensearch.transport.client.IndicesAdminClient; + +import com.google.common.collect.ImmutableMap; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class CreateAnomalyDetectorToolEnhancedTests { + @Mock + private Client client; + @Mock + private AdminClient adminClient; + @Mock + private IndicesAdminClient indicesAdminClient; + @Mock + private GetMappingsResponse getMappingsResponse; + @Mock + private MappingMetadata mappingMetadata; + @Mock + private NamedWriteableRegistry namedWriteableRegistry; + + private Map mockedMappings; + private Map indexMappings; + + @Mock + private MLTaskResponse mlTaskResponse; + @Mock + private ModelTensorOutput modelTensorOutput; + @Mock + private ModelTensors modelTensors; + + private ModelTensor modelTensor; + private Map modelReturns; + + private String mockedIndexName = "http_logs"; + private String mockedResponse = "{category_field=|aggregation_field=response,responseLatency|aggregation_method=count,avg|interval=10}"; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + createMappings(); + + // Setup mapping mocks + when(mappingMetadata.getSourceAsMap()).thenReturn(indexMappings); + when(getMappingsResponse.getMappings()).thenReturn(mockedMappings); + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesAdminClient); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(getMappingsResponse); + return null; + }).when(indicesAdminClient).getMappings(any(), any()); + + initMLTensors(); + CreateAnomalyDetectorToolEnhanced.Factory.getInstance().init(client, namedWriteableRegistry); + } + + @Test + public void testModelIdIsNullOrEmpty() { + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> CreateAnomalyDetectorToolEnhanced.Factory.getInstance().create(ImmutableMap.of("model_id", "")) + ); + assertEquals("model_id cannot be empty.", exception.getMessage()); + } + + @Test + public void testInvalidModelType() { + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "model_type", "unknown")) + ); + assertEquals("Unsupported model_type: unknown", exception.getMessage()); + } + + @Test + public void testValidModelTypes() { + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "model_type", "openai")); + assertEquals(CreateAnomalyDetectorToolEnhanced.TYPE, tool.getName()); + assertEquals("modelId", tool.getModelId()); + assertEquals("OPENAI", tool.getModelType().toString()); + + tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "model_type", "claude")); + assertEquals(CreateAnomalyDetectorToolEnhanced.TYPE, tool.getName()); + assertEquals("modelId", tool.getModelId()); + assertEquals("CLAUDE", tool.getModelType().toString()); + } + + @Test + public void testDefaultModelType() { + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + assertEquals("CLAUDE", tool.getModelType().toString()); + } + + @Test + public void testEmptyModelType() { + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "model_type", "")); + assertEquals("CLAUDE", tool.getModelType().toString()); + } + + @Test + public void testCustomPrompt() { + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "prompt", "custom prompt")); + assertEquals("custom prompt", tool.getContextPrompt()); + } + + @Test + public void testIndexNameValidation_SystemIndex() { + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList(".system_index")))), + ActionListener.wrap(response -> {}, e -> { + throw new IllegalArgumentException(e.getMessage()); + }) + ) + ); + assertTrue(exception.getMessage().contains("System indices not supported")); + } + + @Test + public void testInputFormat_IndicesList() { + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + // This should extract first index from list + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList(mockedIndexName)))), + ActionListener.wrap(response -> { + assertTrue(response.contains("indexName")); + }, e -> log.error("Error: ", e)) + ); + } + + @Test + public void testLLMResponseParsing_ValidFormat() { + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + // Test that all fields from LLM response are correctly parsed + String validResponse = "{category_field=host|aggregation_field=response,responseLatency|aggregation_method=count,avg|interval=15}"; + modelReturns = Collections.singletonMap("response", validResponse); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList(mockedIndexName)))), + ActionListener.wrap(response -> { + assertTrue(response.contains("\"status\":\"success\"")); + assertTrue(response.contains("detectorId")); + assertTrue(response.contains("detectorName")); + Map result = gson.fromJson(response, Map.class); + assertEquals("success", result.get("status")); + }, e -> fail("Should successfully parse valid LLM response: " + e.getMessage())) + ); + } + + private void createMappings() { + indexMappings = new HashMap<>(); + indexMappings + .put( + "properties", + ImmutableMap + .of( + "response", + ImmutableMap.of("type", "integer"), + "responseLatency", + ImmutableMap.of("type", "float"), + "host", + ImmutableMap.of("type", "keyword"), + "date", + ImmutableMap.of("type", "date") + ) + ); + mockedMappings = new HashMap<>(); + mockedMappings.put(mockedIndexName, mappingMetadata); + + modelReturns = Collections.singletonMap("response", mockedResponse); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + } + + @Test + public void testLLMResponseParsing_InvalidFormat() { + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + modelReturns = Collections.singletonMap("response", "invalid format without curly braces"); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList(mockedIndexName)))), + ActionListener.wrap(response -> { + assertTrue(response.contains("\"status\":\"failed_validation\"")); + assertTrue(response.contains("Cannot parse LLM response after")); + }, e -> fail("Should return JSON response, not throw exception: " + e.getMessage())) + ); + } + + @Test + public void testLLMResponseParsing_EmptyResponse() { + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + modelReturns = Collections.singletonMap("response", ""); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList(mockedIndexName)))), + ActionListener.wrap(response -> { + assertTrue(response.contains("\"status\":\"failed_validation\"")); + assertTrue(response.contains("Remote endpoint fails to inference, no response found")); + }, e -> fail("Should return JSON response, not throw exception: " + e.getMessage())) + ); + } + + @Test + public void testLLMResponseParsing_NullResponse() { + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + modelReturns = Collections.singletonMap("response", null); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList(mockedIndexName)))), + ActionListener.wrap(response -> { + assertTrue(response.contains("\"status\":\"failed_validation\"")); + assertTrue(response.contains("Remote endpoint fails to inference, no response found")); + }, e -> fail("Should return JSON response, not throw exception: " + e.getMessage())) + ); + } + + @Test + public void testIntervalParsing_ValidInterval() { + // Test that interval is correctly parsed from LLM response + String responseWithInterval = "{category_field=|aggregation_field=response|aggregation_method=count|interval=15}"; + modelReturns = Collections.singletonMap("response", responseWithInterval); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList(mockedIndexName)))), + ActionListener.wrap(response -> { + assertTrue("Response should contain success status", response.contains("\"status\":\"success\"")); + assertTrue("Response should contain detector configuration", response.contains("detectorName")); + }, e -> fail("Should not throw exception: " + e.getMessage())) + ); + } + + @Test + public void testIntervalParsing_DefaultInterval() { + // Test that default interval (10) is used when not specified + String responseWithoutInterval = "{category_field=|aggregation_field=response|aggregation_method=count|interval=}"; + modelReturns = Collections.singletonMap("response", responseWithoutInterval); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList(mockedIndexName)))), + ActionListener.wrap(response -> { + assertTrue("Response should contain success status", response.contains("\"status\":\"success\"")); + assertTrue("Response should contain detector configuration", response.contains("detectorName")); + }, e -> fail("Should not throw exception: " + e.getMessage())) + ); + } + + @Test + public void testCategoryField_Empty() { + // Test single-entity detector (no category field) + String responseNoCategoryField = "{category_field=|aggregation_field=response|aggregation_method=count|interval=10}"; + modelReturns = Collections.singletonMap("response", responseNoCategoryField); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList(mockedIndexName)))), + ActionListener.wrap(response -> { + assertTrue(response.contains("\"status\":\"success\"")); + assertTrue(response.contains("detectorName")); + }, e -> fail("Should not throw exception: " + e.getMessage())) + ); + } + + @Test + public void testCategoryField_WithValue() { + // Test multi-entity detector (with category field) + String responseWithCategoryField = "{category_field=host|aggregation_field=response|aggregation_method=count|interval=10}"; + modelReturns = Collections.singletonMap("response", responseWithCategoryField); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList(mockedIndexName)))), + ActionListener.wrap(response -> { + assertTrue("Response should contain success status", response.contains("\"status\":\"success\"")); + assertTrue("Response should contain detector configuration", response.contains("detectorName")); + }, e -> fail("Should not throw exception: " + e.getMessage())) + ); + } + + @Test + public void testIndexNotFound() { + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + // Mock IndexNotFoundException + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onFailure(new Exception("IndexNotFoundException[no such index [nonexistent]]")); + return null; + }).when(indicesAdminClient).getMappings(any(), any()); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList("nonexistent")))), + ActionListener.wrap(response -> { + assertTrue(response.contains("\"status\":\"failed_validation\"")); + assertTrue(response.contains("does not exist") || response.contains("no such index")); + }, e -> fail("Should return JSON response, not throw exception: " + e.getMessage())) + ); + } + + @Test + public void testMultipleIndices() { + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Arrays.asList("index1", "index2")))), + ActionListener.wrap(response -> { + assertTrue(response.contains("index1")); + assertTrue(response.contains("index2")); + }, e -> log.error("Error: ", e)) + ); + } + + @Test + public void testIndexWithMultipleDateFields() { + Map multiDateMappings = new HashMap<>(); + multiDateMappings + .put( + "properties", + ImmutableMap + .of( + "timestamp1", + ImmutableMap.of("type", "date"), + "timestamp2", + ImmutableMap.of("type", "date"), + "created_at", + ImmutableMap.of("type", "date"), + "updated_at", + ImmutableMap.of("type", "date"), + "event_time", + ImmutableMap.of("type", "date_nanos"), + "response", + ImmutableMap.of("type", "integer") + ) + ); + + when(mappingMetadata.getSourceAsMap()).thenReturn(multiDateMappings); + + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList("multi_date_index")))), + ActionListener.wrap(response -> { + assertTrue(response.contains("indexName")); + }, e -> log.error("Error: ", e)) + ); + } + + @Test + public void testIndexWithNoDateFields() { + Map noDateMappings = new HashMap<>(); + noDateMappings + .put("properties", ImmutableMap.of("response", ImmutableMap.of("type", "integer"), "host", ImmutableMap.of("type", "keyword"))); + + when(mappingMetadata.getSourceAsMap()).thenReturn(noDateMappings); + + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList("no_date_index")))), + ActionListener.wrap(response -> { + assertTrue(response.contains("\"status\":\"failed_validation\"")); + assertTrue(response.contains("has no date fields")); + }, e -> fail("Should return JSON response, not throw exception: " + e.getMessage())) + ); + } + + // ===== NEW TESTS FOR COMMIT 1 CHANGES ===== + + @Test + public void testGetAggMethod_ReturnsActualType() { + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + // Build features with different aggregation types and verify getAggMethod extracts them correctly + AggregationBuilder avgAgg = AnomalyDetectorToolHelper.createAggregationBuilder("avg", "bytes"); + Feature avgFeature = new Feature("id1", "feature_bytes", true, avgAgg); + assertEquals("avg", tool.getAggMethod(avgFeature)); + + AggregationBuilder sumAgg = AnomalyDetectorToolHelper.createAggregationBuilder("sum", "bytes"); + Feature sumFeature = new Feature("id2", "feature_bytes", true, sumAgg); + assertEquals("sum", tool.getAggMethod(sumFeature)); + + // count maps to value_count internally — verify it maps back to "count" + AggregationBuilder countAgg = AnomalyDetectorToolHelper.createAggregationBuilder("count", "requests"); + Feature countFeature = new Feature("id3", "feature_requests", true, countAgg); + assertEquals("count", tool.getAggMethod(countFeature)); + + AggregationBuilder maxAgg = AnomalyDetectorToolHelper.createAggregationBuilder("max", "latency"); + Feature maxFeature = new Feature("id4", "feature_latency", true, maxAgg); + assertEquals("max", tool.getAggMethod(maxFeature)); + + AggregationBuilder minAgg = AnomalyDetectorToolHelper.createAggregationBuilder("min", "latency"); + Feature minFeature = new Feature("id5", "feature_latency", true, minAgg); + assertEquals("min", tool.getAggMethod(minFeature)); + } + + @Test + public void testIndexInsightSuccess_InsightReachesLLMPrompt() throws Exception { + // Mock Index Insight to return content + mockFullDetectorCreationChain(); + String insightContent = "dataSource: web_logs, recommendedFeatures: bytes_sent"; + IndexInsight insight = IndexInsight + .builder() + .index(mockedIndexName) + .content(insightContent) + .status(IndexInsightTaskStatus.COMPLETED) + .taskType(MLIndexInsightType.ALL) + .lastUpdatedTime(Instant.now()) + .build(); + MLIndexInsightGetResponse insightResponse = new MLIndexInsightGetResponse(insight); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(insightResponse); + return null; + }).when(client).execute(eq(MLIndexInsightGetAction.INSTANCE), any(), any()); + + // Capture the LLM prompt to verify insight was injected + AtomicReference capturedPrompt = new AtomicReference<>(); + + // Re-mock LLM call to capture the request AND return a valid response + String validResponse = "{category_field=|aggregation_field=response|aggregation_method=count|interval=10}"; + modelReturns = Collections.singletonMap("response", validResponse); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + + // Override LLM mock to also capture the prompt + doAnswer(invocation -> { + MLPredictionTaskRequest req = (MLPredictionTaskRequest) invocation.getArguments()[1]; + RemoteInferenceInputDataSet ds = (RemoteInferenceInputDataSet) req.getMlInput().getInputDataset(); + capturedPrompt.set(ds.getParameters().toString()); + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference responseRef = new AtomicReference<>(); + + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList(mockedIndexName)))), + ActionListener.wrap(response -> { + responseRef.set(response); + latch.countDown(); + }, e -> { + responseRef.set("ERROR: " + e.getMessage()); + latch.countDown(); + }) + ); + + latch.await(5, java.util.concurrent.TimeUnit.SECONDS); + + // The key assertion: the insight content must appear in the prompt sent to the LLM + Assert.assertNotNull("LLM should have been called", capturedPrompt.get()); + Assert + .assertTrue( + "Prompt must contain the Index Insight content, but was: " + capturedPrompt.get(), + capturedPrompt.get().contains("INDEX ANALYSIS") || capturedPrompt.get().contains(insightContent) + ); + } + + @Test + public void testIndexInsightFailure_StillCreatesDetector() throws Exception { + // Mock Index Insight to fail + mockFullDetectorCreationChain(); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new RuntimeException("Index Insight not available")); + return null; + }).when(client).execute(eq(MLIndexInsightGetAction.INSTANCE), any(), any()); + + // Capture the LLM prompt to verify NO insight was injected + AtomicReference capturedPrompt = new AtomicReference<>(); + + String validResponse = "{category_field=|aggregation_field=response|aggregation_method=count|interval=10}"; + modelReturns = Collections.singletonMap("response", validResponse); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + + doAnswer(invocation -> { + MLPredictionTaskRequest req = (MLPredictionTaskRequest) invocation.getArguments()[1]; + RemoteInferenceInputDataSet ds = (RemoteInferenceInputDataSet) req.getMlInput().getInputDataset(); + capturedPrompt.set(ds.getParameters().toString()); + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference responseRef = new AtomicReference<>(); + + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList(mockedIndexName)))), + ActionListener.wrap(response -> { + responseRef.set(response); + latch.countDown(); + }, e -> { + responseRef.set("ERROR: " + e.getMessage()); + latch.countDown(); + }) + ); + + latch.await(5, java.util.concurrent.TimeUnit.SECONDS); + + // Verify the LLM was still called (tool didn't abort) + Assert.assertNotNull("LLM should have been called despite insight failure", capturedPrompt.get()); + // Verify no insight was injected into the prompt + Assert.assertFalse("Prompt must NOT contain INDEX ANALYSIS when insight failed", capturedPrompt.get().contains("INDEX ANALYSIS")); + } + + @Test + public void testTemplateVariableLeak_ReplacedWithActualField() throws Exception { + // LLM returns literal ${dateFields} instead of the actual date field name + mockFullDetectorCreationChain(); + String leakyResponse = "{category_field=|aggregation_field=${dateFields},responseLatency|aggregation_method=count,avg|interval=10}"; + modelReturns = Collections.singletonMap("response", leakyResponse); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + + // Mock Index Insight to fail fast (skip insight, go straight to LLM) + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new RuntimeException("skip")); + return null; + }).when(client).execute(eq(MLIndexInsightGetAction.INSTANCE), any(), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference responseRef = new AtomicReference<>(); + + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList(mockedIndexName)))), + ActionListener.wrap(response -> { + responseRef.set(response); + latch.countDown(); + }, e -> { + responseRef.set("ERROR: " + e.getMessage()); + latch.countDown(); + }) + ); + + latch.await(5, java.util.concurrent.TimeUnit.SECONDS); + String response = responseRef.get(); + Assert.assertNotNull("Should get a response", response); + // The response must NOT contain the literal template variable — it should have been replaced + Assert + .assertFalse( + "Template variable ${dateFields} should have been replaced with actual date field, but response was: " + response, + response.contains("${dateFields}") + ); + Assert + .assertFalse( + "Template variable ${indexInfo.dateFields} should not appear in output", + response.contains("${indexInfo.dateFields}") + ); + } + + @Test + public void testEmptyAggregationFields_FailsWithMessage() throws Exception { + // LLM returns empty aggregation fields — all fields are blank after split + mockFullDetectorCreationChain(); + String emptyFieldsResponse = "{category_field=|aggregation_field=,|aggregation_method=,|interval=10}"; + modelReturns = Collections.singletonMap("response", emptyFieldsResponse); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + + // Mock Index Insight to fail fast + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new RuntimeException("skip")); + return null; + }).when(client).execute(eq(MLIndexInsightGetAction.INSTANCE), any(), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference responseRef = new AtomicReference<>(); + + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList(mockedIndexName)))), + ActionListener.wrap(response -> { + responseRef.set(response); + latch.countDown(); + }, e -> { + responseRef.set("ERROR: " + e.getMessage()); + latch.countDown(); + }) + ); + + latch.await(5, java.util.concurrent.TimeUnit.SECONDS); + String response = responseRef.get(); + Assert.assertNotNull("Should get a response", response); + // Must NOT report success — empty features should fail + Assert.assertFalse("Empty aggregation fields must not produce a successful detector", response.contains("\"status\":\"success\"")); + } + + // ===== COMMIT 2: OTel Fast-Path Tests ===== + + @Test + public void testOtelTraceMapping_CreatesTwoDetectors_SkipsLLM() throws Exception { + // Set up OTel trace mapping with the 4 required signature fields + Map otelTraceMapping = new HashMap<>(); + otelTraceMapping + .put( + "properties", + Map + .of( + "traceId", + ImmutableMap.of("type", "keyword"), + "spanId", + ImmutableMap.of("type", "keyword"), + "durationInNanos", + ImmutableMap.of("type", "long"), + "serviceName", + ImmutableMap.of("type", "keyword"), + "startTime", + ImmutableMap.of("type", "date_nanos"), + "status", + ImmutableMap.of("type", "object", "properties", Map.of("code", ImmutableMap.of("type", "integer"))) + ) + ); + when(mappingMetadata.getSourceAsMap()).thenReturn(otelTraceMapping); + mockedMappings.put("otel-traces", mappingMetadata); + + // Mock suggest (returns null interval = use default) + create + start + mockOtelDetectorCreationChain(); + + // Mock Index Insight to fail fast + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new RuntimeException("skip")); + return null; + }).when(client).execute(eq(MLIndexInsightGetAction.INSTANCE), any(), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference responseRef = new AtomicReference<>(); + + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList("otel-traces")))), + ActionListener.wrap(r -> { + responseRef.set(r); + latch.countDown(); + }, e -> { + responseRef.set("ERROR: " + e.getMessage()); + latch.countDown(); + }) + ); + latch.await(5, java.util.concurrent.TimeUnit.SECONDS); + + String response = responseRef.get(); + Assert.assertNotNull("Should get a response", response); + Assert.assertFalse("Should not error", response.startsWith("ERROR")); + + // Parse outer result map + Map results = gson.fromJson(response, Map.class); + Object otelResult = results.get("otel-traces"); + Assert.assertNotNull("Should have result for otel-traces", otelResult); + + // OTel path returns a list of detector results + Assert.assertTrue("OTel result should be a List, got: " + otelResult.getClass(), otelResult instanceof List); + List detectors = (List) otelResult; + Assert.assertEquals("Should create exactly 2 detectors for traces", 2, detectors.size()); + + // Verify LLM was never called (OTel path bypasses LLM) + org.mockito.Mockito.verify(client, org.mockito.Mockito.never()).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + } + + @Test + public void testOtelLogMapping_CreatesTwoDetectors_SkipsLLM() throws Exception { + Map otelLogMapping = new HashMap<>(); + otelLogMapping + .put( + "properties", + Map + .of( + "severityNumber", + ImmutableMap.of("type", "integer"), + "severityText", + ImmutableMap.of("type", "keyword"), + "time", + ImmutableMap.of("type", "date"), + "resource", + ImmutableMap + .of( + "type", + "object", + "properties", + Map + .of( + "attributes", + ImmutableMap + .of("type", "object", "properties", Map.of("service.name", ImmutableMap.of("type", "keyword"))) + ) + ) + ) + ); + when(mappingMetadata.getSourceAsMap()).thenReturn(otelLogMapping); + mockedMappings.put("otel-logs", mappingMetadata); + + mockOtelDetectorCreationChain(); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new RuntimeException("skip")); + return null; + }).when(client).execute(eq(MLIndexInsightGetAction.INSTANCE), any(), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference responseRef = new AtomicReference<>(); + + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList("otel-logs")))), + ActionListener.wrap(r -> { + responseRef.set(r); + latch.countDown(); + }, e -> { + responseRef.set("ERROR: " + e.getMessage()); + latch.countDown(); + }) + ); + latch.await(5, java.util.concurrent.TimeUnit.SECONDS); + + String response = responseRef.get(); + Map results = gson.fromJson(response, Map.class); + Object otelResult = results.get("otel-logs"); + Assert.assertTrue("OTel log result should be a List", otelResult instanceof List); + Assert.assertEquals("Should create exactly 2 detectors for logs", 2, ((List) otelResult).size()); + + org.mockito.Mockito.verify(client, org.mockito.Mockito.never()).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + } + + @Test + public void testNonOtelMapping_FallsThrough_ToLLM() throws Exception { + // Default mapping (response:integer, host:keyword, date:date) is NOT OTel + mockFullDetectorCreationChain(); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new RuntimeException("skip")); + return null; + }).when(client).execute(eq(MLIndexInsightGetAction.INSTANCE), any(), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference responseRef = new AtomicReference<>(); + + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList(mockedIndexName)))), + ActionListener.wrap(r -> { + responseRef.set(r); + latch.countDown(); + }, e -> { + responseRef.set("ERROR: " + e.getMessage()); + latch.countDown(); + }) + ); + latch.await(5, java.util.concurrent.TimeUnit.SECONDS); + + // LLM SHOULD have been called (non-OTel falls through) + org.mockito.Mockito.verify(client, org.mockito.Mockito.atLeastOnce()).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + } + + @Test + public void testPartialOtelMapping_DoesNotTriggerFastPath() throws Exception { + // Has durationInNanos + serviceName but missing spanId — should NOT detect as OTel + Map partialMapping = new HashMap<>(); + partialMapping + .put( + "properties", + Map + .of( + "durationInNanos", + ImmutableMap.of("type", "long"), + "serviceName", + ImmutableMap.of("type", "keyword"), + "timestamp", + ImmutableMap.of("type", "date"), + "responseCode", + ImmutableMap.of("type", "integer") + ) + ); + when(mappingMetadata.getSourceAsMap()).thenReturn(partialMapping); + mockedMappings.put("partial-otel", mappingMetadata); + + mockFullDetectorCreationChain(); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new RuntimeException("skip")); + return null; + }).when(client).execute(eq(MLIndexInsightGetAction.INSTANCE), any(), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference responseRef = new AtomicReference<>(); + + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList("partial-otel")))), + ActionListener.wrap(r -> { + responseRef.set(r); + latch.countDown(); + }, e -> { + responseRef.set("ERROR: " + e.getMessage()); + latch.countDown(); + }) + ); + latch.await(5, java.util.concurrent.TimeUnit.SECONDS); + + // LLM should have been called — partial OTel should NOT trigger fast-path + org.mockito.Mockito.verify(client, org.mockito.Mockito.atLeastOnce()).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + } + + // ===== COMMIT 5: Sequential Multi-Detector Tests ===== + + @Test + public void testFilterExpression_AppliedToDetector() throws Exception { + // LLM returns a response with filter=status:gte:400 + String responseWithFilter = + "{category_field=host|aggregation_field=status|aggregation_method=count|filter=status:gte:400|interval=10}"; + modelReturns = Collections.singletonMap("response", responseWithFilter); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + mockFullDetectorCreationChain(); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new RuntimeException("skip")); + return null; + }).when(client).execute(eq(MLIndexInsightGetAction.INSTANCE), any(), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference responseRef = new AtomicReference<>(); + + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList(mockedIndexName)))), + ActionListener.wrap(r -> { + responseRef.set(r); + latch.countDown(); + }, e -> { + responseRef.set("ERROR: " + e.getMessage()); + latch.countDown(); + }) + ); + latch.await(5, java.util.concurrent.TimeUnit.SECONDS); + + String response = responseRef.get(); + Assert.assertNotNull("Should get a response", response); + Assert.assertFalse("Should not error", response.startsWith("ERROR")); + // The detector should have been created (filter parsed successfully) + Assert.assertTrue("Should contain success status", response.contains("success")); + } + + @Test + public void testNoneSignal_StopsLoop() throws Exception { + // First call returns a valid detector, second call returns {NONE} + mockFullDetectorCreationChain(); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new RuntimeException("skip")); + return null; + }).when(client).execute(eq(MLIndexInsightGetAction.INSTANCE), any(), any()); + + // Track LLM call count and return {NONE} on second call + final int[] callCount = { 0 }; + doAnswer(invocation -> { + callCount[0]++; + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + if (callCount[0] == 1) { + // First call: valid response + modelReturns = Collections + .singletonMap( + "response", + "{category_field=host|aggregation_field=response|aggregation_method=count|filter=|interval=10}" + ); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + when(modelTensors.getMlModelTensors()).thenReturn(Collections.singletonList(modelTensor)); + when(modelTensorOutput.getMlModelOutputs()).thenReturn(Collections.singletonList(modelTensors)); + when(mlTaskResponse.getOutput()).thenReturn(modelTensorOutput); + } else { + // Second call: {NONE} + modelReturns = Collections.singletonMap("response", "{NONE}"); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + when(modelTensors.getMlModelTensors()).thenReturn(Collections.singletonList(modelTensor)); + when(modelTensorOutput.getMlModelOutputs()).thenReturn(Collections.singletonList(modelTensors)); + when(mlTaskResponse.getOutput()).thenReturn(modelTensorOutput); + } + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference responseRef = new AtomicReference<>(); + + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList(mockedIndexName)))), + ActionListener.wrap(r -> { + responseRef.set(r); + latch.countDown(); + }, e -> { + responseRef.set("ERROR: " + e.getMessage()); + latch.countDown(); + }) + ); + latch.await(10, java.util.concurrent.TimeUnit.SECONDS); + + String response = responseRef.get(); + Assert.assertNotNull("Should get a response", response); + // Should have exactly 1 detector (first call succeeded, second returned NONE) + Assert.assertTrue("Should contain success", response.contains("success")); + // LLM was called exactly 2 times (once for detector, once got NONE) + Assert.assertEquals("LLM should be called exactly 2 times", 2, callCount[0]); + } + + @Test + public void testInvalidFilter_GracefulDegradation() throws Exception { + // LLM returns an invalid filter expression — should create detector without filter + String responseWithBadFilter = + "{category_field=|aggregation_field=response|aggregation_method=count|filter=invalid_no_colons|interval=10}"; + modelReturns = Collections.singletonMap("response", responseWithBadFilter); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + mockFullDetectorCreationChain(); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new RuntimeException("skip")); + return null; + }).when(client).execute(eq(MLIndexInsightGetAction.INSTANCE), any(), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference responseRef = new AtomicReference<>(); + + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList(mockedIndexName)))), + ActionListener.wrap(r -> { + responseRef.set(r); + latch.countDown(); + }, e -> { + responseRef.set("ERROR: " + e.getMessage()); + latch.countDown(); + }) + ); + latch.await(5, java.util.concurrent.TimeUnit.SECONDS); + + String response = responseRef.get(); + Assert.assertNotNull("Should get a response", response); + // Should still create the detector — invalid filter is silently ignored + Assert.assertTrue("Should contain success despite bad filter", response.contains("success")); + } + + @Test + public void testOldFormatWithoutFilter_StillWorks() throws Exception { + // LLM returns old format without filter= field — should fall back to old regex + String oldFormatResponse = "{category_field=host|aggregation_field=response|aggregation_method=count|interval=10}"; + modelReturns = Collections.singletonMap("response", oldFormatResponse); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + mockFullDetectorCreationChain(); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new RuntimeException("skip")); + return null; + }).when(client).execute(eq(MLIndexInsightGetAction.INSTANCE), any(), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference responseRef = new AtomicReference<>(); + + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("indices", Collections.singletonList(mockedIndexName)))), + ActionListener.wrap(r -> { + responseRef.set(r); + latch.countDown(); + }, e -> { + responseRef.set("ERROR: " + e.getMessage()); + latch.countDown(); + }) + ); + latch.await(5, java.util.concurrent.TimeUnit.SECONDS); + + String response = responseRef.get(); + Assert.assertNotNull("Should get a response", response); + Assert.assertTrue("Old format should still work via fallback regex", response.contains("success")); + } + + @Test + public void testParseFilterExpression() { + CreateAnomalyDetectorToolEnhanced tool = CreateAnomalyDetectorToolEnhanced.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId")); + + // Valid range operators + Assert.assertNotNull("gte should parse", tool.parseFilterExpression("status:gte:400")); + Assert.assertNotNull("gt should parse", tool.parseFilterExpression("latency:gt:5000")); + Assert.assertNotNull("lte should parse", tool.parseFilterExpression("severity:lte:3")); + Assert.assertNotNull("lt should parse", tool.parseFilterExpression("count:lt:10")); + + // Valid term operator + Assert.assertNotNull("eq should parse", tool.parseFilterExpression("status.code:eq:2")); + + // Null/empty → null + Assert.assertNull("null input", tool.parseFilterExpression(null)); + Assert.assertNull("empty input", tool.parseFilterExpression("")); + + // Invalid format → null (graceful) + Assert.assertNull("no colons", tool.parseFilterExpression("invalid_no_colons")); + Assert.assertNull("one colon", tool.parseFilterExpression("field:value")); + Assert.assertNull("unknown operator", tool.parseFilterExpression("field:between:1")); + } + + /** Mocks suggest + create + start for OTel path (no validate needed). */ + private void mockOtelDetectorCreationChain() { + // Validate detector — return no issues + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener + .onResponse( + new org.opensearch.timeseries.transport.ValidateConfigResponse( + (org.opensearch.timeseries.model.ConfigValidationIssue) null + ) + ); + return null; + }).when(client).execute(eq(org.opensearch.ad.transport.ValidateAnomalyDetectorAction.INSTANCE), any(), any()); + + // Suggest — return null interval (use default) + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(new org.opensearch.timeseries.transport.SuggestConfigParamResponse(null, null, null, null)); + return null; + }).when(client).execute(eq(org.opensearch.ad.transport.SuggestAnomalyDetectorParamAction.INSTANCE), any(), any()); + + // Create detector + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener + .onResponse( + new org.opensearch.ad.transport.IndexAnomalyDetectorResponse( + "otel-detector-id", + 1L, + 1L, + 1L, + null, + org.opensearch.core.rest.RestStatus.CREATED + ) + ); + return null; + }).when(client).execute(eq(org.opensearch.ad.transport.IndexAnomalyDetectorAction.INSTANCE), any(), any()); + + // Start detector + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(new org.opensearch.timeseries.transport.JobResponse("otel-detector-id")); + return null; + }).when(client).execute(eq(org.opensearch.ad.transport.AnomalyDetectorJobAction.INSTANCE), any(), any()); + } + + private void mockSearchForDateFieldSelection() { + SearchResponse searchResponse = org.mockito.Mockito.mock(SearchResponse.class); + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(100, TotalHits.Relation.EQUAL_TO), 1.0f); + when(searchResponse.getHits()).thenReturn(searchHits); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + } + + /** + * Mocks the full async chain: search (date field selection), validate, suggest, create, start. + */ + private void mockFullDetectorCreationChain() { + mockSearchForDateFieldSelection(); + + // Validate detector — return no issues + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener + .onResponse( + new org.opensearch.timeseries.transport.ValidateConfigResponse( + (org.opensearch.timeseries.model.ConfigValidationIssue) null + ) + ); + return null; + }).when(client).execute(eq(org.opensearch.ad.transport.ValidateAnomalyDetectorAction.INSTANCE), any(), any()); + + // Suggest hyperparameters — return defaults + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(new org.opensearch.timeseries.transport.SuggestConfigParamResponse(null, null, null, null)); + return null; + }).when(client).execute(eq(org.opensearch.ad.transport.SuggestAnomalyDetectorParamAction.INSTANCE), any(), any()); + + // Create detector + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener + .onResponse( + new org.opensearch.ad.transport.IndexAnomalyDetectorResponse( + "test-detector-id", + 1L, + 1L, + 1L, + null, + org.opensearch.core.rest.RestStatus.CREATED + ) + ); + return null; + }).when(client).execute(eq(org.opensearch.ad.transport.IndexAnomalyDetectorAction.INSTANCE), any(), any()); + + // Start detector + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(new org.opensearch.timeseries.transport.JobResponse("test-detector-id")); + return null; + }).when(client).execute(eq(org.opensearch.ad.transport.AnomalyDetectorJobAction.INSTANCE), any(), any()); + } + + private void initMLTensors() { + when(modelTensors.getMlModelTensors()).thenReturn(Collections.singletonList(modelTensor)); + when(modelTensorOutput.getMlModelOutputs()).thenReturn(Collections.singletonList(modelTensors)); + when(mlTaskResponse.getOutput()).thenReturn(modelTensorOutput); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + } +} diff --git a/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java index 9b5a48e2..15cc9bcb 100644 --- a/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java @@ -96,8 +96,8 @@ public void setup() { null, null, null, - null, - null + new IntervalTimeConfiguration(5, ChronoUnit.MINUTES), // frequency + false // autoCreated ); } diff --git a/src/test/java/org/opensearch/agent/tools/ToolHelperTests.java b/src/test/java/org/opensearch/agent/tools/ToolHelperTests.java index 4008904d..81bc09a0 100644 --- a/src/test/java/org/opensearch/agent/tools/ToolHelperTests.java +++ b/src/test/java/org/opensearch/agent/tools/ToolHelperTests.java @@ -310,4 +310,134 @@ private Map prepareDeepMap2() { return tmpMap; } + @Test + public void testExtractFieldNamesTypes_OtelLogMapping() { + String mapping = """ + { + "@timestamp": { "type": "date" }, + "observedTimestamp": { "type": "date" }, + "time": { "type": "date" }, + "body": { + "type": "text", + "fields": { "keyword": { "type": "keyword" } } + }, + "severityText": { "type": "keyword" }, + "severityNumber": { "type": "integer" }, + "traceId": { "type": "keyword" }, + "spanId": { "type": "keyword" }, + "flags": { "type": "integer" }, + "attributes": { + "type": "object", + "properties": { + "otelServiceName": { "type": "keyword" }, + "otelTraceID": { "type": "keyword" }, + "otelSpanID": { "type": "keyword" }, + "otelTraceSampled": { "type": "boolean" }, + "thread.name": { "type": "keyword" }, + "thread.id": { "type": "long" }, + "exception": { + "type": "object", + "properties": { + "type": { "type": "keyword" }, + "message": { "type": "text" }, + "stacktrace": { "type": "text" } + } + }, + "http.url": { "type": "keyword" }, + "http.method": { "type": "keyword" }, + "http.route": { "type": "keyword" }, + "http.target": { "type": "keyword" }, + "http.user_agent": { "type": "text" }, + "db.system": { "type": "keyword" }, + "db.operation": { "type": "keyword" }, + "code.namespace": { "type": "keyword" }, + "code.function": { "type": "keyword" }, + "client.address": { "type": "ip" }, + "owner.id": { "type": "integer" }, + "pet.id": { "type": "integer" }, + "hibernate.entity": { "type": "keyword" } + } + }, + "resource": { + "type": "object", + "properties": { + "attributes": { + "type": "object", + "properties": { + "service.name": { "type": "keyword" }, + "service.instance.id": { "type": "keyword" }, + "service.version": { "type": "keyword" }, + "telemetry.sdk.name": { "type": "keyword" }, + "telemetry.sdk.language": { "type": "keyword" }, + "telemetry.sdk.version": { "type": "keyword" }, + "k8s.namespace.name": { "type": "keyword" }, + "k8s.pod.name": { "type": "keyword" }, + "k8s.container.name": { "type": "keyword" }, + "host.name": { "type": "keyword" }, + "cloud.provider": { "type": "keyword" }, + "cloud.region": { "type": "keyword" } + } + } + } + }, + "instrumentationScope": { + "type": "object", + "properties": { + "name": { "type": "keyword" }, + "version": { "type": "keyword" } + } + } + } + """; + Map indexMappings = gson.fromJson(mapping, Map.class); + Map result = new HashMap<>(); + ToolHelper.extractFieldNamesTypes(indexMappings, result, "", true); + + // Date fields at top level + assertEquals("date", result.get("@timestamp")); + assertEquals("date", result.get("observedTimestamp")); + assertEquals("date", result.get("time")); + + // Text field with keyword sub-field + assertEquals("text", result.get("body")); + assertEquals("keyword", result.get("body.keyword")); + + // Top-level leaf fields + assertEquals("keyword", result.get("severityText")); + assertEquals("integer", result.get("severityNumber")); + + // Nested under attributes (object type skipped, children flattened) + assertEquals("keyword", result.get("attributes.otelServiceName")); + assertEquals("long", result.get("attributes.thread.id")); + assertEquals("ip", result.get("attributes.client.address")); + assertEquals("integer", result.get("attributes.owner.id")); + + // Deeply nested: attributes.exception.* + assertEquals("keyword", result.get("attributes.exception.type")); + assertEquals("text", result.get("attributes.exception.message")); + + // Deeply nested: resource.attributes.* — the key question + assertEquals("keyword", result.get("resource.attributes.service.name")); + assertEquals("keyword", result.get("resource.attributes.cloud.region")); + assertEquals("keyword", result.get("resource.attributes.k8s.namespace.name")); + assertEquals("keyword", result.get("resource.attributes.host.name")); + + // instrumentationScope.* + assertEquals("keyword", result.get("instrumentationScope.name")); + assertEquals("keyword", result.get("instrumentationScope.version")); + + // Verify date fields are found + java.util.Set dateFields = ToolHelper.findDateTypeFields(result); + assertEquals(3, dateFields.size()); + assert dateFields.contains("@timestamp"); + assert dateFields.contains("observedTimestamp"); + assert dateFields.contains("time"); + + // Verify object types themselves are NOT in the result + assert !result.containsKey("attributes"); + assert !result.containsKey("resource"); + assert !result.containsKey("resource.attributes"); + assert !result.containsKey("instrumentationScope"); + } + } diff --git a/src/test/java/org/opensearch/integTest/CreateAnomalyDetectorToolEnhancedIT.java b/src/test/java/org/opensearch/integTest/CreateAnomalyDetectorToolEnhancedIT.java new file mode 100644 index 00000000..9a1ddcef --- /dev/null +++ b/src/test/java/org/opensearch/integTest/CreateAnomalyDetectorToolEnhancedIT.java @@ -0,0 +1,258 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; + +import org.hamcrest.MatcherAssert; +import org.opensearch.agent.tools.CreateAnomalyDetectorToolEnhanced; +import org.opensearch.client.ResponseException; + +import lombok.SneakyThrows; + +public class CreateAnomalyDetectorToolEnhancedIT extends ToolIntegrationTest { + private final String NORMAL_INDEX = "http_logs"; + private final String NORMAL_INDEX_WITH_NO_DATE_FIELDS = "normal_index_with_no_date_fields"; + private final String ABNORMAL_INDEX = "abnormal_index"; + + @Override + List promptHandlers() { + PromptHandler createAnomalyDetectorToolHandler = new PromptHandler() { + @Override + String response(String prompt) { + if (prompt.contains(NORMAL_INDEX)) { + int flag = randomIntBetween(0, 5); + switch (flag) { + case 0: + return "{category_field=|aggregation_field=response,responseLatency|aggregation_method=count,avg|interval=10}"; + case 1: + return "{category_field=ip|aggregation_field=response,responseLatency|aggregation_method=count,avg|interval=10}"; + case 2: + return "{category_field=|aggregation_field=responseLatency|aggregation_method=avg|interval=10}"; + case 3: + return "{category_field=country.keyword|aggregation_field=response,responseLatency|aggregation_method=count,avg|interval=15}"; + case 4: + return "{category_field=\"ip\"|aggregation_field=\"responseLatency\"|aggregation_method=\"avg\"|interval=10}"; + case 5: + return "{category_field= |aggregation_field= responseLatency |aggregation_method= avg |interval=10}"; + default: + return "{category_field=|aggregation_field=response|aggregation_method=count|interval=10}"; + } + } else { + return "wrong response"; + } + } + + @Override + boolean apply(String prompt) { + return true; + } + }; + return List.of(createAnomalyDetectorToolHandler); + } + + @Override + String toolType() { + return CreateAnomalyDetectorToolEnhanced.TYPE; + } + + public void testCreateAnomalyDetectorToolEnhanced_SingleIndex() { + prepareIndex(); + String agentId = registerAgent(); + String result = executeAgent(agentId, "{\"parameters\": {\"input\": \"{\\\"indices\\\":[\\\"" + NORMAL_INDEX + "\\\"]}\"}}"); + // Verify successful detector creation + assertTrue(result.contains("indexName")); + assertTrue(result.contains(NORMAL_INDEX)); + // Should have either success or a specific failure status + assertTrue(result.contains("status")); + } + + public void testCreateAnomalyDetectorToolEnhanced_MultipleIndices() { + prepareIndex(); + String agentId = registerAgent(); + String result = executeAgent( + agentId, + "{\"parameters\": {\"input\": \"{\\\"indices\\\":[\\\"" + + NORMAL_INDEX + + "\\\",\\\"" + + NORMAL_INDEX_WITH_NO_DATE_FIELDS + + "\\\"]}\"}}" + ); + + // Both indices should be in response + assertTrue(result.contains(NORMAL_INDEX)); + assertTrue(result.contains(NORMAL_INDEX_WITH_NO_DATE_FIELDS)); + + // Both should have failed_validation status (one for LLM parsing, one for no date fields) + assertTrue("Should contain failed_validation status", result.contains("\"status\":\"failed_validation\"")); + + // Should contain specific error messages + assertTrue("Should contain LLM parsing error", result.contains("Cannot parse LLM response")); + assertTrue("Should contain no date fields error", result.contains("has no date fields")); + } + + public void testCreateAnomalyDetectorToolEnhanced_WithNoDateFields() { + prepareIndex(); + String agentId = registerAgent(); + String result = executeAgent( + agentId, + "{\"parameters\": {\"input\": \"{\\\"indices\\\":[\\\"" + NORMAL_INDEX_WITH_NO_DATE_FIELDS + "\\\"]}\"}}" + ); + assertTrue(result.contains("failed_validation")); + assertTrue(result.contains("no date fields")); + } + + public void testCreateAnomalyDetectorToolEnhanced_WithSystemIndex() { + prepareIndex(); + String agentId = registerAgent(); + Exception exception = assertThrows( + ResponseException.class, + () -> executeAgent(agentId, "{\"parameters\": {\"input\": \"{\\\"indices\\\":[\\\".test\\\"]}\"}}") + ); + MatcherAssert.assertThat(exception.getMessage(), allOf(containsString("System indices not supported"))); + } + + public void testCreateAnomalyDetectorToolEnhanced_WithMissingIndex() { + prepareIndex(); + String agentId = registerAgent(); + String result = executeAgent(agentId, "{\"parameters\": {\"input\": \"{\\\"indices\\\":[\\\"non-existent\\\"]}\"}}"); + assertTrue(result.contains("failed_validation")); + assertTrue(result.contains("does not exist") || result.contains("no such index")); + } + + public void testCreateAnomalyDetectorToolEnhanced_WithEmptyInput() { + prepareIndex(); + String agentId = registerAgent(); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {}}")); + MatcherAssert.assertThat(exception.getMessage(), allOf(containsString("Input parameter is required"))); + } + + @SneakyThrows + private void prepareIndex() { + createIndexWithConfiguration( + NORMAL_INDEX, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"response\": {\n" + + " \"type\": \"integer\"\n" + + " },\n" + + " \"responseLatency\": {\n" + + " \"type\": \"float\"\n" + + " },\n" + + " \"ip\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"country\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"date\": {\n" + + " \"type\": \"date\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + addDocToIndex( + NORMAL_INDEX, + "0", + List.of("response", "responseLatency", "ip", "country", "date"), + List.of(200, 0.15, "192.168.1.1", "US", "2024-07-03T10:22:56,520") + ); + addDocToIndex( + NORMAL_INDEX, + "1", + List.of("response", "responseLatency", "ip", "country", "date"), + List.of(200, 3.15, "192.168.1.2", "UK", "2024-07-03T10:22:57,520") + ); + + createIndexWithConfiguration( + NORMAL_INDEX_WITH_NO_DATE_FIELDS, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"product\": {\n" + + " \"type\": \"keyword\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + addDocToIndex(NORMAL_INDEX_WITH_NO_DATE_FIELDS, "0", List.of("product"), List.of("product1")); + addDocToIndex(NORMAL_INDEX_WITH_NO_DATE_FIELDS, "1", List.of("product"), List.of("product2")); + + createIndexWithConfiguration( + ABNORMAL_INDEX, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"date\": {\n" + + " \"type\": \"date\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + addDocToIndex(ABNORMAL_INDEX, "0", List.of("date"), List.of("2024-07-03T10:22:56,520")); + addDocToIndex(ABNORMAL_INDEX, "1", List.of("date"), List.of("2024-07-03T10:22:57,520")); + } + + @SneakyThrows + private String registerAgent() { + String registerAgentRequestBody = Files + .readString( + Path + .of( + this + .getClass() + .getClassLoader() + .getResource( + "org/opensearch/agent/tools/register_flow_agent_of_create_anomaly_detector_tool_enhanced_request_body.json" + ) + .toURI() + ) + ); + registerAgentRequestBody = registerAgentRequestBody.replace("", modelId); + registerAgentRequestBody = registerAgentRequestBody + .replace( + "", + "\n\nHuman: Analyze this index and suggest ONE anomaly detector for operational monitoring that generates clear, actionable alerts.\n\n" + + "Index: ${indexInfo.indexName}\n" + + "Mapping: ${indexInfo.indexMapping}\n" + + "Available date fields: ${dateFields}\n\n" + + "OUTPUT FORMAT:\n" + + "Return ONLY this structured format, no explanation:\n" + + "{category_field=field_name_or_empty|aggregation_field=field1,field2|aggregation_method=method1,method2|interval=minutes}\n\n" + + "Assistant:" + ); + + return createAgent(registerAgentRequestBody); + } + + @SneakyThrows + private String registerAgentWithWrongModelId() { + String registerAgentRequestBody = Files + .readString( + Path + .of( + this + .getClass() + .getClassLoader() + .getResource( + "org/opensearch/agent/tools/register_flow_agent_of_create_anomaly_detector_tool_enhanced_request_body.json" + ) + .toURI() + ) + ); + registerAgentRequestBody = registerAgentRequestBody.replace("", "wrong_model_id"); + return createAgent(registerAgentRequestBody); + } +} diff --git a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_create_anomaly_detector_tool_enhanced_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_create_anomaly_detector_tool_enhanced_request_body.json new file mode 100644 index 00000000..a510faf8 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_create_anomaly_detector_tool_enhanced_request_body.json @@ -0,0 +1,13 @@ +{ + "name": "Test_create_anomaly_detector_enhanced_flow_agent", + "type": "flow", + "tools": [ + { + "type": "CreateAnomalyDetectorToolEnhanced", + "parameters": { + "model_id": "", + "model_type": "CLAUDE" + } + } + ] +}